Skip to content

Commit

Permalink
cleanup + use cache for embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
justine-gehring committed May 6, 2024
1 parent a21dfcd commit 0abd1c3
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
9 changes: 5 additions & 4 deletions src/main/java/io/moderne/ai/EmbeddingModelClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class EmbeddingModelClient {
@Nullable
private static EmbeddingModelClient INSTANCE;

private ObjectMapper mapper = JsonMapper.builder()
private final ObjectMapper mapper = JsonMapper.builder()
.constructorDetector(ConstructorDetector.USE_PROPERTIES_BASED)
.build()
.registerModule(new ParameterNamesModule())
Expand Down Expand Up @@ -152,8 +152,9 @@ private Function<String, float[]> timeEmbedding(List<Duration> timings) {
}

public double getDistance(String t1, String t2) {
float[] e1 = getEmbedding(t1);
float[] e2 = getEmbedding(t2);
List<Duration> timings = new ArrayList<>(2);
float[] e1 = embeddingCache.computeIfAbsent(t1, timeEmbedding(timings));
float[] e2 = embeddingCache.computeIfAbsent(t2, timeEmbedding(timings));
return dist(e1, e2);

}
Expand Down Expand Up @@ -199,7 +200,7 @@ public float[] getEmbedding(String text) {

@Value
private static class GradioRequest {
private final String[] data;
String[] data;

GradioRequest(String... data) {
this.data = data;
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/get_is_related.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def predict(self, query: str, snippet: str) -> float:
"""Returns a normalized score between [0, 1] reflecting the likelihood that the snippet is a
positive match for the query."""

def _scaled_sigmoid(self, a: np.number | np.ndarray) -> np.number | np.ndarray:
def _scaled_sigmoid(self, a):
"""a scaled sigmoid function to map values to [0, 1]"""
return 1 / (1 + np.exp(-self._sigmoid_scale * (a - self._sigmoid_shift)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import org.openrewrite.test.RecipeSpec;
import org.openrewrite.test.RewriteTest;

import java.util.List;

import static org.openrewrite.java.Assertions.java;

@DisabledIfEnvironmentVariable(named = "CI", matches = "true")
Expand Down

0 comments on commit 0abd1c3

Please sign in to comment.