-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
814 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
@JsonIgnoreProperties(ignoreUnknown = true) | ||
public class IBMAuthBearerToken { | ||
|
||
@JsonProperty("access_token") | ||
String accessToken; | ||
@JsonProperty("expiration") | ||
int expiration; | ||
|
||
String getAccessToken() { | ||
return this.accessToken; | ||
} | ||
|
||
public void setAccessToken(String accessToken) { | ||
this.accessToken = accessToken; | ||
} | ||
|
||
int getExpiration() { | ||
return this.expiration; | ||
} | ||
|
||
public void setExpiration(int expiration) { | ||
this.expiration = expiration; | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import java.io.IOException; | ||
import java.util.Base64; | ||
import java.util.Date; | ||
import java.util.LinkedHashMap; | ||
import java.util.Map; | ||
import okhttp3.MediaType; | ||
import okhttp3.OkHttpClient; | ||
import okhttp3.Request; | ||
import okhttp3.RequestBody; | ||
import okhttp3.Response; | ||
|
||
public class WatsonxAuthenticator { | ||
|
||
IBMAuthBearerToken bearerToken; | ||
OkHttpClient client; | ||
Request request; | ||
Boolean isZenApiKey = false; | ||
|
||
// On Cloud | ||
public WatsonxAuthenticator(String apiKey) { | ||
this.client = new OkHttpClient().newBuilder() | ||
.build(); | ||
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); | ||
RequestBody body = RequestBody.create(mediaType, | ||
"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + apiKey); | ||
this.request = new Request.Builder() | ||
.url("https://iam.cloud.ibm.com/identity/token") | ||
.method("POST", body) | ||
.addHeader("Content-Type", "application/x-www-form-urlencoded") | ||
.build(); | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), | ||
IBMAuthBearerToken.class); | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
// Zen API Key | ||
public WatsonxAuthenticator(String username, String zenApiKey) { | ||
IBMAuthBearerToken token = new IBMAuthBearerToken(); | ||
String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes()) | ||
.toString(); | ||
token.setAccessToken(tokenStr); | ||
this.bearerToken = token; | ||
this.isZenApiKey = true; | ||
} | ||
|
||
// Watsonx API Key | ||
public WatsonxAuthenticator(String username, String apiKey, | ||
String host) {//TODO add support for password | ||
this.client = new OkHttpClient().newBuilder() | ||
.build(); | ||
ObjectMapper mapper = new ObjectMapper(); | ||
Map<String, String> authParams = new LinkedHashMap<>(); | ||
authParams.put("username", username); | ||
authParams.put("api_key", apiKey); | ||
|
||
String authParamsStr = ""; | ||
try { | ||
authParamsStr = mapper.writeValueAsString(authParams); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException(e); | ||
} | ||
|
||
MediaType mediaType = MediaType.parse("application/json"); | ||
RequestBody body = RequestBody.create(mediaType, authParamsStr); | ||
this.request = new Request.Builder() | ||
.url(host | ||
+ "/icp4d-api/v1/authorize") // TODO add support for IAM endpoint v1/auth/identitytoken | ||
.method("POST", body) | ||
.addHeader("Content-Type", "application/json") | ||
.build(); | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), | ||
IBMAuthBearerToken.class); | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
private void generateNewBearerToken() { | ||
try { | ||
Response response = client.newCall(request).execute(); | ||
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(), | ||
IBMAuthBearerToken.class); | ||
} catch (IOException e) { | ||
System.out.println(e); | ||
} | ||
} | ||
|
||
public String getBearerTokenValue() { | ||
if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000) | ||
< (new Date().getTime() + 60000))) { | ||
generateNewBearerToken(); | ||
} | ||
return this.bearerToken.getAccessToken(); | ||
} | ||
} |
174 changes: 174 additions & 0 deletions
174
src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
package ee.carlrobert.llm.client.watsonx; | ||
|
||
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import ee.carlrobert.llm.PropertiesLoader; | ||
import ee.carlrobert.llm.client.DeserializationUtil; | ||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse; | ||
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError; | ||
import ee.carlrobert.llm.completion.CompletionEventListener; | ||
import ee.carlrobert.llm.completion.CompletionEventSourceListener; | ||
import java.io.IOException; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import okhttp3.Headers; | ||
import okhttp3.MediaType; | ||
import okhttp3.OkHttpClient; | ||
import okhttp3.Request; | ||
import okhttp3.RequestBody; | ||
import okhttp3.sse.EventSource; | ||
import okhttp3.sse.EventSources; | ||
|
||
public class WatsonxClient { | ||
|
||
private static final MediaType APPLICATION_JSON = MediaType.parse("application/json"); | ||
private final OkHttpClient httpClient; | ||
private final String host; | ||
private final String apiVersion; | ||
private final WatsonxAuthenticator authenticator; | ||
|
||
private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) { | ||
this.httpClient = httpClientBuilder.build(); | ||
this.apiVersion = builder.apiVersion; | ||
this.host = builder.host; | ||
if (builder.isOnPrem) { | ||
if (builder.isZenApiKey) { | ||
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey); | ||
} else { | ||
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey, | ||
builder.host); | ||
} | ||
} else { | ||
this.authenticator = new WatsonxAuthenticator(builder.apiKey); | ||
} | ||
} | ||
|
||
public EventSource getCompletionAsync( | ||
WatsonxCompletionRequest request, | ||
CompletionEventListener<String> eventListener) { | ||
return EventSources.createFactory(httpClient).newEventSource( | ||
buildCompletionRequest(request), | ||
getCompletionEventSourceListener(eventListener)); | ||
} | ||
|
||
public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) { | ||
try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) { | ||
return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
protected Request buildCompletionRequest(WatsonxCompletionRequest request) { | ||
var headers = new HashMap<>(getRequiredHeaders()); | ||
if (request.getStream()) { | ||
headers.put("Accept", "text/event-stream"); | ||
} | ||
try { | ||
String path = (request.getDeploymentId() == null || request.getDeploymentId().isEmpty()) ? "text/" | ||
: "deployments/" + request.getDeploymentId() + "/"; | ||
String generation = request.getStream() ? "generation_stream" : "generation"; | ||
return new Request.Builder() | ||
.url(host + "/ml/v1/" + path + generation + "?version=" + apiVersion) | ||
.headers(Headers.of(headers)) | ||
.post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON)) | ||
.build(); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException("Unable to process request", e); | ||
} | ||
} | ||
|
||
private Map<String, String> getRequiredHeaders() { | ||
return new HashMap<>(Map.of("Authorization", | ||
(this.authenticator.isZenApiKey ? "ZenApiKey " : "Bearer ") | ||
+ authenticator.getBearerTokenValue())); | ||
} | ||
|
||
private CompletionEventSourceListener<String> getCompletionEventSourceListener( | ||
CompletionEventListener<String> eventListener) { | ||
return new CompletionEventSourceListener<>(eventListener) { | ||
@Override | ||
protected String getMessage(String data) { | ||
try { | ||
return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class) | ||
.getResults().get(0).getGeneratedText(); | ||
} catch (Exception e) { | ||
try { | ||
String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class) | ||
.getError() | ||
.getMessage(); | ||
return message == null ? "" : message; | ||
} catch (Exception ex) { | ||
System.out.println(ex); | ||
return ""; | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
protected ErrorDetails getErrorDetails(String error) { | ||
try { | ||
return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError(); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
}; | ||
} | ||
|
||
public static class Builder { | ||
|
||
private final String apiKey; | ||
private String host = PropertiesLoader.getValue("watsonx.baseUrl"); | ||
private String apiVersion = "2024-03-14"; | ||
private Boolean isOnPrem; | ||
private Boolean isZenApiKey; | ||
private String username; | ||
|
||
public Builder(String apiKey) { | ||
this.apiKey = apiKey; | ||
} | ||
|
||
public Builder setApiVersion(String apiVersion) { | ||
this.apiVersion = apiVersion; | ||
return this; | ||
} | ||
|
||
public Builder setHost(String host) { | ||
this.host = host; | ||
return this; | ||
} | ||
|
||
public Builder setIsZenApiKey(Boolean isZenApiKey) { | ||
this.isZenApiKey = isZenApiKey; | ||
return this; | ||
} | ||
|
||
public Builder setIsOnPrem(Boolean isOnPrem) { | ||
this.isOnPrem = isOnPrem; | ||
return this; | ||
} | ||
|
||
public Builder setUsername(String username) { | ||
this.username = username; | ||
return this; | ||
} | ||
|
||
public WatsonxClient build(OkHttpClient.Builder builder) { | ||
return new WatsonxClient(this, builder); | ||
} | ||
|
||
public WatsonxClient build() { | ||
return build(new OkHttpClient.Builder()); | ||
} | ||
} | ||
} | ||
|
||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.