Skip to content

Commit

Permalink
Refactor(oauth): resolve RuntimeWarning (#2152)
Browse files Browse the repository at this point in the history
  • Loading branch information
angela-tran authored Jun 12, 2024
2 parents 0a246e1 + 06186fa commit 9556fe5
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 83 deletions.
12 changes: 0 additions & 12 deletions benefits/oauth/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,3 @@ class OAuthAppConfig(AppConfig):
name = "benefits.oauth"
label = "oauth"
verbose_name = "Benefits OAuth"

def ready(self):
# delay import until the ready() function is called, signaling that
# Django has loaded all the apps and models
from .client import oauth, register_providers

# wrap registration in try/catch
# even though we are in a ready() function, sometimes it's called early?
try:
register_providers(oauth)
except Exception:
pass
38 changes: 22 additions & 16 deletions benefits/oauth/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

from authlib.integrations.django_client import OAuth

from benefits.core.models import AuthProvider


logger = logging.getLogger(__name__)

oauth = OAuth()
Expand Down Expand Up @@ -42,23 +39,32 @@ def _authorize_params(scheme):
return params


def register_providers(oauth_registry):
def _register_provider(oauth_registry, provider):
"""
Register OAuth clients into the given registry, using configuration from AuthProvider models.
Register OAuth clients into the given registry, using configuration from AuthProvider model.
Adapted from https://stackoverflow.com/a/64174413.
"""
logger.info("Registering OAuth clients")
logger.debug(f"Registering OAuth client: {provider.client_name}")

client = oauth_registry.register(
provider.client_name,
client_id=provider.client_id,
server_metadata_url=_server_metadata_url(provider.authority),
client_kwargs=_client_kwargs(provider.scope),
authorize_params=_authorize_params(provider.scheme),
)

return client

providers = AuthProvider.objects.all()

for provider in providers:
logger.debug(f"Registering OAuth client: {provider.client_name}")
def create_client(oauth_registry, provider):
"""
Returns an OAuth client, registering it if needed.
"""
client = oauth_registry.create_client(provider.client_name)

if client is None:
client = _register_provider(oauth_registry, provider)

oauth_registry.register(
provider.client_name,
client_id=provider.client_id,
server_metadata_url=_server_metadata_url(provider.authority),
client_kwargs=_client_kwargs(provider.scope),
authorize_params=_authorize_params(provider.scheme),
)
return client
7 changes: 4 additions & 3 deletions benefits/oauth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from benefits.core import session
from . import analytics, redirects
from .client import oauth
from .client import oauth, create_client
from .middleware import VerifierUsesAuthVerificationSessionRequired


Expand All @@ -24,7 +24,8 @@
def login(request):
"""View implementing OIDC authorize_redirect."""
verifier = session.verifier(request)
oauth_client = oauth.create_client(verifier.auth_provider.client_name)

oauth_client = create_client(oauth, verifier.auth_provider)

if not oauth_client:
raise Exception(f"oauth_client not registered: {verifier.auth_provider.client_name}")
Expand All @@ -43,7 +44,7 @@ def login(request):
def authorize(request):
"""View implementing OIDC token authorization."""
verifier = session.verifier(request)
oauth_client = oauth.create_client(verifier.auth_provider.client_name)
oauth_client = create_client(oauth, verifier.auth_provider)

if not oauth_client:
raise Exception(f"oauth_client not registered: {verifier.auth_provider.client_name}")
Expand Down
9 changes: 2 additions & 7 deletions docs/configuration/oauth.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,9 @@ The [data migration file](./data.md) contains sample values for an `AuthProvider
The [`benefits.oauth.client`][oauth-client] module defines helpers for registering OAuth clients, and creating instances for
use in e.g. views.

- `register_providers(oauth_registry)` uses data from `AuthProvider` instances to register clients into the given registry
- `oauth` is an `authlib.integrations.django_client.OAuth` instance

Providers are registered into this instance once in the [`OAuthAppConfig.ready()`][oauth-app-ready] function at application
startup.
Consumers call `benefits.oauth.client.create_client(oauth, provider)` with the name of a client to obtain an Authlib client
instance. If that client name has not been registered yet, `_register_provider(oauth_registry, provider)` uses data from the given `AuthProvider` instance to register the client into this instance and returns the client object.

Consumers call `oauth.create_client(client_name)` with the name of a previously registered client to obtain an Authlib client
instance.

[oauth-app-ready]: https://github.com/cal-itp/benefits/blob/dev/benefits/oauth/__init__.py
[oauth-client]: https://github.com/cal-itp/benefits/blob/dev/benefits/oauth/client.py
23 changes: 0 additions & 23 deletions tests/pytest/oauth/test_app.py

This file was deleted.

68 changes: 46 additions & 22 deletions tests/pytest/oauth/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from benefits.core.models import AuthProvider
from benefits.oauth.client import _client_kwargs, _server_metadata_url, _authorize_params, register_providers
from benefits.oauth.client import _client_kwargs, _server_metadata_url, _authorize_params, _register_provider, create_client


def test_client_kwargs():
Expand Down Expand Up @@ -39,33 +39,57 @@ def test_authorize_params_no_scheme():


@pytest.mark.django_db
def test_register_providers(mocker, mocked_oauth_registry):
mock_providers = []
def test_register_provider(mocker, mocked_oauth_registry):
mocked_client_provider = mocker.Mock(spec=AuthProvider)
mocked_client_provider.client_name = "client_name_1"
mocked_client_provider.client_id = "client_id_1"

for i in range(3):
p = mocker.Mock(spec=AuthProvider)
p.client_name = f"client_name_{i}"
p.client_id = f"client_id_{i}"
mock_providers.append(p)
mocker.patch("benefits.oauth.client._client_kwargs", return_value={"client": "kwargs"})
mocker.patch("benefits.oauth.client._server_metadata_url", return_value="https://metadata.url")
mocker.patch("benefits.oauth.client._authorize_params", return_value={"scheme": "test_scheme"})

_register_provider(mocked_oauth_registry, mocked_client_provider)

mocked_client_provider = mocker.patch("benefits.oauth.client.AuthProvider")
mocked_client_provider.objects.all.return_value = mock_providers
mocked_oauth_registry.register.assert_any_call(
"client_name_1",
client_id="client_id_1",
server_metadata_url="https://metadata.url",
client_kwargs={"client": "kwargs"},
authorize_params={"scheme": "test_scheme"},
)


@pytest.mark.django_db
def test_create_client_already_registered(mocker, mocked_oauth_registry):
mocked_client_provider = mocker.Mock(spec=AuthProvider)
mocked_client_provider.client_name = "client_name_1"
mocked_client_provider.client_id = "client_id_1"

create_client(mocked_oauth_registry, mocked_client_provider)

mocked_oauth_registry.create_client.assert_any_call("client_name_1")
mocked_oauth_registry.register.assert_not_called()


@pytest.mark.django_db
def test_create_client_already_not_registered_yet(mocker, mocked_oauth_registry):
mocked_client_provider = mocker.Mock(spec=AuthProvider)
mocked_client_provider.client_name = "client_name_1"
mocked_client_provider.client_id = "client_id_1"

mocker.patch("benefits.oauth.client._client_kwargs", return_value={"client": "kwargs"})
mocker.patch("benefits.oauth.client._server_metadata_url", return_value="https://metadata.url")
mocker.patch("benefits.oauth.client._authorize_params", return_value={"scheme": "test_scheme"})

register_providers(mocked_oauth_registry)

mocked_client_provider.objects.all.assert_called_once()
mocked_oauth_registry.create_client.return_value = None

for provider in mock_providers:
i = mock_providers.index(provider)
create_client(mocked_oauth_registry, mocked_client_provider)

mocked_oauth_registry.register.assert_any_call(
f"client_name_{i}",
client_id=f"client_id_{i}",
server_metadata_url="https://metadata.url",
client_kwargs={"client": "kwargs"},
authorize_params={"scheme": "test_scheme"},
)
mocked_oauth_registry.create_client.assert_any_call("client_name_1")
mocked_oauth_registry.register.assert_any_call(
"client_name_1",
client_id="client_id_1",
server_metadata_url="https://metadata.url",
client_kwargs={"client": "kwargs"},
authorize_params={"scheme": "test_scheme"},
)

0 comments on commit 9556fe5

Please sign in to comment.