Skip to content

Commit

Permalink
thread-safe token refresh (#532)
Browse files Browse the repository at this point in the history
* thread-safe token refresh

* add some debug logs

* fix a benign bug

* fix failing unittest
  • Loading branch information
hasan7n authored Feb 18, 2024
1 parent cf9f41b commit d1581dc
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
4 changes: 3 additions & 1 deletion cli/medperf/account_management/token_storage/filesystem.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import base64
import logging
from medperf.utils import remove_path
from medperf import config

Expand Down Expand Up @@ -27,7 +28,7 @@ def __get_paths(self, account_id):

def set_tokens(self, account_id, access_token, refresh_token):
access_token_file, refresh_token_file = self.__get_paths(account_id)

logging.debug("Writing tokens to disk.")
fd = os.open(access_token_file, os.O_CREAT | os.O_WRONLY, 0o600)
os.write(fd, access_token.encode("utf-8"))
os.close(fd)
Expand All @@ -38,6 +39,7 @@ def set_tokens(self, account_id, access_token, refresh_token):

def read_tokens(self, account_id):
access_token_file, refresh_token_file = self.__get_paths(account_id)
logging.debug("Reading tokens to disk.")
with open(access_token_file) as f:
access_token = f.read()
with open(refresh_token_file) as f:
Expand Down
18 changes: 17 additions & 1 deletion cli/medperf/comms/auth/auth0.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import time
import logging
import threading
from medperf.comms.auth.interface import Auth
from medperf.comms.auth.token_verifier import verify_token
from medperf.exceptions import CommunicationError
Expand All @@ -17,6 +19,7 @@ def __init__(self):
self.domain = config.auth_domain
self.client_id = config.auth_client_id
self.audience = config.auth_audience
self._lock = threading.Lock()

def login(self, email):
"""Retrieves and stores an access token/refresh token pair from the auth0
Expand Down Expand Up @@ -149,6 +152,18 @@ def logout(self):

@property
def access_token(self):
"""Thread-safe access token retrieval"""
# TODO: lock the credentials file to have this process-safe
# If someone is preparing their dataset, and configured
# the preparation to send reports async, there might be a
# risk if they tried to run other commands separately (e.g., dataset ls)
# (i.e., two processes may try to refresh an expired access token, which
# may trigger refresh token reuse since we use refresh token rotation.)
with self._lock:
return self._access_token

@property
def _access_token(self):
"""Reads and returns an access token of the currently logged
in user to be used for authorizing requests to the MedPerf server.
Refresh the token if necessary.
Expand Down Expand Up @@ -187,6 +202,7 @@ def __refresh_access_token(self, refresh_token):
"refresh_token": refresh_token,
}
token_issued_at = time.time()
logging.debug("Refreshing access token.")
res = requests.post(url=url, headers=headers, data=body)

if res.status_code != 200:
Expand All @@ -204,8 +220,8 @@ def __refresh_access_token(self, refresh_token):
access_token,
refresh_token,
id_token_payload,
token_expires_in,
token_issued_at,
token_expires_in,
)

return access_token
Expand Down
2 changes: 1 addition & 1 deletion cli/medperf/tests/comms/test_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,5 @@ def test_refresh_token_sets_new_tokens(mocker):

# Assert
spy.assert_called_once_with(
access_token, refresh_token, id_token_payload, expires_in, ANY
access_token, refresh_token, id_token_payload, ANY, expires_in
)

0 comments on commit d1581dc

Please sign in to comment.