Skip to content

Commit

Permalink
use requests with attested connection pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
HernanGatta committed Aug 27, 2023
1 parent 70e9e94 commit b7b231c
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions python-package/src/promptguard/promptguard_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@
This module exposes wrappers around API calls to the PromptGuard service.
"""
import json
import threading
from dataclasses import dataclass
from http import HTTPStatus
from http.client import HTTPException
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from atls import AttestedHTTPSConnection, AttestedTLSContext
from atls.validators import AZ_AAS_GLOBAL_JKUS, AzAasAciValidator, Validator
import requests
from atls.utils.requests import HTTPAAdapter
from atls.validators import Validator
from atls.validators.azure.aas import PUBLIC_JKUS, AciValidator
from promptguard.authentication import get_api_key
from promptguard.configuration import get_server_config

_session: Optional[requests.Session] = None
_sessionLock: threading.Lock = threading.Lock()


@dataclass
class SanitizeResponse:
Expand Down Expand Up @@ -39,7 +45,7 @@ def sanitize(input_texts: List[str]) -> SanitizeResponse:
a secret entropy value.
"""
response = _send_request_to_promptguard_service(
endpoint="/sanitize", payload={"input_texts": input_texts}
endpoint="sanitize", payload={"input_texts": input_texts}
)
return SanitizeResponse(**json.loads(response))

Expand Down Expand Up @@ -68,7 +74,7 @@ def desanitize(sanitized_text: str, secure_context: str) -> DesanitizeResponse:
The deanonymzied version of `sanitized_text` with PII added back in.
"""
response = _send_request_to_promptguard_service(
endpoint="/desanitize",
endpoint="desanitize",
payload={
"sanitized_text": sanitized_text,
"secure_context": secure_context,
Expand Down Expand Up @@ -103,37 +109,34 @@ def _send_request_to_promptguard_service(
if the request was successful
"""

global _session
global _sessionLock

with _sessionLock:
if _session is None:
_session = requests.Session()
_session.mount("httpa://", HTTPAAdapter(_get_default_validators()))

api_key = get_api_key()
hostname, port = get_server_config()

ctx = AttestedTLSContext(_get_default_validators())
conn = AttestedHTTPSConnection(hostname, ctx, port)
response = _session.request(
"POST",
f"httpa://{hostname}:{port}/{endpoint}",
headers={"Authorization": f"Bearer {api_key}"},
data=json.dumps(payload),
)

headers = {"Authorization": f"Bearer {api_key}"}
response_code = response.status_code
response_text = response.text

try:
conn.request(
"POST",
endpoint,
json.dumps(payload),
headers,
if response_code != HTTPStatus.OK:
raise HTTPException(
f"Error response from the PromptGuard server: "
f"[HTTP {response_code}] {response_text}"
)

response = conn.getresponse()

response_code = response.getcode()
response_body = response.read()
response_text = response_body.decode()

if response_code != HTTPStatus.OK:
raise HTTPException(
f"Error response from the PromptGuard server: "
f"[HTTP {response_code}] {response_text}"
)

return response_text
finally:
conn.close()
return response_text


def _get_default_validators() -> List[Validator]:
Expand All @@ -146,6 +149,6 @@ def _get_default_validators() -> List[Validator]:
list of Validator
One or more aTLS validators
"""
az_aas_aci_validator = AzAasAciValidator(jkus=AZ_AAS_GLOBAL_JKUS)
aci_validator = AciValidator(jkus=PUBLIC_JKUS)

return [az_aas_aci_validator]
return [aci_validator]

0 comments on commit b7b231c

Please sign in to comment.