Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use requests with attested connection pooling #22

Merged
merged 3 commits into from
Aug 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 103 additions & 57 deletions python-package/src/promptguard/promptguard_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,26 @@
This module exposes wrappers around API calls to the PromptGuard service.
"""
import json
import threading
from dataclasses import dataclass
from http import HTTPStatus
from http.client import HTTPException
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from atls import AttestedHTTPSConnection, AttestedTLSContext
from atls.validators import AZ_AAS_GLOBAL_JKUS, AzAasAciValidator, Validator
import requests
from atls.utils.requests import HTTPAAdapter
from atls.validators import Validator
from atls.validators.azure.aas import PUBLIC_JKUS, AciValidator
from promptguard.authentication import get_api_key
from promptguard.configuration import get_server_config

# Global requests session to leverage connection pooling to in turn avoid
# establishing a new connection for each request to the service.
_session: Optional[requests.Session] = None

# Protects the global requests session when creating it for the first time.
_session_lock: threading.Lock = threading.Lock()


@dataclass
class SanitizeResponse:
Expand All @@ -24,75 +34,94 @@ class SanitizeResponse:
The sanitized form of the input texts without PII. List has the same
dimensions as the input_texts list.
secret_entropy : str
A set of bytes encoded as a string which contains context
needed to desanitize the entities in sanitized_text. Should
be passed along to the desanitize endpoint.
A set of bytes encoded as a string that contains the context needed to
desanitize the entities in sanitized_text; it must be passed along to
the desanitize endpoint.
"""

sanitized_texts: List[str]
secure_context: str


def sanitize(input_texts: List[str]) -> SanitizeResponse:
def sanitize(
input_texts: List[str],
retries: Optional[int] = None,
timeout: Optional[int] = None,
) -> SanitizeResponse:
"""
Takes in a list of text prompts and returns a list of
sanitized texts with PII redacted from it.
Takes in a list of prompts and returns a list of sanitized prompts with PII
redacted from it.

Parameters
----------
input_text : list of str
List of prompt that you want to be sanitized together.
List of prompts to sanitize together.

Returns
-------
SanitizeResponse
The anonymzied version of input_texts without PII and
a secret entropy value.
The anonymized version of input_texts without PII and a secret entropy
value.
"""
response = _send_request_to_promptguard_service(
endpoint="/sanitize", payload={"input_texts": input_texts}
endpoint="sanitize",
payload={"input_texts": input_texts},
retries=retries,
timeout=timeout,
)
return SanitizeResponse(**json.loads(response))


@dataclass
class DesanitizeResponse:
"""
Class representing the return value of the desanitize method
Class representing the return value of the desanitize method.

Attributes
----------
desanitized_text : str
The desanitized form of the input text with PII added back in
The desanitized form of the input text with PII added back in.
"""

desanitized_text: str


def desanitize(sanitized_text: str, secure_context: str) -> DesanitizeResponse:
def desanitize(
sanitized_text: str,
secure_context: str,
retries: Optional[int] = None,
timeout: Optional[int] = None,
) -> DesanitizeResponse:
"""
Takes in a sanitized response and returns the desanitized
text with PII added back to it.
Takes in a sanitized response and returns the desanitized text with PII
added back to it.

Parameters
----------
sanitized_text : str
Sanitized response that you want to be desanitized.
secure_context : str
Secret entropy value that should have been returned by
the call to `sanitize`.
Secret entropy value that should have been returned by the call to
sanitize.
retries : int, optional
The number of retries to submit a request to the service before giving
up when errors occur.
timeout : int, optional
The number of seconds to wait until a request to the service times out.

Returns
-------
DesanitizeResponse
The deanonymzied version of `sanitized_text` with PII added back in.
The deanonymzied version of sanitized_text with PII added back in.
"""
response = _send_request_to_promptguard_service(
endpoint="/desanitize",
endpoint="desanitize",
payload={
"sanitized_text": sanitized_text,
"secure_context": secure_context,
},
retries=retries,
timeout=timeout,
)
return DesanitizeResponse(**json.loads(response))

Expand All @@ -101,59 +130,76 @@ def desanitize(sanitized_text: str, secure_context: str) -> DesanitizeResponse:


def _send_request_to_promptguard_service(
endpoint: str, payload: Dict[str, Union[str, List[str]]]
endpoint: str,
payload: Dict[str, Union[str, List[str]]],
retries: Optional[int] = None,
timeout: Optional[int] = None,
) -> str:
"""
Helper method which takes in the name of the endpoint, and a
payload dictionary, and converts it into the form needed to send
the request to the Promptguard service. Returns the response
recieved if its successful, and raises an error otherwise.
Helper method which takes in the name of the endpoint and a payload
dictionary, and converts it into the form needed to send the request to the
PromptGuard service. Returns the response received if it's successful, and
raises an error otherwise.

Parameters
----------
endpoint : str
The name of the endpoint you are trying to hit
The name of the endpoint you are trying to hit.
payload : dict
The payload of the request as a dictionary
The payload of the request as a dictionary.
retries : int, optional
The number of retries to submit a request to the service before giving
up when errors occur.
timeout : int, optional
The number of seconds to wait until a request to the service times out.

Returns
-------
str
The response body returned by the request, only returned
if the request was successful
The response body returned by the request, only returned if the request
was successful.
"""

api_key = get_api_key()
hostname, port = get_server_config()
global _session
global _session_lock

ctx = AttestedTLSContext(_get_default_validators())
conn = AttestedHTTPSConnection(hostname, ctx, port)
with _session_lock:
if _session is None:
_session = requests.Session()
_session.mount("httpa://", HTTPAAdapter(_get_default_validators()))

headers = {"Authorization": f"Bearer {api_key}"}
api_key = get_api_key()
hostname, port = get_server_config()

try:
conn.request(
"POST",
endpoint,
json.dumps(payload),
headers,
)
if retries is None:
retries = 3

conn_except: ConnectionError
while retries > 0:
try:
response = _session.request(
"POST",
f"httpa://{hostname}:{port}/{endpoint}",
headers={"Authorization": f"Bearer {api_key}"},
data=json.dumps(payload),
timeout=timeout,
)

response = conn.getresponse()
response_code = response.status_code
response_text = response.text

response_code = response.getcode()
response_body = response.read()
response_text = response_body.decode()
if response_code != HTTPStatus.OK:
raise HTTPException(
f"Error response from the PromptGuard server: "
f"[HTTP {response_code}] {response_text}"
)

if response_code != HTTPStatus.OK:
raise HTTPException(
f"Error response from the PromptGuard server: "
f"[HTTP {response_code}] {response_text}"
)
return response_text
except ConnectionError as e:
conn_except = e
retries -= 1

return response_text
finally:
conn.close()
raise conn_except


def _get_default_validators() -> List[Validator]:
Expand All @@ -166,6 +212,6 @@ def _get_default_validators() -> List[Validator]:
list of Validator
One or more aTLS validators
"""
az_aas_aci_validator = AzAasAciValidator(jkus=AZ_AAS_GLOBAL_JKUS)
aci_validator = AciValidator(jkus=PUBLIC_JKUS)

return [az_aas_aci_validator]
return [aci_validator]
Loading