diff --git a/src/main/java/io/moderne/ai/EmbeddingModelClient.java b/src/main/java/io/moderne/ai/EmbeddingModelClient.java index 59221f8..4ca3ff1 100644 --- a/src/main/java/io/moderne/ai/EmbeddingModelClient.java +++ b/src/main/java/io/moderne/ai/EmbeddingModelClient.java @@ -49,7 +49,7 @@ public class EmbeddingModelClient { @Nullable private static EmbeddingModelClient INSTANCE; - private ObjectMapper mapper = JsonMapper.builder() + private final ObjectMapper mapper = JsonMapper.builder() .constructorDetector(ConstructorDetector.USE_PROPERTIES_BASED) .build() .registerModule(new ParameterNamesModule()) @@ -152,8 +152,9 @@ private Function timeEmbedding(List timings) { } public double getDistance(String t1, String t2) { - float[] e1 = getEmbedding(t1); - float[] e2 = getEmbedding(t2); + List timings = new ArrayList<>(2); + float[] e1 = embeddingCache.computeIfAbsent(t1, timeEmbedding(timings)); + float[] e2 = embeddingCache.computeIfAbsent(t2, timeEmbedding(timings)); return dist(e1, e2); } @@ -199,7 +200,7 @@ public float[] getEmbedding(String text) { @Value private static class GradioRequest { - private final String[] data; + String[] data; GradioRequest(String... data) { this.data = data; diff --git a/src/main/resources/get_is_related.py b/src/main/resources/get_is_related.py index 7d1fd70..a7c6583 100644 --- a/src/main/resources/get_is_related.py +++ b/src/main/resources/get_is_related.py @@ -67,7 +67,7 @@ def predict(self, query: str, snippet: str) -> float: """Returns a normalized score between [0, 1] reflecting the likelihood that the snippet is a positive match for the query.""" - def _scaled_sigmoid(self, a: np.number | np.ndarray) -> np.number | np.ndarray: + def _scaled_sigmoid(self, a): """a scaled sigmoid function to map values to [0, 1]""" return 1 / (1 + np.exp(-self._sigmoid_scale * (a - self._sigmoid_shift))) diff --git a/src/test/java/io/moderne/ai/research/FindCodeThatResemblesTest.java b/src/test/java/io/moderne/ai/research/FindCodeThatResemblesTest.java index 9842b57..e523b49 100644 --- a/src/test/java/io/moderne/ai/research/FindCodeThatResemblesTest.java +++ b/src/test/java/io/moderne/ai/research/FindCodeThatResemblesTest.java @@ -22,8 +22,6 @@ import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; -import java.util.List; - import static org.openrewrite.java.Assertions.java; @DisabledIfEnvironmentVariable(named = "CI", matches = "true")