Skip to content

Commit

Permalink
code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mq200 committed Sep 14, 2024
1 parent 2722590 commit 35d79ae
Show file tree
Hide file tree
Showing 15 changed files with 814 additions and 8 deletions.
16 changes: 15 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.8.17] - 2024-09-12

### Fixed

- Ollama host overriding

## [0.8.16] - 2024-09-05

### Fixed

- Ollama auth

## [0.8.15] - 2024-09-05

### Added
Expand Down Expand Up @@ -210,7 +222,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Upgrade OpenAI chat models: **gpt-4-0125-preview**, **gpt-3.5-turbo-0125**

[0.8.15]: https://github.com/carlrobertoh/llm-client/compare/fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc...HEAD
[0.8.17]: https://github.com/carlrobertoh/llm-client/compare/6b7e26477b8e3454e78c8c639e97c8803fa5a301...HEAD
[0.8.16]: https://github.com/carlrobertoh/llm-client/compare/d714854331915387da583c9a5b24877cc06286e...6b7e26477b8e3454e78c8c639e97c8803fa5a301
[0.8.15]: https://github.com/carlrobertoh/llm-client/compare/fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc...d714854331915387da583c9a5b24877cc06286e
[0.8.14]: https://github.com/carlrobertoh/llm-client/compare/6461c8458325e7b2a33670fc09493b3357eb094c...fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc
[0.8.13]: https://github.com/carlrobertoh/llm-client/compare/a55fe7dcefbe6b911d5b99950d402dd06a66ec1e...6461c8458325e7b2a33670fc09493b3357eb094c
[0.8.12]: https://github.com/carlrobertoh/llm-client/compare/6fdf91d29194bfed92c7e23280953d47614e62a5...a55fe7dcefbe6b911d5b99950d402dd06a66ec1e
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ To use the package, you need to use following Maven dependency:
<dependency>
<groupId>ee.carlrobert</groupId>
<artifactId>llm-client</artifactId>
<version>0.8.14</version>
<version>0.8.17</version>
</dependency>
```
Gradle dependency:
```kts
dependencies {
implementation("ee.carlrobert:llm-client:0.8.14")
implementation("ee.carlrobert:llm-client:0.8.17")
}
```

Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {
}

group = "ee.carlrobert"
version = "0.8.15"
version = "0.8.17"

repositories {
mavenCentral()
Expand Down
12 changes: 9 additions & 3 deletions src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,16 @@ private void processStreamRequest(
private HttpRequest buildPostHttpRequest(
Object request,
String path) throws JsonProcessingException {
return HttpRequest.newBuilder(URI.create(BASE_URL + path))
var baseHost = port == null ? BASE_URL : format("http://localhost:%d", port);
var requestBuilder = HttpRequest.newBuilder(URI.create((host == null ? baseHost : host) + path))
.POST(HttpRequest.BodyPublishers.ofString(new ObjectMapper().writeValueAsString(request)))
.header("Content-Type", "application/json")
.timeout(Duration.ofSeconds(30))
.header("Content-Type", "application/x-ndjson");

if (apiKey != null) {
requestBuilder.header("Authorization", "Bearer " + apiKey);
}

return requestBuilder.timeout(Duration.ofSeconds(30))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package ee.carlrobert.llm.client.watsonx;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

@JsonIgnoreProperties(ignoreUnknown = true)
public class IBMAuthBearerToken {

@JsonProperty("access_token")
String accessToken;
@JsonProperty("expiration")
int expiration;

String getAccessToken() {
return this.accessToken;
}

public void setAccessToken(String accessToken) {
this.accessToken = accessToken;
}

int getExpiration() {
return this.expiration;
}

public void setExpiration(int expiration) {
this.expiration = expiration;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package ee.carlrobert.llm.client.watsonx;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.Base64;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.Map;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;

public class WatsonxAuthenticator {

IBMAuthBearerToken bearerToken;
OkHttpClient client;
Request request;
Boolean isZenApiKey = false;

// On Cloud
public WatsonxAuthenticator(String apiKey) {
this.client = new OkHttpClient().newBuilder()
.build();
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
RequestBody body = RequestBody.create(mediaType,
"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + apiKey);
this.request = new Request.Builder()
.url("https://iam.cloud.ibm.com/identity/token")
.method("POST", body)
.addHeader("Content-Type", "application/x-www-form-urlencoded")
.build();
try {
Response response = client.newCall(request).execute();
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
IBMAuthBearerToken.class);
} catch (IOException e) {
System.out.println(e);
}
}

// Zen API Key
public WatsonxAuthenticator(String username, String zenApiKey) {
IBMAuthBearerToken token = new IBMAuthBearerToken();
String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes())
.toString();
token.setAccessToken(tokenStr);
this.bearerToken = token;
this.isZenApiKey = true;
}

// Watsonx API Key
public WatsonxAuthenticator(String username, String apiKey,
String host) {//TODO add support for password
this.client = new OkHttpClient().newBuilder()
.build();
ObjectMapper mapper = new ObjectMapper();
Map<String, String> authParams = new LinkedHashMap<>();
authParams.put("username", username);
authParams.put("api_key", apiKey);

String authParamsStr = "";
try {
authParamsStr = mapper.writeValueAsString(authParams);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, authParamsStr);
this.request = new Request.Builder()
.url(host
+ "/icp4d-api/v1/authorize") // TODO add support for IAM endpoint v1/auth/identitytoken
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
try {
Response response = client.newCall(request).execute();
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
IBMAuthBearerToken.class);
} catch (IOException e) {
System.out.println(e);
}
}

private void generateNewBearerToken() {
try {
Response response = client.newCall(request).execute();
this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
IBMAuthBearerToken.class);
} catch (IOException e) {
System.out.println(e);
}
}

public String getBearerTokenValue() {
if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000)
< (new Date().getTime() + 60000))) {
generateNewBearerToken();
}
return this.bearerToken.getAccessToken();
}
}
174 changes: 174 additions & 0 deletions src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package ee.carlrobert.llm.client.watsonx;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse;
import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError;
import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import okhttp3.Headers;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;

public class WatsonxClient {

private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");
private final OkHttpClient httpClient;
private final String host;
private final String apiVersion;
private final WatsonxAuthenticator authenticator;

private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
this.httpClient = httpClientBuilder.build();
this.apiVersion = builder.apiVersion;
this.host = builder.host;
if (builder.isOnPrem) {
if (builder.isZenApiKey) {
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey);
} else {
this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey,
builder.host);
}
} else {
this.authenticator = new WatsonxAuthenticator(builder.apiKey);
}
}

public EventSource getCompletionAsync(
WatsonxCompletionRequest request,
CompletionEventListener<String> eventListener) {
return EventSources.createFactory(httpClient).newEventSource(
buildCompletionRequest(request),
getCompletionEventSourceListener(eventListener));
}

public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) {
try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) {
return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

protected Request buildCompletionRequest(WatsonxCompletionRequest request) {
var headers = new HashMap<>(getRequiredHeaders());
if (request.getStream()) {
headers.put("Accept", "text/event-stream");
}
try {
String path = (request.getDeploymentId() == null || request.getDeploymentId().isEmpty()) ? "text/"
: "deployments/" + request.getDeploymentId() + "/";
String generation = request.getStream() ? "generation_stream" : "generation";
return new Request.Builder()
.url(host + "/ml/v1/" + path + generation + "?version=" + apiVersion)
.headers(Headers.of(headers))
.post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON))
.build();
} catch (JsonProcessingException e) {
throw new RuntimeException("Unable to process request", e);
}
}

private Map<String, String> getRequiredHeaders() {
return new HashMap<>(Map.of("Authorization",
(this.authenticator.isZenApiKey ? "ZenApiKey " : "Bearer ")
+ authenticator.getBearerTokenValue()));
}

private CompletionEventSourceListener<String> getCompletionEventSourceListener(
CompletionEventListener<String> eventListener) {
return new CompletionEventSourceListener<>(eventListener) {
@Override
protected String getMessage(String data) {
try {
return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class)
.getResults().get(0).getGeneratedText();
} catch (Exception e) {
try {
String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class)
.getError()
.getMessage();
return message == null ? "" : message;
} catch (Exception ex) {
System.out.println(ex);
return "";
}
}
}

@Override
protected ErrorDetails getErrorDetails(String error) {
try {
return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
};
}

public static class Builder {

private final String apiKey;
private String host = PropertiesLoader.getValue("watsonx.baseUrl");
private String apiVersion = "2024-03-14";
private Boolean isOnPrem;
private Boolean isZenApiKey;
private String username;

public Builder(String apiKey) {
this.apiKey = apiKey;
}

public Builder setApiVersion(String apiVersion) {
this.apiVersion = apiVersion;
return this;
}

public Builder setHost(String host) {
this.host = host;
return this;
}

public Builder setIsZenApiKey(Boolean isZenApiKey) {
this.isZenApiKey = isZenApiKey;
return this;
}

public Builder setIsOnPrem(Boolean isOnPrem) {
this.isOnPrem = isOnPrem;
return this;
}

public Builder setUsername(String username) {
this.username = username;
return this;
}

public WatsonxClient build(OkHttpClient.Builder builder) {
return new WatsonxClient(this, builder);
}

public WatsonxClient build() {
return build(new OkHttpClient.Builder());
}
}
}






Loading

0 comments on commit 35d79ae

Please sign in to comment.