Added initial validator #1

Merged: 3 commits, Aug 29, 2024

README.md (50 additions, 50 deletions)

| Developed by | Guardrails AI |
| --- | --- |
| Date of development | Aug 15, 2024 |
| Validator type | Moderation |
| Blog | |
| License | Apache 2 |
| Input/Output | Output |

## Description

### Intended Use
> ⚠️ This is a remote-inference-only validator, so remote inferencing must be enabled during `guardrails configure`.

This validator moderates both user prompts and LLM responses to prevent harmful topics from surfacing in either direction. It is based on [LlamaGuard 7B](https://huggingface.co/meta-llama/LlamaGuard-7b), which is in turn based on [Llama 2](https://arxiv.org/abs/2307.09288).


The following policies are available; they are accessed directly from the `LlamaGuard7B` validator class:

- `LlamaGuard7B.POLICY__NO_VIOLENCE_HATE`
- `LlamaGuard7B.POLICY__NO_SEXUAL_CONTENT`
- `LlamaGuard7B.POLICY__NO_CRIMINAL_PLANNING`
- `LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS`
- `LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS`
- `LlamaGuard7B.POLICY__NO_ENOURAGE_SELF_HARM`

If no policies are supplied, all of them are applied; otherwise only the supplied policies are enforced, as sketched below.
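
As a minimal sketch of the two modes (assumption: the `validation_endpoint` argument shown in the usage example below may also be needed, depending on your deployment):

```python
from guardrails import Guard
from guardrails.hub import LlamaGuard7B

# No `policies` kwarg: every policy listed above is enforced.
guard_all = Guard().use(LlamaGuard7B)

# Explicit `policies` kwarg: only the listed policies are enforced.
guard_subset = Guard().use(
    LlamaGuard7B,
    policies=[
        LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS,
        LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS,
    ],
)
```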

### Requirements

* Dependencies:
    - guardrails-ai>=0.4.0

* Foundation model access keys:
    - OPENAI_API_KEY

## Installation

```bash
$ guardrails hub install hub://guardrails/llamaguard_7b
```

or

```python
from guardrails import install
install("hub://guardrails/llamaguard_7b")
```

## Usage Examples
In this example, we apply the validator to a string output generated by an LLM.

```python
# Import Guard and Validator
from guardrails import Guard, OnFailAction
from guardrails.hub import LlamaGuard7B

# Setup Guard
guard = Guard().use(
    LlamaGuard7B,
    validation_endpoint="https://guardrails--fa-hg-lg7b.modal.run/v1/chat/classification",
    policies=[LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS],  # not supplying the `policies` kwarg applies all policies
    on_fail=OnFailAction.EXCEPTION,
)

try:
    guard.validate("How can I get weeds out of my garbage bag after cutting my lawn")  # Guardrail passes
    print("Expected: Validation passes")
except Exception as e:
    print("Unexpected: ", e)

try:
    guard.validate("How can I get weed for when cutting my lawn")  # Guardrail fails
    print("Unexpected: Validation passes")
except Exception as e:
    print("Expected: ", e)
```

Output:

```
Expected: Validation passes
Expected: Validation failed for field with errors: Prompt contains unsafe content. Classification: unsafe, Violated Policy: POLICY__NO_ILLEGAL_DRUGS
```

# API Reference

**`__init__(self, policies=None, on_fail="noop")`**
<ul>
Initializes a new instance of the `LlamaGuard7B` class.

**Parameters**
- **`policies`** *(List[str])*: A list of policies to enforce; each entry must be one of `LlamaGuard7B.POLICY__NO_VIOLENCE_HATE`, `LlamaGuard7B.POLICY__NO_SEXUAL_CONTENT`, `LlamaGuard7B.POLICY__NO_CRIMINAL_PLANNING`, `LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS`, `LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS`, or `LlamaGuard7B.POLICY__NO_ENOURAGE_SELF_HARM`. If omitted, all policies are applied.
- **`on_fail`** *(str, Callable)*: The policy to enact when a validator fails. If `str`, must be one of `reask`, `fix`, `filter`, `refrain`, `noop`, `exception` or `fix_reask`. Otherwise, must be a function that is called when the validator fails.
</ul>
<br/>
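
For illustration, a minimal constructor sketch (not taken from the source) using a different `on_fail` string; any of the values listed above can be passed the same way:

```python
from guardrails import Guard
from guardrails.hub import LlamaGuard7B

# "filter" drops the failing value from the output instead of raising an exception.
guard = Guard().use(
    LlamaGuard7B,
    policies=[LlamaGuard7B.POLICY__NO_VIOLENCE_HATE],
    on_fail="filter",
)
```
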
**`validate(self, value, metadata) -> ValidationResult`**
<ul>
Validates the given `value` against the configured LlamaGuard policies. Note: this method is not called directly by the user; it is invoked internally when the guard runs.

**Parameters**
- **`value`** *(Any)*: The input value to validate.
- **`metadata`** *(dict)*: A dictionary containing metadata required for validation. No additional metadata keys are needed for this validator.
</ul>
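
As a usage sketch (assumption: the output string below is illustrative only), `validate` is exercised indirectly through `guard.parse`, and no extra metadata needs to be passed:

```python
from guardrails import Guard, OnFailAction
from guardrails.hub import LlamaGuard7B

guard = Guard().use(LlamaGuard7B, on_fail=OnFailAction.EXCEPTION)

# `validate` runs internally on the LLM output handed to `parse`;
# this validator needs no additional metadata keys.
guard.parse("Here are some safe tips for keeping a lawn free of weeds.")
```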

inference/serving-non-optimized-fastapi.py (51 additions, 12 deletions)
@@ -1,5 +1,4 @@
import os
from typing import Optional

import modal

@@ -101,11 +100,27 @@ def tgi_app():

    from typing import List
    from pydantic import BaseModel
    import logging

    TOKEN = os.getenv("TOKEN")
    if TOKEN is None:
        raise ValueError("Please set the TOKEN environment variable")

    # Create a logger
    logger = logging.getLogger(MODEL_ALIAS)
    logger.setLevel(logging.DEBUG)

    # Create a handler for logging to stdout
    stdout_handler = logging.StreamHandler()
    stdout_handler.setLevel(logging.DEBUG)

    # Create a formatter for the log messages
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stdout_handler.setFormatter(formatter)

    # Add the handler to the logger
    logger.addHandler(stdout_handler)

    volume.reload()  # ensure we have the latest version of the weights

    app = fastapi.FastAPI()
@@ -129,6 +144,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
detail="Invalid authentication credentials",
)
return {"username": "authenticated_user"}

@app.exception_handler(Exception)
def error_handler(request, exc):
status_code = 500
detail = "Internal Server Error"
logger.exception(exc)
if isinstance(exc, fastapi.HTTPException):
status_code = exc.status_code
detail = exc.detail
return fastapi.responses.JSONResponse(
status_code=status_code,
content={
"status": status_code,
"response": {
"detail": detail,
}
},
)

router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])

@@ -137,11 +170,13 @@ class ChatMessages(BaseModel):
        content: str

    class ChatClassificationRequestBody(BaseModel):
        policies: Optional[List[str]] = None
        chat: List[ChatMessages]


    @router.post("/v1/chat/classification")
    async def chat_classification_response(body: ChatClassificationRequestBody):
        policies = body.policies
        chat = body.model_dump().get("chat", [])

print("Serving request for chat classification...")
@@ -154,24 +189,28 @@ async def chat_classification_response(body: ChatClassificationRequestBody):
        cleaned_response = response.lower().strip()

        subclass = None  # default; only set when the model flags the response as unsafe
        if "unsafe" in cleaned_response:
            split_cln_response = response.strip().split(os.linesep)
            subclass = split_cln_response[1] if len(split_cln_response) > 1 else None

            if policies and subclass in policies:
                is_unsafe = True
            elif policies and subclass not in policies:
                is_unsafe = False
            else:
                is_unsafe = True
        else:
            is_unsafe = False

        return {
            "status": 200,
            "response": {
                "class": "unsafe" if is_unsafe else "safe",
                "subclass": subclass,
                "applied_policies": policies,
                "raw_output": response
            }
        }


    app.include_router(router)
    return app


# @app.local_entrypoint()
# def main():
# model = Model()
# model.generate.remote()
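
For context, a hypothetical client-side sketch of calling the `/v1/chat/classification` route defined above. The endpoint URL is a placeholder, the message fields mirror the (partially collapsed) `ChatMessages` model, and the policy strings are assumed to match what the validator sends:

```python
import os

import requests  # any HTTP client works; requests is assumed here

ENDPOINT = "https://<your-modal-deployment>.modal.run/v1/chat/classification"  # placeholder
TOKEN = os.environ["TOKEN"]  # the same bearer token the server checks

payload = {
    # Optional: omit "policies" to have the server apply every policy.
    "policies": ["POLICY__NO_ILLEGAL_DRUGS"],
    # "role" is assumed; only "content" is visible in the diff above.
    "chat": [{"role": "user", "content": "How can I get weed for when cutting my lawn"}],
}

resp = requests.post(
    ENDPOINT,
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # {"status": 200, "response": {"class": ..., "subclass": ..., ...}}
```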