Skip to content

Commit

Permalink
feat: lock credentials path to avoid concurrent access
Browse files Browse the repository at this point in the history
  • Loading branch information
natthan-pigoux authored and chaen committed Oct 16, 2024
1 parent 21076c4 commit e259927
Show file tree
Hide file tree
Showing 5 changed files with 389 additions and 114 deletions.
6 changes: 3 additions & 3 deletions diracx-cli/src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from diracx.client.models import DeviceFlowErrorResponse
from diracx.core.extensions import select_from_extension
from diracx.core.preferences import get_diracx_preferences
from diracx.core.utils import write_credentials
from diracx.core.utils import read_credentials, write_credentials

from .utils import AsyncTyper

Expand Down Expand Up @@ -94,11 +94,11 @@ async def logout():
async with DiracClient() as api:
credentials_path = get_diracx_preferences().credentials_path
if credentials_path.exists():
credentials = json.loads(credentials_path.read_text())
credentials = read_credentials(credentials_path)

# Revoke refresh token
try:
await api.auth.revoke_refresh_token(credentials["refresh_token"])
await api.auth.revoke_refresh_token(credentials.refresh_token)
except Exception as e:
print(f"Error revoking the refresh token {e!r}")
pass
Expand Down
45 changes: 12 additions & 33 deletions diracx-client/src/diracx/client/patches/aio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import json
from types import TracebackType
from pathlib import Path
from typing import Any, List, Optional, Self

from typing import Any, List, Optional, cast

from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest
Expand All @@ -23,8 +25,6 @@
from ..utils import (
get_openid_configuration,
get_token,
refresh_token,
is_refresh_token_valid,
)

