From e259927ff7283f2f5ade54b21891d52323a5b4b5 Mon Sep 17 00:00:00 2001 From: natthan-pigoux Date: Tue, 17 Sep 2024 09:32:11 +0200 Subject: [PATCH] feat: lock credentials path to avoid concurrent access --- diracx-cli/src/diracx/cli/__init__.py | 6 +- .../src/diracx/client/patches/aio/utils.py | 45 +--- .../src/diracx/client/patches/utils.py | 204 ++++++++++------- diracx-core/src/diracx/core/utils.py | 41 +++- diracx-core/tests/test_utils.py | 207 +++++++++++++++++- 5 files changed, 389 insertions(+), 114 deletions(-) diff --git a/diracx-cli/src/diracx/cli/__init__.py b/diracx-cli/src/diracx/cli/__init__.py index 9f2a04f9..a5137248 100644 --- a/diracx-cli/src/diracx/cli/__init__.py +++ b/diracx-cli/src/diracx/cli/__init__.py @@ -10,7 +10,7 @@ from diracx.client.models import DeviceFlowErrorResponse from diracx.core.extensions import select_from_extension from diracx.core.preferences import get_diracx_preferences -from diracx.core.utils import write_credentials +from diracx.core.utils import read_credentials, write_credentials from .utils import AsyncTyper @@ -94,11 +94,11 @@ async def logout(): async with DiracClient() as api: credentials_path = get_diracx_preferences().credentials_path if credentials_path.exists(): - credentials = json.loads(credentials_path.read_text()) + credentials = read_credentials(credentials_path) # Revoke refresh token try: - await api.auth.revoke_refresh_token(credentials["refresh_token"]) + await api.auth.revoke_refresh_token(credentials.refresh_token) except Exception as e: print(f"Error revoking the refresh token {e!r}") pass diff --git a/diracx-client/src/diracx/client/patches/aio/utils.py b/diracx-client/src/diracx/client/patches/aio/utils.py index b010521f..61a1d470 100644 --- a/diracx-client/src/diracx/client/patches/aio/utils.py +++ b/diracx-client/src/diracx/client/patches/aio/utils.py @@ -12,7 +12,9 @@ import json from types import TracebackType from pathlib import Path -from typing import Any, List, Optional, Self + +from typing import Any, List, Optional, cast + from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest @@ -23,8 +25,6 @@ from ..utils import ( get_openid_configuration, get_token, - refresh_token, - is_refresh_token_valid, ) __all__: List[str] = [ @@ -55,20 +55,12 @@ async def get_token( tenant_id: Optional[str] = None, **kwargs: Any, ) -> AccessToken: - """Refresh the access token using the refresh_token flow. - :param str scopes: The type of access needed. - :keyword str claims: Additional claims required in the token, such as those returned in a resource - provider's claims challenge following an authorization failure. - :keyword str tenant_id: Optional tenant to include in the token request. - :rtype: AccessToken - :return: An AccessToken instance containing the token string and its expiration time in Unix time. - """ - return refresh_token( + return get_token( self.location, + kwargs.get("token"), self.token_endpoint, self.client_id, - kwargs["refresh_token"], - verify=self.verify, + self.verify, ) async def close(self) -> None: @@ -109,28 +101,15 @@ async def on_request( :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - self._token: AccessToken | None - self._credential: DiracTokenCredential - credentials: dict[str, Any] - try: - self._token = get_token(self._credential.location, self._token) - except RuntimeError: - # If we are here, it means the credentials path does not exist + self._token = await self._credential.get_token("", token=self._token) + if not self._token: + # If we are here, it means the token is not available # we suppose it is not needed to perform the request return - if not self._token: - credentials = json.loads(self._credential.location.read_text()) - refresh_token = credentials["refresh_token"] - if not is_refresh_token_valid(refresh_token): - # If we are here, it means the refresh token is not valid anymore - # we suppose it is not needed to perform the request - return - self._token = await self._credential.get_token( - "", refresh_token=refresh_token - ) - - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" + request.http_request.headers["Authorization"] = ( + "Bearer " + cast(AccessToken, self._token).token + ) class DiracClientMixin(metaclass=abc.ABCMeta): diff --git a/diracx-client/src/diracx/client/patches/utils.py b/diracx-client/src/diracx/client/patches/utils.py index 834703c9..01abf22d 100644 --- a/diracx-client/src/diracx/client/patches/utils.py +++ b/diracx-client/src/diracx/client/patches/utils.py @@ -1,37 +1,120 @@ from __future__ import annotations -from datetime import datetime, timezone +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +import fcntl import json +import os +from diracx.core.utils import EXPIRES_GRACE_SECONDS, serialize_credentials import jwt import requests from pathlib import Path -from typing import Any, Dict, List, Optional, cast, Self + +from typing import Any, Dict, List, Optional, TextIO from urllib import parse from azure.core.credentials import AccessToken from azure.core.credentials import TokenCredential from azure.core.pipeline import PipelineRequest from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from ..generated.models import TokenResponse -from diracx.core.models import TokenResponse as CoreTokenResponse +from diracx.core.models import TokenResponse from diracx.core.preferences import DiracxPreferences, get_diracx_preferences -import sys +class TokenStatus(Enum): + VALID = "valid" + REFRESH = "refresh" + INVALID = "invalid" -def refresh_token( + +@dataclass +class TokenResult: + status: TokenStatus + access_token: Optional[AccessToken] = None + refresh_token: Optional[str] = None + +def get_openid_configuration( + endpoint: str, *, verify: bool | str = True +) -> Dict[str, str]: + """Get the openid configuration from the .well-known endpoint""" + response = requests.get( + url=parse.urljoin(endpoint, ".well-known/openid-configuration"), + verify=verify, + ) + if not response.ok: + raise RuntimeError("Cannot fetch any information from the .well-known endpoint") + return response.json() + + +def get_token( location: Path, + token: AccessToken | None, + token_endpoint: str, + client_id: str, + verify: bool, +) -> AccessToken | None: + """Get the access token if available and still valid.""" + # Immediately return the token if it is available and still valid + if token and is_token_valid(token): + return token + + if not location.exists(): + # If we are here, it means the credentials path does not exist + # we suppose access token is not needed to perform the request + return None + + with open(location, "r+") as f: + # Acquire exclusive lock + fcntl.flock(f, fcntl.LOCK_EX) + try: + response = extract_token_from_credentials(f, token) + if response.status == TokenStatus.VALID: + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + return response.access_token + + if response.status == TokenStatus.INVALID: + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + return None + + # If we are here, it means the token needs to be refreshed + token_response = refresh_token( + token_endpoint, + client_id, + response.refresh_token, + verify=verify, + ) + + # Write the new credentials to the file + f.seek(0) + f.truncate() + f.write(serialize_credentials(token_response)) + f.flush() + os.fsync(f.fileno()) + + # Get an AccessToken instance + return AccessToken( + token=token_response.access_token, + expires_on=datetime.now(tz=timezone.utc) + + timedelta(seconds=token_response.expires_in - EXPIRES_GRACE_SECONDS), + ) + finally: + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + + +def refresh_token( token_endpoint: str, client_id: str, refresh_token: str, *, verify: bool | str = True, -) -> AccessToken: +) -> TokenResponse: """Refresh the access token using the refresh_token flow.""" - from diracx.core.utils import write_credentials - response = requests.post( url=token_endpoint, data={ @@ -48,56 +131,49 @@ def refresh_token( ) res = response.json() - token_response = TokenResponse( + return TokenResponse( access_token=res["access_token"], expires_in=res["expires_in"], token_type=res.get("token_type"), refresh_token=res.get("refresh_token"), ) - write_credentials(cast(CoreTokenResponse, token_response), location=location) - credentials = json.loads(location.read_text()) - return AccessToken(credentials.get("access_token"), credentials.get("expires_on")) - - -def get_openid_configuration( - endpoint: str, *, verify: bool | str = True -) -> Dict[str, str]: - """Get the openid configuration from the .well-known endpoint""" - response = requests.get( - url=parse.urljoin(endpoint, ".well-known/openid-configuration"), - verify=verify, - ) - if not response.ok: - raise RuntimeError("Cannot fetch any information from the .well-known endpoint") - return response.json() - -def get_token(location: Path, token: AccessToken | None) -> AccessToken | None: +def extract_token_from_credentials( + token_file_descriptor: TextIO, token: AccessToken | None +) -> TokenResult: """Get token if available and still valid.""" - # If the credentials path does not exist, raise an error - if not location.exists(): - raise RuntimeError("credentials are not set") - - # Load the existing credentials - if not token: - credentials = json.loads(location.read_text()) + # If we are here, it means the token is not available or not valid anymore + # We try to get it from the file + try: + credentials = json.load(token_file_descriptor) + except json.JSONDecodeError: + return TokenResult(TokenStatus.INVALID) + + try: token = AccessToken( - cast(str, credentials.get("access_token")), - cast(int, credentials.get("expires_on")), + token=credentials["access_token"], + expires_on=credentials["expires_on"], ) + refresh_token = credentials["refresh_token"] + except KeyError: + return TokenResult(TokenStatus.INVALID) - # We check the validity of the token - # If not valid, then return None to inform the caller that a new token - # is needed - if not is_token_valid(token): - return None + # We check the validity of the tokens + if is_token_valid(token): + return TokenResult(TokenStatus.VALID, access_token=token) + + if is_refresh_token_valid(refresh_token): + return TokenResult(TokenStatus.REFRESH, refresh_token=credentials.refresh_token) - return token + # If we are here, it means the refresh token is not valid anymore + return TokenResult(TokenStatus.INVALID) -def is_refresh_token_valid(refresh_token: str) -> bool: +def is_refresh_token_valid(refresh_token: str | None) -> bool: """Check if the refresh token is still valid.""" + if not refresh_token: + return False # Decode the refresh token refresh_payload = jwt.decode(refresh_token, options={"verify_signature": False}) if not refresh_payload or "exp" not in refresh_payload: @@ -138,20 +214,12 @@ def get_token( tenant_id: Optional[str] = None, **kwargs: Any, ) -> AccessToken: - """Refresh the access token using the refresh_token flow. - :param str scopes: The type of access needed. - :keyword str claims: Additional claims required in the token, such as those returned in a resource - provider's claims challenge following an authorization failure. - :keyword str tenant_id: Optional tenant to include in the token request. - :rtype: AccessToken - :return: An AccessToken instance containing the token string and its expiration time in Unix time. - """ - return refresh_token( + return get_token( self.location, + kwargs.get("token"), self.token_endpoint, self.client_id, - kwargs["refresh_token"], - verify=self.verify, + self.verify, ) @@ -167,35 +235,19 @@ def __init__( ) -> None: super().__init__(credential, *scopes, **kwargs) - def on_request( - self, request: PipelineRequest - ) -> None: # pylint:disable=invalid-overridden-method + def on_request(self, request: PipelineRequest) -> None: """Authorization Bearer is optional here. :param request: The pipeline request object to be modified. :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - self._token: AccessToken | None - self._credential: DiracTokenCredential - credentials: dict[str, Any] - - try: - self._token = get_token(self._credential.location, self._token) - except RuntimeError: - # If we are here, it means the credentials path does not exist + self._token = self._credential.get_token("", token=self._token) + if not self._token: + # If we are here, it means the token is not available # we suppose it is not needed to perform the request return - if not self._token: - credentials = json.loads(self._credential.location.read_text()) - refresh_token = credentials["refresh_token"] - if not is_refresh_token_valid(refresh_token): - # If we are here, it means the refresh token is not valid anymore - # we suppose it is not needed to perform the request - return - self._token = self._credential.get_token("", refresh_token=refresh_token) - - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" + self._update_headers(request.http_request.headers, self._token.token) class DiracClientMixin: diff --git a/diracx-core/src/diracx/core/utils.py b/diracx-core/src/diracx/core/utils.py index 63e0a310..95c1319a 100644 --- a/diracx-core/src/diracx/core/utils.py +++ b/diracx-core/src/diracx/core/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import fcntl import json import os import re @@ -37,10 +38,48 @@ def serialize_credentials(token_response: TokenResponse) -> str: return json.dumps(credential_data) +def read_credentials(location: Path) -> TokenResponse: + """Read credentials from a file.""" + from diracx.core.preferences import get_diracx_preferences + + credentials_path = location or get_diracx_preferences().credentials_path + try: + with open(credentials_path, "r") as f: + # Lock the file to prevent other processes from writing to it at the same time + fcntl.flock(f, fcntl.LOCK_SH | fcntl.LOCK_NB) + # Read the credentials from the file + try: + credentials = json.load(f) + finally: + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) + except (BlockingIOError, FileNotFoundError, json.JSONDecodeError) as e: + raise RuntimeError(f"Error reading credentials: {e}") from e + + return TokenResponse( + access_token=credentials["access_token"], + expires_in=credentials["expires_on"] + - int(datetime.now(tz=timezone.utc).timestamp()), + token_type="Bearer", # noqa: S106 + refresh_token=credentials.get("refresh_token"), + ) + + def write_credentials(token_response: TokenResponse, *, location: Path | None = None): """Write credentials received in dirax_preferences.credentials_path.""" from diracx.core.preferences import get_diracx_preferences credentials_path = location or get_diracx_preferences().credentials_path credentials_path.parent.mkdir(parents=True, exist_ok=True) - credentials_path.write_text(serialize_credentials(token_response)) + + with open(credentials_path, "w") as f: + # Lock the file to prevent other processes from writing to it at the same time + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + try: + # Write the credentials to the file + f.write(serialize_credentials(token_response)) + f.flush() + os.fsync(f.fileno()) + finally: + # Release the lock + fcntl.flock(f, fcntl.LOCK_UN) diff --git a/diracx-core/tests/test_utils.py b/diracx-core/tests/test_utils.py index 7a6f2c8b..0da0cc3d 100644 --- a/diracx-core/tests/test_utils.py +++ b/diracx-core/tests/test_utils.py @@ -1,6 +1,21 @@ from __future__ import annotations -from diracx.core.utils import dotenv_files_from_environment +import fcntl +import time +from datetime import datetime, timedelta, timezone +from multiprocessing import Pool +from pathlib import Path +from tempfile import NamedTemporaryFile + +import pytest + +from diracx.core.models import TokenResponse +from diracx.core.utils import ( + dotenv_files_from_environment, + read_credentials, + serialize_credentials, + write_credentials, +) def test_dotenv_files_from_environment(monkeypatch): @@ -24,3 +39,193 @@ def test_dotenv_files_from_environment(monkeypatch): {"TEST_PREFIX_2a": "/c", "TEST_PREFIX": "/a", "TEST_PREFIX_1": "/b"}, ) assert dotenv_files_from_environment("TEST_PREFIX") == ["/a", "/b"] + + +TOKEN_RESPONSE_DICT = { + "access_token": "test_token", + "expires_in": int((datetime.now(tz=timezone.utc) + timedelta(days=1)).timestamp()), + "token_type": "Bearer", + "refresh_token": "test_refresh", +} +CREDENTIALS_CONTENT = serialize_credentials(TokenResponse(**TOKEN_RESPONSE_DICT)) + + +def lock_and_read_file(file_path): + """Lock and read file.""" + with open(file_path, "r") as f: + fcntl.flock(f, fcntl.LOCK_SH | fcntl.LOCK_NB) + f.read() + time.sleep(2) + fcntl.flock(f, fcntl.LOCK_UN) + + +def lock_and_write_file(file_path: Path): + """Lock and write file.""" + with open(file_path, "a") as f: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + f.write(CREDENTIALS_CONTENT) + time.sleep(2) + fcntl.flock(f, fcntl.LOCK_UN) + + +@pytest.fixture +def token_setup() -> tuple[TokenResponse, Path]: + """Setup token response and location.""" + token_location = Path(NamedTemporaryFile().name) + token_response = TokenResponse(**TOKEN_RESPONSE_DICT) + return token_response, token_location + + +@pytest.fixture +def concurrent_access_to_lock_file(): + + def run_processes(proc_to_test, *, read=True): + """Run the process to be tested and attempt to read or write concurrently.""" + location = proc_to_test[1]["location"] + error_dict = dict() + with Pool(2) as pool: + if read: + # Creating the file before reading it + with open(location, "w") as f: + f.write(CREDENTIALS_CONTENT) + pool.apply_async( + lock_and_read_file, + args=(location,), + error_callback=lambda e: error_callback( + e, error_dict, "lock_and_read_file" + ), + ) + else: + pool.apply_async( + lock_and_write_file, + args=(location,), + error_callback=lambda e: error_callback( + e, error_dict, "lock_and_write_file" + ), + ) + time.sleep(1) + pool.apply_async( + proc_to_test[0], + kwds=proc_to_test[1], + error_callback=lambda e: error_callback( + e, error_dict, f"{proc_to_test[0].__name__}" + ), + ) + pool.close() + pool.join() + return error_dict + + return run_processes + + +def error_callback(error, error_dict, process_name): + """Called if the process fails.""" + error_dict[process_name] = error + + +def assert_read_credentials_error_message(exc_info): + assert "Error reading credentials:" in exc_info.value.args[0] + + +def test_read_credentials_reading_locked_file( + token_setup, concurrent_access_to_lock_file +): + """Test that read_credentials reading a locked file end in error.""" + _, token_location = token_setup + process_to_test = (read_credentials, {"location": token_location}) + error_dict = concurrent_access_to_lock_file(process_to_test, read=False) + process_name = process_to_test[0].__name__ + if process_name in error_dict.keys(): + assert isinstance(error_dict[process_name], RuntimeError) + else: + raise AssertionError( + "Expected a RuntimeError while reading locked credentials." + ) + + +def test_write_credentials_writing_locked_file( + token_setup, concurrent_access_to_lock_file +): + """Test that write_credentials writing a locked file end in error.""" + token_response, token_location = token_setup + process_to_test = ( + write_credentials, + {"token_response": token_response, "location": token_location}, + ) + error_dict = concurrent_access_to_lock_file(process_to_test) + process_name = process_to_test[0].__name__ + if process_name in error_dict.keys(): + assert isinstance(error_dict[process_name], BlockingIOError) + else: + raise AssertionError( + "Expected a BlockingIOError while writing locked credentials." + ) + + +def create_temp_file(content=None) -> Path: + """Helper function to create a temporary file with optional content.""" + temp_file = NamedTemporaryFile(delete=False) + temp_path = Path(temp_file.name) + temp_file.close() + if content is not None: + temp_path.write_text(content) + return temp_path + + +def test_read_credentials_empty_file(): + """Test that read_credentials raises an appropriate error for an empty file.""" + temp_file = create_temp_file("") + + with pytest.raises(RuntimeError) as exc_info: + read_credentials(location=temp_file) + + temp_file.unlink() + assert_read_credentials_error_message(exc_info) + + +def test_write_credentials_empty_file(token_setup): + """Test that write_credentials raises an appropriate error for an empty file.""" + temp_file = create_temp_file("") + token_response, _ = token_setup + write_credentials(token_response, location=temp_file) + temp_file.unlink() + + +def test_read_credentials_missing_file(): + """Test that read_credentials raises an appropriate error for a missing file.""" + missing_file = Path("/path/to/nonexistent/file.txt") + with pytest.raises(RuntimeError) as exc_info: + read_credentials(location=missing_file) + assert_read_credentials_error_message(exc_info) + + +def test_write_credentials_unavailable_path(token_setup): + """Test that write_credentials raises error when it can't create path.""" + wrong_path = Path("/wrong/path/file.txt") + token_response, _ = token_setup + with pytest.raises(PermissionError): + write_credentials(token_response, location=wrong_path) + + +def test_read_credentials_invalid_content(): + """Test that read_credentials raises an appropriate error for a file with invalid content.""" + temp_file = create_temp_file("invalid content") + + with pytest.raises(RuntimeError) as exc_info: + read_credentials(location=temp_file) + + temp_file.unlink() + assert_read_credentials_error_message(exc_info) + + +def test_read_credentials_valid_file(token_setup): + """Test that read_credentials works correctly with a valid file.""" + token_response, _ = token_setup + temp_file = create_temp_file(content=CREDENTIALS_CONTENT) + + credentials = read_credentials(location=temp_file) + temp_file.unlink() + assert credentials.access_token == token_response.access_token + assert credentials.expires_in < token_response.expires_in + assert credentials.token_type == token_response.token_type + assert credentials.refresh_token == token_response.refresh_token