feat(api): add correlationID to Track Chat requests (#3668)
* Add CorrelationID to chat request

Signed-off-by: Siddharth More <[email protected]>

* Remove get_token_metrics

Signed-off-by: Siddharth More <[email protected]>

* Add CorrelationID to proto

Signed-off-by: Siddharth More <[email protected]>

* Fix correlation method name

Signed-off-by: Siddharth More <[email protected]>

* Update core/http/endpoints/openai/chat.go

Co-authored-by: Ettore Di Giacinto <[email protected]>
Signed-off-by: Siddharth More <[email protected]>

* Update core/http/endpoints/openai/chat.go

Signed-off-by: Ettore Di Giacinto <[email protected]>
Signed-off-by: Siddharth More <[email protected]>

---------

Signed-off-by: Siddharth More <[email protected]>
Signed-off-by: Ettore Di Giacinto <[email protected]>
Co-authored-by: Ettore Di Giacinto <[email protected]>
siddimore and mudler authored Sep 28, 2024
1 parent e94a50e commit 50a3b54
Showing 5 changed files with 36 additions and 1 deletion.
1 change: 1 addition & 0 deletions backend/backend.proto
@@ -136,6 +136,7 @@ message PredictOptions {
repeated Message Messages = 44;
repeated string Videos = 45;
repeated string Audios = 46;
string CorrelationId = 47;
}

// The response message containing the result
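
For context, a minimal sketch of how the Go side might populate the new field when building a backend request. The import path of the generated bindings and the helper name are assumptions for illustration, not part of this commit:

```go
// Sketch only: the import path of the generated backend.proto bindings is assumed.
package grpcclient

import (
	pb "github.com/mudler/LocalAI/pkg/grpc/proto" // assumed location of the generated stubs
)

// WithCorrelation copies the request's correlation ID onto the gRPC options
// so the backend can log it alongside the results. Hypothetical helper.
func WithCorrelation(opts *pb.PredictOptions, correlationID string) *pb.PredictOptions {
	opts.CorrelationId = correlationID
	return opts
}
```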
14 changes: 14 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
@@ -2106,6 +2106,9 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
data["ignore_eos"] = predict->ignoreeos();
data["embeddings"] = predict->embeddings();

// Add the correlationid to json data
data["correlation_id"] = predict->correlationid();

// for each image in the request, add the image data
//
for (int i = 0; i < predict->images_size(); i++) {
@@ -2344,6 +2347,11 @@ class BackendServiceImpl final : public backend::Backend::Service {
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);

// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
});

// Send the reply
writer->Write(reply);

@@ -2367,6 +2375,12 @@ class BackendServiceImpl final : public backend::Backend::Service {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {

// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
});

completion_text = result.result_json.value("content", "");
int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
7 changes: 7 additions & 0 deletions core/http/endpoints/openai/chat.go
@@ -161,6 +161,12 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
// Set CorrelationID
correlationID := c.Get("X-Correlation-ID")
if len(strings.TrimSpace(correlationID)) == 0 {
correlationID = id
}
c.Set("X-Correlation-ID", correlationID)

modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
@@ -444,6 +450,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Set("X-Correlation-ID", id)

responses := make(chan schema.OpenAIResponse)

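
Outside the diff, the header-fallback pattern used above can be shown in isolation: reuse the caller's X-Correlation-ID if present, otherwise generate one, and echo it back on the response. A minimal sketch with a hypothetical route; it is not part of this commit:

```go
package main

import (
	"strings"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
)

func main() {
	app := fiber.New()

	app.Post("/v1/chat/completions", func(c *fiber.Ctx) error {
		// Reuse the caller-supplied correlation ID, or mint a fresh UUID.
		correlationID := c.Get("X-Correlation-ID")
		if len(strings.TrimSpace(correlationID)) == 0 {
			correlationID = uuid.New().String()
		}
		// Echo the ID on the response so clients can correlate replies with requests.
		c.Set("X-Correlation-ID", correlationID)
		return c.SendString("ok")
	})

	if err := app.Listen(":8080"); err != nil {
		panic(err)
	}
}
```

A client that sends `X-Correlation-ID: abc123` sees the same value in the response headers; a client that omits the header gets back the server-generated UUID.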
2 changes: 2 additions & 0 deletions core/http/endpoints/openai/completion.go
@@ -57,6 +57,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
}

return func(c *fiber.Ctx) error {
// Add Correlation
c.Set("X-Correlation-ID", id)
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
13 changes: 12 additions & 1 deletion core/http/endpoints/openai/request.go
@@ -6,6 +6,7 @@ import (
"fmt"

"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
@@ -15,6 +16,11 @@
"github.com/rs/zerolog/log"
)

type correlationIDKeyType string

// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"

func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)

@@ -24,9 +30,14 @@
}

received, _ := json.Marshal(input)
// Extract or generate the correlation ID
correlationID := c.Get("X-Correlation-ID", uuid.New().String())

ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)

input.Context = ctxWithCorrelationID
input.Cancel = cancel

log.Debug().Msgf("Request received: %s", string(received))
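
Code further down the pipeline that only receives the request's context can recover the ID through the typed key. The helper below is a hypothetical sketch built on the `CorrelationIDKey` introduced above; it is not part of this commit:

```go
import "context"

// correlationFromContext reads the correlation ID stored by readRequest.
// Hypothetical helper; returns the empty string when no ID was attached.
func correlationFromContext(ctx context.Context) string {
	if id, ok := ctx.Value(CorrelationIDKey).(string); ok {
		return id
	}
	return ""
}
```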
