Merge pull request #211 from Portkey-AI/feat/llamaPayload
Llama payload
VisargD authored Aug 28, 2024
2 parents 62cf6b2 + 2b2fba4 commit 7a1d8a1
Showing 1 changed file with 103 additions and 61 deletions.
164 changes: 103 additions & 61 deletions portkey_ai/llamaindex/portkey_llama_callback_handler.py
@@ -1,14 +1,12 @@
from enum import Enum
from enum import Enum, auto
import json
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger
from datetime import datetime
from llama_index.core.callbacks.schema import (
CBEventType,
)
from uuid import uuid4
from llama_index.legacy.schema import NodeRelationship

try:
from llama_index.core.callbacks.base_handler import (
@@ -27,25 +25,30 @@ class LlamaIndexCallbackHandler(LlamaIndexBaseCallbackHandler):
def __init__(
self,
api_key: str,
metadata: Optional[Dict[str, Any]] = {},
metadata: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
event_starts_to_ignore=[
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.SYNTHESIZE,
CBEventType.EXCEPTION,
CBEventType.TREE,
CBEventType.RERANKING,
],
event_ends_to_ignore=[
CBEventType.CHUNKING,
CBEventType.NODE_PARSING,
CBEventType.SYNTHESIZE,
CBEventType.EXCEPTION,
CBEventType.TREE,
CBEventType.RERANKING,
],
)

self.api_key = api_key
self.metadata = metadata
self.metadata: Dict[str, Any] = metadata or {}
self.metadata.update({"_source": "LlamaIndex", "_source_type": "Agent"})

self.portkey_logger = Logger(api_key=api_key)

@@ -83,7 +86,7 @@ def on_event_start( # type: ignore
span_id = str(event_id)
parent_span_id = parent_id
span_name = event_type
start_time = int(datetime.now().timestamp())
start_time = time.time()

if parent_id == "root":
parent_span_id = self.main_span_id
@@ -102,15 +105,18 @@ def on_event_start( # type: ignore
request_payload = self.retrieve_event_start(payload)
elif event_type == "templating":
request_payload = self.templating_event_start(payload)
elif event_type == "sub_question":
request_payload = self.sub_question_event_start(payload)
else:
request_payload = payload
return ""

start_event_information = {
"span_id": span_id,
"parent_span_id": parent_span_id,
"span_name": span_name.value,
"trace_id": self.global_trace_id,
"request": request_payload,
"event_type": event_type,
"start_time": start_time,
"metadata": self.metadata,
}
@@ -131,8 +137,8 @@ def on_event_end(
if span_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"
response_payload["response_time"] = total_time
else:
if event_type == "llm":
@@ -149,8 +155,10 @@ def on_event_end(
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
elif event_type == "sub_question":
response_payload = self.sub_question_event_end(payload, event_id)
else:
response_payload = payload
return

self.event_map[span_id]["response"] = response_payload

@@ -202,32 +210,26 @@ def llm_event_start(self, payload: Any) -> Any:
return self.request

def llm_event_end(self, payload: Any, event_id) -> Any:
result: Dict[str, Any] = {}
result["body"] = {}

try:
data = self.serialize(payload)
except Exception:
data = payload.__dict__

if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]

self.response = {}

data = payload.get("response", {})
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

chunks = payload.get("messages", {})
self.completion_tokens = self._token_counter.estimate_tokens_in_messages(chunks)
self.token_llm = self.prompt_tokens + self.completion_tokens
self.response["status"] = 200
self.response["body"] = {
"choices": [
{
"index": 0,
"message": {
"role": data.message.role.value,
"content": data.message.content,
},
"logprobs": data.logprobs,
"finish_reason": "done",
}
]
}
self.response["body"].update(

result["body"] = data["response"]
result["body"].update(
{
"usage": {
"prompt_tokens": self.prompt_tokens,
@@ -236,18 +238,16 @@ def llm_event_end(self, payload: Any, event_id) -> Any:
}
}
)
self.response["body"].update({"id": event_id})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": getattr(data, "model", "")})
self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
result["body"].update({"id": event_id})
result["body"].update({"created": int(time.time())})
result["body"].update({"model": getattr(data, "model", "")})
result["streamingMode"] = self.streamingMode

self.response["response_time"] = total_time
result["status"] = 200
result["headers"] = {}
result["response_time"] = total_time

return self.response
return result

# ------------------------------------------------------ #
def embedding_event_start(self, payload: Any) -> Any:
@@ -270,9 +270,9 @@ def embedding_event_start(self, payload: Any) -> Any:
def embedding_event_end(self, payload: Any, event_id) -> Any:
if event_id in self.event_map:
event = self.event_map[event_id]
# event["request"]["body"]["input"] = payload.get("chunks", "")
event["request"]["body"]["input"] = payload.get("chunks", "")
# Setting as ...INPUT... to avoid logging the entire data input file
event["request"]["body"]["input"] = "...INPUT..."
# event["request"]["body"]["input"] = "...INPUT..."

start_time = event["start_time"]

@@ -302,27 +302,29 @@ def embedding_event_end(self, payload: Any, event_id) -> Any:
)
self.response["headers"] = {}

end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

self.response["response_time"] = total_time

return self.response

# ------------------------------------------------------ #
def agent_step_event_start(self, payload: Any) -> Any:
data = json.dumps(self.serialize(payload))
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)
return data

def agent_step_event_end(self, payload: Any, event_id) -> Any:
if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
data = self.serialize(payload)
json.dumps(data)
result = self.transform_agent_step_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result
@@ -337,35 +339,39 @@ def function_call_event_end(self, payload: Any, event_id) -> Any:
event = self.event_map[event_id]
start_time = event["start_time"]
data = self.serialize(payload)
json.dumps(data)
result = self.transform_function_call_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #
def query_event_start(self, payload: Any) -> Any:
data = json.dumps(self.serialize(payload))
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)
return data

def query_event_end(self, payload: Any, event_id) -> Any:
if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
data = self.serialize(payload)
json.dumps(data)
result = self.transform_query_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #
def retrieve_event_start(self, payload: Any) -> Any:
data = json.dumps(self.serialize(payload))
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)
return data

def retrieve_event_end(self, payload: Any, event_id) -> Any:
@@ -374,18 +380,16 @@ def retrieve_event_end(self, payload: Any, event_id) -> Any:
start_time = event["start_time"]

data = self.serialize(payload)
json.dumps(data)
result = self.transform_retrieve_end(data)
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #
def templating_event_start(self, payload: Any) -> Any:
data = self.serialize(payload)
json.dumps(data)
result = self.transform_templating_start(data)
return result

@@ -395,14 +399,44 @@ def templating_event_end(self, payload: Any, event_id) -> Any:
start_time = event["start_time"]
result = self.transform_templating_end(event_id)

end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["response_time"] = total_time
return result

# ------------------------------------------------------ #

def sub_question_event_start(self, payload: Any) -> Any:
try:
data = json.dumps(self.serialize(payload))
except Exception:
data = json.dumps(payload.__dict__)

return data

def sub_question_event_end(self, payload: Any, event_id) -> Any:
result: Dict[str, Any] = {}
result["body"] = {}
if event_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]

try:
data = self.serialize(payload)
except Exception:
data = payload.__dict__

end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

result["body"] = data
result["body"]["response_time"] = total_time

return result

# ------------------------------------------------------ #

# ----------------- EVENT Transformers ----------------- #
def transform_agent_step_end(self, data: Any) -> Any:
try:
@@ -587,3 +621,11 @@ def serialize(self, obj):
if isinstance(obj, tuple):
return tuple(self.serialize(item) for item in obj)
return obj


class NodeRelationship(str, Enum):
SOURCE = auto()
PREVIOUS = auto()
NEXT = auto()
PARENT = auto()
CHILD = auto()

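A note on the constructor change above (`metadata: Optional[Dict[str, Any]] = {}` replaced by `= None` plus `metadata or {}`): a mutable default argument is evaluated once at function definition, so the old signature made every handler share the same dict and let one instance's metadata updates leak into another's. A minimal sketch of the pitfall, illustrative only and not code from this repository:

```python
# Illustrative only: why a dict default argument is risky.
class Handler:
    def __init__(self, metadata={}):  # default dict is created once, at definition time
        self.metadata = metadata
        self.metadata["_source"] = "LlamaIndex"  # mutates the shared default

a = Handler()
b = Handler()
assert a.metadata is b.metadata  # both instances hold the very same dict object
```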
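The latency bookkeeping also changes throughout the file: timestamps move from `int(datetime.now().timestamp())` (whole seconds) to `time.time()`, and `response_time` becomes a zero-padded millisecond string built with `f"{((end_time - start_time) * 1000):04.0f}"`. A small sketch of what that format produces, with an assumed sleep just for illustration:

```python
import time

start_time = time.time()
time.sleep(0.025)  # pretend the event took ~25 ms
end_time = time.time()

# Width 4, zero-filled, no decimal places: 25.3 ms -> "0025", 1234.6 ms -> "1235"
total_time = f"{((end_time - start_time) * 1000):04.0f}"
print(total_time)
```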
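To exercise the handler end to end, the sketch below wires it into LlamaIndex's global callback manager. The import path and constructor signature come from the changed file above; `Settings` and `CallbackManager` are standard `llama-index-core` APIs, and the API key and metadata values are placeholders.

```python
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

from portkey_ai.llamaindex.portkey_llama_callback_handler import (
    LlamaIndexCallbackHandler,
)

# Placeholder key and metadata; the handler merges in _source/_source_type itself.
handler = LlamaIndexCallbackHandler(
    api_key="YOUR_PORTKEY_API_KEY",
    metadata={"_user": "example-user"},
)

# Route LlamaIndex events (llm, embedding, query, retrieve, templating,
# agent_step, function_call, sub_question) through the handler.
Settings.callback_manager = CallbackManager([handler])
```

Once registered, any query or agent run emits the span and trace records that `on_event_start` / `on_event_end` assemble and forward to the Portkey logger.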