Skip to content

Commit

Permalink
ANTHROPIC_JSON: allow control characters in JSON strings if strict=Fa…
Browse files Browse the repository at this point in the history
…lse (#644)
  • Loading branch information
voberoi authored May 1, 2024
1 parent 8cd5c43 commit 6491aec
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 9 deletions.
26 changes: 20 additions & 6 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import json
import logging
from functools import wraps
from typing import Annotated, Any, Optional, TypeVar, cast

from docstring_parser import parse
from openai.types.chat import ChatCompletion
from pydantic import BaseModel, Field, TypeAdapter, ConfigDict, create_model # type: ignore - remove once Pydantic is updated
from pydantic import ( # type: ignore - remove once Pydantic is updated
BaseModel,
ConfigDict,
Field,
TypeAdapter,
create_model,
)

from instructor.exceptions import IncompleteOutputException
from instructor.mode import Mode
from instructor.utils import extract_json_from_codeblock, classproperty

from instructor.utils import classproperty, extract_json_from_codeblock

T = TypeVar("T")

Expand Down Expand Up @@ -141,9 +148,16 @@ def parse_anthropic_json(

text = completion.content[0].text
extra_text = extract_json_from_codeblock(text)
return cls.model_validate_json(
extra_text, context=validation_context, strict=strict
)

if strict:
return cls.model_validate_json(
extra_text, context=validation_context, strict=True
)
else:
# Allow control characters.
parsed = json.loads(extra_text, strict=False)
# Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
return cls.model_validate(parsed, context=validation_context, strict=False)

@classmethod
def parse_cohere_tools(
Expand Down
63 changes: 60 additions & 3 deletions tests/test_function_calls.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import TypeVar

import pytest
from pydantic import BaseModel
from anthropic.types import Message, Usage
from openai.resources.chat.completions import ChatCompletion
from pydantic import BaseModel, ValidationError

from instructor import openai_schema, OpenAISchema
import instructor
from instructor import OpenAISchema, openai_schema
from instructor.exceptions import IncompleteOutputException


T = TypeVar("T")


Expand Down Expand Up @@ -51,6 +52,24 @@ def mock_completion(request: T) -> ChatCompletion:

return completion

@pytest.fixture # type: ignore[misc]
def mock_anthropic_message(request: T) -> Message:
data_content = '{\n"data": "Claude says hi"\n}'
if hasattr(request, "param"):
data_content = request.param.get("data_content", data_content)
return Message(
id="test_id",
content=[{ "type": "text", "text": data_content }],
model="claude-3-haiku-20240307",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=Usage(
input_tokens=100,
output_tokens=100,
)
)

def test_openai_schema() -> None:
@openai_schema
Expand Down Expand Up @@ -122,3 +141,41 @@ def test_incomplete_output_exception_raise(
) -> None:
with pytest.raises(IncompleteOutputException):
test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS)

def test_anthropic_no_exception(
test_model: type[OpenAISchema], mock_anthropic_message: Message
) -> None:
test_model_instance = test_model.from_response(mock_anthropic_message, mode=instructor.Mode.ANTHROPIC_JSON)
assert test_model_instance.data == "Claude says hi"

@pytest.mark.parametrize(
"mock_anthropic_message",
[{"data_content": '{\n"data": "Claude likes\ncontrol\ncharacters"\n}'}],
indirect=True,
) # type: ignore[misc]
def test_control_characters_not_allowed_in_anthropic_json_strict_mode(
test_model: type[OpenAISchema], mock_anthropic_message: Message
) -> None:
with pytest.raises(ValidationError) as exc_info:
test_model.from_response(
mock_anthropic_message, mode=instructor.Mode.ANTHROPIC_JSON, strict=True
)

# https://docs.pydantic.dev/latest/errors/validation_errors/#json_invalid
exc = exc_info.value
assert len(exc.errors()) == 1
assert exc.errors()[0]["type"] == "json_invalid"
assert "control character" in exc.errors()[0]["msg"]

@pytest.mark.parametrize(
"mock_anthropic_message",
[{"data_content": '{\n"data": "Claude likes\ncontrol\ncharacters"\n}'}],
indirect=True,
) # type: ignore[misc]
def test_control_characters_allowed_in_anthropic_json_non_strict_mode(
test_model: type[OpenAISchema], mock_anthropic_message: Message
) -> None:
test_model_instance = test_model.from_response(
mock_anthropic_message, mode=instructor.Mode.ANTHROPIC_JSON, strict=False
)
assert test_model_instance.data == "Claude likes\ncontrol\ncharacters"

0 comments on commit 6491aec

Please sign in to comment.