From 10e31fe9c2c4e1ddef838db724d627ee42486f66 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Tue, 20 Feb 2024 09:38:03 -0600 Subject: [PATCH 01/11] support top level list in rail, fail gracefully for pydantic --- guardrails/classes/output_type.py | 4 ++-- guardrails/classes/validation_outcome.py | 2 +- guardrails/cli/validate.py | 4 ++-- guardrails/guard.py | 9 ++++++++ guardrails/rail.py | 27 +++++++++++++++------- guardrails/schema/json_schema.py | 29 ++++++++++++------------ guardrails/schema/string_schema.py | 6 +---- guardrails/utils/json_utils.py | 6 +++-- guardrails/utils/reask_utils.py | 4 ++-- guardrails/validator_base.py | 12 ++++++++++ tests/unit_tests/test_validators.py | 2 +- 11 files changed, 68 insertions(+), 37 deletions(-) diff --git a/guardrails/classes/output_type.py b/guardrails/classes/output_type.py index 8c1c26135..f40714047 100644 --- a/guardrails/classes/output_type.py +++ b/guardrails/classes/output_type.py @@ -1,3 +1,3 @@ -from typing import Dict, TypeVar +from typing import Dict, List, TypeVar -OT = TypeVar("OT", str, Dict) +OT = TypeVar("OT", str, Dict, List) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 3e1ab45eb..4a7586e8b 100644 --- a/guardrails/classes/validation_outcome.py +++ b/guardrails/classes/validation_outcome.py @@ -9,7 +9,7 @@ from guardrails.utils.reask_utils import ReAsk -class ValidationOutcome(Generic[OT], ArbitraryModel): +class ValidationOutcome(ArbitraryModel, Generic[OT]): raw_llm_output: Optional[str] = Field( description="The raw, unchanged output from the LLM call.", default=None ) diff --git a/guardrails/cli/validate.py b/guardrails/cli/validate.py index df38a5125..5281639a9 100644 --- a/guardrails/cli/validate.py +++ b/guardrails/cli/validate.py @@ -1,5 +1,5 @@ import json -from typing import Dict, Union +from typing import Dict, List, Union import typer @@ -7,7 +7,7 @@ from guardrails.cli.guardrails import guardrails -def validate_llm_output(rail: 
str, llm_output: str) -> Union[str, Dict, None]: +def validate_llm_output(rail: str, llm_output: str) -> Union[str, Dict, List, None]: """Validate guardrails.yml file.""" guard = Guard.from_rail(rail) result = guard.parse(llm_output) diff --git a/guardrails/guard.py b/guardrails/guard.py index f896ec47b..f9dff781f 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -220,6 +220,10 @@ def from_rail( return cast( Guard[str], cls(rail=rail, num_reasks=num_reasks, tracer=tracer) ) + elif rail.output_type == "list": + return cast( + Guard[List], cls(rail=rail, num_reasks=num_reasks, tracer=tracer) + ) return cast(Guard[Dict], cls(rail=rail, num_reasks=num_reasks, tracer=tracer)) @classmethod @@ -247,6 +251,10 @@ def from_rail_string( return cast( Guard[str], cls(rail=rail, num_reasks=num_reasks, tracer=tracer) ) + elif rail.output_type == "list": + return cast( + Guard[List], cls(rail=rail, num_reasks=num_reasks, tracer=tracer) + ) return cast(Guard[Dict], cls(rail=rail, num_reasks=num_reasks, tracer=tracer)) @classmethod @@ -272,6 +280,7 @@ def from_pydantic( reask_prompt=reask_prompt, reask_instructions=reask_instructions, ) + # TODO: Add List[BaseModel] support return cast( Guard[Dict], cls(rail, num_reasks=num_reasks, base_model=output_class, tracer=tracer), diff --git a/guardrails/rail.py b/guardrails/rail.py index ec8e9d34a..2ed68ae7f 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -6,6 +6,7 @@ from lxml import etree as ET from pydantic import BaseModel +from guardrails.datatypes import List from guardrails.prompt import Instructions, Prompt from guardrails.schema import JsonSchema, Schema, StringSchema from guardrails.utils.xml_utils import cast_xml_to_string @@ -65,8 +66,11 @@ def output_type(self): if isinstance(self.output_schema, StringSchema): return "str" - else: - return "dict" + elif isinstance(self.output_schema, JsonSchema) and isinstance( + self.output_schema.root_datatype, List + ): + return "list" + return "dict" @classmethod def 
from_pydantic( @@ -226,18 +230,25 @@ def load_output_schema_from_xml( Returns: A Schema object. """ + schema_type = root.attrib["type"] if "type" in root.attrib else "object" # If root contains a `type="string"` attribute, then it's a StringSchema - if "type" in root.attrib and root.attrib["type"] == "string": + if schema_type == "string": return StringSchema.from_xml( root, reask_prompt_template=reask_prompt, reask_instructions_template=reask_instructions, ) - return JsonSchema.from_xml( - root, - reask_prompt_template=reask_prompt, - reask_instructions_template=reask_instructions, - ) + elif schema_type in ["object", "list"]: + return JsonSchema.from_xml( + root, + reask_prompt_template=reask_prompt, + reask_instructions_template=reask_instructions, + ) + else: + raise ValueError( + "The type attribute of the tag must be one of:" + ' "string", "object", or "list"' + ) @staticmethod def load_string_schema_from_string( diff --git a/guardrails/schema/json_schema.py b/guardrails/schema/json_schema.py index fb5e119cf..a71596186 100644 --- a/guardrails/schema/json_schema.py +++ b/guardrails/schema/json_schema.py @@ -8,7 +8,9 @@ from guardrails import validator_service from guardrails.classes.history import Iteration -from guardrails.datatypes import Choice, DataType, Object +from guardrails.datatypes import Choice, DataType +from guardrails.datatypes import List as ListDataType +from guardrails.datatypes import Object from guardrails.llm_providers import ( AsyncOpenAICallable, AsyncOpenAIChatCallable, @@ -35,7 +37,7 @@ prune_obj_for_reasking, ) from guardrails.utils.telemetry_utils import trace_validation_result -from guardrails.validator_base import FailResult, check_refrain_in_dict, filter_in_dict +from guardrails.validator_base import FailResult, check_refrain, filter_in_schema class JsonSchema(Schema): @@ -43,7 +45,7 @@ class JsonSchema(Schema): def __init__( self, - schema: Object, + schema: Union[Object, ListDataType], reask_prompt_template: Optional[str] = None, 
reask_instructions_template: Optional[str] = None, ) -> None: @@ -163,7 +165,12 @@ def from_xml( if "strict" in root.attrib and root.attrib["strict"] == "true": strict = True - schema = Object.from_xml(root, strict=strict) + schema_type = root.attrib["type"] if "type" in root.attrib else "object" + + if schema_type == "list": + schema = ListDataType.from_xml(root, strict=strict) + else: + schema = Object.from_xml(root, strict=strict) return cls( schema, @@ -302,9 +309,6 @@ def validate( if data is None: return None - if not isinstance(data, dict): - raise TypeError(f"Argument `data` must be a dictionary, not {type(data)}.") - validated_response = deepcopy(data) if not verify_schema_against_json( @@ -337,14 +341,14 @@ def validate( iteration=iteration, ) - if check_refrain_in_dict(validated_response): + if check_refrain(validated_response): # If the data contains a `Refrain` value, we return an empty # dictionary. logger.debug("Refrain detected.") validated_response = {} # Remove all keys that have `Filter` values. - validated_response = filter_in_dict(validated_response) + validated_response = filter_in_schema(validated_response) # TODO: Capture error messages once Top Level error handling is merged in trace_validation_result( @@ -371,9 +375,6 @@ async def async_validate( if data is None: return None - if not isinstance(data, dict): - raise TypeError(f"Argument `data` must be a dictionary, not {type(data)}.") - validated_response = deepcopy(data) if not verify_schema_against_json( @@ -406,14 +407,14 @@ async def async_validate( iteration=iteration, ) - if check_refrain_in_dict(validated_response): + if check_refrain(validated_response): # If the data contains a `Refain` value, we return an empty # dictionary. logger.debug("Refrain detected.") validated_response = {} # Remove all keys that have `Filter` values. 
- validated_response = filter_in_dict(validated_response) + validated_response = filter_in_schema(validated_response) # TODO: Capture error messages once Top Level error handling is merged in trace_validation_result( diff --git a/guardrails/schema/string_schema.py b/guardrails/schema/string_schema.py index 0c9f1d8b3..f3eaf1d52 100644 --- a/guardrails/schema/string_schema.py +++ b/guardrails/schema/string_schema.py @@ -19,11 +19,7 @@ from guardrails.utils.constants import constants from guardrails.utils.reask_utils import FieldReAsk, ReAsk from guardrails.utils.telemetry_utils import trace_validation_result -from guardrails.validator_base import ( - ValidatorSpec, - check_refrain_in_dict, - filter_in_dict, -) +from guardrails.validator_base import ValidatorSpec, check_refrain_in_dict, filter_in_dict class StringSchema(Schema): diff --git a/guardrails/utils/json_utils.py b/guardrails/utils/json_utils.py index cccdce8a5..fd272d4ae 100644 --- a/guardrails/utils/json_utils.py +++ b/guardrails/utils/json_utils.py @@ -267,7 +267,9 @@ def verify( ) -def generate_type_skeleton_from_schema(schema: Object) -> Placeholder: +def generate_type_skeleton_from_schema( + schema: Union[Object, ListDataType] +) -> Placeholder: """Generate a JSON skeleton from an XML schema.""" def _recurse_schema(schema: DataType): @@ -327,7 +329,7 @@ def _recurse_schema(schema: DataType): def verify_schema_against_json( - schema: Object, + schema: Union[Object, ListDataType], generated_json: Dict[str, Any], prune_extra_keys: bool = False, coerce_types: bool = False, diff --git a/guardrails/utils/reask_utils.py b/guardrails/utils/reask_utils.py index 4e45f617b..c42aba16c 100644 --- a/guardrails/utils/reask_utils.py +++ b/guardrails/utils/reask_utils.py @@ -86,9 +86,9 @@ def _gather_reasks_in_list( def get_pruned_tree( - root: ObjectType, + root: Union[ObjectType, ListType], reasks: Optional[List[FieldReAsk]] = None, -) -> ObjectType: +) -> Union[ObjectType, ListType]: """Prune tree of any elements 
that are not in `reasks`. Return the tree with only the elements that are keys of `reasks` and diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index e31c40ac0..225a04eae 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -76,6 +76,12 @@ def check_refrain_in_dict(schema: Dict) -> bool: return False +def check_refrain(schema: Union[List, Dict]) -> bool: + if isinstance(schema, List): + return check_refrain_in_list(schema) + return check_refrain_in_dict(schema) + + def filter_in_list(schema: List) -> List: """Remove out all Filter objects from a list. @@ -130,6 +136,12 @@ def filter_in_dict(schema: Dict) -> Dict: return filtered_dict +def filter_in_schema(schema: Union[Dict, List]) -> Union[Dict, List]: + if isinstance(schema, List): + return filter_in_list(schema) + return filter_in_dict(schema) + + validators_registry = {} types_to_validators = defaultdict(list) diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index 28ddf37e3..8aea93410 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -69,7 +69,7 @@ ({"a": 1}, False), ], ) -def test_check_refrain(input_dict, expected): +def test_check_refrain_in_dict(input_dict, expected): assert check_refrain_in_dict(input_dict) == expected From fadf99c61193f4703bfb29f2e1c5cf48fcd7e227 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Tue, 20 Feb 2024 15:44:26 -0600 Subject: [PATCH 02/11] list of pydantic model support --- guardrails/guard.py | 11 ++++++--- guardrails/rail.py | 10 ++++----- guardrails/run.py | 8 +++++-- guardrails/schema/json_schema.py | 36 +++++++++++++++++++++++++++--- guardrails/schema/string_schema.py | 6 ++++- 5 files changed, 57 insertions(+), 14 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index f9dff781f..77f044518 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -80,7 +80,9 @@ def __init__( self, rail: Optional[Rail] = None, 
num_reasks: Optional[int] = None, - base_model: Optional[Type[BaseModel]] = None, + base_model: Optional[ + Union[Type[BaseModel], Type[List[Type[BaseModel]]]] + ] = None, tracer: Optional[Tracer] = None, ): """Initialize the Guard with optional Rail instance, num_reasks, and @@ -260,7 +262,7 @@ def from_rail_string( @classmethod def from_pydantic( cls, - output_class: Type[BaseModel], + output_class: Union[Type[BaseModel], Type[List[Type[BaseModel]]]], prompt: Optional[str] = None, instructions: Optional[str] = None, num_reasks: Optional[int] = None, @@ -280,7 +282,10 @@ def from_pydantic( reask_prompt=reask_prompt, reask_instructions=reask_instructions, ) - # TODO: Add List[BaseModel] support + if rail.output_type == "list": + return cast( + Guard[List], cls(rail, num_reasks=num_reasks, base_model=output_class) + ) return cast( Guard[Dict], cls(rail, num_reasks=num_reasks, base_model=output_class, tracer=tracer), diff --git a/guardrails/rail.py b/guardrails/rail.py index 2ed68ae7f..7cde6ea8b 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -1,12 +1,12 @@ """Rail class.""" import warnings from dataclasses import dataclass -from typing import Optional, Sequence, Type +from typing import List, Optional, Sequence, Type, Union from lxml import etree as ET from pydantic import BaseModel -from guardrails.datatypes import List +from guardrails.datatypes import List as ListDataType from guardrails.prompt import Instructions, Prompt from guardrails.schema import JsonSchema, Schema, StringSchema from guardrails.utils.xml_utils import cast_xml_to_string @@ -67,7 +67,7 @@ def output_type(self): if isinstance(self.output_schema, StringSchema): return "str" elif isinstance(self.output_schema, JsonSchema) and isinstance( - self.output_schema.root_datatype, List + self.output_schema.root_datatype, ListDataType ): return "list" return "dict" @@ -75,7 +75,7 @@ def output_type(self): @classmethod def from_pydantic( cls, - output_class: Type[BaseModel], + output_class: 
Union[Type[BaseModel], Type[List[Type[BaseModel]]]], prompt: Optional[str] = None, instructions: Optional[str] = None, reask_prompt: Optional[str] = None, @@ -267,7 +267,7 @@ def load_string_schema_from_string( @staticmethod def load_json_schema_from_pydantic( - output_class: Type[BaseModel], + output_class: Union[Type[BaseModel], Type[List[Type[BaseModel]]]], reask_prompt_template: Optional[str] = None, reask_instructions_template: Optional[str] = None, ): diff --git a/guardrails/run.py b/guardrails/run.py index 4af15f76b..0d133e4c4 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -64,7 +64,9 @@ def __init__( msg_history_schema: Optional[StringSchema] = None, metadata: Optional[Dict[str, Any]] = None, output: Optional[str] = None, - base_model: Optional[Type[BaseModel]] = None, + base_model: Optional[ + Union[Type[BaseModel], Type[List[Type[BaseModel]]]] + ] = None, full_schema_reask: bool = False, ): if prompt: @@ -676,7 +678,9 @@ def __init__( msg_history_schema: Optional[StringSchema] = None, metadata: Optional[Dict[str, Any]] = None, output: Optional[str] = None, - base_model: Optional[Type[BaseModel]] = None, + base_model: Optional[ + Union[Type[BaseModel], Type[List[Type[BaseModel]]]] + ] = None, full_schema_reask: bool = False, ): super().__init__( diff --git a/guardrails/schema/json_schema.py b/guardrails/schema/json_schema.py index a71596186..f555925f2 100644 --- a/guardrails/schema/json_schema.py +++ b/guardrails/schema/json_schema.py @@ -1,6 +1,17 @@ import json from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Type, + Union, + cast, + get_args, + get_origin, +) from lxml import etree as ET from pydantic import BaseModel @@ -36,8 +47,10 @@ get_pruned_tree, prune_obj_for_reasking, ) +from guardrails.utils.safe_get import safe_get from guardrails.utils.telemetry_utils import trace_validation_result from guardrails.validator_base import 
FailResult, check_refrain, filter_in_schema +from guardrails.validatorsattr import ValidatorsAttr class JsonSchema(Schema): @@ -181,13 +194,30 @@ def from_xml( @classmethod def from_pydantic( cls, - model: Type[BaseModel], + model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]], reask_prompt_template: Optional[str] = None, reask_instructions_template: Optional[str] = None, ) -> Self: strict = False - schema = convert_pydantic_model_to_datatype(model, strict=strict) + type_origin = get_origin(model) + + if type_origin == list: + item_types = get_args(model) + if len(item_types) > 1: + raise ValueError("List data type must have exactly one child.") + item_type = safe_get(item_types, 0) + if not item_type or not issubclass(item_type, BaseModel): + raise ValueError("List item type must be a Pydantic model.") + item_schema = convert_pydantic_model_to_datatype(item_type, strict=strict) + children = {"item": item_schema} + validators_attr = ValidatorsAttr.from_validators( + [], ListDataType.tag, strict + ) + schema = ListDataType(children, validators_attr, False, None, None) + else: + pydantic_model = cast(Type[BaseModel], model) + schema = convert_pydantic_model_to_datatype(pydantic_model, strict=strict) return cls( schema, diff --git a/guardrails/schema/string_schema.py b/guardrails/schema/string_schema.py index f3eaf1d52..0c9f1d8b3 100644 --- a/guardrails/schema/string_schema.py +++ b/guardrails/schema/string_schema.py @@ -19,7 +19,11 @@ from guardrails.utils.constants import constants from guardrails.utils.reask_utils import FieldReAsk, ReAsk from guardrails.utils.telemetry_utils import trace_validation_result -from guardrails.validator_base import ValidatorSpec, check_refrain_in_dict, filter_in_dict +from guardrails.validator_base import ( + ValidatorSpec, + check_refrain_in_dict, + filter_in_dict, +) class StringSchema(Schema): From 9687cc8672c0e9bc8294f35f729741bf543e7437 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 21 Feb 2024 11:50:29 -0600 
Subject: [PATCH 03/11] handle list is not a class exception --- guardrails/utils/pydantic_utils/v2.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/guardrails/utils/pydantic_utils/v2.py b/guardrails/utils/pydantic_utils/v2.py index fd5aa9d3a..e8e75d598 100644 --- a/guardrails/utils/pydantic_utils/v2.py +++ b/guardrails/utils/pydantic_utils/v2.py @@ -63,12 +63,15 @@ def add_validator( def is_pydantic_base_model(type_annotation: Any) -> Union[Type[BaseModel], None]: """Check if a type_annotation is a Pydantic BaseModel.""" - if ( - type_annotation is not None - and isinstance(type_annotation, type) - and issubclass(type_annotation, BaseModel) - ): - return type_annotation + try: + if ( + type_annotation is not None + and isinstance(type_annotation, type) + and issubclass(type_annotation, BaseModel) + ): + return type_annotation + except TypeError: + pass return None From 8949b69b69a4051f7b003d0668fe19d219297ae4 Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Fri, 23 Feb 2024 01:28:07 -0800 Subject: [PATCH 04/11] add integration tests --- tests/integration_tests/mock_llm_outputs.py | 3 +- .../test_assets/lists_object.py | 33 +++++++++++++++ tests/integration_tests/test_guard.py | 42 +++++++++++++++++-- 3 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 tests/integration_tests/test_assets/lists_object.py diff --git a/tests/integration_tests/mock_llm_outputs.py b/tests/integration_tests/mock_llm_outputs.py index e4288bc44..285194ebd 100644 --- a/tests/integration_tests/mock_llm_outputs.py +++ b/tests/integration_tests/mock_llm_outputs.py @@ -7,7 +7,7 @@ ) from guardrails.utils.llm_response import LLMResponse -from .test_assets import entity_extraction, pydantic, python_rail, string +from .test_assets import entity_extraction, pydantic, python_rail, string, lists_object class MockOpenAICallable(OpenAICallable): @@ -36,6 +36,7 @@ def _invoke_llm(self, prompt, *args, **kwargs): 
python_rail.VALIDATOR_PARALLELISM_PROMPT_1: python_rail.VALIDATOR_PARALLELISM_RESPONSE_1, # noqa: E501 python_rail.VALIDATOR_PARALLELISM_PROMPT_2: python_rail.VALIDATOR_PARALLELISM_RESPONSE_2, # noqa: E501 python_rail.VALIDATOR_PARALLELISM_PROMPT_3: python_rail.VALIDATOR_PARALLELISM_RESPONSE_3, # noqa: E501 + lists_object.LIST_PROMPT: lists_object.LIST_OUTPUT, } try: diff --git a/tests/integration_tests/test_assets/lists_object.py b/tests/integration_tests/test_assets/lists_object.py new file mode 100644 index 000000000..1d0ca34d6 --- /dev/null +++ b/tests/integration_tests/test_assets/lists_object.py @@ -0,0 +1,33 @@ +from typing import List +from pydantic import BaseModel + + +LIST_PROMPT = """Create a list of items that may be found in a grocery store. + +Json Output: + +""" + + +LIST_OUTPUT = """[{"name": "apple", "price": 1.0}, {"name": "banana", "price": 0.5}, {"name": "orange", "price": 1.5}]""" # noqa: E501 + + +class Item(BaseModel): + name: str + price: float + + +PYDANTIC_RAIL_WITH_LIST = List[Item] + + +RAIL_SPEC_WITH_LIST = """ + + + + + + + + Create a list of items that may be found in a grocery store. 
+ +""" \ No newline at end of file diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 6e756c42e..2ab48a587 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -1,7 +1,7 @@ import enum import json import os -from typing import Optional, Union +from typing import List, Optional, Union import pytest from pydantic import BaseModel @@ -19,6 +19,7 @@ MockOpenAICallable, MockOpenAIChatCallable, entity_extraction, + lists_object, ) from .test_assets import pydantic, string @@ -56,6 +57,7 @@ def rail_spec(): @pytest.fixture(scope="module") def llm_output(): + """Mock LLM output for the rail_spec.""" return """ { "dummy_string": "Some string", @@ -77,6 +79,7 @@ def llm_output(): @pytest.fixture(scope="module") def validated_output(): + """Mock validated output for the rail_spec.""" return { "dummy_string": "Some string", "dummy_integer": 42, @@ -94,8 +97,7 @@ def validated_output(): def guard_initializer( rail: Union[str, BaseModel], prompt: str, instructions: Optional[str] = None ) -> Guard: - """Helper function to initialize a Guard object using the correct - method.""" + """Helper function to initialize a Guard using the correct method.""" if isinstance(rail, str): return Guard.from_rail_string(rail) @@ -130,7 +132,11 @@ def guard_initializer( def test_entity_extraction_with_reask( mocker, rail, prompt, test_full_schema_reask, multiprocessing_validators ): - """Test that the entity extraction works with re-asking.""" + """Test that the entity extraction works with re-asking. + + This test creates a Guard for the entity extraction use case. It performs + a single call to the LLM and then re-asks the LLM for a second time. 
+ """ mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable) mocker.patch( "guardrails.validators.Validator.run_in_separate_process", @@ -851,3 +857,31 @@ def invoke( result = chain.invoke({"topic": topic}) assert result == output + + +@pytest.mark.parametrize( + "rail,prompt", + [ + ( + lists_object.PYDANTIC_RAIL_WITH_LIST, + "Create a list of items that may be found in a grocery store." + ), + (lists_object.RAIL_SPEC_WITH_LIST, None) + ], +) +def test_guard_with_top_level_list_return_type(mocker, rail, prompt): + # Create a Guard with a top level list return type + + # Mock the LLM + mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable) + + guard = guard_initializer(rail, prompt=prompt) + + output = guard(llm_api=get_static_openai_create_func()) + + # Validate the output + assert output.validated_output == [ + {"name": "apple", "price": 1.0}, + {"name": "banana", "price": 0.5}, + {"name": "orange", "price": 1.5}, + ] From 7e220dd074c97285f4f73c567503f5730acce9d6 Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Mon, 26 Feb 2024 20:45:57 -0800 Subject: [PATCH 05/11] more tests --- tests/unit_tests/test_json_schema.py | 51 +++++++++++++++++++++++++ tests/unit_tests/test_rail.py | 19 +++++++++ tests/unit_tests/test_validator_base.py | 43 +++++++++++++++++++++ 3 files changed, 113 insertions(+) create mode 100644 tests/unit_tests/test_json_schema.py create mode 100644 tests/unit_tests/test_validator_base.py diff --git a/tests/unit_tests/test_json_schema.py b/tests/unit_tests/test_json_schema.py new file mode 100644 index 000000000..d9ca5be03 --- /dev/null +++ b/tests/unit_tests/test_json_schema.py @@ -0,0 +1,51 @@ +import sys +from typing import List +from pydantic import BaseModel +import pytest +from lxml import etree as ET + +from guardrails.schema.json_schema import JsonSchema + +# Set up XML parser +XMLPARSER = ET.XMLParser(encoding="utf-8") + + +def test_json_schema_from_xml_outermost_list(): + rail_spec = """ 
+ + + + + +""" + try: + xml = ET.fromstring(rail_spec, parser=XMLPARSER) + JsonSchema.from_xml(xml) + except Exception as e: + pytest.fail(f"JsonSchema.from_xml() raised an exception: {e}") + + +def test_json_schema_from_pydantic_outermost_list_typing(): + class Foo(BaseModel): + field: str + + # Test 1: typing.List with BaseModel + try: + JsonSchema.from_pydantic(model=List[Foo]) + except Exception as e: + pytest.fail(f"JsonSchema.from_pydantic() raised an exception: {e}") + + +@pytest.mark.skipif( + sys.version_info.major <= 3 and sys.version_info.minor <= 8, + reason="requires Python > 3.8" +) +def test_json_schema_from_pydantic_outermost_list(): + class Foo(BaseModel): + field: str + + # Test 1: typing.List with BaseModel + try: + JsonSchema.from_pydantic(model=list[Foo]) + except Exception as e: + pytest.fail(f"JsonSchema.from_pydantic() raised an exception: {e}") diff --git a/tests/unit_tests/test_rail.py b/tests/unit_tests/test_rail.py index 7803bbe47..3e019d8d8 100644 --- a/tests/unit_tests/test_rail.py +++ b/tests/unit_tests/test_rail.py @@ -151,6 +151,25 @@ def test_rail_list_with_object(): Rail.from_string(rail_spec) +def test_rail_outermost_list(): + rail_spec = """ + + + + + + + + + +Hello world + + + +""" + Rail.from_string(rail_spec) + + def test_format_deprecated(): rail_spec = """ diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py new file mode 100644 index 000000000..081cf5260 --- /dev/null +++ b/tests/unit_tests/test_validator_base.py @@ -0,0 +1,43 @@ +# Write tests for check_refrain and filter_in_schema in guardrails/validator_base.py +import pytest + +from guardrails.validator_base import check_refrain, filter_in_schema, Refrain, Filter + + +@pytest.mark.parametrize( + "schema,expected", + [ + (["a", Refrain(), "b"], True), + (["a", "b"], False), + (["a", ["b", Refrain(), "c"], "d"], True), + (["a", ["b", "c", "d"], "e"], False), + (["a", {"b": Refrain(), "c": "d"}, "e"], True), + (["a", {"b": "c", 
"d": "e"}, "f"], False), + ({"a": "b"}, False), + ({"a": Refrain()}, True), + ({"a": "b", "c": {"d": Refrain()}}, True), + ({"a": "b", "c": {"d": "e"}}, False), + ({"a": "b", "c": ["d", Refrain()]}, True), + ({"a": "b", "c": ["d", "e"]}, False), + ] +) +def test_check_refrain(schema, expected): + assert check_refrain(schema) == expected + + +@pytest.mark.parametrize( + "schema,expected", + [ + (["a", Filter(), "b"], ["a", "b"]), + (["a", ["b", Filter(), "c"], "d"], ["a", ["b", "c"], "d"]), + (["a", ["b", "c", "d"], "e"], ["a", ["b", "c", "d"], "e"]), + (["a", {"b": Filter(), "c": "d"}, "e"], ["a", {"c": "d"}, "e"]), + ({"a": "b"}, {"a": "b"}), + ({"a": Filter()}, {}), + ({"a": "b", "c": {"d": Filter()}}, {"a": "b", "c": {}}), + ({"a": "b", "c": {"d": "e"}}, {"a": "b", "c": {"d": "e"}}), + ({"a": "b", "c": ["d", Filter()]}, {"a": "b", "c": ["d"]}), + ] +) +def test_filter_in_schema(schema, expected): + assert filter_in_schema(schema) == expected From 73251d45f463253c43a487eb98136c5d93462d30 Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Mon, 26 Feb 2024 20:47:37 -0800 Subject: [PATCH 06/11] fix lint issues --- tests/integration_tests/mock_llm_outputs.py | 2 +- tests/integration_tests/test_assets/lists_object.py | 4 ++-- tests/integration_tests/test_guard.py | 9 +++++---- tests/unit_tests/test_json_schema.py | 11 ++++++----- tests/unit_tests/test_validator_base.py | 6 +++--- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/integration_tests/mock_llm_outputs.py b/tests/integration_tests/mock_llm_outputs.py index 285194ebd..310c695ef 100644 --- a/tests/integration_tests/mock_llm_outputs.py +++ b/tests/integration_tests/mock_llm_outputs.py @@ -7,7 +7,7 @@ ) from guardrails.utils.llm_response import LLMResponse -from .test_assets import entity_extraction, pydantic, python_rail, string, lists_object +from .test_assets import entity_extraction, lists_object, pydantic, python_rail, string class MockOpenAICallable(OpenAICallable): diff --git 
a/tests/integration_tests/test_assets/lists_object.py b/tests/integration_tests/test_assets/lists_object.py index 1d0ca34d6..4310700cc 100644 --- a/tests/integration_tests/test_assets/lists_object.py +++ b/tests/integration_tests/test_assets/lists_object.py @@ -1,6 +1,6 @@ from typing import List -from pydantic import BaseModel +from pydantic import BaseModel LIST_PROMPT = """Create a list of items that may be found in a grocery store. @@ -30,4 +30,4 @@ class Item(BaseModel): Create a list of items that may be found in a grocery store. -""" \ No newline at end of file +""" diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 2ab48a587..97e5d9012 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -134,8 +134,9 @@ def test_entity_extraction_with_reask( ): """Test that the entity extraction works with re-asking. - This test creates a Guard for the entity extraction use case. It performs - a single call to the LLM and then re-asks the LLM for a second time. + This test creates a Guard for the entity extraction use case. It + performs a single call to the LLM and then re-asks the LLM for a + second time. """ mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable) mocker.patch( @@ -864,9 +865,9 @@ def invoke( [ ( lists_object.PYDANTIC_RAIL_WITH_LIST, - "Create a list of items that may be found in a grocery store." 
+ "Create a list of items that may be found in a grocery store.", ), - (lists_object.RAIL_SPEC_WITH_LIST, None) + (lists_object.RAIL_SPEC_WITH_LIST, None), ], ) def test_guard_with_top_level_list_return_type(mocker, rail, prompt): diff --git a/tests/unit_tests/test_json_schema.py b/tests/unit_tests/test_json_schema.py index d9ca5be03..accd9086a 100644 --- a/tests/unit_tests/test_json_schema.py +++ b/tests/unit_tests/test_json_schema.py @@ -1,8 +1,9 @@ import sys from typing import List -from pydantic import BaseModel + import pytest from lxml import etree as ET +from pydantic import BaseModel from guardrails.schema.json_schema import JsonSchema @@ -28,7 +29,7 @@ def test_json_schema_from_xml_outermost_list(): def test_json_schema_from_pydantic_outermost_list_typing(): class Foo(BaseModel): field: str - + # Test 1: typing.List with BaseModel try: JsonSchema.from_pydantic(model=List[Foo]) @@ -37,13 +38,13 @@ class Foo(BaseModel): @pytest.mark.skipif( - sys.version_info.major <= 3 and sys.version_info.minor <= 8, - reason="requires Python > 3.8" + sys.version_info.major <= 3 and sys.version_info.minor <= 8, + reason="requires Python > 3.8", ) def test_json_schema_from_pydantic_outermost_list(): class Foo(BaseModel): field: str - + # Test 1: typing.List with BaseModel try: JsonSchema.from_pydantic(model=list[Foo]) diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py index 081cf5260..962ee55f5 100644 --- a/tests/unit_tests/test_validator_base.py +++ b/tests/unit_tests/test_validator_base.py @@ -1,7 +1,7 @@ # Write tests for check_refrain and filter_in_schema in guardrails/validator_base.py import pytest -from guardrails.validator_base import check_refrain, filter_in_schema, Refrain, Filter +from guardrails.validator_base import Filter, Refrain, check_refrain, filter_in_schema @pytest.mark.parametrize( @@ -19,7 +19,7 @@ ({"a": "b", "c": {"d": "e"}}, False), ({"a": "b", "c": ["d", Refrain()]}, True), ({"a": "b", "c": ["d", "e"]}, 
False), - ] + ], ) def test_check_refrain(schema, expected): assert check_refrain(schema) == expected @@ -37,7 +37,7 @@ def test_check_refrain(schema, expected): ({"a": "b", "c": {"d": Filter()}}, {"a": "b", "c": {}}), ({"a": "b", "c": {"d": "e"}}, {"a": "b", "c": {"d": "e"}}), ({"a": "b", "c": ["d", Filter()]}, {"a": "b", "c": ["d"]}), - ] + ], ) def test_filter_in_schema(schema, expected): assert filter_in_schema(schema) == expected From 74c3bd59eb8704ce462efdfe35ebfc4f49f30baf Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Mon, 26 Feb 2024 20:51:33 -0800 Subject: [PATCH 07/11] fix lint --- tests/integration_tests/test_guard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 97e5d9012..d3a7e7b8f 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -1,7 +1,7 @@ import enum import json import os -from typing import List, Optional, Union +from typing import Optional, Union import pytest from pydantic import BaseModel From 0a4ef0f71dbe7c164a3754862a6be6c8008bb519 Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Mon, 26 Feb 2024 22:39:57 -0800 Subject: [PATCH 08/11] add docs --- .../structured_data_with_guardrails.mdx | 219 ++++++++++++++++++ docusaurus/sidebars.js | 2 +- 2 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 docs/how_to_guides/structured_data_with_guardrails.mdx diff --git a/docs/how_to_guides/structured_data_with_guardrails.mdx b/docs/how_to_guides/structured_data_with_guardrails.mdx new file mode 100644 index 000000000..644392929 --- /dev/null +++ b/docs/how_to_guides/structured_data_with_guardrails.mdx @@ -0,0 +1,219 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Generate structured data with Guardrails AI + +Guardrails AI is effective for generating structured data across a variety of LLMs. This guide contains +the following: +1. 
General instructions on generating structured data from Guardrails using `Pydantic` or Markup (i.e. `RAIL`), and +2. Examples to generate structured data using `Pydantic` or Markup. + +## Syntax for generating structured data + +There are two ways to generate structured data with Guardrails AI: using `Pydantic` or Markup (i.e. `RAIL`). + +1. **Pydantic**: In order to generate structured data with Pydantic models, create a Pydantic model with the desired fields and types, then create a `Guard` object that uses the Pydantic model to generate structured data, and finally call the LLM of your choice with the `guard` object to generate structured data. +2. **RAIL**: In order to generate structured data with RAIL specs, create a RAIL spec with the desired fields and types, then create a `Guard` object that uses the RAIL spec to generate structured data, and finally call the LLM of your choice with the `guard` object to generate structured data. + +Below is the syntax for generating structured data with Guardrails AI using `Pydantic` or Markup (i.e. `RAIL`). + + + + In order to generate structured data, first create a Pydantic model with the desired fields and types. + ```python + from pydantic import BaseModel + + class Person(BaseModel): + name: str + age: int + is_employed: bool + ``` + + Then, create a `Guard` object that uses the Pydantic model to generate structured data. + ```python + from guardrails import Guard + + guard = Guard.from_pydantic(Person) + ``` + + Finally, call the LLM of your choice with the `guard` object to generate structured data. + ```python + import openai + + res = guard( + openai.chat.completion.create, + model="gpt-3.5-turbo", + ) + ``` + + + In order to generate structured data, first create a RAIL spec with the desired fields and types. + ```xml + + + + + + + ``` + + Then, create a `Guard` object that uses the RAIL spec to generate structured data. 
+ ```python + from guardrails import Guard + + guard = Guard.from_s(""" + + + + + + + + """) + ``` + + Finally, call the LLM of your choice with the `guard` object to generate structured data. + ```python + import openai + + res = guard( + openai.chat.completion.create, + model="gpt-3.5-turbo", + ) + ``` + + + +## Generate a JSON object with simple types + + + + ```json + { + "name": "John Doe", + "age": 30, + "is_employed": true + } + ``` + + + ```python + from pydantic import BaseModel + + class Person(BaseModel): + name: str + age: int + is_employed: bool + ``` + + + ```xml + + + + + + + + ``` + + + + +## Generate a dictionary of nested types + + + + ```json + { + "name": "John Doe", + "age": 30, + "is_employed": true, + "address": { + "street": "123 Main St", + "city": "Anytown", + "zip": "12345" + } + } + ``` + + + ```python + from pydantic import BaseModel + + class Address(BaseModel): + street: str + city: str + zip: str + + class Person(BaseModel): + name: str + age: int + is_employed: bool + address: Address + ``` + + + ```xml + + + + + + + + + + + + + ``` + + + + +## Generate a list of types + + + + ```json + [ + { + "name": "John Doe", + "age": 30, + "is_employed": true + }, + { + "name": "Jane Smith", + "age": 25, + "is_employed": false + } + ] + ``` + + + ```python + from pydantic import BaseModel + + class Person(BaseModel): + name: str + age: int + is_employed: bool + + people = list[Person] + ``` + + + ```xml + + + + + + + + + + ``` + + diff --git a/docusaurus/sidebars.js b/docusaurus/sidebars.js index 4d6874170..e043b2903 100644 --- a/docusaurus/sidebars.js +++ b/docusaurus/sidebars.js @@ -44,7 +44,7 @@ const sidebars = { type: "category", label: "How-to Guides", collapsed: true, - items: ["how_to_guides/logs", "how_to_guides/streaming", "how_to_guides/llm_api_wrappers", "how_to_guides/rail", "how_to_guides/envvars" ], + items: ["how_to_guides/logs", "how_to_guides/streaming", "how_to_guides/llm_api_wrappers", "how_to_guides/rail", 
"how_to_guides/envvars", "how_to_guides/structured_data_with_guardrails" ], }, "the_guard", { From 2f23a1cc5640a47c0171488290b7bac815eada52 Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Mon, 4 Mar 2024 23:08:56 -0800 Subject: [PATCH 09/11] add __str__ methods for call and iteration --- guardrails/classes/history/call.py | 3 +++ guardrails/classes/history/iteration.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/guardrails/classes/history/call.py b/guardrails/classes/history/call.py index 9cf20f405..0623085fd 100644 --- a/guardrails/classes/history/call.py +++ b/guardrails/classes/history/call.py @@ -366,3 +366,6 @@ def tree(self) -> Tree: ) return tree + + def __str__(self) -> str: + return pretty_repr(self) diff --git a/guardrails/classes/history/iteration.py b/guardrails/classes/history/iteration.py index dac3fc9b6..0fcb82ef3 100644 --- a/guardrails/classes/history/iteration.py +++ b/guardrails/classes/history/iteration.py @@ -189,3 +189,6 @@ def create_msg_history_table( style="on #F0FFF0", ), ) + + def __str__(self) -> str: + return pretty_repr(self) From 201ea565f329c04c021f3fb63d3e0f0363e8ad22 Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Mon, 4 Mar 2024 23:46:11 -0800 Subject: [PATCH 10/11] fix bug in pydantic type casting for union of lists and dicts --- guardrails/classes/output_type.py | 2 +- guardrails/classes/validation_outcome.py | 4 ++++ guardrails/utils/reask_utils.py | 4 ++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/guardrails/classes/output_type.py b/guardrails/classes/output_type.py index f40714047..2aff9fc82 100644 --- a/guardrails/classes/output_type.py +++ b/guardrails/classes/output_type.py @@ -1,3 +1,3 @@ from typing import Dict, List, TypeVar -OT = TypeVar("OT", str, Dict, List) +OT = TypeVar("OT", str, List, Dict) diff --git a/guardrails/classes/validation_outcome.py b/guardrails/classes/validation_outcome.py index 4a7586e8b..4f9df16ae 100644 --- a/guardrails/classes/validation_outcome.py +++ 
b/guardrails/classes/validation_outcome.py @@ -1,6 +1,7 @@ from typing import Generic, Iterator, Optional, Tuple, Union, cast from pydantic import Field +from rich.pretty import pretty_repr from guardrails.classes.history import Call, Iteration from guardrails.classes.output_type import OT @@ -83,3 +84,6 @@ def __iter__( def __getitem__(self, keys): """Get a subset of the ValidationOutcome's fields.""" return iter(getattr(self, k) for k in keys) + + def __str__(self) -> str: + return pretty_repr(self) diff --git a/guardrails/utils/reask_utils.py b/guardrails/utils/reask_utils.py index c42aba16c..d6a1a4faf 100644 --- a/guardrails/utils/reask_utils.py +++ b/guardrails/utils/reask_utils.py @@ -82,6 +82,10 @@ def _gather_reasks_in_list( valid_output = deepcopy(validated_output) _gather_reasks_in_dict(validated_output, valid_output) return reasks, valid_output + elif isinstance(validated_output, List): + valid_output = deepcopy(validated_output) + _gather_reasks_in_list(validated_output, valid_output) + return reasks, valid_output return reasks, None From e05040a34089ecadcfe1892278ed8294a9da7c84 Mon Sep 17 00:00:00 2001 From: Shreya Rajpal Date: Tue, 5 Mar 2024 00:16:31 -0800 Subject: [PATCH 11/11] fix type hints --- guardrails/schema/json_schema.py | 2 +- guardrails/utils/reask_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/guardrails/schema/json_schema.py b/guardrails/schema/json_schema.py index f555925f2..07c1c1a33 100644 --- a/guardrails/schema/json_schema.py +++ b/guardrails/schema/json_schema.py @@ -453,7 +453,7 @@ async def async_validate( return validated_response - def introspect(self, data: Any) -> Tuple[List[ReAsk], Optional[Dict]]: + def introspect(self, data: Any) -> Tuple[List[ReAsk], Union[Dict, List, None]]: if isinstance(data, SkeletonReAsk): return [data], None elif isinstance(data, NonParseableReAsk): diff --git a/guardrails/utils/reask_utils.py b/guardrails/utils/reask_utils.py index d6a1a4faf..926eb7fd4 100644 --- 
a/guardrails/utils/reask_utils.py +++ b/guardrails/utils/reask_utils.py @@ -26,8 +26,8 @@ class NonParseableReAsk(ReAsk): def gather_reasks( - validated_output: Optional[Union[str, Dict, ReAsk]] -) -> Tuple[List[ReAsk], Optional[Dict]]: + validated_output: Optional[Union[str, Dict, List, ReAsk]] +) -> Tuple[List[ReAsk], Union[Dict, List, None]]: """Traverse output and gather all ReAsk objects. Args: