Skip to content

Commit

Permalink
feat: Added support for configslug in the apis
Browse files Browse the repository at this point in the history
  • Loading branch information
noble-varghese committed Oct 10, 2023
1 parent 4bf49c3 commit 9dd9286
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 18 deletions.
2 changes: 1 addition & 1 deletion portkey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

api_key = os.environ.get(PORTKEY_API_KEY_ENV)
base_url = os.environ.get(PORTKEY_PROXY_ENV, PORTKEY_BASE_URL)
config: Optional[Config] = None
config: Optional[Union[Config, str]] = None
mode: Optional[Union[Modes, ModesLiteral]] = None
__version__ = VERSION
__all__ = [
Expand Down
63 changes: 51 additions & 12 deletions portkey/api_resources/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .utils import (
Modes,
Config,
ConfigSlug,
retrieve_config,
Params,
Message,
Expand Down Expand Up @@ -43,7 +44,7 @@ def create(
cls,
*,
prompt: Optional[str] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: Literal[True],
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -59,7 +60,7 @@ def create(
cls,
*,
prompt: Optional[str] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: Literal[False] = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -75,7 +76,7 @@ def create(
cls,
*,
prompt: Optional[str] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -90,7 +91,7 @@ def create(
cls,
*,
prompt: Optional[str] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -100,7 +101,6 @@ def create(
) -> Union[TextCompletion, Stream[TextCompletionChunk]]:
if config is None:
config = retrieve_config()
_client = APIClient(api_key=config.api_key, base_url=config.base_url)
params = Params(
prompt=prompt,
temperature=temperature,
Expand All @@ -109,6 +109,24 @@ def create(
top_p=top_p,
**kwargs,
)
_client = (
APIClient()
if isinstance(config, str)
else APIClient(api_key=config.api_key, base_url=config.base_url)
)

if isinstance(config, str):
body = ConfigSlug(config=config)
return cls(_client)._post(
"/v1/complete",
body=body,
params=params,
cast_to=ChatCompletion,
stream_cls=Stream[TextCompletionChunk],
stream=stream,
mode="",
)

if config.mode == Modes.SINGLE.value:
return cls(_client)._post(
"/v1/complete",
Expand Down Expand Up @@ -149,7 +167,7 @@ def create(
cls,
*,
messages: Optional[List[Message]] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: Literal[True],
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -165,7 +183,7 @@ def create(
cls,
*,
messages: Optional[List[Message]] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: Literal[False] = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -181,7 +199,7 @@ def create(
cls,
*,
messages: Optional[List[Message]] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -196,7 +214,7 @@ def create(
cls,
*,
messages: Optional[List[Message]] = None,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Expand All @@ -206,7 +224,6 @@ def create(
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
if config is None:
config = retrieve_config()
_client = APIClient(api_key=config.api_key, base_url=config.base_url)
params = Params(
messages=messages,
temperature=temperature,
Expand All @@ -215,6 +232,24 @@ def create(
top_p=top_p,
**kwargs,
)
_client = (
APIClient()
if isinstance(config, str)
else APIClient(api_key=config.api_key, base_url=config.base_url)
)

if isinstance(config, str):
body = ConfigSlug(config=config)
return cls(_client)._post(
"/v1/chatComplete",
body=body,
params=params,
cast_to=ChatCompletion,
stream_cls=Stream[ChatCompletionChunk],
stream=stream,
mode="",
)

if config.mode == Modes.SINGLE.value:
return cls(_client)._post(
"/v1/chatComplete",
Expand Down Expand Up @@ -254,12 +289,16 @@ def create(
cls,
*,
prompt_id: str,
config: Optional[Config] = None,
config: Optional[Union[Config, str]] = None,
variables: Optional[Mapping[str, Any]] = None,
) -> Union[GenericResponse, Stream[GenericResponse]]:
if config is None:
config = retrieve_config()
_client = APIClient(api_key=config.api_key, base_url=config.base_url)
_client = (
APIClient()
if isinstance(config, str)
else APIClient(api_key=config.api_key, base_url=config.base_url)
)
body = {"variables": variables}
return cls(_client)._post(
f"/v1/prompts/{prompt_id}/generate",
Expand Down
12 changes: 9 additions & 3 deletions portkey/api_resources/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .utils import (
remove_empty_values,
Body,
ConfigSlug,
Options,
RequestConfig,
OverrideParams,
Expand Down Expand Up @@ -127,7 +128,7 @@ def post(
self,
path: str,
*,
body: Union[List[Body], Any],
body: Union[List[Body], Any, ConfigSlug],
mode: str,
cast_to: Type[ResponseT],
stream: bool,
Expand Down Expand Up @@ -187,7 +188,7 @@ def _construct(
*,
method: str,
url: str,
body: List[Body],
body: Union[List[Body], ConfigSlug],
mode: str,
stream: bool,
params: Params,
Expand All @@ -196,8 +197,13 @@ def _construct(
opts.method = method
opts.url = url
params_dict = {} if params is None else params.dict()
config = (
body.config
if isinstance(body, ConfigSlug)
else self._config(mode, body).dict()
)
json_body = {
"config": self._config(mode, body).dict(),
"config": config,
"params": {**params_dict, "stream": stream},
}
opts.json_body = remove_empty_values(json_body)
Expand Down
7 changes: 5 additions & 2 deletions portkey/api_resources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ class ModelParams(BaseModel):
n: Optional[int] = None
stop_sequences: Optional[List[str]] = None
timeout: Union[float, None] = None
retry_settings: Optional[RetrySettings] = None
functions: Optional[List[Function]] = None
function_call: Optional[Union[None, str, Function]] = None
logprobs: Optional[int] = None
Expand Down Expand Up @@ -234,6 +233,10 @@ class Body(LLMOptions):
...


class ConfigSlug(BaseModel):
config: str


class Params(ConversationInput, ModelParams, extra="forbid"):
...

Expand Down Expand Up @@ -496,7 +499,7 @@ def default_base_url() -> str:
raise ValueError(MISSING_BASE_URL)


def retrieve_config() -> Config:
def retrieve_config() -> Union[Config, str]:
if portkey.config:
return portkey.config
raise ValueError(MISSING_CONFIG_MESSAGE)
Expand Down

0 comments on commit 9dd9286

Please sign in to comment.