Merge pull request #1 from guardrails-ai/feat/initial-validator
Added initial validator
dtam authored Aug 29, 2024
2 parents 9422670 + 141d6d9 commit 99aac90
Showing 3 changed files with 196 additions and 84 deletions.
99 changes: 49 additions & 50 deletions README.md
@@ -2,29 +2,49 @@

| Developed by | Guardrails AI |
| --- | --- |
| Date of development | Aug 15, 2024 |
| Validator type | Moderation |
| Blog | |
| License | Apache 2 |
| Input/Output | Output |

## Description

### Intended Use
This validator moderates both user prompts and LLM output responses to prevent harmful topics from surfacing in either scenario. It is based on [LlamaGuard 7B](https://huggingface.co/meta-llama/LlamaGuard-7b), which is in turn based on [Llama 2](https://arxiv.org/abs/2307.09288).

> ⚠️ This validator is a remote-inference-only validator, so remote inferencing must be enabled during `guardrails configure`.


The following policies can be applied; they are accessed directly from the `LlamaGuard7B` validator class:

- `LlamaGuard7B.POLICY__NO_VIOLENCE_HATE`
- `LlamaGuard7B.POLICY__NO_SEXUAL_CONTENT`
- `LlamaGuard7B.POLICY__NO_CRIMINAL_PLANNING`
- `LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS`
- `LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS`
- `LlamaGuard7B.POLICY__NO_ENOURAGE_SELF_HARM`

If no policies are supplied, all policies are applied; otherwise only the supplied policies are enforced.
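
For example, a minimal sketch (mirroring the usage example below) that enforces only a subset of policies:

```python
from guardrails import Guard, OnFailAction
from guardrails.hub import LlamaGuard7B

# Only these two policies are enforced; content violating other policies is not flagged.
guard = Guard().use(
    LlamaGuard7B,
    policies=[
        LlamaGuard7B.POLICY__NO_VIOLENCE_HATE,
        LlamaGuard7B.POLICY__NO_ENOURAGE_SELF_HARM,
    ],
    on_fail=OnFailAction.EXCEPTION,
)
```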

### Requirements

* Dependencies:
- guardrails-ai>=0.4.0

* Foundation model access keys:
- OPENAI_API_KEY

## Installation

```bash
$ guardrails hub install hub://guardrails/llamaguard_7b
```

or

```python
from guardrails import install
install("hub://guardrails/llamaguard_7b")
```

## Usage Examples
@@ -34,58 +34,43 @@
In this example, we apply the validator to a string output generated by an LLM.

```python
# Import Guard and Validator
from guardrails import Guard, OnFailAction
from guardrails.hub import LlamaGuard7B

# Setup Guard
guard = Guard().use(
    LlamaGuard7B,
    policies=[LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS],  # omitting the `policies` kwarg applies all policies
    on_fail=OnFailAction.EXCEPTION,
)

try:
    guard.validate("How can I get weeds out of my garbage bag after cutting my lawn")  # Guardrail passes
    print("Expected: Validation passes")
except Exception as e:
    print("Unexpected: ", e)

try:
    guard.validate("How can I get weed for when cutting my lawn")  # Guardrail fails
    print("Unexpected: Validation passes")
except Exception as e:
    print("Expected: ", e)
```

Output:

```
Expected: Validation passes
Expected: Validation failed for field with errors: Prompt contains unsafe content. Classification: unsafe, Violated Policy: POLICY__NO_ILLEGAL_DRUGS
```
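
The guard can also wrap an LLM call directly and moderate the generated text before it is returned. A minimal sketch, assuming a Guardrails 0.5-style passthrough call (via LiteLLM) and an `OPENAI_API_KEY` in the environment; the model name is illustrative, so adjust it to your setup:

```python
from guardrails import Guard, OnFailAction
from guardrails.hub import LlamaGuard7B

guard = Guard().use(LlamaGuard7B, on_fail=OnFailAction.EXCEPTION)  # all policies

# The guard forwards the call to the LLM and validates the response text.
result = guard(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Suggest a safe weekend hiking plan."}],
)
print(result.validated_output)
```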

# API Reference

**`__init__(self, policies=None, on_fail="noop")`**
<ul>
Initializes a new instance of the `LlamaGuard7B` class.

**Parameters**
- **`policies`** *(List[str])*: A list of policies to enforce. Each policy must be one of `LlamaGuard7B.POLICY__NO_VIOLENCE_HATE`, `LlamaGuard7B.POLICY__NO_SEXUAL_CONTENT`, `LlamaGuard7B.POLICY__NO_CRIMINAL_PLANNING`, `LlamaGuard7B.POLICY__NO_GUNS_AND_ILLEGAL_WEAPONS`, `LlamaGuard7B.POLICY__NO_ILLEGAL_DRUGS`, or `LlamaGuard7B.POLICY__NO_ENOURAGE_SELF_HARM`. If omitted, all policies are applied.
- **`on_fail`** *(str, Callable)*: The policy to enact when a validator fails. If `str`, must be one of `reask`, `fix`, `filter`, `refrain`, `noop`, `exception` or `fix_reask`. Otherwise, must be a function that is called when the validator fails.
</ul>
<br/>
@@ -101,10 +106,4 @@ Note:

**Parameters**
- **`value`** *(Any)*: The input value to validate.
- **`metadata`** *(dict)*: A dictionary containing metadata required for validation. No additional metadata keys are needed for this validator.
</ul>
63 changes: 51 additions & 12 deletions inference/serving-non-optimized-fastapi.py
@@ -1,5 +1,4 @@
import os
import time
from typing import Optional

import modal

@@ -101,11 +100,27 @@ def tgi_app():

    from typing import List
    from pydantic import BaseModel
    import logging

    TOKEN = os.getenv("TOKEN")
    if TOKEN is None:
        raise ValueError("Please set the TOKEN environment variable")

    # Create a logger
    logger = logging.getLogger(MODEL_ALIAS)
    logger.setLevel(logging.DEBUG)

    # Create a handler for logging to stdout
    stdout_handler = logging.StreamHandler()
    stdout_handler.setLevel(logging.DEBUG)

    # Create a formatter for the log messages
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stdout_handler.setFormatter(formatter)

    # Add the handler to the logger
    logger.addHandler(stdout_handler)

    volume.reload()  # ensure we have the latest version of the weights

    app = fastapi.FastAPI()
@@ -129,6 +144,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
detail="Invalid authentication credentials",
)
return {"username": "authenticated_user"}

@app.exception_handler(Exception)
def error_handler(request, exc):
status_code = 500
detail = "Internal Server Error"
logger.exception(exc)
if isinstance(exc, fastapi.HTTPException):
status_code = exc.status_code
detail = exc.detail
return fastapi.responses.JSONResponse(
status_code=status_code,
content={
"status": status_code,
"response": {
"detail": detail,
}
},
)

router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])

@@ -137,11 +170,13 @@ class ChatMessages(BaseModel):
        content: str

    class ChatClassificationRequestBody(BaseModel):
        policies: Optional[List[str]] = None
        chat: List[ChatMessages]
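
    # Illustrative request body for /v1/chat/classification. The policy strings are
    # an assumption: they should match the category codes parsed from LlamaGuard's
    # output in the handler below (e.g. "O3"); adjust to whatever the client sends.
    # {
    #   "policies": ["O3"],
    #   "chat": [{"role": "user", "content": "How do I pick a lock?"}]
    # }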


    @router.post("/v1/chat/classification")
    async def chat_classification_response(body: ChatClassificationRequestBody):
        policies = body.policies
        chat = body.model_dump().get("chat", [])

        print("Serving request for chat classification...")
@@ -154,24 +189,28 @@ async def chat_classification_response(body: ChatClassificationRequestBody):
        cleaned_response = response.lower().strip()

        subclass = None  # ensure subclass is defined even when the response is safe
        if "unsafe" in cleaned_response:
            split_cln_response = response.strip().split(os.linesep)
            subclass = split_cln_response[1] if len(split_cln_response) > 1 else None

            if policies and subclass in policies:
                is_unsafe = True
            elif policies and subclass not in policies:
                is_unsafe = False
            else:
                is_unsafe = True
        else:
            is_unsafe = False

        return {
            "status": 200,
            "response": {
                "class": "unsafe" if is_unsafe else "safe",
                "subclass": subclass,
                "applied_policies": policies,
                "raw_output": response
            }
        }
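
    # Illustrative response payloads from the endpoint above. The raw_output format
    # follows LlamaGuard's "safe" / "unsafe\nO<n>" convention; exact category codes
    # are assumptions.
    #   safe:   {"status": 200, "response": {"class": "safe", "subclass": null,
    #            "applied_policies": ["O3"], "raw_output": "safe"}}
    #   unsafe: {"status": 200, "response": {"class": "unsafe", "subclass": "O3",
    #            "applied_policies": ["O3"], "raw_output": "unsafe\nO3"}}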


    app.include_router(router)
    return app


# @app.local_entrypoint()
# def main():
# model = Model()
# model.generate.remote()
