fix: clean up cohere templating #1030

Open · wants to merge 4 commits into base: main
Changes from 1 commit
42 changes: 5 additions & 37 deletions instructor/client_cohere.py
@@ -2,17 +2,13 @@

import cohere
import instructor
from functools import wraps
from typing import (
TypeVar,
overload,
)
from typing import Any
from typing_extensions import ParamSpec
from pydantic import BaseModel
from instructor.patch import handle_cohere_templating, handle_context
from instructor.process_response import handle_response_model
from instructor.retry import retry_async


T_Model = TypeVar("T_Model", bound=BaseModel)
@@ -58,39 +54,11 @@ def from_cohere(
**kwargs,
)

@wraps(client.chat)
async def new_create_async(
response_model: type[T_Model] | None = None,
validation_context: dict[str, Any] | None = None,
max_retries: int = 1,
context: dict[str, Any] | None = None,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
prepared_response_model, new_kwargs = handle_response_model(
response_model=response_model,
if isinstance(client, cohere.AsyncClient):
return instructor.AsyncInstructor(
client=client,
create=instructor.patch(create=client.chat, mode=mode),
provider=instructor.Provider.COHERE,
mode=mode,
**kwargs,
)

context = handle_context(context, validation_context)
new_kwargs = handle_cohere_templating(new_kwargs, context)

response = await retry_async(
func=client.chat,
response_model=prepared_response_model,
context=context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
)
return response

return instructor.AsyncInstructor(
client=client,
create=new_create_async,
provider=instructor.Provider.COHERE,
mode=mode,
**kwargs,
)
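
For orientation, a minimal sketch of the simplified wiring, assuming Cohere's v1 Client/AsyncClient; only from_cohere itself comes from this diff, the rest is illustrative.

import cohere
import instructor

# from_cohere now branches on the client type and routes client.chat through
# instructor.patch, rather than maintaining a hand-rolled async wrapper with
# its own retry and templating plumbing.
sync_client = instructor.from_cohere(cohere.Client())        # instructor.Instructor
async_client = instructor.from_cohere(cohere.AsyncClient())  # instructor.AsyncInstructor
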
88 changes: 42 additions & 46 deletions instructor/patch.py
@@ -77,27 +77,9 @@
return context


def handle_cohere_templating(
new_kwargs: dict[str, Any], context: dict[str, Any] | None = None
) -> dict[str, Any]:
if not context:
return new_kwargs

from textwrap import dedent
from jinja2 import Template

new_kwargs["message"] = (
dedent(Template(new_kwargs["message"]).render(**context))
if context
else new_kwargs["message"]
)
new_kwargs["chat_history"] = handle_templating(new_kwargs["chat_history"], context)
return new_kwargs


def handle_templating(
messages: list[dict[str, Any]], context: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
kwargs: dict[str, Any], context: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
Handle templating for messages using the provided context.

@@ -106,47 +88,70 @@
Anthropic's format with parts.

Args:
messages (list[dict[str, Any]]): A list of message dictionaries to process.
kwargs (dict[str, Any]): Keyword arguments being passed to the create method.
context (dict[str, Any] | None, optional): A dictionary of variables to use in templating.
Defaults to None.

Returns:
list[dict[str, Any]]: The processed list of messages with templated content.

Note:
- If no context is provided, the original messages are returned unchanged.
- For OpenAI format, the 'content' field is processed if it's a string.
- For Anthropic format, each 'text' part within the 'content' list is processed.
- The function uses Jinja2 for templating and applies textwrap.dedent for formatting.

TODO: Gemini and Cohere formats are missing here.
"""
if context is None:
return messages
return kwargs

from jinja2 import Template
from textwrap import dedent

new_kwargs = kwargs.copy()

assert any(
key in new_kwargs for key in ["message", "messages", "contents"]
), "Expected 'message', 'messages' or 'contents' in kwargs"

# Handle templating for Cohere's message field
if "message" in new_kwargs:
new_kwargs["message"] = (
dedent(Template(new_kwargs["message"]).render(**context))
if context
else new_kwargs["message"]
)
new_kwargs["chat_history"] = handle_templating(
new_kwargs["chat_history"], context
)
return new_kwargs

if "messages" in new_kwargs:
messages = new_kwargs["messages"]
elif "contents" in new_kwargs:
messages = new_kwargs["contents"]
else:
raise ValueError("Expected 'message', 'messages' or 'contents' in kwargs")

# Handle templating for OpenAI and Anthropic
for message in messages:
if hasattr(message, "parts"):
# VertexAI Support
if isinstance(message.parts, list): # type: ignore

Check failure on line 135 in instructor/patch.py (GitHub Actions / Pyright, macos-latest, 3.9 / 3.10 / 3.11): Unnecessary "# type: ignore" comment (reportUnnecessaryTypeIgnoreComment)
import vertexai.generative_models as gm # type: ignore

return gm.Content(
role=message.role, # type: ignore
parts=[
gm.Part.from_text(dedent(Template(part.text).render(**context))) # type: ignore
if hasattr(part, "text") # type: ignore
else part
(
gm.Part.from_text(dedent(Template(part.text).render(**context))) # type: ignore
if hasattr(part, "text") # type: ignore
else part
)
for part in message.parts # type: ignore
],
)
return message # type: ignore

Check failure on line 149 in instructor/patch.py (GitHub Actions / Pyright, macos-latest, 3.9 / 3.10 / 3.11): Unnecessary "# type: ignore" comment (reportUnnecessaryTypeIgnoreComment)

if isinstance(message.get("message"), str):
message["message"] = dedent(Template(message["message"]).render(**context))
continue

# Handle OpenAI format
if isinstance(message.get("content"), str):
message["content"] = dedent(Template(message["content"]).render(**context))
@@ -160,7 +165,9 @@
and part.get("type") == "text" # type:ignore
and isinstance(part.get("text"), str) # type:ignore
):
part["text"] = dedent(Template(part["text"]).render(**context)) # type:ignore
part["text"] = dedent(
Template(part["text"]).render(**context)

Check failure on line 169 in instructor/patch.py (GitHub Actions / Pyright, macos-latest, 3.9 / 3.10 / 3.11): Argument type is unknown; argument corresponds to parameter "source" in function "__new__" (reportUnknownArgumentType)
) # type:ignore

Check failure on line 170 in instructor/patch.py (GitHub Actions / Pyright, macos-latest, 3.9 / 3.10 / 3.11): Unnecessary "# type: ignore" comment (reportUnnecessaryTypeIgnoreComment)

# Gemini Support
if isinstance(message.get("parts"), list):
@@ -172,7 +179,7 @@
new_parts.append(part) # type: ignore
message["parts"] = new_parts

return messages
return new_kwargs


@overload
@@ -245,12 +252,7 @@
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
if "messages" in new_kwargs:
new_kwargs["messages"] = handle_templating(new_kwargs["messages"], context)

elif "contents" in new_kwargs:
new_kwargs["contents"] = handle_templating(new_kwargs["contents"], context)

new_kwargs = handle_templating(new_kwargs, context)

response = await retry_async(
func=func, # type: ignore
@@ -280,13 +282,7 @@
response_model=response_model, mode=mode, **kwargs
)

if "messages" in new_kwargs:
new_kwargs["messages"] = handle_templating(new_kwargs["messages"], context)
elif "message" in new_kwargs and "chat_history" in new_kwargs:
new_kwargs = handle_cohere_templating(new_kwargs, context)

elif "contents" in new_kwargs:
new_kwargs["contents"] = handle_templating(new_kwargs["contents"], context)
new_kwargs = handle_templating(new_kwargs, context)

response = retry_sync(
func=func, # type: ignore
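
To make the consolidated behavior concrete, here is a small standalone sketch of the per-field transformation that handle_templating applies (render the Jinja2 template against the context, then dedent); the render helper and the example strings are illustrative only, not part of the PR.

from textwrap import dedent
from jinja2 import Template


def render(text: str, context: dict) -> str:
    # Mirrors the per-field transformation in handle_templating:
    # render the Jinja2 template against the context dict, then dedent it.
    return dedent(Template(text).render(**context))


context = {"customer": "Ada", "issue": "billing"}
print(render("Summarise the {{ issue }} issue for {{ customer }}.", context))
# Summarise the billing issue for Ada.
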