Skip to content

Commit

Permalink
Bump pyjwt to 2.8 (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Apr 18, 2024
1 parent e34e724 commit 9a376a3
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

- Bumped fastapi to 0.110.*.

- Bumped pyjwt to 2.8.*.

## 0.12.3 (2024-03-20)
-------------------
Expand Down
19 changes: 7 additions & 12 deletions clean_python/oauth2/token_verifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# (c) Nelen & Schuurmans

import logging
import socket
from typing import Any

import jwt
Expand Down Expand Up @@ -80,7 +79,10 @@ def __init__(
self, settings: TokenVerifierSettings, logger: logging.Logger | None = None
):
self.settings = settings
self.jwk_client = PyJWKClient(f"{settings.issuer}/.well-known/jwks.json")
self.jwk_client = PyJWKClient(
f"{settings.issuer}/.well-known/jwks.json",
timeout=self.settings.jwks_timeout,
)

def __call__(self, authorization: str | None) -> Token:
# Step 0: retrieve the token from the Authorization header
Expand All @@ -95,7 +97,7 @@ def __call__(self, authorization: str | None) -> Token:
# Step 1: Confirm the structure of the JWT. This check is part of get_kid since
# jwt.get_unverified_header will raise a JWTError if the structure is wrong.
try:
key = self.get_key(jwt_str, self.settings.jwks_timeout) # JSON Web Key
key = self.get_key(jwt_str) # JSON Web Key
except PyJWTError as e:
raise Unauthorized(f"Token is invalid: {e}")
# Step 2: Validate the JWT signature and standard claims
Expand Down Expand Up @@ -124,16 +126,9 @@ def __call__(self, authorization: str | None) -> Token:
self.authorize_user(token.user)
return token

def get_key(self, token: str, timeout: float = 1.0) -> jwt.PyJWK:
def get_key(self, token: str) -> jwt.PyJWK:
"""Return the JSON Web KEY (JWK) corresponding to kid."""
# NB: pyjwt does not allow timeouts, but we can set it using the
# global value
old_timeout = socket.getdefaulttimeout()
try:
socket.setdefaulttimeout(timeout)
return self.jwk_client.get_signing_key_from_jwt(token)
finally:
socket.setdefaulttimeout(old_timeout)
return self.jwk_client.get_signing_key_from_jwt(token)

def verify_token_use(self, claims: dict[str, Any]) -> None:
"""Check the token_use claim."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ test = [
]
dramatiq = ["dramatiq==1.15.*"]
fastapi = ["fastapi==0.110.*"]
auth = ["pyjwt==2.6.*", "cryptography==42.0.*"] # pyjwt[crypto]
auth = ["pyjwt==2.8.*", "cryptography==42.0.*"] # pyjwt[crypto]
amqp = ["pika==1.3.*"]
celery = ["celery==5.3.*"]
fluentbit = ["fluent-logger"]
Expand Down
32 changes: 11 additions & 21 deletions tests/oauth2/test_verifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# (c) Nelen & Schuurmans

import json
import socket
import time
import urllib.request
from io import BytesIO
Expand All @@ -21,7 +20,7 @@ def patched_verifier(settings, jwk_patched):
return TokenVerifier(settings)


def test_verifier_ok(patched_verifier, token_generator, jwk_patched):
def test_verifier_ok(patched_verifier, token_generator):
token = token_generator()
verified_token = patched_verifier("Bearer " + token)

Expand All @@ -30,9 +29,16 @@ def test_verifier_ok(patched_verifier, token_generator, jwk_patched):
assert verified_token.tenant is None
assert verified_token.scope == {"user"}

jwk_patched.assert_called_once_with(
"https://some/auth/server/.well-known/jwks.json"
)

def test_jwks_call(token_generator, jwk_patched, settings):
token = token_generator()
TokenVerifier(settings).get_key(token)

assert jwk_patched.call_count == 1
((request,), kwargs) = jwk_patched.call_args
assert request.get_full_url() == "https://some/auth/server/.well-known/jwks.json"
assert request.get_method() == "GET"
assert kwargs["timeout"] == settings.jwks_timeout


def test_verifier_exp_leeway(patched_verifier, token_generator):
Expand Down Expand Up @@ -86,22 +92,6 @@ def test_verifier_no_header(patched_verifier, header):
patched_verifier(header)


@mock.patch.object(urllib.request, "urlopen")
def test_get_key_timeout(urlopen, patched_verifier, token_generator, public_key):
def side_effect():
assert socket.getdefaulttimeout() == 0.1
return BytesIO(json.dumps({"keys": [public_key]}).encode())

urlopen.return_value.__enter__.side_effect = side_effect

assert socket.getdefaulttimeout() is None
key = patched_verifier.get_key(token_generator(), timeout=0.1)
assert socket.getdefaulttimeout() is None

assert isinstance(key, jwt.PyJWK)
assert key.key_id == public_key["kid"]


@mock.patch.object(urllib.request, "urlopen")
def test_get_key_invalid_kid(urlopen, settings, token_generator, public_key):
urlopen.return_value.__enter__.return_value = BytesIO(
Expand Down

0 comments on commit 9a376a3

Please sign in to comment.