Skip to content

Commit

Permalink
Merge pull request #127 from Portkey-AI/feat/langchainCallbackHandler
Browse files Browse the repository at this point in the history
langchain llamaindex callback handler
  • Loading branch information
VisargD authored Jun 22, 2024
2 parents 59dd12d + 9d973b7 commit 04483d4
Show file tree
Hide file tree
Showing 10 changed files with 530 additions and 4 deletions.
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,10 @@ upload:
rm -rf dist

dev:
pip install -e ".[dev]"
pip install -e ".[dev]"

langchain_callback:
pip install -e ".[langchain_callback]"

llama_index_callback:
pip install -e ".[llama_index_callback]"
33 changes: 33 additions & 0 deletions portkey_ai/api_resources/apis/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
import os
from typing import Optional
import requests

from portkey_ai.api_resources.global_constants import PORTKEY_BASE_URL


class Logger:
def __init__(
self,
api_key: Optional[str] = None,
) -> None:
api_key = api_key or os.getenv("PORTKEY_API_KEY")
if api_key is None:
raise ValueError("API key is required to use the Logger API")

self.headers = {
"Content-Type": "application/json",
"x-portkey-api-key": api_key,
}

self.url = PORTKEY_BASE_URL + "/logs"

def log(
self,
log_object: dict,
):
response = requests.post(
url=self.url, data=json.dumps(log_object), headers=self.headers
)

return response
3 changes: 2 additions & 1 deletion portkey_ai/llms/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .chat import ChatPortkey
from .completion import PortkeyLLM
from .portkey_langchain_callback import PortkeyLangchain

__all__ = ["ChatPortkey", "PortkeyLLM"]
__all__ = ["ChatPortkey", "PortkeyLLM", "PortkeyLangchain"]
170 changes: 170 additions & 0 deletions portkey_ai/llms/langchain/portkey_langchain_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from datetime import datetime
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger

try:
from langchain_core.callbacks import BaseCallbackHandler
except ImportError:
raise ImportError("Please pip install langchain-core to use PortkeyLangchain")


class PortkeyLangchain(BaseCallbackHandler):
def __init__(
self,
api_key: str,
) -> None:
super().__init__()
self.startTimestamp: float = 0
self.endTimestamp: float = 0

self.api_key = api_key

self.portkey_logger = Logger(api_key=api_key)

self.log_object: Dict[str, Any] = {}
self.prompt_records: Any = []

self.request: Any = {}
self.response: Any = {}

# self.responseHeaders: Dict[str, Any] = {}
self.responseBody: Any = None
self.responseStatus: int = 0

self.streamingMode: bool = False

if not api_key:
raise ValueError("Please provide an API key to use PortkeyCallbackHandler")

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
for prompt in prompts:
messages = prompt.split("\n")
for message in messages:
role, content = message.split(":", 1)
self.prompt_records.append(
{"role": role.lower(), "content": content.strip()}
)

self.startTimestamp = float(datetime.now().timestamp())

self.streamingMode = kwargs.get("invocation_params", False).get("stream", False)

self.request["method"] = "POST"
self.request["url"] = serialized.get("kwargs", "").get(
"base_url", "chat/completions"
)
self.request["provider"] = serialized["id"][2]
self.request["headers"] = serialized.get("kwargs", {}).get(
"default_headers", {}
)
self.request["headers"].update({"provider": serialized["id"][2]})
self.request["body"] = {"messages": self.prompt_records}
self.request["body"].update({**kwargs.get("invocation_params", {})})

