Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for async httpx clients #48

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 168 additions & 23 deletions httpx_auth/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from hashlib import sha256, sha512
from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode
from typing import Optional, Generator
from typing import Optional, Generator, Union, AsyncGenerator

import httpx

Expand Down Expand Up @@ -59,11 +59,9 @@ def _get_query_parameter(url: str, param_name: str) -> Optional[str]:
return all_values[0] if all_values else None


def request_new_grant_with_post(
url: str, data, grant_name: str, client: httpx.Client
def process_new_grant_response(
response: httpx.Response, grant_name: str,
) -> (str, int):
response = client.post(url, data=data)

if response.is_error:
# As described in https://tools.ietf.org/html/rfc6749#section-5.2
raise InvalidGrantRequest(response)
Expand Down Expand Up @@ -152,6 +150,8 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs):
reaches the actual server. Set it to 0 to deactivate this feature and use the same token until actual expiry.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param async_client: httpx.AsyncClient instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as body parameters in the token URL.
"""
self.token_url = token_url
Expand All @@ -175,6 +175,7 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs):
# Time is expressed in seconds
self.timeout = int(kwargs.pop("timeout", None) or 60)
self.client = kwargs.pop("client", None)
self.async_client = kwargs.pop("async_client", None)

# As described in https://tools.ietf.org/html/rfc6749#section-4.3.2
self.data = {
Expand All @@ -190,7 +191,7 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs):
all_parameters_in_url = _add_parameters(self.token_url, self.data)
self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest()

def auth_flow(
def sync_auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
Expand All @@ -201,13 +202,25 @@ def auth_flow(
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

async def async_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
token = await OAuth2.token_cache.get_token_async(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token_async,
)
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
client = self.client or httpx.Client()
self._configure_client(client)
try:
grant_response = client.post(self.token_url, data=self.data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.3.3
token, expires_in = request_new_grant_with_post(
self.token_url, self.data, self.token_field_name, client
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
Expand All @@ -216,7 +229,23 @@ def request_new_token(self) -> tuple:
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: httpx.Client):
async def request_new_token_async(self) -> tuple:
client = self.async_client or httpx.AsyncClient()
self._configure_client(client)
try:
grant_response = await client.post(self.token_url, data=self.data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.3.3
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
if self.async_client is None:
await client.aclose()
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]):
client.auth = (self.username, self.password)
client.timeout = self.timeout

Expand Down Expand Up @@ -248,6 +277,8 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
reaches the actual server. Set it to 0 to deactivate this feature and use the same token until actual expiry.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param async_client: httpx.AsyncClient instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as query parameter in the token URL.
"""
self.token_url = token_url
Expand All @@ -272,6 +303,7 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
self.timeout = int(kwargs.pop("timeout", None) or 60)

self.client = kwargs.pop("client", None)
self.async_client = kwargs.pop("async_client", None)

# As described in https://tools.ietf.org/html/rfc6749#section-4.4.2
self.data = {"grant_type": "client_credentials"}
Expand All @@ -283,7 +315,7 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
all_parameters_in_url = _add_parameters(self.token_url, self.data)
self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest()

