diff --git a/CHANGELOG.md b/CHANGELOG.md
index 926ad74..a3ed83a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.8.17] - 2024-09-12
+
+### Fixed
+
+- Ollama host overriding
+
+## [0.8.16] - 2024-09-05
+
+### Fixed
+
+- Ollama auth
+
## [0.8.15] - 2024-09-05
### Added
@@ -210,7 +222,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Upgrade OpenAI chat models: **gpt-4-0125-preview**, **gpt-3.5-turbo-0125**
-[0.8.15]: https://github.com/carlrobertoh/llm-client/compare/fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc...HEAD
+[0.8.17]: https://github.com/carlrobertoh/llm-client/compare/6b7e26477b8e3454e78c8c639e97c8803fa5a301...HEAD
+[0.8.16]: https://github.com/carlrobertoh/llm-client/compare/d714854331915387da583c9a5b24877cc06286e...6b7e26477b8e3454e78c8c639e97c8803fa5a301
+[0.8.15]: https://github.com/carlrobertoh/llm-client/compare/fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc...d714854331915387da583c9a5b24877cc06286e
[0.8.14]: https://github.com/carlrobertoh/llm-client/compare/6461c8458325e7b2a33670fc09493b3357eb094c...fa0539e06d6cd8d21a4d0fa3336c747c2cb68fcc
[0.8.13]: https://github.com/carlrobertoh/llm-client/compare/a55fe7dcefbe6b911d5b99950d402dd06a66ec1e...6461c8458325e7b2a33670fc09493b3357eb094c
[0.8.12]: https://github.com/carlrobertoh/llm-client/compare/6fdf91d29194bfed92c7e23280953d47614e62a5...a55fe7dcefbe6b911d5b99950d402dd06a66ec1e
diff --git a/README.md b/README.md
index 4849f5c..5263b86 100644
--- a/README.md
+++ b/README.md
@@ -12,13 +12,13 @@ To use the package, you need to use following Maven dependency:
ee.carlrobert
llm-client
- 0.8.14
+ 0.8.17
```
Gradle dependency:
```kts
dependencies {
- implementation("ee.carlrobert:llm-client:0.8.14")
+ implementation("ee.carlrobert:llm-client:0.8.17")
}
```
diff --git a/build.gradle.kts b/build.gradle.kts
index 01984de..150f3a0 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -7,7 +7,7 @@ plugins {
}
group = "ee.carlrobert"
-version = "0.8.15"
+version = "0.8.17"
repositories {
mavenCentral()
diff --git a/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java b/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java
index 3918570..956e8ea 100644
--- a/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java
+++ b/src/main/java/ee/carlrobert/llm/client/ollama/OllamaClient.java
@@ -244,10 +244,16 @@ private void processStreamRequest(
private HttpRequest buildPostHttpRequest(
Object request,
String path) throws JsonProcessingException {
- return HttpRequest.newBuilder(URI.create(BASE_URL + path))
+ var baseHost = port == null ? BASE_URL : format("http://localhost:%d", port);
+ var requestBuilder = HttpRequest.newBuilder(URI.create((host == null ? baseHost : host) + path))
.POST(HttpRequest.BodyPublishers.ofString(new ObjectMapper().writeValueAsString(request)))
- .header("Content-Type", "application/json")
- .timeout(Duration.ofSeconds(30))
+ .header("Content-Type", "application/x-ndjson");
+
+ if (apiKey != null) {
+ requestBuilder.header("Authorization", "Bearer " + apiKey);
+ }
+
+ return requestBuilder.timeout(Duration.ofSeconds(30))
.build();
}
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java b/src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java
new file mode 100644
index 0000000..58668d1
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/IBMAuthBearerToken.java
@@ -0,0 +1,29 @@
+package ee.carlrobert.llm.client.watsonx;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class IBMAuthBearerToken {
+
+ @JsonProperty("access_token")
+ String accessToken;
+ @JsonProperty("expiration")
+ int expiration;
+
+ String getAccessToken() {
+ return this.accessToken;
+ }
+
+ public void setAccessToken(String accessToken) {
+ this.accessToken = accessToken;
+ }
+
+ int getExpiration() {
+ return this.expiration;
+ }
+
+ public void setExpiration(int expiration) {
+ this.expiration = expiration;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java
new file mode 100644
index 0000000..640c913
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxAuthenticator.java
@@ -0,0 +1,107 @@
+package ee.carlrobert.llm.client.watsonx;
+
+import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.IOException;
+import java.util.Base64;
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import okhttp3.MediaType;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.Response;
+
+public class WatsonxAuthenticator {
+
+ IBMAuthBearerToken bearerToken;
+ OkHttpClient client;
+ Request request;
+ Boolean isZenApiKey = false;
+
+ // On Cloud
+ public WatsonxAuthenticator(String apiKey) {
+ this.client = new OkHttpClient().newBuilder()
+ .build();
+ MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
+ RequestBody body = RequestBody.create(mediaType,
+ "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + apiKey);
+ this.request = new Request.Builder()
+ .url("https://iam.cloud.ibm.com/identity/token")
+ .method("POST", body)
+ .addHeader("Content-Type", "application/x-www-form-urlencoded")
+ .build();
+ try {
+ Response response = client.newCall(request).execute();
+ this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
+ IBMAuthBearerToken.class);
+ } catch (IOException e) {
+ System.out.println(e);
+ }
+ }
+
+ // Zen API Key
+ public WatsonxAuthenticator(String username, String zenApiKey) {
+ IBMAuthBearerToken token = new IBMAuthBearerToken();
+ String tokenStr = Base64.getEncoder().encode((username + ":" + zenApiKey).getBytes())
+ .toString();
+ token.setAccessToken(tokenStr);
+ this.bearerToken = token;
+ this.isZenApiKey = true;
+ }
+
+ // Watsonx API Key
+ public WatsonxAuthenticator(String username, String apiKey,
+ String host) {//TODO add support for password
+ this.client = new OkHttpClient().newBuilder()
+ .build();
+ ObjectMapper mapper = new ObjectMapper();
+ Map authParams = new LinkedHashMap<>();
+ authParams.put("username", username);
+ authParams.put("api_key", apiKey);
+
+ String authParamsStr = "";
+ try {
+ authParamsStr = mapper.writeValueAsString(authParams);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+
+ MediaType mediaType = MediaType.parse("application/json");
+ RequestBody body = RequestBody.create(mediaType, authParamsStr);
+ this.request = new Request.Builder()
+ .url(host
+ + "/icp4d-api/v1/authorize") // TODO add support for IAM endpoint v1/auth/identitytoken
+ .method("POST", body)
+ .addHeader("Content-Type", "application/json")
+ .build();
+ try {
+ Response response = client.newCall(request).execute();
+ this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
+ IBMAuthBearerToken.class);
+ } catch (IOException e) {
+ System.out.println(e);
+ }
+ }
+
+ private void generateNewBearerToken() {
+ try {
+ Response response = client.newCall(request).execute();
+ this.bearerToken = OBJECT_MAPPER.readValue(response.body().string(),
+ IBMAuthBearerToken.class);
+ } catch (IOException e) {
+ System.out.println(e);
+ }
+ }
+
+ public String getBearerTokenValue() {
+ if (!isZenApiKey && (this.bearerToken == null || (this.bearerToken.getExpiration() * 1000)
+ < (new Date().getTime() + 60000))) {
+ generateNewBearerToken();
+ }
+ return this.bearerToken.getAccessToken();
+ }
+}
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java
new file mode 100644
index 0000000..d8204ab
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/WatsonxClient.java
@@ -0,0 +1,174 @@
+package ee.carlrobert.llm.client.watsonx;
+
+import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import ee.carlrobert.llm.PropertiesLoader;
+import ee.carlrobert.llm.client.DeserializationUtil;
+import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
+import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionRequest;
+import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponse;
+import ee.carlrobert.llm.client.watsonx.completion.WatsonxCompletionResponseError;
+import ee.carlrobert.llm.completion.CompletionEventListener;
+import ee.carlrobert.llm.completion.CompletionEventSourceListener;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import okhttp3.Headers;
+import okhttp3.MediaType;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSources;
+
+public class WatsonxClient {
+
+ private static final MediaType APPLICATION_JSON = MediaType.parse("application/json");
+ private final OkHttpClient httpClient;
+ private final String host;
+ private final String apiVersion;
+ private final WatsonxAuthenticator authenticator;
+
+ private WatsonxClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {
+ this.httpClient = httpClientBuilder.build();
+ this.apiVersion = builder.apiVersion;
+ this.host = builder.host;
+ if (builder.isOnPrem) {
+ if (builder.isZenApiKey) {
+ this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey);
+ } else {
+ this.authenticator = new WatsonxAuthenticator(builder.username, builder.apiKey,
+ builder.host);
+ }
+ } else {
+ this.authenticator = new WatsonxAuthenticator(builder.apiKey);
+ }
+ }
+
+ public EventSource getCompletionAsync(
+ WatsonxCompletionRequest request,
+ CompletionEventListener eventListener) {
+ return EventSources.createFactory(httpClient).newEventSource(
+ buildCompletionRequest(request),
+ getCompletionEventSourceListener(eventListener));
+ }
+
+ public WatsonxCompletionResponse getCompletion(WatsonxCompletionRequest request) {
+ try (var response = httpClient.newCall(buildCompletionRequest(request)).execute()) {
+ return DeserializationUtil.mapResponse(response, WatsonxCompletionResponse.class);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ protected Request buildCompletionRequest(WatsonxCompletionRequest request) {
+ var headers = new HashMap<>(getRequiredHeaders());
+ if (request.getStream()) {
+ headers.put("Accept", "text/event-stream");
+ }
+ try {
+ String path = (request.getDeploymentId() == null || request.getDeploymentId().isEmpty()) ? "text/"
+ : "deployments/" + request.getDeploymentId() + "/";
+ String generation = request.getStream() ? "generation_stream" : "generation";
+ return new Request.Builder()
+ .url(host + "/ml/v1/" + path + generation + "?version=" + apiVersion)
+ .headers(Headers.of(headers))
+ .post(RequestBody.create(OBJECT_MAPPER.writeValueAsString(request), APPLICATION_JSON))
+ .build();
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Unable to process request", e);
+ }
+ }
+
+ private Map getRequiredHeaders() {
+ return new HashMap<>(Map.of("Authorization",
+ (this.authenticator.isZenApiKey ? "ZenApiKey " : "Bearer ")
+ + authenticator.getBearerTokenValue()));
+ }
+
+ private CompletionEventSourceListener getCompletionEventSourceListener(
+ CompletionEventListener eventListener) {
+ return new CompletionEventSourceListener<>(eventListener) {
+ @Override
+ protected String getMessage(String data) {
+ try {
+ return OBJECT_MAPPER.readValue(data, WatsonxCompletionResponse.class)
+ .getResults().get(0).getGeneratedText();
+ } catch (Exception e) {
+ try {
+ String message = OBJECT_MAPPER.readValue(data, WatsonxCompletionResponseError.class)
+ .getError()
+ .getMessage();
+ return message == null ? "" : message;
+ } catch (Exception ex) {
+ System.out.println(ex);
+ return "";
+ }
+ }
+ }
+
+ @Override
+ protected ErrorDetails getErrorDetails(String error) {
+ try {
+ return OBJECT_MAPPER.readValue(error, WatsonxCompletionResponseError.class).getError();
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ };
+ }
+
+ public static class Builder {
+
+ private final String apiKey;
+ private String host = PropertiesLoader.getValue("watsonx.baseUrl");
+ private String apiVersion = "2024-03-14";
+ private Boolean isOnPrem;
+ private Boolean isZenApiKey;
+ private String username;
+
+ public Builder(String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public Builder setApiVersion(String apiVersion) {
+ this.apiVersion = apiVersion;
+ return this;
+ }
+
+ public Builder setHost(String host) {
+ this.host = host;
+ return this;
+ }
+
+ public Builder setIsZenApiKey(Boolean isZenApiKey) {
+ this.isZenApiKey = isZenApiKey;
+ return this;
+ }
+
+ public Builder setIsOnPrem(Boolean isOnPrem) {
+ this.isOnPrem = isOnPrem;
+ return this;
+ }
+
+ public Builder setUsername(String username) {
+ this.username = username;
+ return this;
+ }
+
+ public WatsonxClient build(OkHttpClient.Builder builder) {
+ return new WatsonxClient(this, builder);
+ }
+
+ public WatsonxClient build() {
+ return build(new OkHttpClient.Builder());
+ }
+ }
+}
+
+
+
+
+
+
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java
new file mode 100644
index 0000000..4a45f9e
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionErrorDetails.java
@@ -0,0 +1,41 @@
+package ee.carlrobert.llm.client.watsonx.completion;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
+
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class WatsonxCompletionErrorDetails {
+
+ private static final String DEFAULT_ERROR_MSG = "Something went wrong. Please try again later.";
+ public static WatsonxCompletionErrorDetails DEFAULT_ERROR = new WatsonxCompletionErrorDetails(
+ DEFAULT_ERROR_MSG, null);
+ String code;
+ String message;
+ ErrorDetails details;
+
+
+ @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
+ public WatsonxCompletionErrorDetails(
+ @JsonProperty("message") String message,
+ @JsonProperty("code") String code) {
+ this.message = message;
+ this.code = code;
+ this.details = new ErrorDetails(message, null, null, code);
+ }
+
+ public String getMessage() {
+ return message;
+ }
+
+ public String getCode() {
+ return code;
+ }
+
+ public ErrorDetails getDetails() {
+ return details;
+ }
+
+}
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java
new file mode 100644
index 0000000..69b5c95
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionModel.java
@@ -0,0 +1,56 @@
+package ee.carlrobert.llm.client.watsonx.completion;
+
+import ee.carlrobert.llm.completion.CompletionModel;
+import java.util.Arrays;
+
+public enum WatsonxCompletionModel implements CompletionModel {
+
+ GRANITE_3B_CODE_INSTRUCT("ibm/granite-3b-code-instruct", "IBM Granite 3B Code Instruct", 8192),
+ GRANITE_8B_CODE_INSTRUCT("ibm/granite-8b-code-instruct", "IBM Granite 8B Code Instruct", 8192),
+ GRANITE_20B_CODE_INSTRUCT("ibm/granite-20b-code-instruct", "IBM Granite 20B Code Instruct", 8192),
+ GRANITE_34B_CODE_INSTRUCT("ibm/granite-34b-code-instruct", "IBM Granite 34B Code Instruct", 8192),
+ CODELLAMA_34_B_INSTRUCT("codellama/codellama-34b-instruct-hf", "Code Llama 34B Instruct", 8192),
+ MIXTRAL_8_7B("mistralai/mixtral-8x7b-instruct-v01", "Mixtral (8x7B)", 32768),
+ MIXTRAL_LARGE("mistralai/mistral-large", "Mistral Large", 128000),
+ LLAMA_3_1_70B("meta-llama/llama-3-1-70b-instruct", "Llama 3.1 Instruct (70B)", 128000),
+ LLAMA_3_1_8B("meta-llama/llama-3-1-8b-instruct", "Llama 3.1 Instruct (8B)", 128000),
+ LLAMA_2_7B("meta-llama/llama-2-70b-chat", "Llama 2 Chat (70B)", 4096),
+ LLAMA_2_13B("meta-llama/llama-2-13b-chat", "Llama 2 Chat (13B)", 4096),
+ GRANITE_13B_INSTRUCT_V2("ibm/granite-13b-instruct-v2", "IBM Granite 13B Instruct V2", 8192),
+ GRANITE_13B_CHAT_V2("ibm/granite-13b-chat-v2", "IBM Granite 13B Chat V2", 8192),
+ GRANITE_20B_MULTILINGUAL("ibm/granite-20b-multilingual", "IBM Granite 20B Multilingual", 8192);
+
+ private final String code;
+ private final String description;
+ private final int maxTokens;
+
+ WatsonxCompletionModel(String code, String description, int maxTokens) {
+ this.code = code;
+ this.description = description;
+ this.maxTokens = maxTokens;
+ }
+
+ public static WatsonxCompletionModel findByCode(String code) {
+ return Arrays.stream(WatsonxCompletionModel.values())
+ .filter(item -> item.getCode().equals(code))
+ .findFirst().orElseThrow();
+ }
+
+ public String getCode() {
+ return code;
+ }
+
+ public String getDescription() {
+ return description;
+ }
+
+ public int getMaxTokens() {
+ return maxTokens;
+ }
+
+ @Override
+ public String toString() {
+ return description;
+ }
+}
+
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionRequest.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionRequest.java
new file mode 100644
index 0000000..b2c6cc7
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionRequest.java
@@ -0,0 +1,242 @@
+package ee.carlrobert.llm.client.watsonx.completion;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import ee.carlrobert.llm.completion.CompletionRequest;
+
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class WatsonxCompletionRequest implements CompletionRequest {
+
+ String input;
+ @JsonProperty("project_id")
+ String projectId;
+ @JsonProperty("space_id")
+ String spaceId;
+ @JsonProperty("model_id")
+ String modelId;
+ String deploymentId;
+ Boolean stream;
+ WatsonxCompletionParameters parameters;
+
+ public WatsonxCompletionRequest(Builder builder) {
+ System.out.println("Model ID: " + builder.modelId);
+ System.out.println("Deployment ID: " + builder.deploymentId);
+ System.out.println("decodingMethod: " + builder.decodingMethod);
+ System.out.println("maxNewTokens: " + builder.maxNewTokens);
+ System.out.println("minNewTokens: " + builder.minNewTokens);
+ System.out.println("randomSeed: " + builder.randomSeed);
+ System.out.println("stopSequences: " + builder.stopSequences);
+ System.out.println("timeLimit: " + builder.timeLimit);
+ System.out.println("topK: " + builder.topK);
+ System.out.println("topP: " + builder.topP);
+ System.out.println("temperature: " + builder.temperature);
+ System.out.println("rep penalty: " + builder.repetitionPenalty);
+ System.out.println("include stop seq: " + builder.includeStopSequence);
+ this.input = builder.input;
+ this.stream = builder.stream;
+ this.projectId = builder.projectId;
+ this.spaceId = builder.spaceId;
+ this.modelId = builder.modelId;
+ this.deploymentId = builder.deploymentId;
+ this.parameters = new WatsonxCompletionParameters(
+ builder.decodingMethod,
+ builder.maxNewTokens,
+ builder.minNewTokens,
+ builder.randomSeed,
+ builder.stopSequences,
+ builder.timeLimit,
+ builder.topK,
+ builder.topP,
+ builder.temperature,
+ builder.repetitionPenalty,
+ builder.includeStopSequence);
+ }
+
+ public Boolean getStream() {
+ return this.stream;
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ public String getDeploymentId() {
+ return deploymentId;
+ }
+
+ public String getSpaceId() {
+ return spaceId;
+ }
+
+ public String getProjectId() {
+ return projectId;
+ }
+
+ public String getInput() {
+ return input;
+ }
+
+ public WatsonxCompletionParameters getParameters() {
+ return parameters;
+ }
+
+ public static class Builder {
+
+ String input;
+ String projectId;
+ String spaceId;
+ String modelId;
+ String deploymentId;
+ Boolean stream;
+ String decodingMethod;
+ Integer maxNewTokens;
+ Integer minNewTokens;
+ Integer randomSeed;
+ String[] stopSequences;
+ Integer timeLimit;
+ Integer topK;
+ Double topP;
+ Double repetitionPenalty;
+ Boolean includeStopSequence;
+ Double temperature;
+
+ public Builder(String prompt) {
+ this.input = prompt;
+ }
+
+ public Builder setInput(String input) {
+ this.input = input;
+ return this;
+ }
+
+ public Builder setModelId(String modelId) {
+ this.modelId = modelId;
+ return this;
+ }
+
+ public Builder setDeploymentId(String deploymentId) {
+ this.deploymentId = deploymentId;
+ return this;
+ }
+
+ public Builder setSpaceId(String spaceId) {
+ this.spaceId = spaceId;
+ return this;
+ }
+
+ public Builder setProjectId(String projectId) {
+ this.projectId = projectId;
+ return this;
+ }
+
+ public Builder setStream(Boolean stream) {
+ this.stream = stream;
+ return this;
+ }
+
+ public Builder setMaxNewTokens(Integer maxNewTokens) {
+ this.maxNewTokens = maxNewTokens;
+ return this;
+ }
+
+ public Builder setMinNewTokens(Integer minNewTokens) {
+ this.minNewTokens = minNewTokens;
+ return this;
+ }
+
+ public Builder setTemperature(Double temperature) {
+ this.temperature = temperature;
+ return this;
+ }
+
+ public Builder setRepetitionPenalty(Double frequencyPenalty) {
+ this.repetitionPenalty = frequencyPenalty;
+ return this;
+ }
+
+ public Builder setDecodingMethod(String decodingMethod) {
+ this.decodingMethod = decodingMethod;
+ return this;
+ }
+
+ public Builder setStopSequences(String[] stopSequences) {
+ this.stopSequences = stopSequences;
+ return this;
+ }
+
+ public Builder setIncludeStopSequence(Boolean includeStopSequence) {
+ this.includeStopSequence = includeStopSequence;
+ return this;
+ }
+
+ public Builder setRandomSeed(Integer randomSeed) {
+ this.randomSeed = randomSeed;
+ return this;
+ }
+
+ public Builder setTopP(Double topP) {
+ this.topP = topP;
+ return this;
+ }
+
+ public Builder setTopK(Integer topK) {
+ this.topK = topK;
+ return this;
+ }
+
+ public WatsonxCompletionRequest build() {
+ return new WatsonxCompletionRequest(this);
+ }
+ }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ public class WatsonxCompletionParameters {
+
+ @JsonProperty("decoding_method")
+ String decodingMethod;
+ @JsonProperty("max_new_tokens")
+ Integer maxNewTokens;
+ @JsonProperty("min_new_tokens")
+ Integer minNewTokens;
+ @JsonProperty("random_seed")
+ Integer randomSeed;
+ @JsonProperty("stop_sequences")
+ String[] stopSequences;
+ @JsonProperty("time_limit")
+ Integer timeLimit;
+ @JsonProperty("top_k")
+ Integer topK;
+ @JsonProperty("top_p")
+ Double topP;
+ Double temperature;
+ @JsonProperty("repetition_penalty")
+ Double repetitionPenalty;
+ @JsonProperty("include_stop_sequence")
+ Boolean includeStopSequence;
+
+ public WatsonxCompletionParameters(
+ String decodingMethod,
+ Integer maxNewTokens,
+ Integer minNewTokens,
+ Integer randomSeed,
+ String[] stopSequences,
+ Integer timeLimit,
+ Integer topK,
+ Double topP,
+ Double temperature,
+ Double repetitionPenalty,
+ Boolean includeStopSequence) {
+ this.decodingMethod = decodingMethod;
+ this.maxNewTokens = maxNewTokens;
+ this.minNewTokens = minNewTokens;
+ this.randomSeed = randomSeed;
+ this.stopSequences = stopSequences;
+ this.timeLimit = timeLimit;
+ this.topK = topK;
+ this.topP = topP;
+ this.temperature = temperature;
+ this.repetitionPenalty = repetitionPenalty;
+ this.includeStopSequence = includeStopSequence;
+ }
+ }
+}
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponse.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponse.java
new file mode 100644
index 0000000..01a48d0
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponse.java
@@ -0,0 +1,39 @@
+package ee.carlrobert.llm.client.watsonx.completion;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import ee.carlrobert.llm.completion.CompletionResponse;
+import java.util.List;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class WatsonxCompletionResponse implements CompletionResponse {
+
+ private final String modelId;
+ private String createdAt;
+ private List results;
+
+ @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
+ public WatsonxCompletionResponse(
+ @JsonProperty("model_id") String modelId,
+ @JsonProperty("created_at") String createdAt,
+ @JsonProperty("results") List results) {
+ this.modelId = modelId;
+ this.createdAt = createdAt;
+ this.results = results;
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ public String getCreatedAt() {
+ return createdAt;
+ }
+
+ public List getResults() {
+ return results;
+ }
+}
+
+
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponseError.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponseError.java
new file mode 100644
index 0000000..b08eb9b
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResponseError.java
@@ -0,0 +1,24 @@
+package ee.carlrobert.llm.client.watsonx.completion;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
+import java.util.List;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class WatsonxCompletionResponseError {
+
+ private final List error;
+
+ @JsonIgnoreProperties(ignoreUnknown = true)
+ @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
+ public WatsonxCompletionResponseError(
+ @JsonProperty("error") List error) {
+ this.error = error;
+ }
+
+ public ErrorDetails getError() {
+ return (error == null || error.isEmpty()) ? new ErrorDetails("") : error.get(0).getDetails();
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResult.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResult.java
new file mode 100644
index 0000000..a24235c
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionResult.java
@@ -0,0 +1,33 @@
+package ee.carlrobert.llm.client.watsonx.completion;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class WatsonxCompletionResult {
+
+ String generatedText;
+ String stopReason;
+ int generatedTokenCount;
+ int inputTokenCount;
+ int seed;
+
+ @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
+ public WatsonxCompletionResult(
+ @JsonProperty("generated_text") String generatedText,
+ @JsonProperty("stop_reason") String stopReason,
+ @JsonProperty("generated_token_count") int generatedTokenCount,
+ @JsonProperty("input_token_count") int inputTokenCount,
+ @JsonProperty("seed") int seed) {
+ this.generatedText = generatedText;
+ this.stopReason = stopReason;
+ this.generatedTokenCount = generatedTokenCount;
+ this.inputTokenCount = inputTokenCount;
+ this.seed = seed;
+ }
+
+ public String getGeneratedText() {
+ return this.generatedText;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionStreamResponse.java b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionStreamResponse.java
new file mode 100644
index 0000000..0164d4c
--- /dev/null
+++ b/src/main/java/ee/carlrobert/llm/client/watsonx/completion/WatsonxCompletionStreamResponse.java
@@ -0,0 +1,40 @@
+package ee.carlrobert.llm.client.watsonx.completion;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.List;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class WatsonxCompletionStreamResponse {
+
+ private final String modelId;
+ private List items;
+ private String createdAt;
+
+ @JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
+ public WatsonxCompletionStreamResponse(
+ @JsonProperty("model_id") String modelId,
+ @JsonProperty("created_at") String createdAt,
+ @JsonProperty("items") List items) {
+ this.modelId = modelId;
+ this.createdAt = createdAt;
+ this.items = items;
+ }
+
+ public String getModelId() {
+ WatsonxCompletionResponse firstItem = items.get(0);
+ return firstItem.getModelId();
+ }
+
+ public String getCreatedAt() {
+ WatsonxCompletionResponse firstItem = items.get(0);
+ return firstItem.getCreatedAt();
+ }
+
+ public List getItems() {
+ return items;
+ }
+}
+
+
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index 5747d22..efcfd80 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -5,4 +5,5 @@ anthropic.baseUrl=https://api.anthropic.com
you.baseUrl=https://you.com
llama.baseUrl=http://localhost:8080
ollama.baseUrl=http://localhost:11434
-google.baseUrl=https://generativelanguage.googleapis.com
\ No newline at end of file
+google.baseUrl=https://generativelanguage.googleapis.com
+watsonx.baseUrl=https://us-south.ml.cloud.ibm.com
\ No newline at end of file