diff --git a/portkey_ai/api_resources/apis/generation.py b/portkey_ai/api_resources/apis/generation.py
index cc0b217..a723fb2 100644
--- a/portkey_ai/api_resources/apis/generation.py
+++ b/portkey_ai/api_resources/apis/generation.py
@@ -2,6 +2,11 @@
 import warnings
 from typing import Literal, Optional, Union, Mapping, Any, overload
 from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient
+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletion,
+    PromptCompletionChunk,
+    PromptRender,
+)
 from portkey_ai.api_resources.utils import (
     retrieve_config,
     GenericResponse,
@@ -88,14 +93,14 @@ def render(
         self,
         *,
         prompt_id: str,
-        variables: Optional[Mapping[str, Any]] = None,
+        variables: Mapping[str, Any],
         stream: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptRender:
         """Prompt render Method"""
         body = {
             "variables": variables,
@@ -110,8 +115,8 @@ def render(
             f"/prompts/{prompt_id}/render",
             body=body,
             params=None,
-            cast_to=GenericResponse,
-            stream_cls=Stream[GenericResponse],
+            cast_to=PromptRender,
+            stream_cls=Stream[PromptRender],
             stream=False,
             headers={},
         )
@@ -128,13 +133,14 @@ async def render(
         self,
         *,
         prompt_id: str,
-        variables: Optional[Mapping[str, Any]] = None,
+        variables: Mapping[str, Any],
+        stream: bool = False,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptRender:
         """Prompt render Method"""
         body = {
             "variables": variables,
@@ -142,15 +148,16 @@ async def render(
             "max_tokens": max_tokens,
             "top_k": top_k,
             "top_p": top_p,
+            "stream": stream,
             **kwargs,
         }
         return await self._post(
             f"/prompts/{prompt_id}/render",
             body=body,
             params=None,
-            cast_to=GenericResponse,
+            cast_to=PromptRender,
             stream=False,
-            stream_cls=AsyncStream[GenericResponse],
+            stream_cls=AsyncStream[PromptRender],
             headers={},
         )

@@ -172,7 +179,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Stream[GenericResponse]:
+    ) -> Stream[PromptCompletionChunk]:
         ...

     @overload
@@ -188,7 +195,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptCompletion:
         ...

     @overload
@@ -204,7 +211,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, Stream[GenericResponse]]:
+    ) -> Union[PromptCompletion, Stream[PromptCompletionChunk]]:
         ...

     def create(
@@ -219,7 +226,7 @@ def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, Stream[GenericResponse]]:
+    ) -> Union[PromptCompletion, Stream[PromptCompletionChunk],]:
         """Prompt completions Method"""
         if config is None:
             config = retrieve_config()
@@ -236,8 +243,8 @@ def create(
             f"/prompts/{prompt_id}/completions",
             body=body,
             params=None,
-            cast_to=GenericResponse,
-            stream_cls=Stream[GenericResponse],
+            cast_to=PromptCompletion,
+            stream_cls=Stream[PromptCompletionChunk],
             stream=stream,
             headers={},
         )
@@ -260,7 +267,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> AsyncStream[GenericResponse]:
+    ) -> AsyncStream[PromptCompletionChunk]:
         ...

     @overload
@@ -276,7 +283,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> GenericResponse:
+    ) -> PromptCompletion:
         ...

     @overload
@@ -292,7 +299,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, AsyncStream[GenericResponse]]:
+    ) -> Union[PromptCompletion, AsyncStream[PromptCompletionChunk]]:
         ...

     async def create(
@@ -307,7 +314,7 @@ async def create(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         **kwargs,
-    ) -> Union[GenericResponse, AsyncStream[GenericResponse]]:
+    ) -> Union[PromptCompletion, AsyncStream[PromptCompletionChunk]]:
         """Prompt completions Method"""
         if config is None:
             config = retrieve_config()
@@ -324,8 +331,8 @@ async def create(
             f"/prompts/{prompt_id}/completions",
             body=body,
             params=None,
-            cast_to=GenericResponse,
-            stream_cls=AsyncStream[GenericResponse],
+            cast_to=PromptCompletion,
+            stream_cls=AsyncStream[PromptCompletionChunk],
             stream=stream,
             headers={},
         )
diff --git a/portkey_ai/api_resources/common_types.py b/portkey_ai/api_resources/common_types.py
index ddb765a..f35a842 100644
--- a/portkey_ai/api_resources/common_types.py
+++ b/portkey_ai/api_resources/common_types.py
@@ -1,6 +1,11 @@
 from typing import TypeVar, Union
 import httpx
+
+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletionChunk,
+    PromptRender,
+)

 from .streaming import Stream, AsyncStream
 from .utils import GenericResponse
 from .types.chat_complete_type import ChatCompletionChunk
@@ -9,13 +14,27 @@
 StreamT = TypeVar(
     "StreamT",
     bound=Stream[
-        Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response]
+        Union[
+            ChatCompletionChunk,
+            TextCompletionChunk,
+            GenericResponse,
+            PromptCompletionChunk,
+            PromptRender,
+            httpx.Response,
+        ]
     ],
 )

 AsyncStreamT = TypeVar(
     "AsyncStreamT",
     bound=AsyncStream[
-        Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response]
+        Union[
+            ChatCompletionChunk,
+            TextCompletionChunk,
+            GenericResponse,
+            PromptCompletionChunk,
+            PromptRender,
+            httpx.Response,
+        ]
     ],
 )
diff --git a/portkey_ai/api_resources/types/generation_type.py b/portkey_ai/api_resources/types/generation_type.py
new file mode 100644
index 0000000..417e4a8
--- /dev/null
+++ b/portkey_ai/api_resources/types/generation_type.py
@@ -0,0 +1,116 @@
+import json
+from typing import Dict, Optional, Union
+import httpx
+
+from portkey_ai.api_resources.types.chat_complete_type import (
+    ChatCompletionMessage,
+    Choice,
+    StreamChoice,
+    Usage,
+)
+from portkey_ai.api_resources.types.complete_type import Logprobs, TextChoice
+
+from .utils import parse_headers
+from typing import List, Any
+from pydantic import BaseModel
+
+
+class PromptCompletion(BaseModel):
+    id: Optional[str]
+    choices: List[Choice]
+    created: Optional[int]
+    model: Optional[str]
+    object: Optional[str]
+    system_fingerprint: Optional[str] = None
+    usage: Optional[Usage] = None
+    index: Optional[int] = None
+    text: Optional[str] = None
+    logprobs: Optional[Logprobs] = None
+    finish_reason: Optional[str] = None
+    _headers: Optional[httpx.Headers] = None
+
+    def __str__(self):
+        return json.dumps(self.dict(), indent=4)
+
+    def __getitem__(self, key):
+        return getattr(self, key, None)
+
+    def get(self, key: str, default: Optional[Any] = None):
+        return getattr(self, key, None) or default
+
+    def get_headers(self) -> Optional[Dict[str, str]]:
+        return parse_headers(self._headers)
+
+
+class PromptCompletionChunk(BaseModel):
+    id: Optional[str] = None
+    object: Optional[str] = None
+    created: Optional[int] = None
+    model: Optional[str] = None
+    provider: Optional[str] = None
+    choices: Optional[Union[List[TextChoice], List[StreamChoice]]]
+
+    def __str__(self):
+        return json.dumps(self.dict(), indent=4)
+
+    def __getitem__(self, key):
+        return getattr(self, key, None)
+
+    def get(self, key: str, default: Optional[Any] = None):
+        return getattr(self, key, None) or default
+
+
+FunctionParameters = Dict[str, object]
+
+
+class Function(BaseModel):
+    name: Optional[str]
+    description: Optional[str] = None
+    parameters: Optional[FunctionParameters] = None
+
+
+class Tool(BaseModel):
+    function: Function
+    type: Optional[str]
+
+
+class PromptRenderData(BaseModel):
+    messages: Optional[List[ChatCompletionMessage]] = None
+    prompt: Optional[str] = None
+    model: Optional[str] = None
+    suffix: Optional[str] = None
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stop_sequences: Optional[List[str]] = None
+    timeout: Union[float, None] = None
+    functions: Optional[List[Function]] = None
+    function_call: Optional[Union[None, str, Function]] = None
+    logprobs: Optional[bool] = None
+    top_logprobs: Optional[int] = None
+    echo: Optional[bool] = None
+    stop: Optional[Union[str, List[str]]] = None
+    presence_penalty: Optional[int] = None
+    frequency_penalty: Optional[int] = None
+    best_of: Optional[int] = None
+    logit_bias: Optional[Dict[str, int]] = None
+    user: Optional[str] = None
+    organization: Optional[str] = None
+    tool_choice: Optional[Union[None, str]] = None
+    tools: Optional[List[Tool]] = None
+
+
+class PromptRender(BaseModel):
+    success: Optional[bool] = True
+    data: PromptRenderData
+
+    def __str__(self):
+        return json.dumps(self.dict(), indent=4)
+
+    def __getitem__(self, key):
+        return getattr(self, key, None)
+
+    def get(self, key: str, default: Optional[Any] = None):
+        return getattr(self, key, None) or default
diff --git a/portkey_ai/api_resources/utils.py b/portkey_ai/api_resources/utils.py
index 2308e5e..860952e 100644
--- a/portkey_ai/api_resources/utils.py
+++ b/portkey_ai/api_resources/utils.py
@@ -15,6 +15,11 @@
     TextCompletionChunk,
     TextCompletion,
 )
+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletion,
+    PromptCompletionChunk,
+    PromptRender,
+)
 from .exceptions import (
     APIStatusError,
     BadRequestError,
@@ -56,7 +61,7 @@ class CacheType(str, Enum, metaclass=MetaEnum):

 ResponseT = TypeVar(
     "ResponseT",
-    bound="Union[ChatCompletionChunk, ChatCompletions, TextCompletion, TextCompletionChunk, GenericResponse, httpx.Response]",  # noqa: E501
+    bound="Union[ChatCompletionChunk, ChatCompletions, TextCompletion, TextCompletionChunk, GenericResponse, PromptCompletion, PromptCompletionChunk, PromptRender, httpx.Response]",  # noqa: E501
 )
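
For reviewers, a minimal usage sketch of how the new typed responses would surface to SDK callers. It assumes the `render` and `create` methods above are exposed through the client's `prompts` namespace (`portkey.prompts.render(...)` and `portkey.prompts.completions.create(...)`); the API key, virtual key, and prompt ID below are placeholders, not real values.

```python
# Sketch only: keys and prompt_id are hypothetical placeholders.
from portkey_ai import Portkey

portkey = Portkey(
    api_key="PORTKEY_API_KEY",         # placeholder
    virtual_key="OPENAI_VIRTUAL_KEY",  # placeholder
)

# render() now requires `variables` and returns a PromptRender model
# instead of a GenericResponse.
rendered = portkey.prompts.render(
    prompt_id="pp-movie-bot-123456",   # placeholder prompt ID
    variables={"movie": "Dune"},
)
print(rendered.data.messages)          # typed PromptRenderData fields

# completions.create() returns PromptCompletion, or a stream of
# PromptCompletionChunk objects when stream=True.
completion = portkey.prompts.completions.create(
    prompt_id="pp-movie-bot-123456",
    variables={"movie": "Dune"},
    max_tokens=64,
)
print(completion.choices[0])

for chunk in portkey.prompts.completions.create(
    prompt_id="pp-movie-bot-123456",
    variables={"movie": "Dune"},
    stream=True,
):
    print(chunk.choices)               # each chunk is a PromptCompletionChunk
```

The caller-visible change is that `variables` becomes a required keyword argument on `render`, and responses deserialize into `PromptRender`, `PromptCompletion`, or `PromptCompletionChunk` models rather than the untyped `GenericResponse`.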