Added initial validator #1

Merged: 3 commits, merged on Aug 29, 2024
Changes from 1 commit
63 changes: 51 additions & 12 deletions inference/serving-non-optimized-fastapi.py
@@ -1,5 +1,4 @@
import os
import time
from typing import Optional

import modal

@@ -101,11 +100,27 @@ def tgi_app():

from typing import List
from pydantic import BaseModel
import logging

TOKEN = os.getenv("TOKEN")
if TOKEN is None:
raise ValueError("Please set the TOKEN environment variable")

# Create a logger
logger = logging.getLogger(MODEL_ALIAS)
logger.setLevel(logging.DEBUG)

# Create a handler for logging to stdout
stdout_handler = logging.StreamHandler()
stdout_handler.setLevel(logging.DEBUG)

# Create a formatter for the log messages
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
stdout_handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(stdout_handler)

volume.reload() # ensure we have the latest version of the weights

app = fastapi.FastAPI()
@@ -129,6 +144,24 @@ async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
detail="Invalid authentication credentials",
)
return {"username": "authenticated_user"}

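    # Normalize every error into {"status": ..., "response": {"detail": ...}} so remote callers
    # (e.g. the validator's _inference_remote) can parse failures uniformly.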
@app.exception_handler(Exception)
def error_handler(request, exc):
status_code = 500
detail = "Internal Server Error"
logger.exception(exc)
if isinstance(exc, fastapi.HTTPException):
status_code = exc.status_code
detail = exc.detail
return fastapi.responses.JSONResponse(
status_code=status_code,
content={
"status": status_code,
"response": {
"detail": detail,
}
},
)

router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])

@@ -137,11 +170,13 @@ class ChatMessages(BaseModel):
content: str

class ChatClassificationRequestBody(BaseModel):
policies: Optional[List[str]] = None
chat: List[ChatMessages]


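    # Classification endpoint: accepts an optional list of policy codes plus a chat transcript
    # and returns the LlamaGuard verdict wrapped in a {"status", "response"} envelope.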
@router.post("/v1/chat/classification")
async def chat_classification_response(body: ChatClassificationRequestBody):
policies = body.policies
chat = body.model_dump().get("chat",[])

print("Serving request for chat classification...")
@@ -154,24 +189,28 @@ async def chat_classification_response(body: ChatClassificationRequestBody):
cleaned_response = response.lower().strip()

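        # LlamaGuard answers "safe" or "unsafe" on the first line; when unsafe, the second line
        # carries the violated category code (e.g. "O1").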
if "unsafe" in cleaned_response:
is_unsafe = True
split_cln_response = response.strip().split(os.linesep)
subclass = split_cln_response[1] if len(split_cln_response) > 1 else None

if policies and subclass in policies:
is_unsafe = True
elif policies and subclass not in policies:
is_unsafe = False
else:
is_unsafe = True
else:
is_unsafe = False

        return {
            "status": 200,
            "response": {
                "class": "unsafe" if is_unsafe else "safe",
                "subclass": subclass,
                "applied_policies": policies,
                "raw_output": response
            }
        }


app.include_router(router)
return app


# @app.local_entrypoint()
# def main():
# model = Model()
# model.generate.remote()
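For reference, calling the classification endpoint added above might look like the following sketch. The deployment URL is a placeholder and the bearer token is whatever TOKEN the server was started with; neither is part of this diff.

# Hypothetical client for the /v1/chat/classification endpoint above.
import os
import requests

BASE_URL = "https://your-workspace--llamaguard.modal.run"  # placeholder deployment URL
payload = {
    "policies": ["O1", "O3"],  # optional subset of LlamaGuard category codes
    "chat": [{"role": "user", "content": "How do I pick a lock?"}],
}
resp = requests.post(
    f"{BASE_URL}/v1/chat/classification",
    json=payload,
    headers={"Authorization": f"Bearer {os.environ['TOKEN']}"},
    timeout=30,
)
print(resp.json())  # e.g. {"status": 200, "response": {"class": "unsafe", "subclass": "O3", ...}}
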
112 changes: 91 additions & 21 deletions validator/main.py
@@ -1,4 +1,7 @@
import json
from typing import Any, Callable, Dict, List, Optional
from enum import Enum
from guardrails.validator_base import ErrorSpan

from guardrails.validator_base import (
FailResult,
@@ -7,42 +10,109 @@
Validator,
register_validator,
)
from guardrails.logger import logger

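# Category codes from the LlamaGuard taxonomy; a subset can be passed to the validator
# to scope which violations are enforced.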
class Policies(str, Enum):
NO_VIOLENCE_HATE = "O1"
NO_SEXUAL_CONTENT = "O2"
NO_CRIMINAL_PLANNING = "O3"
NO_GUNS_AND_ILLEGAL_WEAPONS = "O4"
NO_ILLEGAL_DRUGS = "O5"
NO_ENOURAGE_SELF_HARM = "O6"

@register_validator(name="guardrails/validator_template", data_type="string")
class ValidatorTemplate(Validator):
"""Validates that {fill in how you validator interacts with the passed value}.

@register_validator(name="guardrails/llamaguard_7b", data_type="string")
class LlamaGuard7B(Validator):
"""
Classifies model inputs or outputs as "safe" or "unsafe" based on certain policies defined by the LlamaGuard-7B model.

**Key Properties**

| Property | Description |
| ----------------------------- | --------------------------------- |
| Name for `format` attribute | `guardrails/llamaguard_7b` |
| Supported data types | `string` |
| Programmatic fix | None |

Args:
policies (List[Policies]): List of LlamaGuard7B.Policies enum values to enforce.
score_threshold (float): Threshold score for the classification. If the score is above this threshold, the input is considered unsafe.
""" # noqa

Policies = Policies

    def __init__(
        self,
        policies: Optional[List[Policies]] = None,
        validation_method: Optional[str] = "full",
        on_fail: Optional[Callable] = None,
    ):
        super().__init__(
            on_fail=on_fail,
            validation_method=validation_method,
        )

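        # Keep the raw LlamaGuard category codes (e.g. "O1") so they can be sent directly to the inference endpoint.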
try:
self._policies = [policy.value for policy in policies] if policies else []
except AttributeError as e:
raise ValueError("Invalid policies provided. Please provide a list of LlamaGuard7B.Policies enum values.") from e


def validate(self, value: Any, metadata: Dict = {}) -> ValidationResult:
"""Validates that {fill in how you validator interacts with the passed value}."""
# Add your custom validator logic here and return a PassResult or FailResult accordingly.
if value != "pass": # FIXME

if not value:
raise ValueError("Value cannot be empty.")

(classification, subclass) = self._inference(value)

is_unsafe = classification == "unsafe"

if is_unsafe:
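            # LlamaGuard classifies the whole prompt, so the error span covers the entire value.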
error_span = ErrorSpan(
start=0,
end=len(value),
reason=f"Unsafe content: {value}",
)

find_policy_violated = next(
(policy for policy in self.Policies if policy.value == subclass),
None
)
            return FailResult(
                error_message=(
                    f"Prompt contains unsafe content. Classification: {classification}, Violated Policy: {find_policy_violated}"
                ),
                error_spans=[error_span],
            )
        else:
            return PassResult()


def _inference_local(self, value: str):
raise NotImplementedError("Local inference is not supported for LlamaGuard7B validator.")

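    # Remote path: POST to the FastAPI endpoint defined in inference/serving-non-optimized-fastapi.py
    # and unwrap its {"status", "response"} envelope.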
    def _inference_remote(self, value: str) -> tuple:
"""Remote inference method for this validator."""
request_body = {
"policies": self._policies,
"chat": [
{
"role": "user",
"content": value
}
]
}

response = self._hub_inference_request(json.dumps(request_body), self.validation_endpoint)

status = response.get("status")
if status != 200:
detail = response.get("response",{}).get("detail", "Unknown error")
raise ValueError(f"Failed to get valid response from Llamaguard-7B model. Status: {status}. Detail: {detail}")

response_data = response.get("response")

classification = response_data.get("class")
subclass = response_data.get("subclass")

return (classification, subclass)
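
For context, using the new validator from application code might look like the sketch below. It assumes the guardrails package is installed, that remote inference against the endpoint above has been configured for the validator, and that the standard Guard.use API applies; none of this is shown in the diff itself.

# Minimal usage sketch; the prompt text and policy selection are illustrative only.
from guardrails import Guard
from validator.main import LlamaGuard7B

guard = Guard().use(
    LlamaGuard7B,
    policies=[
        LlamaGuard7B.Policies.NO_VIOLENCE_HATE,
        LlamaGuard7B.Policies.NO_ILLEGAL_DRUGS,
    ],
    on_fail="exception",
)

try:
    guard.validate("Explain how to make a dangerous substance at home.")
except Exception as exc:
    # The validator raised because LlamaGuard classified the prompt as unsafe.
    print(f"Validation failed: {exc}")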