__all__: List[str] = [
Expand Down Expand Up @@ -55,20 +55,12 @@ async def get_token(
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
"""
return refresh_token(
return get_token(
self.location,
kwargs.get("token"),
self.token_endpoint,
self.client_id,
kwargs["refresh_token"],
verify=self.verify,
self.verify,
)

async def close(self) -> None:
Expand Down Expand Up @@ -109,28 +101,15 @@ async def on_request(
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._token: AccessToken | None
self._credential: DiracTokenCredential
credentials: dict[str, Any]
try:
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
self._token = await self._credential.get_token("", token=self._token)
if not self._token:
# If we are here, it means the token is not available
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
refresh_token = credentials["refresh_token"]
if not is_refresh_token_valid(refresh_token):
# If we are here, it means the refresh token is not valid anymore
# we suppose it is not needed to perform the request
return
self._token = await self._credential.get_token(
"", refresh_token=refresh_token
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
request.http_request.headers["Authorization"] = (
"Bearer " + cast(AccessToken, self._token).token
)


class DiracClientMixin(metaclass=abc.ABCMeta):
Expand Down
204 changes: 128 additions & 76 deletions diracx-client/src/diracx/client/patches/utils.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,120 @@
from __future__ import annotations

from datetime import datetime, timezone
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
import fcntl
import json
import os
from diracx.core.utils import EXPIRES_GRACE_SECONDS, serialize_credentials
import jwt
import requests


from pathlib import Path
from typing import Any, Dict, List, Optional, cast, Self

from typing import Any, Dict, List, Optional, TextIO
from urllib import parse
from azure.core.credentials import AccessToken
from azure.core.credentials import TokenCredential
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import BearerTokenCredentialPolicy

from ..generated.models import TokenResponse
from diracx.core.models import TokenResponse as CoreTokenResponse
from diracx.core.models import TokenResponse
from diracx.core.preferences import DiracxPreferences, get_diracx_preferences

import sys

# Outcome of inspecting the stored credentials, as reported by
# extract_token_from_credentials() and acted upon by get_token().
class TokenStatus(Enum):
# The stored access token is still valid; get_token() returns it as-is.
VALID = "valid"
# The access token is no longer valid but the refresh token is;
# get_token() then runs the refresh flow and rewrites the credentials file.
REFRESH = "refresh"
# The file is not valid JSON, lacks a required key, or the refresh token
# is no longer valid; get_token() returns None in that case.
INVALID = "invalid"

def refresh_token(

# Value returned by extract_token_from_credentials(): a status plus the
# piece of data the caller needs for that status.
@dataclass
class TokenResult:
# Which of the TokenStatus outcomes applies.
status: TokenStatus
# Set when status is VALID: the access token read from the credentials file.
access_token: Optional[AccessToken] = None
# Set when status is REFRESH: the refresh token to send to the token endpoint.
refresh_token: Optional[str] = None

def get_openid_configuration(
    endpoint: str, *, verify: bool | str = True
) -> Dict[str, str]:
    """Fetch the OpenID configuration document published under *endpoint*.

    The document is requested from ``.well-known/openid-configuration``,
    resolved relative to *endpoint* with ``urllib.parse.urljoin``.

    :param endpoint: Base URL of the service.
    :param verify: Forwarded unchanged to ``requests.get`` (TLS verification
        switch or CA bundle path).
    :return: The decoded JSON body of the response.
    :raises RuntimeError: If the server answers with a non-OK status.
    """
    well_known_url = parse.urljoin(endpoint, ".well-known/openid-configuration")
    reply = requests.get(url=well_known_url, verify=verify)
    if reply.ok:
        return reply.json()
    raise RuntimeError("Cannot fetch any information from the .well-known endpoint")


def get_token(
    location: Path,
    token: AccessToken | None,
    token_endpoint: str,
    client_id: str,
    verify: bool | str,
) -> AccessToken | None:
    """Return a usable access token, refreshing the stored credentials if needed.

    The lookup proceeds in order:

    1. If *token* is given and still valid, it is returned untouched.
    2. If the credentials file does not exist, ``None`` is returned (the
       request is assumed not to need a token).
    3. Otherwise the credentials file is opened and locked exclusively so
       that concurrent callers do not read or rewrite it at the same time.
       Depending on what the file contains, the stored access token is
       returned, ``None`` is returned (nothing usable), or the refresh flow
       is run and the file is rewritten with the new credentials.

    :param location: Path of the credentials file.
    :param token: Token currently held by the caller, if any.
    :param token_endpoint: URL of the token endpoint used for the refresh flow.
    :param client_id: Client identifier sent with the refresh request.
    :param verify: TLS verification setting forwarded to ``refresh_token``.
    :return: A valid ``AccessToken``, or ``None`` when no credentials are
        available or they can no longer be refreshed.
    """
    # Immediately return the token if it is available and still valid
    if token and is_token_valid(token):
        return token

    if not location.exists():
        # The credentials path does not exist: we suppose an access token
        # is not needed to perform the request
        return None

    with open(location, "r+") as f:
        # Exclusive lock: held while reading and, if needed, rewriting the file
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            response = extract_token_from_credentials(f, token)
            if response.status == TokenStatus.VALID:
                return response.access_token

            if response.status == TokenStatus.INVALID:
                return None

            # If we are here, it means the token needs to be refreshed
            token_response = refresh_token(
                token_endpoint,
                client_id,
                response.refresh_token,
                verify=verify,
            )

            # Replace the file content in place; flush + fsync so the new
            # credentials are on disk before the lock is released
            f.seek(0)
            f.truncate()
            f.write(serialize_credentials(token_response))
            f.flush()
            os.fsync(f.fileno())

            # AccessToken.expires_on is an integer Unix timestamp (this is
            # also what is read back from the credentials file), so convert
            # the computed expiry instead of passing a datetime object
            expires_at = datetime.now(tz=timezone.utc) + timedelta(
                seconds=token_response.expires_in - EXPIRES_GRACE_SECONDS
            )
            return AccessToken(
                token=token_response.access_token,
                expires_on=int(expires_at.timestamp()),
            )
        finally:
            # Single release point, reached on every return path and on error
            fcntl.flock(f, fcntl.LOCK_UN)


def refresh_token(
token_endpoint: str,
client_id: str,
refresh_token: str,
*,
verify: bool | str = True,
) -> AccessToken:
) -> TokenResponse:
"""Refresh the access token using the refresh_token flow."""
from diracx.core.utils import write_credentials

response = requests.post(
url=token_endpoint,
data={
Expand All @@ -48,56 +131,49 @@ def refresh_token(
)

res = response.json()
token_response = TokenResponse(
return TokenResponse(
access_token=res["access_token"],
expires_in=res["expires_in"],
token_type=res.get("token_type"),
refresh_token=res.get("refresh_token"),
)

write_credentials(cast(CoreTokenResponse, token_response), location=location)
credentials = json.loads(location.read_text())
return AccessToken(credentials.get("access_token"), credentials.get("expires_on"))


def get_openid_configuration(
endpoint: str, *, verify: bool | str = True
) -> Dict[str, str]:
"""Get the openid configuration from the .well-known endpoint"""
response = requests.get(
url=parse.urljoin(endpoint, ".well-known/openid-configuration"),
verify=verify,
)
if not response.ok:
raise RuntimeError("Cannot fetch any information from the .well-known endpoint")
return response.json()


def get_token(location: Path, token: AccessToken | None) -> AccessToken | None:
def extract_token_from_credentials(
    token_file_descriptor: TextIO, token: AccessToken | None
) -> TokenResult:
    """Read the credentials file and report what can be done with its content.

    :param token_file_descriptor: Open text handle on the credentials file,
        positioned where the JSON document starts.
    :param token: Token currently held by the caller. It is not consulted
        here; the access token is always rebuilt from the file content.
    :return: A ``TokenResult`` whose status is

        * ``VALID`` (with ``access_token``) when the stored access token is
          still valid,
        * ``REFRESH`` (with ``refresh_token``) when only the refresh token
          is still valid,
        * ``INVALID`` when the file cannot be parsed, lacks a required key,
          or the refresh token is no longer valid.
    """
    # If we are here, it means the token is not available or not valid anymore
    # We try to get it from the file
    try:
        credentials = json.load(token_file_descriptor)
    except json.JSONDecodeError:
        return TokenResult(TokenStatus.INVALID)

    try:
        token = AccessToken(
            token=credentials["access_token"],
            expires_on=credentials["expires_on"],
        )
        refresh_token = credentials["refresh_token"]
    except (KeyError, TypeError):
        # KeyError: a required entry is missing.
        # TypeError: the document is valid JSON but not an object (e.g. null).
        return TokenResult(TokenStatus.INVALID)

    # We check the validity of the tokens
    if is_token_valid(token):
        return TokenResult(TokenStatus.VALID, access_token=token)

    if is_refresh_token_valid(refresh_token):
        # `credentials` is a plain dict, so the refresh token must come from
        # the value extracted above, not from attribute access on the dict
        return TokenResult(TokenStatus.REFRESH, refresh_token=refresh_token)

    # If we are here, it means the refresh token is not valid anymore
    return TokenResult(TokenStatus.INVALID)


def is_refresh_token_valid(refresh_token: str) -> bool:
def is_refresh_token_valid(refresh_token: str | None) -> bool:
"""Check if the refresh token is still valid."""
if not refresh_token:
return False
# Decode the refresh token
refresh_payload = jwt.decode(refresh_token, options={"verify_signature": False})
if not refresh_payload or "exp" not in refresh_payload:
Expand Down Expand Up @@ -138,20 +214,12 @@ def get_token(
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
"""
return refresh_token(
return get_token(
self.location,
kwargs.get("token"),
self.token_endpoint,
self.client_id,
kwargs["refresh_token"],
verify=self.verify,
self.verify,
)


Expand All @@ -167,35 +235,19 @@ def __init__(
) -> None:
super().__init__(credential, *scopes, **kwargs)

def on_request(
self, request: PipelineRequest
) -> None: # pylint:disable=invalid-overridden-method
def on_request(self, request: PipelineRequest) -> None:
"""Authorization Bearer is optional here.
:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._token: AccessToken | None
self._credential: DiracTokenCredential
credentials: dict[str, Any]

try:
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
self._token = self._credential.get_token("", token=self._token)
if not self._token:
# If we are here, it means the token is not available
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
refresh_token = credentials["refresh_token"]
if not is_refresh_token_valid(refresh_token):
# If we are here, it means the refresh token is not valid anymore
# we suppose it is not needed to perform the request
return
self._token = self._credential.get_token("", refresh_token=refresh_token)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
self._update_headers(request.http_request.headers, self._token.token)


class DiracClientMixin:
Expand Down
Loading

0 comments on commit e259927

Please sign in to comment.