diff --git a/python-package/sample/connect_aci.py b/python-package/sample/connect_aci.py index a24fc49..722d08a 100755 --- a/python-package/sample/connect_aci.py +++ b/python-package/sample/connect_aci.py @@ -1,3 +1,5 @@ +#!/usr/bin/python3 + # Sample usage of the atls package # # Suppose a simple HTTP server with a single GET endpoint /index is running @@ -17,10 +19,13 @@ # --url /index import argparse -from typing import List, Optional +import ast +from typing import List, Mapping, Optional -from atls import AttestedHTTPSConnection, AttestedTLSContext -from atls.validators import AzAasAciValidator +import requests +from atls import ATLSContext, HTTPAConnection +from atls.utils.requests import HTTPAAdapter +from atls.validators.azure.aas import AciValidator # Parse arguments parser = argparse.ArgumentParser() @@ -36,29 +41,51 @@ parser.add_argument( "--method", default="GET", - help="HTTP method to use in the request " "(default: GET)", + help="HTTP method to use in the request (default: GET)", ) parser.add_argument( "--url", default="/index", - help="URL to perform the HTTP request against " "(default: /index)", + help="URL to perform the HTTP request against (default: /index)", ) parser.add_argument( "--policy", nargs="*", - help="path to a CCE policy in Rego format, may be " - "specified multiple times, once for each allowed policy " - "(default: ignore)", + help="path to a CCE policy in Rego format, may be specified multiple " + "times, once for each allowed policy (default: ignore)", ) parser.add_argument( "--jku", nargs="*", - help="allowed JWKS URL to verify the JKU claim in the AAS " - "JWT token against, may be specified multiple times, one " - "for each allowed value (default: ignore)", + action="extend", + help="allowed JWKS URL to verify the JKU claim in the AAS JWT token " + "against, may be specified multiple times, one for each allowed value " + "(default: ignore)", +) + +parser.add_argument( + "--body", + type=argparse.FileType("r"), + help="path to a file containing the content to include in the request " + "(default: nothing)", +) + +parser.add_argument( + "--headers", + type=argparse.FileType("r"), + help="path to a file containing the string representation of a Python " + "dictionary containing the headers to be sent along with the request " + "(default: none)", +) + +parser.add_argument( + "--use-requests", + action="store_true", + help="use the requests library with the HTTPS/aTLS adapater (default: " + "false)", ) args = parser.parse_args() @@ -81,17 +108,69 @@ # - The JKUs array carries all allowed JWKS URLs, or none if the JKU claim in # the AAS JWT token sent by the server during the aTLS handshake should not # be checked. -validator = AzAasAciValidator(policies=policies, jkus=jkus) +validator = AciValidator(policies=policies, jkus=jkus) + +# Parse provided headers, if any. +headers: Mapping[str, str] = {} +if args.headers is not None: + raw = args.headers.read() + headers = ast.literal_eval(raw) + +# Read in the provided body, if any. +body: Optional[str] = None +if args.body is not None: + body = args.body.read() + + +def use_direct() -> None: + # Set up the aTLS context, including at least one attestation document + # validator (only one need succeed). + ctx = ATLSContext([validator]) + + # Set up the HTTP request machinery using the aTLS context. + conn = HTTPAConnection(args.server, ctx, args.port) + + # Send the HTTP request, and read and print the response in the usual way. + conn.request( + args.method, + args.url, + body, + headers, + ) + + response = conn.getresponse() + code = response.getcode() + + print(f"Status: {code}") + print(f"Response: {response.read().decode()}") + + conn.close() + + +def use_requests() -> None: + session = requests.Session() + + # Mount the HTTP/aTLS adapter such that any URL whose scheme is httpa:// + # results in an HTTPAConnection object that in turn establishes an aTLS + # connection with the server. + session.mount("httpa://", HTTPAAdapter([validator])) -# Set up the aTLS context, including at least one attestation document -# validator (only one need succeed). -ctx = AttestedTLSContext([validator]) + # The rest of the usage of the requests library is as usual. Do remember to + # use session.request from the session object that has the mounted adapter, + # not requests.request, since that's the global request function and has + # therefore no knowledge of the adapter. + response = session.request( + args.method, + f"httpa://{args.server}:{args.port}{args.url}", + data=body, + headers=headers, + ) -# Set up the HTTP request machinery using the aTLS context. -conn = AttestedHTTPSConnection(args.server, ctx, args.port) + print(f"Status: {response.status_code}") + print(f"Response: {response.text}") -# Send the HTTP request, and read and print the response in the usual way. -conn.request(args.method, args.url) -print(conn.getresponse().read().decode()) -conn.close() +if args.use_requests: + use_requests() +else: + use_direct()