def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Run when chain starts running."""

def on_llm_end(self, response: Any, **kwargs: Any) -> None:
self.endTimestamp = float(datetime.now().timestamp())
responseTime = self.endTimestamp - self.startTimestamp

usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr]

self.response["status"] = (
200 if self.responseStatus == 0 else self.responseStatus
)
self.response["body"] = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response.generations[0][0].text,
},
"logprobs": response.generations[0][0].generation_info.get("logprobs", ""), # type: ignore[union-attr] # noqa: E501
"finish_reason": response.generations[0][0].generation_info.get("finish_reason", ""), # type: ignore[union-attr] # noqa: E501
}
]
}
self.response["body"].update({"usage": usage})
self.response["body"].update({"id": str(kwargs.get("run_id", ""))})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501
self.response["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501
self.response["time"] = int(responseTime * 1000)
self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

self.log_object.update(
{
"request": self.request,
"response": self.response,
}
)

self.portkey_logger.log(log_object=self.log_object)

def on_chain_end(
self,
outputs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
pass

def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self.responseBody = error
self.responseStatus = error.status_code # type: ignore[attr-defined]
"""Do nothing."""
pass

def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self.responseBody = error
self.responseStatus = error.status_code # type: ignore[attr-defined]
"""Do nothing."""
pass

def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self.responseBody = error
self.responseStatus = error.status_code # type: ignore[attr-defined]
pass

def on_text(self, text: str, **kwargs: Any) -> None:
pass

def on_agent_finish(self, finish: Any, **kwargs: Any) -> None:
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.streamingMode = True
"""Do nothing."""
pass

def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass

def on_agent_action(self, action: Any, **kwargs: Any) -> Any:
"""Do nothing."""
pass

def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
4 changes: 2 additions & 2 deletions portkey_ai/llms/llama_index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .completions import PortkeyLLM
from .portkey_llama_callback import PortkeyLlamaindex

__all__ = ["PortkeyLLM"]
__all__ = ["PortkeyLlamaindex"]
160 changes: 160 additions & 0 deletions portkey_ai/llms/llama_index/portkey_llama_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger
from datetime import datetime

try:
from llama_index.core.callbacks.base_handler import (
BaseCallbackHandler as LlamaIndexBaseCallbackHandler,
)
from llama_index.core.utilities.token_counting import TokenCounter
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install llama-index to use Portkey Callback Handler"
)
except ImportError:
raise ImportError("Please pip install llama-index to use Portkey Callback Handler")


class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler):
startTimestamp: int = 0
endTimestamp: float = 0

def __init__(
self,
api_key: str,
) -> None:
super().__init__(
event_starts_to_ignore=[],
event_ends_to_ignore=[],
)

self.api_key = api_key

self.portkey_logger = Logger(api_key=api_key)

self._token_counter = TokenCounter()
self.completion_tokens = 0
self.prompt_tokens = 0
self.token_llm = 0

self.log_object: Dict[str, Any] = {}
self.prompt_records: Any = []

self.request: Any = {}
self.response: Any = {}

self.responseTime: int = 0
self.streamingMode: bool = False

if not api_key:
raise ValueError("Please provide an API key to use PortkeyCallbackHandler")

def on_event_start( # type: ignore[return]
self,
event_type: Any,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""Run when an event starts and return id of event."""

if event_type == "llm":
self.llm_event_start(payload)

def on_event_end(
self,
event_type: Any,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Run when an event ends."""

if event_type == "llm":
self.llm_event_stop(payload, event_id)

def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
self.startTimestamp = int(datetime.now().timestamp())

def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""

def llm_event_start(self, payload: Any) -> None:
if "messages" in payload:
chunks = payload.get("messages", {})
self.prompt_tokens = self._token_counter.estimate_tokens_in_messages(chunks)
messages = payload.get("messages", {})
self.prompt_records = [
{"role": m.role.value, "content": m.content} for m in messages
]
self.request["method"] = "POST"
self.request["url"] = payload.get("serialized", {}).get(
"api_base", "chat/completions"
)
self.request["provider"] = payload.get("serialized", {}).get("class_name", "")
self.request["headers"] = {}
self.request["body"] = {"messages": self.prompt_records}
self.request["body"].update(
{"model": payload.get("serialized", {}).get("model", "")}
)
self.request["body"].update(
{"temperature": payload.get("serialized", {}).get("temperature", "")}
)

return None

def llm_event_stop(self, payload: Any, event_id) -> None:
self.endTimestamp = float(datetime.now().timestamp())
responseTime = self.endTimestamp - self.startTimestamp

data = payload.get("response", {})

chunks = payload.get("messages", {})
self.completion_tokens = self._token_counter.estimate_tokens_in_messages(chunks)
self.token_llm = self.prompt_tokens + self.completion_tokens
self.response["status"] = 200
self.response["body"] = {
"choices": [
{
"index": 0,
"message": {
"role": data.message.role.value,
"content": data.message.content,
},
"logprobs": data.logprobs,
"finish_reason": "done",
}
]
}
self.response["body"].update(
{
"usage": {
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"total_tokens": self.token_llm,
}
}
)
self.response["body"].update({"id": event_id})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": data.raw.get("model", "")})
self.response["time"] = int(responseTime * 1000)
self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

self.log_object.update(
{
"request": self.request,
"response": self.response,
}
)
self.portkey_logger.log(log_object=self.log_object)

return None
Loading

0 comments on commit 04483d4

Please sign in to comment.