
Merge pull request #110 from Portkey-AI/fix/promptResponse
Fix/prompt response
VisargD authored Apr 5, 2024
2 parents 862cc8a + e3cd216 commit 81802fc
Showing 4 changed files with 170 additions and 23 deletions.
47 changes: 27 additions & 20 deletions portkey_ai/api_resources/apis/generation.py
@@ -2,6 +2,11 @@
import warnings
from typing import Literal, Optional, Union, Mapping, Any, overload
from portkey_ai.api_resources.base_client import APIClient, AsyncAPIClient
+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletion,
+    PromptCompletionChunk,
+    PromptRender,
+)
from portkey_ai.api_resources.utils import (
retrieve_config,
GenericResponse,
@@ -88,14 +93,14 @@ def render(
self,
*,
prompt_id: str,
-        variables: Optional[Mapping[str, Any]] = None,
+        variables: Mapping[str, Any],
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> GenericResponse:
+    ) -> PromptRender:
"""Prompt render Method"""
body = {
"variables": variables,
@@ -110,8 +115,8 @@ def render(
f"/prompts/{prompt_id}/render",
body=body,
params=None,
-            cast_to=GenericResponse,
-            stream_cls=Stream[GenericResponse],
+            cast_to=PromptRender,
+            stream_cls=Stream[PromptRender],
stream=False,
headers={},
)
@@ -128,29 +133,31 @@ async def render(
self,
*,
prompt_id: str,
-        variables: Optional[Mapping[str, Any]] = None,
+        variables: Mapping[str, Any],
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> GenericResponse:
+    ) -> PromptRender:
"""Prompt render Method"""
body = {
"variables": variables,
"temperature": temperature,
"max_tokens": max_tokens,
"top_k": top_k,
"top_p": top_p,
"stream": stream,
**kwargs,
}
return await self._post(
f"/prompts/{prompt_id}/render",
body=body,
params=None,
-            cast_to=GenericResponse,
+            cast_to=PromptRender,
stream=False,
-            stream_cls=AsyncStream[GenericResponse],
+            stream_cls=AsyncStream[PromptRender],
headers={},
)

@@ -172,7 +179,7 @@ def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> Stream[GenericResponse]:
+    ) -> Stream[PromptCompletionChunk]:
...

@overload
@@ -188,7 +195,7 @@ def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> GenericResponse:
+    ) -> PromptCompletion:
...

@overload
@@ -204,7 +211,7 @@ def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> Union[GenericResponse, Stream[GenericResponse]]:
+    ) -> Union[PromptCompletion, Stream[PromptCompletionChunk]]:
...

def create(
@@ -219,7 +226,7 @@ def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> Union[GenericResponse, Stream[GenericResponse]]:
+    ) -> Union[PromptCompletion, Stream[PromptCompletionChunk],]:
"""Prompt completions Method"""
if config is None:
config = retrieve_config()
@@ -236,8 +243,8 @@ def create(
f"/prompts/{prompt_id}/completions",
body=body,
params=None,
-            cast_to=GenericResponse,
-            stream_cls=Stream[GenericResponse],
+            cast_to=PromptCompletion,
+            stream_cls=Stream[PromptCompletionChunk],
stream=stream,
headers={},
)
@@ -260,7 +267,7 @@ async def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> AsyncStream[GenericResponse]:
+    ) -> AsyncStream[PromptCompletionChunk]:
...

@overload
@@ -276,7 +283,7 @@ async def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> GenericResponse:
+    ) -> PromptCompletion:
...

@overload
@@ -292,7 +299,7 @@ async def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> Union[GenericResponse, AsyncStream[GenericResponse]]:
+    ) -> Union[PromptCompletion, AsyncStream[PromptCompletionChunk]]:
...

async def create(
@@ -307,7 +314,7 @@ async def create(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs,
-    ) -> Union[GenericResponse, AsyncStream[GenericResponse]]:
+    ) -> Union[PromptCompletion, AsyncStream[PromptCompletionChunk]]:
"""Prompt completions Method"""
if config is None:
config = retrieve_config()
@@ -324,8 +331,8 @@ async def create(
f"/prompts/{prompt_id}/completions",
body=body,
params=None,
-            cast_to=GenericResponse,
-            stream_cls=AsyncStream[GenericResponse],
+            cast_to=PromptCompletion,
+            stream_cls=AsyncStream[PromptCompletionChunk],
stream=stream,
headers={},
)
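
With these changes, prompt render and completion calls return typed models instead of GenericResponse. A minimal usage sketch (assuming a valid Portkey API key; the prompt ID and variable names below are placeholders):

from portkey_ai import Portkey

client = Portkey(api_key="PORTKEY_API_KEY")  # placeholder key

# render() now returns a PromptRender, and `variables` is required.
rendered = client.prompts.render(
    prompt_id="pp-example",             # placeholder prompt ID
    variables={"user_input": "hello"},  # placeholder variables
)
print(rendered.data.messages)

# completions.create() returns a PromptCompletion when stream=False.
completion = client.prompts.completions.create(
    prompt_id="pp-example",
    variables={"user_input": "hello"},
)
print(completion.choices[0])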
23 changes: 21 additions & 2 deletions portkey_ai/api_resources/common_types.py
@@ -1,6 +1,11 @@
from typing import TypeVar, Union

import httpx

+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletionChunk,
+    PromptRender,
+)
from .streaming import Stream, AsyncStream
from .utils import GenericResponse
from .types.chat_complete_type import ChatCompletionChunk
@@ -9,13 +14,27 @@
StreamT = TypeVar(
"StreamT",
bound=Stream[
-        Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response]
+        Union[
+            ChatCompletionChunk,
+            TextCompletionChunk,
+            GenericResponse,
+            PromptCompletionChunk,
+            PromptRender,
+            httpx.Response,
+        ]
],
)

AsyncStreamT = TypeVar(
"AsyncStreamT",
bound=AsyncStream[
-        Union[ChatCompletionChunk, TextCompletionChunk, GenericResponse, httpx.Response]
+        Union[
+            ChatCompletionChunk,
+            TextCompletionChunk,
+            GenericResponse,
+            PromptCompletionChunk,
+            PromptRender,
+            httpx.Response,
+        ]
],
)
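
The widened bounds are what let Stream and AsyncStream carry the new chunk type, so a streaming prompt completion can be typed as Stream[PromptCompletionChunk]. A short sketch, reusing the placeholder client and prompt ID from the example above:

# stream=True selects the Stream[PromptCompletionChunk] overload.
stream = client.prompts.completions.create(
    prompt_id="pp-example",
    variables={"user_input": "hello"},
    stream=True,
)
for chunk in stream:  # each item is a PromptCompletionChunk
    if chunk.choices:
        print(chunk.choices[0])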
116 changes: 116 additions & 0 deletions portkey_ai/api_resources/types/generation_type.py
@@ -0,0 +1,116 @@
import json
from typing import Dict, Optional, Union
import httpx

from portkey_ai.api_resources.types.chat_complete_type import (
ChatCompletionMessage,
Choice,
StreamChoice,
Usage,
)
from portkey_ai.api_resources.types.complete_type import Logprobs, TextChoice

from .utils import parse_headers
from typing import List, Any
from pydantic import BaseModel


class PromptCompletion(BaseModel):
id: Optional[str]
choices: List[Choice]
created: Optional[int]
model: Optional[str]
object: Optional[str]
system_fingerprint: Optional[str] = None
usage: Optional[Usage] = None
index: Optional[int] = None
text: Optional[str] = None
logprobs: Optional[Logprobs] = None
finish_reason: Optional[str] = None
_headers: Optional[httpx.Headers] = None

def __str__(self):
return json.dumps(self.dict(), indent=4)

def __getitem__(self, key):
return getattr(self, key, None)

def get(self, key: str, default: Optional[Any] = None):
return getattr(self, key, None) or default

def get_headers(self) -> Optional[Dict[str, str]]:
return parse_headers(self._headers)


class PromptCompletionChunk(BaseModel):
id: Optional[str] = None
object: Optional[str] = None
created: Optional[int] = None
model: Optional[str] = None
provider: Optional[str] = None
choices: Optional[Union[List[TextChoice], List[StreamChoice]]]

def __str__(self):
return json.dumps(self.dict(), indent=4)

def __getitem__(self, key):
return getattr(self, key, None)

def get(self, key: str, default: Optional[Any] = None):
return getattr(self, key, None) or default


FunctionParameters = Dict[str, object]


class Function(BaseModel):
name: Optional[str]
description: Optional[str] = None
parameters: Optional[FunctionParameters] = None


class Tool(BaseModel):
function: Function
type: Optional[str]


class PromptRenderData(BaseModel):
messages: Optional[List[ChatCompletionMessage]] = None
prompt: Optional[str] = None
model: Optional[str] = None
suffix: Optional[str] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop_sequences: Optional[List[str]] = None
timeout: Union[float, None] = None
functions: Optional[List[Function]] = None
function_call: Optional[Union[None, str, Function]] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
echo: Optional[bool] = None
stop: Optional[Union[str, List[str]]] = None
presence_penalty: Optional[int] = None
frequency_penalty: Optional[int] = None
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, int]] = None
user: Optional[str] = None
organization: Optional[str] = None
tool_choice: Optional[Union[None, str]] = None
tools: Optional[List[Tool]] = None


class PromptRender(BaseModel):
success: Optional[bool] = True
data: PromptRenderData

def __str__(self):
return json.dumps(self.dict(), indent=4)

def __getitem__(self, key):
return getattr(self, key, None)

def get(self, key: str, default: Optional[Any] = None):
return getattr(self, key, None) or default
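
These models support attribute access, dict-style access via __getitem__, and a get() fallback. A quick sketch with illustrative field values:

from portkey_ai.api_resources.types.generation_type import (
    PromptRender,
    PromptRenderData,
)

render = PromptRender(
    success=True,
    data=PromptRenderData(
        prompt="Say hello to {{name}}",  # illustrative values
        model="gpt-3.5-turbo",
        max_tokens=64,
    ),
)

assert render.data.model == "gpt-3.5-turbo"              # attribute access
assert render["data"].prompt == "Say hello to {{name}}"  # __getitem__
assert render.get("missing", "fallback") == "fallback"   # get() with default
print(render)  # __str__ pretty-prints the model as indented JSON

Note that get() falls back via `or`, so falsy field values (for example success=False or an empty string) also return the default.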
7 changes: 6 additions & 1 deletion portkey_ai/api_resources/utils.py
@@ -15,6 +15,11 @@
TextCompletionChunk,
TextCompletion,
)
+from portkey_ai.api_resources.types.generation_type import (
+    PromptCompletion,
+    PromptCompletionChunk,
+    PromptRender,
+)
from .exceptions import (
APIStatusError,
BadRequestError,
@@ -56,7 +61,7 @@ class CacheType(str, Enum, metaclass=MetaEnum):

ResponseT = TypeVar(
"ResponseT",
bound="Union[ChatCompletionChunk, ChatCompletions, TextCompletion, TextCompletionChunk, GenericResponse, httpx.Response]", # noqa: E501
bound="Union[ChatCompletionChunk, ChatCompletions, TextCompletion, TextCompletionChunk, GenericResponse, PromptCompletion, PromptCompletionChunk, PromptRender, httpx.Response]", # noqa: E501
)
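
With the widened bound, code generic over ResponseT can now hand back the new prompt models. An illustrative sketch (parse_response is a hypothetical stand-in, not the SDK's internal request helper):

from typing import Type

def parse_response(payload: dict, cast_to: Type[ResponseT]) -> ResponseT:
    # Hypothetical helper: parse a raw payload into whichever typed model
    # the caller requests; the new prompt types now satisfy the bound.
    return cast_to(**payload)

chunk = parse_response({"id": "c-1", "model": "gpt-4"}, PromptCompletionChunk)  # placeholder payload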

