Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl committed Jun 11, 2024
1 parent b53afab commit f97ecde
Show file tree
Hide file tree
Showing 18 changed files with 80 additions and 92 deletions.
2 changes: 1 addition & 1 deletion instructor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,4 @@
if importlib.util.find_spec("vertexai") is not None:
from .client_vertexai import from_vertexai

__all__ += ["from_vertexai"]
__all__ += ["from_vertexai"]
10 changes: 5 additions & 5 deletions instructor/cli/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def calculate_cost(

def group_and_sum_by_date_and_snapshot(usage_data: list[dict[str, Any]]) -> Table:
"""Group and sum the usage data by date and snapshot, including costs."""
summary: defaultdict[
str, defaultdict[str, dict[str, Union[int, float]]]
] = defaultdict(
lambda: defaultdict(
lambda: {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
summary: defaultdict[str, defaultdict[str, dict[str, Union[int, float]]]] = (
defaultdict(
lambda: defaultdict(
lambda: {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
)
)
)

Expand Down
27 changes: 9 additions & 18 deletions instructor/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def create(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Awaitable[T]:
...
) -> Awaitable[T]: ...

@overload
def create(
Expand All @@ -75,8 +74,7 @@ def create(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> T:
...
) -> T: ...

# TODO: we should overload a case where response_model is None
def create(
Expand Down Expand Up @@ -108,8 +106,7 @@ def create_partial(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> AsyncGenerator[T, None]:
...
) -> AsyncGenerator[T, None]: ...

@overload
def create_partial(
Expand All @@ -120,8 +117,7 @@ def create_partial(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Generator[T, None, None]:
...
) -> Generator[T, None, None]: ...

def create_partial(
self,
Expand Down Expand Up @@ -155,8 +151,7 @@ def create_iterable(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> AsyncGenerator[T, None]:
...
) -> AsyncGenerator[T, None]: ...

@overload
def create_iterable(
Expand All @@ -167,8 +162,7 @@ def create_iterable(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Generator[T, None, None]:
...
) -> Generator[T, None, None]: ...

def create_iterable(
self,
Expand Down Expand Up @@ -203,8 +197,7 @@ def create_with_completion(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Awaitable[tuple[T, Any]]:
...
) -> Awaitable[tuple[T, Any]]: ...

@overload
def create_with_completion(
Expand All @@ -215,8 +208,7 @@ def create_with_completion(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> tuple[T, Any]:
...
) -> tuple[T, Any]: ...

def create_with_completion(
self,
Expand Down Expand Up @@ -432,8 +424,7 @@ def from_litellm(
completion: Callable[..., Any],
mode: instructor.Mode = instructor.Mode.TOOLS,
**kwargs: Any,
) -> Instructor:
...
) -> Instructor: ...


@overload
Expand Down
6 changes: 2 additions & 4 deletions instructor/client_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ def from_anthropic(
),
mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS,
**kwargs: Any,
) -> instructor.Instructor:
...
) -> instructor.Instructor: ...


@overload
Expand All @@ -26,8 +25,7 @@ def from_anthropic(
),
mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS,
**kwargs: Any,
) -> instructor.AsyncInstructor:
...
) -> instructor.AsyncInstructor: ...


def from_anthropic(
Expand Down
6 changes: 2 additions & 4 deletions instructor/client_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,15 @@ def from_cohere(
client: cohere.Client,
mode: instructor.Mode = instructor.Mode.COHERE_TOOLS,
**kwargs: Any,
) -> instructor.Instructor:
...
) -> instructor.Instructor: ...


@overload
def from_cohere(
client: cohere.AsyncClient,
mode: instructor.Mode = instructor.Mode.COHERE_TOOLS,
**kwargs: Any,
) -> instructor.AsyncInstructor:
...
) -> instructor.AsyncInstructor: ...


def from_cohere(
Expand Down
4 changes: 3 additions & 1 deletion instructor/client_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def from_gemini(
use_async: bool = False,
**kwargs: Any,
) -> instructor.Instructor | instructor.AsyncInstructor:
assert mode == instructor.Mode.GEMINI_JSON, "Mode must be instructor.Mode.GEMINI_JSON"
assert (
mode == instructor.Mode.GEMINI_JSON
), "Mode must be instructor.Mode.GEMINI_JSON"

assert isinstance(
client,
Expand Down
6 changes: 2 additions & 4 deletions instructor/client_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@ def from_groq(
client: groq.Groq,
mode: instructor.Mode = instructor.Mode.TOOLS,
**kwargs: Any,
) -> instructor.Instructor:
...
) -> instructor.Instructor: ...


@overload
def from_groq(
client: groq.AsyncGroq,
mode: instructor.Mode = instructor.Mode.TOOLS,
**kwargs: Any,
) -> instructor.AsyncInstructor:
...
) -> instructor.AsyncInstructor: ...


def from_groq(
Expand Down
6 changes: 2 additions & 4 deletions instructor/client_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ def from_mistral(
client: mistralai.client.MistralClient,
mode: instructor.Mode = instructor.Mode.MISTRAL_TOOLS,
**kwargs: Any,
) -> instructor.Instructor:
...
) -> instructor.Instructor: ...


@overload
def from_mistral(
client: mistralaiasynccli.MistralAsyncClient,
mode: instructor.Mode = instructor.Mode.MISTRAL_TOOLS,
**kwargs: Any,
) -> instructor.AsyncInstructor:
...
) -> instructor.AsyncInstructor: ...


def from_mistral(
Expand Down
57 changes: 30 additions & 27 deletions instructor/client_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,59 +2,60 @@

from typing import Any

from vertexai.preview.generative_models import ToolConfig #type: ignore[reportMissingTypeStubs]
import vertexai.generative_models as gm #type: ignore[reportMissingTypeStubs]
from vertexai.preview.generative_models import ToolConfig # type: ignore[reportMissingTypeStubs]
import vertexai.generative_models as gm # type: ignore[reportMissingTypeStubs]
from pydantic import BaseModel
import instructor
import jsonref #type: ignore[reportMissingTypeStubs]
import jsonref # type: ignore[reportMissingTypeStubs]


def _create_vertexai_tool(model: BaseModel) -> gm.Tool:
schema: dict[Any, Any] = jsonref.replace_refs(model.model_json_schema()) #type: ignore[reportMissingTypeStubs]
schema: dict[Any, Any] = jsonref.replace_refs(model.model_json_schema()) # type: ignore[reportMissingTypeStubs]

parameters: dict[Any, Any] = {
"type": schema["type"],
"properties": schema["properties"],
"required": schema["required"]
"required": schema["required"],
}

declaration = gm.FunctionDeclaration(
name=model.__name__,
description=model.__doc__,
parameters=parameters
name=model.__name__, description=model.__doc__, parameters=parameters
)

tool = gm.Tool(function_declarations=[declaration])

return tool


def _vertexai_message_parser(message: dict[str, str]) -> gm.Content:
return gm.Content(
role=message["role"],
parts=[
gm.Part.from_text(message["content"])
]
)
role=message["role"], parts=[gm.Part.from_text(message["content"])]
)


def vertexai_function_response_parser(response: gm.GenerationResponse, exception: Exception) -> gm.Content:
def vertexai_function_response_parser(
response: gm.GenerationResponse, exception: Exception
) -> gm.Content:
return gm.Content(
parts=[
gm.Part.from_function_response(
name=response.candidates[0].content.parts[0].function_call.name,
response={
"content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors"
}
)
]
parts=[
gm.Part.from_function_response(
name=response.candidates[0].content.parts[0].function_call.name,
response={
"content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors"
},
)
]
)


def vertexai_process_response(_kwargs: dict[str, Any], model: BaseModel):
messages = _kwargs.pop("messages")
contents = [
_vertexai_message_parser(message) #type: ignore[reportUnkownArgumentType]
if isinstance(message, dict) else message
_vertexai_message_parser(message) # type: ignore[reportUnkownArgumentType]
if isinstance(message, dict)
else message
for message in messages
]
]
tool = _create_vertexai_tool(model=model)
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
Expand All @@ -70,7 +71,9 @@ def from_vertexai(
_async: bool = False,
**kwargs: Any,
) -> instructor.Instructor:
assert mode == instructor.Mode.VERTEXAI_TOOLS, "Mode must be instructor.Mode.VERTEXAI_TOOLS"
assert (
mode == instructor.Mode.VERTEXAI_TOOLS
), "Mode must be instructor.Mode.VERTEXAI_TOOLS"

assert isinstance(
client, gm.GenerativeModel
Expand Down
2 changes: 2 additions & 0 deletions instructor/dsl/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def from_response(

if sys.version_info >= (3, 10):
from types import UnionType

def is_union_type(typehint: type[Iterable[T]]) -> bool:
return get_origin(get_args(typehint)[0]) in (Union, UnionType)
else:

def is_union_type(typehint: type[Iterable[T]]) -> bool:
return get_origin(get_args(typehint)[0]) is Union

Expand Down
6 changes: 3 additions & 3 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ def parse_vertexai_tools(
validation_context: Optional[dict[str, Any]] = None,
strict: Optional[bool] = None,
) -> BaseModel:
strict=False
tool_call= completion.candidates[0].content.parts[0].function_call.args # type: ignore
strict = False
tool_call = completion.candidates[0].content.parts[0].function_call.args # type: ignore
model = {}
for field in tool_call: # type: ignore
for field in tool_call: # type: ignore
model[field] = tool_call[field]
return cls.model_validate(model, context=validation_context, strict=strict)

Expand Down
18 changes: 6 additions & 12 deletions instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def __call__(
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
...
) -> T_Model: ...


class AsyncInstructorChatCompletionCreate(Protocol):
Expand All @@ -47,40 +46,35 @@ async def __call__(
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
...
) -> T_Model: ...


@overload
def patch(
client: OpenAI,
mode: Mode = Mode.TOOLS,
) -> OpenAI:
...
) -> OpenAI: ...


@overload
def patch(
client: AsyncOpenAI,
mode: Mode = Mode.TOOLS,
) -> AsyncOpenAI:
...
) -> AsyncOpenAI: ...


@overload
def patch(
create: Callable[T_ParamSpec, T_Retval],
mode: Mode = Mode.TOOLS,
) -> InstructorChatCompletionCreate:
...
) -> InstructorChatCompletionCreate: ...


@overload
def patch(
create: Awaitable[T_Retval],
mode: Mode = Mode.TOOLS,
) -> InstructorChatCompletionCreate:
...
) -> InstructorChatCompletionCreate: ...


def patch(
Expand Down
Loading

0 comments on commit f97ecde

Please sign in to comment.