Make thread-safe, retain cache despite refresh #344

Closed
41 changes: 21 additions & 20 deletions mopidy_spotify/playlists.py
@@ -1,5 +1,6 @@
 import logging
 import threading
+import time
 
 from mopidy import backend
 from mopidy.core import listener
@@ -27,11 +28,11 @@ def as_list(self):
 
         return list(self._get_flattened_playlist_refs())
 
-    def _get_flattened_playlist_refs(self):
+    def _get_flattened_playlist_refs(self, force_refresh=True):
         if not self._backend._web_client.logged_in:
             return []
 
-        user_playlists = self._backend._web_client.get_user_playlists()
+        user_playlists = self._backend._web_client.get_user_playlists(force_refresh=force_refresh)
         return translator.to_playlist_refs(
             user_playlists, self._backend._web_client.user_id
         )
@@ -64,24 +65,24 @@ def refresh(self):
         logger.info("Refreshing Spotify playlists")
 
         def refresher():
-            try:
-                with utils.time_logger("playlists.refresh()", logging.DEBUG):
-                    _sp_links.clear()
-                    self._backend._web_client.clear_cache()
-                    count = 0
-                    for playlist_ref in self._get_flattened_playlist_refs():
-                        self._get_playlist(playlist_ref.uri)
-                        count += 1
-                    logger.info(f"Refreshed {count} Spotify playlists")
-
-                listener.CoreListener.send("playlists_loaded")
-                self._loaded = True
-            except Exception as e:
-                logger.exception(
-                    f"An error occurred while refreshing Spotify playlists: {e}"
-                )
-            finally:
-                self._refreshing = False
+            self._loaded = False
+            while not self._loaded:
+                try:
+                    with utils.time_logger("playlists.refresh()", logging.DEBUG):
+                        count = 0
+                        for playlist_ref in self._get_flattened_playlist_refs(force_refresh=True):
+                            self._get_playlist(playlist_ref.uri)
+                            count += 1
+                        logger.info(f"Refreshed {count} Spotify playlists")
+
+                    listener.CoreListener.send("playlists_loaded")
+                    self._refreshing = False
+                    self._loaded = True
+                except Exception as e:
+                    logger.exception(
+                        f"An error occurred while refreshing Spotify playlists, retrying: {e}"
+                    )
+                    time.sleep(3000)
 
         thread = threading.Thread(target=refresher)
         thread.daemon = True
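
Note: the refresher no longer wipes _sp_links or the web cache up front, so a failed refresh keeps serving the previously loaded playlists, and it now retries instead of giving up after one failed pass. A minimal sketch of the retry-until-success pattern, with illustrative names (refresh_until_loaded and fetch_playlists are not from this PR):

import time

def refresh_until_loaded(fetch_playlists, retry_delay=3000):
    loaded = False
    while not loaded:
        try:
            fetch_playlists()  # may raise on network or auth errors
            loaded = True      # only set after a full successful pass
        except Exception:
            time.sleep(retry_delay)  # back off, then retry the whole pass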
145 changes: 79 additions & 66 deletions mopidy_spotify/web.py
@@ -9,6 +9,7 @@
 from datetime import datetime
 from enum import Enum, unique
 from typing import Optional
+from threading import Lock
 
 import requests
 
@@ -20,6 +21,58 @@
 def _trace(*args, **kwargs):
     logger.log(utils.TRACE, *args, **kwargs)
 
+
+def _format_url(base_url, url, *args, **kwargs):
+    b = urllib.parse.urlsplit(base_url)
+    u = urllib.parse.urlsplit(url.format(*args))
+
+    if u.scheme or u.netloc:
+        scheme, netloc, path = u.scheme, u.netloc, u.path
+        query = urllib.parse.parse_qsl(u.query, keep_blank_values=True)
+    else:
+        scheme, netloc = b.scheme, b.netloc
+        path = os.path.normpath(os.path.join(b.path, u.path))
+        query = urllib.parse.parse_qsl(b.query, keep_blank_values=True)
+        query.extend(
+            urllib.parse.parse_qsl(u.query, keep_blank_values=True)
+        )
+
+    for key, value in kwargs.items():
+        query.append((key, value))
+
+    encoded_query = urllib.parse.urlencode(dict(query))
+    return urllib.parse.urlunsplit(
+        (scheme, netloc, path, encoded_query, "")
+    )
+
+
+def _normalise_query_string(url, params=None):
+    u = urllib.parse.urlsplit(url)
+    scheme, netloc, path = u.scheme, u.netloc, u.path
+
+    query = dict(urllib.parse.parse_qsl(u.query, keep_blank_values=True))
+    if isinstance(params, dict):
+        query.update(params)
+    sorted_unique_query = sorted(query.items())
+    encoded_query = urllib.parse.urlencode(sorted_unique_query)
+    return urllib.parse.urlunsplit(
+        (scheme, netloc, path, encoded_query, "")
+    )
+
+
+def _parse_retry_after(response):
+    """Parse Retry-After header from response if it is set."""
+    value = response.headers.get("Retry-After")
+
+    if not value:
+        seconds = 0
+    elif re.match(r"^\s*[0-9]+\s*$", value):
+        seconds = int(value)
+    else:
+        date_tuple = email.utils.parsedate(value)
+        if date_tuple is None:
+            seconds = 0
+        else:
+            seconds = time.mktime(date_tuple) - time.time()
+    return max(0, seconds)
 
 
 class OAuthTokenRefreshError(Exception):
     def __init__(self, reason):
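
Note: hoisting these helpers to module level resolves the old "# TODO: Move this out as a helper and unit-test it directly?" from _prepare_url. Expected behaviour, traced from the code above (the URLs are illustrative):

# Relative paths resolve against the base URL; extra kwargs become query params.
_format_url("https://api.spotify.com/v1/", "users/{}/playlists", "alice", limit=50)
# -> "https://api.spotify.com/v1/users/alice/playlists?limit=50"

# Duplicate keys collapse (params win) and keys are sorted, giving stable cache keys.
_normalise_query_string("tracks?market=US&ids=b", params={"ids": "a"})
# -> "tracks?ids=a&market=US"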
@@ -64,32 +117,38 @@ def __init__(
 
         self._headers = {"Content-Type": "application/json"}
         self._session = utils.get_requests_session(proxy_config or {})
+        self._cache_mutex = Lock()
+        self._refresh_mutex = Lock()
 
     def get(self, path, cache=None, *args, **kwargs):
         if self._authorization_failed:
             logger.debug("Blocking request as previous authorization failed.")
             return WebResponse(None, None)
 
         params = kwargs.pop("params", None)
-        path = self._normalise_query_string(path, params)
+        path = _normalise_query_string(path, params)
 
         _trace(f"Get '{path}'")
 
         ignore_expiry = kwargs.pop("ignore_expiry", False)
+        force_refresh = kwargs.pop("force_refresh", False)
         if cache is not None and path in cache:
             cached_result = cache.get(path)
-            if cached_result.still_valid(ignore_expiry):
+            if not force_refresh and cached_result.still_valid(ignore_expiry):
                 return cached_result
             kwargs.setdefault("headers", {}).update(cached_result.etag_headers)
 
         # TODO: Factor this out once we add more methods.
         # TODO: Don't silently error out.
         try:
+            self._refresh_mutex.acquire()
             if self._should_refresh_token():
                 self._refresh_token()
         except OAuthTokenRefreshError as e:
             logger.error(e)
             return WebResponse(None, None)
+        finally:
+            self._refresh_mutex.release()
 
         # Make sure our headers always override user supplied ones.
         kwargs.setdefault("headers", {}).update(self._headers)
@@ -102,11 +161,15 @@ def get(self, path, cache=None, *args, **kwargs):
             )
             return WebResponse(None, None)
 
-        if self._should_cache_response(cache, result):
-            previous_result = cache.get(path)
-            if previous_result and previous_result.updated(result):
-                result = previous_result
-            cache[path] = result
+        try:
+            self._cache_mutex.acquire()
+            if self._should_cache_response(cache, result):
+                previous_result = cache.get(path)
+                if previous_result and previous_result.updated(result):
+                    result = previous_result
+                cache[path] = result
+        finally:
+            self._cache_mutex.release()
 
         return result

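Note: the acquire()/release() pairs guarded by try/finally are equivalent to the lock's context-manager form, which is the more idiomatic spelling; a sketch of the same cache critical section:

with self._cache_mutex:  # threading.Lock supports the with-statement protocol
    if self._should_cache_response(cache, result):
        previous_result = cache.get(path)
        if previous_result and previous_result.updated(result):
            result = previous_result
        cache[path] = result
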
@@ -150,15 +213,16 @@ def _refresh_token(self):
 
     def _request_with_retries(self, method, url, *args, **kwargs):
         prepared_request = self._session.prepare_request(
-            requests.Request(method, self._prepare_url(url, *args), **kwargs)
+            requests.Request(method, _format_url(self._base_url, url, *args), **kwargs)
         )
 
+        retries = kwargs.pop('retries', self._number_of_retries)
         try_until = time.time() + self._timeout
 
         result = None
         backoff_time = 0
 
-        for i in range(self._number_of_retries):
+        for i in range(retries):
             remaining_timeout = max(try_until - time.time(), 1)
 
             # Give up if we don't have any timeout left after sleeping.
@@ -178,7 +242,7 @@ def _request_with_retries(self, method, url, *args, **kwargs):
                 result = None
             else:
                 status_code = response.status_code
-                backoff_time = self._parse_retry_after(response)
+                backoff_time = _parse_retry_after(response)
                 result = WebResponse.from_requests(prepared_request, response)
 
             if status_code and 400 <= status_code < 600:
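
Note: expected behaviour of _parse_retry_after, traced from its definition above; the stub class below stands in for a requests response and is not part of the PR:

class _StubResponse:
    def __init__(self, headers):
        self.headers = headers

_parse_retry_after(_StubResponse({"Retry-After": "5"}))  # -> 5 (delta-seconds form)
_parse_retry_after(_StubResponse({}))                    # -> 0 (header absent)
# An HTTP-date value is also accepted and converted to seconds from now.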
@@ -212,59 +276,6 @@ def _request_with_retries(self, method, url, *args, **kwargs):
             )
         return result
 
-    def _prepare_url(self, url, *args, **kwargs):
-        # TODO: Move this out as a helper and unit-test it directly?
-        b = urllib.parse.urlsplit(self._base_url)
-        u = urllib.parse.urlsplit(url.format(*args))
-
-        if u.scheme or u.netloc:
-            scheme, netloc, path = u.scheme, u.netloc, u.path
-            query = urllib.parse.parse_qsl(u.query, keep_blank_values=True)
-        else:
-            scheme, netloc = b.scheme, b.netloc
-            path = os.path.normpath(os.path.join(b.path, u.path))
-            query = urllib.parse.parse_qsl(b.query, keep_blank_values=True)
-            query.extend(
-                urllib.parse.parse_qsl(u.query, keep_blank_values=True)
-            )
-
-        for key, value in kwargs.items():
-            query.append((key, value))
-
-        encoded_query = urllib.parse.urlencode(dict(query))
-        return urllib.parse.urlunsplit(
-            (scheme, netloc, path, encoded_query, "")
-        )
-
-    def _normalise_query_string(self, url, params=None):
-        u = urllib.parse.urlsplit(url)
-        scheme, netloc, path = u.scheme, u.netloc, u.path
-
-        query = dict(urllib.parse.parse_qsl(u.query, keep_blank_values=True))
-        if isinstance(params, dict):
-            query.update(params)
-        sorted_unique_query = sorted(query.items())
-        encoded_query = urllib.parse.urlencode(sorted_unique_query)
-        return urllib.parse.urlunsplit(
-            (scheme, netloc, path, encoded_query, "")
-        )
-
-    def _parse_retry_after(self, response):
-        """Parse Retry-After header from response if it is set."""
-        value = response.headers.get("Retry-After")
-
-        if not value:
-            seconds = 0
-        elif re.match(r"^\s*[0-9]+\s*$", value):
-            seconds = int(value)
-        else:
-            date_tuple = email.utils.parsedate(value)
-            if date_tuple is None:
-                seconds = 0
-            else:
-                seconds = time.mktime(date_tuple) - time.time()
-        return max(0, seconds)
 
 
 class WebResponse(dict):
     def __init__(self, url, data, expires=0.0, etag=None, status_code=400):
@@ -433,9 +444,11 @@ def login(self):
     def logged_in(self):
         return self.user_id is not None
 
-    def get_user_playlists(self):
+    def get_user_playlists(self, **kwargs):
         pages = self.get_all(
-            f"users/{self.user_id}/playlists", params={"limit": 50}
+            f"users/{self.user_id}/playlists",
+            params={"limit": 50},
+            **kwargs
         )
         for page in pages:
             yield from page.get("items", [])
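
Note: the **kwargs pass-through is what lets the playlist provider pick between cached and fresh listings. Assuming get_all forwards keyword arguments on to get() (that plumbing is not shown in this diff), the two call styles behave like:

client.get_user_playlists()                    # pages served from still-valid cache entries
client.get_user_playlists(force_refresh=True)  # skips the freshness check but still revalidates via ETags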
@@ -476,7 +489,7 @@ def get_playlist(self, uri):
 
         return playlist
 
-    def clear_cache(self, extra_expiry=None):
+    def clear_cache(self):
         self._cache.clear()


3 changes: 1 addition & 2 deletions tests/test_playlists.py
@@ -193,8 +193,7 @@ def test_refresh_clears_caches(provider, web_client_mock):
     with ThreadJoiner(timeout=1.0):
         provider.refresh()
 
-    assert "bar" not in playlists._sp_links
-    web_client_mock.clear_cache.assert_called_once()
+    assert "spotify:track:abc" in playlists._sp_links
 
 
 def test_lookup(provider):
2 changes: 1 addition & 1 deletion tests/test_web.py
@@ -416,7 +416,7 @@ def test_should_cache_response(oauth_client, cache, ok, expected):
     ],
 )
 def test_normalise_query_string(oauth_client, path, params, expected):
-    result = oauth_client._normalise_query_string(path, params)
+    result = web._normalise_query_string(path, params)
     assert result == expected

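Note: with the helpers at module level, _format_url can get the same direct treatment as _normalise_query_string; a hypothetical test in the style of the one above:

def test_format_url_joins_relative_paths():
    result = web._format_url(
        "https://api.spotify.com/v1/", "users/{}/playlists", "alice", limit=50
    )
    assert result == "https://api.spotify.com/v1/users/alice/playlists?limit=50"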
