Add correlationID to chat completions #2

Open · wants to merge 7 commits into master
2 changes: 1 addition & 1 deletion Makefile
@@ -8,7 +8,7 @@ DETECT_LIBS?=true
 # llama.cpp versions
 GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
 GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
-CPPLLAMA_VERSION?=95bc82fbc0df6d48cf66c857a4dda3d044f45ca2
+CPPLLAMA_VERSION?=b5de3b74a595cbfefab7eeb5a567425c6a9690cf
 
 # go-rwkv version
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
1 change: 1 addition & 0 deletions backend/backend.proto
@@ -136,6 +136,7 @@ message PredictOptions {
   repeated Message Messages = 44;
   repeated string Videos = 45;
   repeated string Audios = 46;
+  string CorrelationId = 47;
 }
 
 // The response message containing the result
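
For reference, the Go binding generated from this field is a plain string on PredictOptions. The sketch below hand-mirrors just the relevant fields rather than importing LocalAI's real generated `pkg/grpc/proto` package, so the struct and values here are illustrative assumptions, not the PR's code:

```go
package main

import "fmt"

// PredictOptions mirrors only the fields relevant to this PR; the real type
// is generated from backend.proto, where `string CorrelationId = 47;` lands.
type PredictOptions struct {
	Prompt        string
	CorrelationId string
}

func main() {
	// A backend client would copy the request's correlation ID into the
	// proto options so the C++ server can pick it up (see grpc-server.cpp).
	opts := PredictOptions{
		Prompt:        "Hello",
		CorrelationId: "3f2a77e1-demo", // hypothetical ID for illustration
	}
	fmt.Printf("%+v\n", opts)
}
```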
14 changes: 14 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
@@ -2106,6 +2106,9 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
     data["ignore_eos"] = predict->ignoreeos();
     data["embeddings"] = predict->embeddings();
 
+    // Add the correlationid to json data
+    data["correlation_id"] = predict->correlationid();
+
     // for each image in the request, add the image data
     //
     for (int i = 0; i < predict->images_size(); i++) {
@@ -2344,6 +2347,11 @@ class BackendServiceImpl final : public backend::Backend::Service {
             int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
             reply.set_prompt_tokens(tokens_evaluated);
 
+            // Log Request Correlation Id
+            LOG_VERBOSE("correlation:", {
+                { "id", data["correlation_id"] }
+            });
+
             // Send the reply
             writer->Write(reply);
 
@@ -2367,6 +2375,12 @@
         std::string completion_text;
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
+
+            // Log Request Correlation Id
+            LOG_VERBOSE("correlation:", {
+                { "id", data["correlation_id"] }
+            });
+
             completion_text = result.result_json.value("content", "");
             int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
             int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
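
The Go side of the project logs through zerolog (imported in request.go further down). A minimal sketch of emitting the same structured `correlation` entry from Go, assuming the ID has already been recovered from the request, mirroring the LOG_VERBOSE calls above:

```go
package main

import (
	"os"

	"github.com/rs/zerolog"
)

func main() {
	logger := zerolog.New(os.Stdout).With().Timestamp().Logger()

	// Counterpart of LOG_VERBOSE("correlation:", {{"id", ...}}): the ID is
	// emitted as a structured field so log lines can be joined across the
	// HTTP frontend and the gRPC backend.
	logger.Debug().Str("id", "3f2a77e1-demo").Msg("correlation") // hypothetical ID
}
```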
7 changes: 7 additions & 0 deletions core/http/endpoints/openai/chat.go
@@ -161,6 +161,12 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	textContentToReturn = ""
 	id = uuid.New().String()
 	created = int(time.Now().Unix())
+	// Set CorrelationID
+	correlationID := c.Get("X-Correlation-ID")
+	if len(strings.TrimSpace(correlationID)) == 0 {
+		correlationID = id
+	}
+	c.Set("X-Correlation-ID", correlationID)
 
 	modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
 	if err != nil {
@@ -444,6 +450,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 	c.Set("Cache-Control", "no-cache")
 	c.Set("Connection", "keep-alive")
 	c.Set("Transfer-Encoding", "chunked")
+	c.Set("X-Correlation-ID", id)
 
 	responses := make(chan schema.OpenAIResponse)
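
Isolated from the endpoint, the header handling above reduces to the pattern below. Note the Fiber semantics: `c.Get` reads a request header while `c.Set` writes a response header. The route and port are placeholders, not part of the PR:

```go
package main

import (
	"log"
	"strings"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
)

func main() {
	app := fiber.New()
	app.Post("/echo", func(c *fiber.Ctx) error {
		// Reuse the client's X-Correlation-ID, or fall back to a fresh
		// UUID, matching ChatEndpoint's fallback to its generated id.
		correlationID := c.Get("X-Correlation-ID")
		if len(strings.TrimSpace(correlationID)) == 0 {
			correlationID = uuid.New().String()
		}
		c.Set("X-Correlation-ID", correlationID) // echo back to the client
		return c.SendString(correlationID)
	})
	log.Fatal(app.Listen(":8080"))
}
```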
2 changes: 2 additions & 0 deletions core/http/endpoints/openai/completion.go
@@ -57,6 +57,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 	}
 
 	return func(c *fiber.Ctx) error {
+		// Add Correlation
+		c.Set("X-Correlation-ID", id)
 		modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
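
Worth noting the asymmetry: the chat endpoint echoes the caller's ID when present, while this completion handler always sets the response header to the server-generated `id`; in both cases readRequest (below) still threads the caller's ID into the request context. An end-to-end sketch of a client exercising the header; the address, route, and model name are assumptions about a local deployment:

```go
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	body := bytes.NewBufferString(`{"model":"example-model","prompt":"Hello"}`)
	req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/v1/completions", body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Supply our own ID; readRequest carries it through the request
	// context even where the response header reports the generated one.
	req.Header.Set("X-Correlation-ID", "client-trace-42")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("X-Correlation-ID:", resp.Header.Get("X-Correlation-ID"))
}
```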
13 changes: 12 additions & 1 deletion core/http/endpoints/openai/request.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 
 	"github.com/gofiber/fiber/v2"
+	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/config"
 	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
 	"github.com/mudler/LocalAI/core/schema"
@@ -15,6 +16,11 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
+type correlationIDKeyType string
+
+// CorrelationIDKey to track request across process boundary
+const CorrelationIDKey correlationIDKeyType = "correlationID"
+
 func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
 	input := new(schema.OpenAIRequest)
 
@@ -24,9 +30,14 @@
 	}
 
 	received, _ := json.Marshal(input)
+	// Extract or generate the correlation ID
+	correlationID := c.Get("X-Correlation-ID", uuid.New().String())
 
 	ctx, cancel := context.WithCancel(o.Context)
-	input.Context = ctx
+	// Add the correlation ID to the new context
+	ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)
+
+	input.Context = ctxWithCorrelationID
 	input.Cancel = cancel
 
 	log.Debug().Msgf("Request received: %s", string(received))
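
The unexported key type is the standard guard against context key collisions: a plain `"correlationID"` string used by another package compares unequal to the typed constant. A minimal sketch of how downstream code would recover the ID, assuming a consumer later in the pipeline:

```go
package main

import (
	"context"
	"fmt"
)

// Same shape as the declarations added in request.go above.
type correlationIDKeyType string

const CorrelationIDKey correlationIDKeyType = "correlationID"

func main() {
	ctx := context.WithValue(context.Background(), CorrelationIDKey, "3f2a77e1-demo")

	// A bare string key does not collide with the typed key, even though
	// both render as "correlationID"; that is the point of the custom type.
	if _, ok := ctx.Value("correlationID").(string); !ok {
		fmt.Println("string key: no value")
	}
	if id, ok := ctx.Value(CorrelationIDKey).(string); ok {
		fmt.Println("correlation ID:", id)
	}
}
```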