
Commit

Merge pull request #99 from openrewrite/opt-latency-on-embeddings
Opt latency on embeddings
justine-gehring authored Sep 11, 2024
2 parents a8f6e55 + c536a50 commit 288e3a6
Showing 3 changed files with 46 additions and 39 deletions.
41 changes: 25 additions & 16 deletions src/main/java/io/moderne/ai/EmbeddingModelClient.java
@@ -21,6 +21,7 @@
 import com.fasterxml.jackson.databind.cfg.ConstructorDetector;
 import com.fasterxml.jackson.databind.json.JsonMapper;
 import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
+import com.fasterxml.jackson.annotation.JsonProperty;
 import kong.unirest.HttpResponse;
 import kong.unirest.Unirest;
 import kong.unirest.UnirestException;
@@ -126,7 +127,7 @@ private boolean checkForUp(Process proc) {
 
     private int checkForUpRequest() {
         try {
-            HttpResponse<String> response = Unirest.head("http://127.0.0.1:7860").asString();
+            HttpResponse<String> response = Unirest.head("http://127.0.0.1:7860/embeddings").asString();
             return response.getStatus();
         } catch (UnirestException e) {
             return 523;
@@ -178,8 +179,8 @@ public float[] getEmbedding(String text) {
 
         try {
             raw = http
-                    .post("http://127.0.0.1:7860/run/predict")
-                    .withContent("application/json", mapper.writeValueAsBytes(new EmbeddingModelClient.GradioRequest(text)))
+                    .post("http://127.0.0.1:7860/embeddings")
+                    .withContent("application/json", mapper.writeValueAsBytes(new EmbeddingModelClient.Request(text)))
                     .send();
         } catch (JsonProcessingException e) {
             throw new RuntimeException(e);
@@ -191,34 +192,42 @@ public float[] getEmbedding(String text) {
 
         float[] embeddings = null;
         try {
-            embeddings = mapper.readValue(raw.getBodyAsBytes(), EmbeddingModelClient.GradioResponse.class).getEmbedding();
+            embeddings = mapper.readValue(raw.getBodyAsBytes(), EmbeddingModelClient.Response.class).getEmbedding();
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
         return embeddings;
     }
 
     @Value
-    private static class GradioRequest {
-        String[] data;
+    private static class Request {
+        @JsonProperty("model")
+        String model = "bge-small";
 
-        GradioRequest(String... data) {
-            this.data = data;
+        @JsonProperty("input")
+        String input;
+
+        Request(String input) {
+            this.input = input;
         }
     }
 
     @Value
-    private static class GradioResponse {
-        List<String> data;
+    private static class Response {
+        @JsonProperty("data")
+        List<EmbeddingData> data;
 
         public float[] getEmbedding() {
-            String d = data.get(0);
-            String[] emStr = d.substring(1, d.length() - 1).split(",");
-            float[] em = new float[emStr.length];
-            for (int i = 0; i < emStr.length; i++) {
-                em[i] = Float.parseFloat(emStr[i]);
+            if (data == null || data.isEmpty()) {
+                return new float[0];
             }
-            return em;
+            return data.get(0).embedding;
         }
+
+        @Value
+        private static class EmbeddingData {
+            @JsonProperty("embedding")
+            float[] embedding;
+        }
     }
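Note on the new wire format: the Request/Response classes above follow the OpenAI-style embeddings schema that infinity_emb serves at /embeddings. A minimal sketch of the exchange, assuming the server from get_embedding.py below is running locally; the requests library and the sample input string are illustrative, and only the data[0].embedding field is actually consumed by the Java client:

# Sketch: exercise the /embeddings contract that Request/Response expect.
# Assumes the infinity_emb server from get_embedding.py is listening on 127.0.0.1:7860.
import requests

payload = {
    "model": "bge-small",  # matches served_model_name in get_embedding.py
    "input": "public void setTimeout(int millis)",  # illustrative input
}
r = requests.post("http://127.0.0.1:7860/embeddings", json=payload, timeout=30)
r.raise_for_status()

# The Java Response class reads only data[0].embedding; other fields are ignored.
embedding = r.json()["data"][0]["embedding"]
print(len(embedding))  # bge-small-en-v1.5 produces 384-dimensional vectors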

2 changes: 1 addition & 1 deletion FindCodeThatResembles.java
@@ -56,7 +56,7 @@ public class FindCodeThatResembles extends ScanningRecipe<FindCodeThatResembles.
description = "Since AI based matching has a higher latency than rules based matching, " +
"we do a first pass to find the top k methods using embeddings. " +
"To narrow the scope, you can specify the top k methods as method filters.",
example = "1000")
example = "5")
int k;

transient CodeSearch codeSearchTable = new CodeSearch(this);
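The smaller example value matches the two-pass design described in the option text: an inexpensive embedding pass keeps only the top k candidate methods before the more expensive matching runs, so a small k directly bounds latency. A minimal sketch of such a first pass, assuming cosine similarity over embeddings like the ones above; numpy and every name here are illustrative, not the recipe's actual implementation:

# Sketch of a first-pass top-k filter over method embeddings (illustrative only).
import numpy as np

def top_k(query_emb, method_embs, k=5):
    q = np.asarray(query_emb, dtype=np.float32)
    m = np.asarray(method_embs, dtype=np.float32)
    # Normalize so the dot product is cosine similarity (bge embeddings are
    # typically already L2-normalized; normalizing again is harmless).
    q = q / np.linalg.norm(q)
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    sims = m @ q
    # argpartition avoids a full sort when only the top k are needed.
    idx = np.argpartition(-sims, min(k, len(sims)) - 1)[:k]
    return idx[np.argsort(-sims[idx])]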
42 changes: 20 additions & 22 deletions src/main/resources/get_embedding.py
@@ -14,30 +14,28 @@
 # limitations under the License.
 #
 
-import os
-os.environ["XDG_CACHE_HOME"]="/HF_CACHE"
-os.environ["HF_HOME"]="/HF_CACHE/huggingface"
-os.environ["HUGGINGFACE_HUB_CACHE"]="/HF_CACHE/huggingface/hub"
-os.environ["TRANSFORMERS_CACHE"]="/HF_CACHE/huggingface"
-import torch #pytorch = 2.0.1
-from transformers import AutoModel, AutoTokenizer, logging # 4.29.2
-import gradio as gr # 3.23.0
-
-logging.set_verbosity_error()
-
-#initialize models
-tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
-model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5")
-model.eval()
+from infinity_emb import EngineArgs, AsyncEmbeddingEngine
+from infinity_emb import create_server
+import uvicorn
+import logging
+from fastapi.responses import JSONResponse
+
+logging.getLogger("infinity_emb").setLevel(logging.ERROR)
 
-def get_embedding(input_string):
-    with torch.no_grad():
+engine_args = EngineArgs(
+    model_name_or_path="michaelfeil/bge-small-en-v1.5",
+    device="cpu",
+    engine="optimum",
+    served_model_name="bge-small",
+    compile=True,
+    batch_size=1
+)
 
-        encoded_input = tokenizer([input_string], padding=True, truncation=True, return_tensors='pt')
-        model_output = model(**encoded_input)
-        # Perform pooling. In this case, cls pooling.
-        embedding = model_output[0][:, 0]
-        embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)[0]
-        return embedding.tolist()
+fastapi_app = create_server(engine_args_list=[engine_args])
+@fastapi_app.head("/embeddings")
+def read_root_head():
+    return JSONResponse({"message": "Infinity embedding is running"})
 
-gr.Interface(fn=get_embedding, inputs="text", outputs="text").launch(server_port=7860)
+uvicorn.run(fastapi_app, host="127.0.0.1", port=7860)
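The explicit HEAD route above exists so the Java client's checkForUpRequest, which now issues HEAD against http://127.0.0.1:7860/embeddings, sees a 200 once the engine is serving. A quick smoke test, assuming the script is running locally; the requests library is an assumption, not part of this change:

# Smoke test mirroring EmbeddingModelClient.checkForUpRequest.
import requests

status = requests.head("http://127.0.0.1:7860/embeddings", timeout=5).status_code
print("server up" if status == 200 else f"unexpected status: {status}")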
