diff --git a/src/main/java/io/moderne/ai/EmbeddingModelClient.java b/src/main/java/io/moderne/ai/EmbeddingModelClient.java index 22763ba..652dd8c 100644 --- a/src/main/java/io/moderne/ai/EmbeddingModelClient.java +++ b/src/main/java/io/moderne/ai/EmbeddingModelClient.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.databind.cfg.ConstructorDetector; import com.fasterxml.jackson.databind.json.JsonMapper; import com.fasterxml.jackson.module.paramnames.ParameterNamesModule; +import com.fasterxml.jackson.annotation.JsonProperty; import kong.unirest.HttpResponse; import kong.unirest.Unirest; import kong.unirest.UnirestException; @@ -126,7 +127,7 @@ private boolean checkForUp(Process proc) { private int checkForUpRequest() { try { - HttpResponse response = Unirest.head("http://127.0.0.1:7860").asString(); + HttpResponse response = Unirest.head("http://127.0.0.1:7860/embeddings").asString(); return response.getStatus(); } catch (UnirestException e) { return 523; @@ -178,8 +179,8 @@ public float[] getEmbedding(String text) { try { raw = http - .post("http://127.0.0.1:7860/run/predict") - .withContent("application/json", mapper.writeValueAsBytes(new EmbeddingModelClient.GradioRequest(text))) + .post("http://127.0.0.1:7860/embeddings") + .withContent("application/json", mapper.writeValueAsBytes(new EmbeddingModelClient.Request(text))) .send(); } catch (JsonProcessingException e) { throw new RuntimeException(e); @@ -191,7 +192,7 @@ public float[] getEmbedding(String text) { float[] embeddings = null; try { - embeddings = mapper.readValue(raw.getBodyAsBytes(), EmbeddingModelClient.GradioResponse.class).getEmbedding(); + embeddings = mapper.readValue(raw.getBodyAsBytes(), EmbeddingModelClient.Response.class).getEmbedding(); } catch (IOException e) { throw new RuntimeException(e); } @@ -199,26 +200,34 @@ public float[] getEmbedding(String text) { } @Value - private static class GradioRequest { - String[] data; + private static class Request { + @JsonProperty("model") + String model = "bge-small"; - GradioRequest(String... data) { - this.data = data; + @JsonProperty("input") + String input; + + Request(String input) { + this.input = input; } } @Value - private static class GradioResponse { - List data; + private static class Response { + @JsonProperty("data") + List data; public float[] getEmbedding() { - String d = data.get(0); - String[] emStr = d.substring(1, d.length() - 1).split(","); - float[] em = new float[emStr.length]; - for (int i = 0; i < emStr.length; i++) { - em[i] = Float.parseFloat(emStr[i]); + if (data == null || data.isEmpty()) { + return new float[0]; } - return em; + return data.get(0).embedding; + } + + @Value + private static class EmbeddingData { + @JsonProperty("embedding") + float[] embedding; } } diff --git a/src/main/java/io/moderne/ai/research/FindCodeThatResembles.java b/src/main/java/io/moderne/ai/research/FindCodeThatResembles.java index 07895d0..823439d 100644 --- a/src/main/java/io/moderne/ai/research/FindCodeThatResembles.java +++ b/src/main/java/io/moderne/ai/research/FindCodeThatResembles.java @@ -56,7 +56,7 @@ public class FindCodeThatResembles extends ScanningRecipe