Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: lock file while read, write and refresh token #299

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions diracx-cli/src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from diracx.client.models import DeviceFlowErrorResponse
from diracx.core.extensions import select_from_extension
from diracx.core.preferences import get_diracx_preferences
from diracx.core.utils import write_credentials
from diracx.core.utils import read_credentials, write_credentials

from .utils import AsyncTyper

Expand Down Expand Up @@ -116,11 +116,11 @@ async def logout():
async with DiracClient() as api:
credentials_path = get_diracx_preferences().credentials_path
if credentials_path.exists():
credentials = json.loads(credentials_path.read_text())
credentials = read_credentials(credentials_path)

# Revoke refresh token
try:
await api.auth.revoke_refresh_token(credentials["refresh_token"])
await api.auth.revoke_refresh_token(credentials.refresh_token)
except Exception as e:
print(f"Error revoking the refresh token {e!r}")
pass
Expand Down
45 changes: 12 additions & 33 deletions diracx-client/src/diracx/client/patches/aio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from importlib.metadata import PackageNotFoundError, distribution
from types import TracebackType
from pathlib import Path
from typing import Any, List, Optional, Self

from typing import Any, List, Optional, cast

from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest
Expand All @@ -24,8 +26,6 @@
from ..utils import (
get_openid_configuration,
get_token,
refresh_token,
is_refresh_token_valid,
)