def auth_flow(
def sync_auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
Expand All @@ -294,13 +326,25 @@ def auth_flow(
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

async def async_auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = await OAuth2.token_cache.get_token_async(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token_async,
)
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
client = self.client or httpx.Client()
self._configure_client(client)
try:
grant_response = client.post(self.token_url, data=self.data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.4.3
token, expires_in = request_new_grant_with_post(
self.token_url, self.data, self.token_field_name, client
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
Expand All @@ -309,7 +353,23 @@ def request_new_token(self) -> tuple:
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: httpx.Client):
async def request_new_token_async(self) -> tuple:
client = self.async_client or httpx.AsyncClient()
self._configure_client(client)
try:
grant_response = await client.post(self.token_url, data=self.data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.4.3
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
if self.async_client is None:
await client.aclose()
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]):
client.auth = (self.client_id, self.client_secret)
client.timeout = self.timeout

Expand Down Expand Up @@ -358,6 +418,8 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
:param password: User password in case basic authentication should be used to retrieve token.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param async_client: httpx.AsyncClient instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as query parameter
in the authorization URL and as body parameters in the token URL.
Usual parameters are:
Expand Down Expand Up @@ -387,6 +449,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
password = kwargs.pop("password", None)
self.auth = (username, password) if username and password else None
self.client = kwargs.pop("client", None)
self.async_client = kwargs.pop("async_client", None)

# As described in https://tools.ietf.org/html/rfc6749#section-4.1.2
code_field_name = kwargs.pop("code_field_name", "code")
Expand Down Expand Up @@ -431,7 +494,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
}
self.token_data.update(kwargs)

def auth_flow(
def sync_auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
Expand All @@ -442,6 +505,17 @@ def auth_flow(
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

async def async_sync_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
token = await OAuth2.token_cache.get_token_async(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token_async,
)
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
# Request code
state, code = oauth2_authentication_responses_server.request_new_grant(
Expand All @@ -454,9 +528,10 @@ def request_new_token(self) -> tuple:
client = self.client or httpx.Client()
self._configure_client(client)
try:
grant_response = client.post(self.token_url, data=self.token_data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in = request_new_grant_with_post(
self.token_url, self.token_data, self.token_field_name, client
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
Expand All @@ -465,7 +540,31 @@ def request_new_token(self) -> tuple:
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: httpx.Client):
async def request_new_token_async(self) -> tuple:
# Request code
state, code = oauth2_authentication_responses_server.request_new_grant(
self.code_grant_details
)

# As described in https://tools.ietf.org/html/rfc6749#section-4.1.3
self.token_data["code"] = code

client = self.async_client or httpx.AsyncClient()
self._configure_client(client)
try:
grant_response = await client.post(self.token_url, data=self.token_data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
if self.async_client is None:
await client.aclose()
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]):
client.auth = self.auth
client.timeout = self.timeout

Expand Down Expand Up @@ -512,6 +611,8 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
:param code_field_name: Field name containing the code. code by default.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param async_client: httpx.AsyncClient instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as query parameter
in the authorization URL and as body parameters in the token URL.
Usual parameters are:
Expand All @@ -530,6 +631,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
BrowserAuth.__init__(self, kwargs)

self.client = kwargs.pop("client", None)
self.async_client = kwargs.pop("async_client", None)

self.header_name = kwargs.pop("header_name", None) or "Authorization"
self.header_value = kwargs.pop("header_value", None) or "Bearer {token}"
Expand Down Expand Up @@ -596,7 +698,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
}
self.token_data.update(kwargs)

def auth_flow(
def sync_auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
Expand All @@ -607,6 +709,17 @@ def auth_flow(
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

async def async_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
token = await OAuth2.token_cache.get_token_async(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token_async,
)
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
# Request code
state, code = oauth2_authentication_responses_server.request_new_grant(
Expand All @@ -619,9 +732,10 @@ def request_new_token(self) -> tuple:
client = self.client or httpx.Client()
self._configure_client(client)
try:
grant_response = client.post(self.token_url, data=self.token_data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in = request_new_grant_with_post(
self.token_url, self.token_data, self.token_field_name, client
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
Expand All @@ -630,7 +744,31 @@ def request_new_token(self) -> tuple:
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: httpx.Client):
async def request_new_token_async(self) -> tuple:
# Request code
state, code = oauth2_authentication_responses_server.request_new_grant(
self.code_grant_details
)

# As described in https://tools.ietf.org/html/rfc6749#section-4.1.3
self.token_data["code"] = code

client = self.async_client or httpx.AsyncClient()
self._configure_client(client)
try:
grant_response = await client.post(self.token_url, data=self.token_data)
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in = process_new_grant_response(
grant_response, self.token_field_name
)
finally:
# Close client only if it was created by this module
if self.async_client is None:
await client.aclose()
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: Union[httpx.Client, httpx.AsyncClient]):
client.timeout = self.timeout

@staticmethod
Expand Down Expand Up @@ -1207,11 +1345,18 @@ class _MultiAuth(httpx.Auth):
def __init__(self, *authentication_modes):
self.authentication_modes = authentication_modes

def auth_flow(
def sync_auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
for authentication_mode in self.authentication_modes:
next(authentication_mode.auth_flow(request))
next(authentication_mode.sync_auth_flow(request))
yield request

async def async_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
for authentication_mode in self.authentication_modes:
await authentication_mode.async_auth_flow(request).__anext__()
yield request

def __add__(self, other) -> "_MultiAuth":
Expand Down
Loading