Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: clean up cohere templating #1030

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 5 additions & 37 deletions instructor/client_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@

import cohere
import instructor
from functools import wraps
from typing import (
TypeVar,
overload,
)
from typing import Any
from typing_extensions import ParamSpec
from pydantic import BaseModel
from instructor.patch import handle_cohere_templating, handle_context
from instructor.process_response import handle_response_model
from instructor.retry import retry_async


T_Model = TypeVar("T_Model", bound=BaseModel)
Expand Down Expand Up @@ -58,39 +54,11 @@ def from_cohere(
**kwargs,
)

@wraps(client.chat)
async def new_create_async(
response_model: type[T_Model] | None = None,
validation_context: dict[str, Any] | None = None,
max_retries: int = 1,
context: dict[str, Any] | None = None,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
prepared_response_model, new_kwargs = handle_response_model(
response_model=response_model,
if isinstance(client, cohere.AsyncClient):
return instructor.AsyncInstructor(
client=client,
create=instructor.patch(create=client.chat, mode=mode),
provider=instructor.Provider.COHERE,
mode=mode,
**kwargs,
)

context = handle_context(context, validation_context)
new_kwargs = handle_cohere_templating(new_kwargs, context)

response = await retry_async(
func=client.chat,
response_model=prepared_response_model,
context=context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
)
return response

return instructor.AsyncInstructor(
client=client,
create=new_create_async,
provider=instructor.Provider.COHERE,
mode=mode,
**kwargs,
)
114 changes: 3 additions & 111 deletions instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from instructor.process_response import handle_response_model
from instructor.retry import retry_async, retry_sync
from instructor.utils import is_async
from instructor.templating import handle_templating

from instructor.mode import Mode
import logging
Expand Down Expand Up @@ -77,104 +78,6 @@ def handle_context(
return context


def handle_cohere_templating(
new_kwargs: dict[str, Any], context: dict[str, Any] | None = None
) -> dict[str, Any]:
if not context:
return new_kwargs

from textwrap import dedent
from jinja2 import Template

new_kwargs["message"] = (
dedent(Template(new_kwargs["message"]).render(**context))
if context
else new_kwargs["message"]
)
new_kwargs["chat_history"] = handle_templating(new_kwargs["chat_history"], context)
return new_kwargs


def handle_templating(
    messages: list[dict[str, Any]], context: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
    """
    Handle templating for messages using the provided context.

    This function processes a list of messages, applying Jinja2 templating to
    their content using the provided context. It handles several provider
    formats: OpenAI (string ``content``), Anthropic (list-of-parts
    ``content``), Cohere (``message``), Gemini (string ``parts``), and
    VertexAI (objects exposing a ``parts`` attribute).

    Args:
        messages (list[dict[str, Any]]): A list of message dictionaries to process.
        context (dict[str, Any] | None, optional): A dictionary of variables to use in templating.
            Defaults to None.

    Returns:
        list[dict[str, Any]]: The processed list of messages with templated content.

    Note:
        - If no context is provided, the original messages are returned unchanged.
        - For OpenAI format, the 'content' field is processed if it's a string.
        - For Anthropic format, each 'text' part within the 'content' list is processed.
        - The function uses Jinja2 for templating and applies textwrap.dedent for formatting.
        - NOTE(review): the VertexAI branch returns from inside the loop (see
          below), so any messages after the first VertexAI one are never
          templated — confirm whether that is intentional.
    """
    if context is None:
        return messages

    # Imported lazily so jinja2 is only a hard dependency when templating is used.
    from jinja2 import Template
    from textwrap import dedent

    for message in messages:
        if hasattr(message, "parts"):
            # VertexAI Support: message is an object, not a dict.
            if isinstance(message.parts, list):  # type: ignore
                import vertexai.generative_models as gm  # type: ignore

                # NOTE(review): this returns a single gm.Content instead of
                # the messages list, and aborts the loop — remaining messages
                # are left untemplated. Looks like a bug; verify callers.
                return gm.Content(
                    role=message.role,  # type: ignore
                    parts=[
                        gm.Part.from_text(dedent(Template(part.text).render(**context)))  # type: ignore
                        if hasattr(part, "text")  # type: ignore
                        else part
                        for part in message.parts  # type: ignore
                    ],
                )
            # Non-list ``parts``: returned as-is, also ending the loop early.
            return message  # type: ignore

        # Cohere format: single "message" string per chat-history entry.
        if isinstance(message.get("message"), str):
            message["message"] = dedent(Template(message["message"]).render(**context))
            continue
        # Handle OpenAI format
        if isinstance(message.get("content"), str):
            message["content"] = dedent(Template(message["content"]).render(**context))
            continue

        # Handle Anthropic format
        if isinstance(message.get("content"), list):
            for part in message["content"]:
                if (
                    isinstance(part, dict)
                    and part.get("type") == "text"  # type:ignore
                    and isinstance(part.get("text"), str)  # type:ignore
                ):
                    part["text"] = dedent(Template(part["text"]).render(**context))  # type:ignore

        # Gemini Support: "parts" is a list of raw strings (non-strings pass through).
        if isinstance(message.get("parts"), list):
            new_parts = []
            for part in message["parts"]:
                if isinstance(part, str):
                    new_parts.append(dedent(Template(part).render(**context)))  # type: ignore
                else:
                    new_parts.append(part)  # type: ignore
            message["parts"] = new_parts

    return messages


@overload
def patch(
client: OpenAI,
Expand Down Expand Up @@ -245,12 +148,7 @@ async def new_create_async(
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
if "messages" in new_kwargs:
new_kwargs["messages"] = handle_templating(new_kwargs["messages"], context)

elif "contents" in new_kwargs:
new_kwargs["contents"] = handle_templating(new_kwargs["contents"], context)

new_kwargs = handle_templating(new_kwargs, context)

response = await retry_async(
func=func, # type: ignore
Expand Down Expand Up @@ -280,13 +178,7 @@ def new_create_sync(
response_model=response_model, mode=mode, **kwargs
)

if "messages" in new_kwargs:
new_kwargs["messages"] = handle_templating(new_kwargs["messages"], context)
elif "message" in new_kwargs and "chat_history" in new_kwargs:
new_kwargs = handle_cohere_templating(new_kwargs, context)

elif "contents" in new_kwargs:
new_kwargs["contents"] = handle_templating(new_kwargs["contents"], context)
new_kwargs = handle_templating(new_kwargs, context)

response = retry_sync(
func=func, # type: ignore
Expand Down
97 changes: 97 additions & 0 deletions instructor/templating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# type: ignore[all]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The # type: ignore[all] comment should be justified with a reason or removed if not necessary, especially in library code.

from typing import Any, Dict
from jinja2 import Template
from textwrap import dedent


def apply_template(text: str, context: Dict[str, Any]) -> str:
    """Render *text* as a Jinja2 template with *context*, then dedent the result."""
    rendered = Template(text).render(**context)
    return dedent(rendered)


def process_message(message: Dict[str, Any], context: Dict[str, Any]) -> None:
    """Render Jinja2 templates inside a single message, mutating it in place.

    Recognizes, in order: VertexAI content objects (``parts`` attribute),
    OpenAI (string ``content``), Anthropic (list ``content`` of typed parts),
    Gemini (list ``parts`` of strings), and Cohere (string ``message``).
    Messages matching none of these are left untouched.
    """
    # VertexAI support: content objects expose a ``parts`` attribute.
    if hasattr(message, "parts") and isinstance(message.parts, list):
        import vertexai.generative_models as gm

        rendered = []
        for piece in message.parts:
            if hasattr(piece, "text"):
                rendered.append(gm.Part.from_text(apply_template(piece.text, context)))
            else:
                rendered.append(piece)
        message.parts = rendered
        return

    content = message.get("content")

    # OpenAI format: plain string content.
    if isinstance(content, str):
        message["content"] = apply_template(content, context)
        return

    # Anthropic format: list of typed parts; only text parts are templated.
    if isinstance(content, list):
        for piece in content:
            if (
                isinstance(piece, dict)
                and piece.get("type") == "text"
                and isinstance(piece.get("text"), str)
            ):
                piece["text"] = apply_template(piece["text"], context)
        return

    # Gemini support: ``parts`` is a list of raw strings (others pass through).
    parts = message.get("parts")
    if isinstance(parts, list):
        message["parts"] = [
            apply_template(piece, context) if isinstance(piece, str) else piece
            for piece in parts
        ]
        return

    # Cohere format: single ``message`` string.
    if isinstance(message.get("message"), str):
        message["message"] = apply_template(message["message"], context)


def handle_templating(
kwargs: Dict[str, Any], context: Dict[str, Any] | None = None
) -> Dict[str, Any]:
"""
Handle templating for messages using the provided context.

This function processes messages, applying Jinja2 templating to their content
using the provided context. It supports various message formats including
OpenAI, Anthropic, Cohere, VertexAI, and Gemini.

Args:
kwargs (Dict[str, Any]): Keyword arguments being passed to the create method.
context (Dict[str, Any] | None, optional): Variables to use in templating. Defaults to None.

Returns:
Dict[str, Any]: The processed kwargs with templated content.

Raises:
ValueError: If no recognized message format is found in kwargs.
"""
if not context:
return kwargs

new_kwargs = kwargs.copy()

# Handle Cohere's message field
if "message" in new_kwargs:
new_kwargs["message"] = apply_template(new_kwargs["message"], context)
new_kwargs["chat_history"] = handle_templating(
new_kwargs["chat_history"], context
)
return new_kwargs

messages = new_kwargs.get("messages") or new_kwargs.get("contents")
if not messages:
raise ValueError("Expected 'message', 'messages' or 'contents' in kwargs")

for message in messages:
process_message(message, context)

return new_kwargs
Loading
Loading