__all__: List[str] = [
Expand Down Expand Up @@ -56,20 +56,12 @@ async def get_token(
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
"""
return refresh_token(
return get_token(
self.location,
kwargs.get("token"),
self.token_endpoint,
self.client_id,
kwargs["refresh_token"],
verify=self.verify,
self.verify,
)

async def close(self) -> None:
Expand Down Expand Up @@ -110,28 +102,15 @@ async def on_request(
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._token: AccessToken | None
self._credential: DiracTokenCredential
credentials: dict[str, Any]
try:
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
self._token = await self._credential.get_token("", token=self._token)
if not self._token:
# If we are here, it means the token is not available
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
refresh_token = credentials["refresh_token"]
if not is_refresh_token_valid(refresh_token):
# If we are here, it means the refresh token is not valid anymore
# we suppose it is not needed to perform the request
return
self._token = await self._credential.get_token(
"", refresh_token=refresh_token
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
request.http_request.headers["Authorization"] = (
"Bearer " + cast(AccessToken, self._token).token
)


class DiracClientMixin(metaclass=abc.ABCMeta):
Expand Down
218 changes: 141 additions & 77 deletions diracx-client/src/diracx/client/patches/utils.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,132 @@
from __future__ import annotations


import json
import os
from diracx.core.utils import EXPIRES_GRACE_SECONDS, serialize_credentials
import jwt
import requests

from datetime import datetime, timezone
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
import fcntl
from importlib.metadata import PackageNotFoundError, distribution

from pathlib import Path
from typing import Any, Dict, List, Optional, cast, Self

from typing import Any, Dict, List, Optional, TextIO
from urllib import parse
from azure.core.credentials import AccessToken
from azure.core.credentials import TokenCredential
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy

from ..generated.models import TokenResponse
from diracx.core.models import TokenResponse as CoreTokenResponse
from diracx.core.models import TokenResponse
from diracx.core.preferences import DiracxPreferences, get_diracx_preferences

import sys

class TokenStatus(Enum):
VALID = "valid"
REFRESH = "refresh"
INVALID = "invalid"

def refresh_token(

@dataclass
class TokenResult:
status: TokenStatus
access_token: Optional[AccessToken] = None
refresh_token: Optional[str] = None


def get_openid_configuration(
endpoint: str, *, verify: bool | str = True
) -> Dict[str, str]:
"""Get the openid configuration from the .well-known endpoint"""
response = requests.get(
url=parse.urljoin(endpoint, ".well-known/openid-configuration"),
verify=verify,
)
if not response.ok:
raise RuntimeError("Cannot fetch any information from the .well-known endpoint")
return response.json()


def get_token(
location: Path,
token: AccessToken | None,
token_endpoint: str,
client_id: str,
verify: bool,
) -> AccessToken | None:
"""Get the access token if available and still valid."""
# Immediately return the token if it is available and still valid
if token and is_token_valid(token):
return token

if not location.exists():
# If we are here, it means the credentials path does not exist
# we suppose access token is not needed to perform the request
return None

with open(location, "r+") as f:
# Acquire exclusive lock
fcntl.flock(f, fcntl.LOCK_EX)
try:
response = extract_token_from_credentials(f, token)
if response.status == TokenStatus.VALID:
# Release the lock
fcntl.flock(f, fcntl.LOCK_UN)
return response.access_token

if response.status == TokenStatus.INVALID:
# Release the lock
fcntl.flock(f, fcntl.LOCK_UN)
return None

if response.status == TokenStatus.REFRESH and response.refresh_token:
# If we are here, it means the token needs to be refreshed
token_response = refresh_token(
token_endpoint,
client_id,
response.refresh_token,
verify=verify,
)

# Write the new credentials to the file
f.seek(0)
f.truncate()
f.write(serialize_credentials(token_response))
f.flush()
os.fsync(f.fileno())

# Get an AccessToken instance
return AccessToken(
token=token_response.access_token,
expires_on=int(
(
datetime.now(tz=timezone.utc)
+ timedelta(
seconds=token_response.expires_in
- EXPIRES_GRACE_SECONDS
)
).timestamp()
),
)
else:
return None
finally:
# Release the lock
fcntl.flock(f, fcntl.LOCK_UN)


def refresh_token(
token_endpoint: str,
client_id: str,
refresh_token: str,
*,
verify: bool | str = True,
) -> AccessToken:
) -> TokenResponse:
"""Refresh the access token using the refresh_token flow."""
from diracx.core.utils import write_credentials

response = requests.post(
url=token_endpoint,
data={
Expand All @@ -50,56 +143,49 @@ def refresh_token(
)

res = response.json()
token_response = TokenResponse(
return TokenResponse(
access_token=res["access_token"],
expires_in=res["expires_in"],
token_type=res.get("token_type"),
refresh_token=res.get("refresh_token"),
)

write_credentials(cast(CoreTokenResponse, token_response), location=location)
credentials = json.loads(location.read_text())
return AccessToken(credentials.get("access_token"), credentials.get("expires_on"))


def get_openid_configuration(
endpoint: str, *, verify: bool | str = True
) -> Dict[str, str]:
"""Get the openid configuration from the .well-known endpoint"""
response = requests.get(
url=parse.urljoin(endpoint, ".well-known/openid-configuration"),
verify=verify,
)
if not response.ok:
raise RuntimeError("Cannot fetch any information from the .well-known endpoint")
return response.json()


def get_token(location: Path, token: AccessToken | None) -> AccessToken | None:
def extract_token_from_credentials(
token_file_descriptor: TextIO, token: AccessToken | None
) -> TokenResult:
"""Get token if available and still valid."""
# If the credentials path does not exist, raise an error
if not location.exists():
raise RuntimeError("credentials are not set")

# Load the existing credentials
if not token:
credentials = json.loads(location.read_text())
# If we are here, it means the token is not available or not valid anymore
# We try to get it from the file
try:
credentials = json.load(token_file_descriptor)
except json.JSONDecodeError:
return TokenResult(TokenStatus.INVALID)

try:
token = AccessToken(
cast(str, credentials.get("access_token")),
cast(int, credentials.get("expires_on")),
token=credentials["access_token"],
expires_on=credentials["expires_on"],
)
refresh_token = credentials["refresh_token"]
except KeyError:
return TokenResult(TokenStatus.INVALID)

# We check the validity of the token
# If not valid, then return None to inform the caller that a new token
# is needed
if not is_token_valid(token):
return None
# We check the validity of the tokens
if is_token_valid(token):
return TokenResult(TokenStatus.VALID, access_token=token)

if is_refresh_token_valid(refresh_token):
return TokenResult(TokenStatus.REFRESH, refresh_token=refresh_token)

return token
# If we are here, it means the refresh token is not valid anymore
return TokenResult(TokenStatus.INVALID)


def is_refresh_token_valid(refresh_token: str) -> bool:
def is_refresh_token_valid(refresh_token: str | None) -> bool:
"""Check if the refresh token is still valid."""
if not refresh_token:
return False
# Decode the refresh token
refresh_payload = jwt.decode(refresh_token, options={"verify_signature": False})
if not refresh_payload or "exp" not in refresh_payload:
Expand Down Expand Up @@ -140,20 +226,12 @@ def get_token(
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
"""
return refresh_token(
return get_token(
self.location,
kwargs.get("token"),
self.token_endpoint,
self.client_id,
kwargs["refresh_token"],
verify=self.verify,
self.verify,
)


Expand All @@ -169,35 +247,21 @@ def __init__(
) -> None:
super().__init__(credential, *scopes, **kwargs)

def on_request(
self, request: PipelineRequest
) -> None: # pylint:disable=invalid-overridden-method
def on_request(self, request: PipelineRequest) -> None:
"""Authorization Bearer is optional here.
:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._token: AccessToken | None
self._credential: DiracTokenCredential
credentials: dict[str, Any]

try:
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
self._token: AccessToken | None = self._credential.get_token(
"", token=self._token
)
if not self._token:
# If we are here, it means the token is not available
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
refresh_token = credentials["refresh_token"]
if not is_refresh_token_valid(refresh_token):
# If we are here, it means the refresh token is not valid anymore
# we suppose it is not needed to perform the request
return
self._token = self._credential.get_token("", refresh_token=refresh_token)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
self._update_headers(request.http_request.headers, self._token.token)


class DiracClientMixin:
Expand Down
Loading
Loading