diff --git a/test/test_base.py b/test/test_base.py index a6234b32..409a02c8 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -213,7 +213,7 @@ def __eq__(self, other): build_opener_mock.assert_called_once_with( _HTTPBasicAuthHandlerMatcher(self)) - open_mock.assert_called_once_with(url, timeout=None) + open_mock.assert_called_once_with(url, timeout=10) @mock.patch('vcstool.clients.vcs_base.urlopen', autospec=True) @mock.patch('vcstool.clients.vcs_base.build_opener', autospec=True) @@ -243,7 +243,71 @@ def __eq__(self, other): return True urlopen_mock.assert_called_once_with( - _RequestMatcher(self), timeout=None) + _RequestMatcher(self), timeout=10) + + @mock.patch('vcstool.clients.vcs_base.urlopen', autospec=True) + def test_load_url_retries(self, urlopen_mock): + urlopen_read_mock = urlopen_mock.return_value.read + urlopen_mock.side_effect = [ + HTTPError(None, 503, 'test1', None, None), + HTTPError(None, 503, 'test2', None, None), + HTTPError(None, 503, 'test3', None, None), + ] + + with self.assertRaisesRegex(HTTPError, 'test3'): + vcs_base.load_url('example.com') + + self.assertEqual(len(urlopen_mock.mock_calls), 3) + urlopen_mock.assert_has_calls([ + mock.call('example.com', timeout=10), + mock.call('example.com', timeout=10), + mock.call('example.com', timeout=10), + ]) + self.assertFalse(urlopen_read_mock.mock_calls) + + @mock.patch('vcstool.clients.vcs_base.urlopen', autospec=True) + def test_load_url_retries_authenticated(self, urlopen_mock): + urlopen_read_mock = urlopen_mock.return_value.read + urlopen_mock.side_effect = [ + HTTPError(None, 401, 'test1', None, None), + HTTPError(None, 503, 'test2', None, None), + HTTPError(None, 503, 'test3', None, None), + HTTPError(None, 503, 'test4', None, None), + ] + + machine = 'example.com' + _create_netrc_file( + os.path.join(self.default_auth_dir, '.netrc'), + textwrap.dedent('''\ + machine %s + password password + ''' % machine)) + + url = 'https://%s/foo/bar' % machine + + with self.assertRaisesRegex(HTTPError, 'test4'): + vcs_base.load_url(url) + + self.assertEqual(len(urlopen_mock.mock_calls), 4) + + class _RequestMatcher(object): + def __init__(self, test): + self.test = test + + def __eq__(self, other): + self.test.assertEqual(other.get_full_url(), url) + self.test.assertEqual( + other.get_header('Private-token'), 'password') + return True + + urlopen_mock.assert_has_calls([ + mock.call(url, timeout=10), + mock.call(_RequestMatcher(self), timeout=10), + mock.call(_RequestMatcher(self), timeout=10), + mock.call(_RequestMatcher(self), timeout=10), + ]) + self.assertFalse(urlopen_read_mock.mock_calls) + def _create_netrc_file(path, contents): diff --git a/vcstool/clients/vcs_base.py b/vcstool/clients/vcs_base.py index f42a700f..5091c37f 100644 --- a/vcstool/clients/vcs_base.py +++ b/vcstool/clients/vcs_base.py @@ -1,4 +1,5 @@ import errno +import functools import glob import logging import netrc @@ -112,26 +113,17 @@ def run_command(cmd, cwd, env=None): def load_url(url, retry=2, retry_period=1, timeout=10): fh = None try: - fh = urlopen(url, timeout=timeout) + fh = _retryable_urlopen(url, timeout=timeout) except HTTPError as e: - e.msg += ' (%s)' % url if e.code in (401, 404): # Try again, but with authentication fh = _authenticated_urlopen(url, timeout=timeout) - elif e.code == 503 and retry: - time.sleep(retry_period) - return load_url( - url, retry=retry - 1, retry_period=retry_period, - timeout=timeout) if fh is None: + e.msg += ' (%s)' % url raise except URLError as e: - if isinstance(e.reason, socket.timeout) and retry: - time.sleep(retry_period) - return load_url( - url, retry=retry - 1, retry_period=retry_period, - timeout=timeout) raise URLError(str(e) + ' (%s)' % url) + return fh.read() @@ -140,33 +132,52 @@ def test_url(url, retry=2, retry_period=1, timeout=10): request.get_method = lambda: 'HEAD' try: - response = urlopen(request) + response = _retryable_urlopen(request) except HTTPError as e: - if e.code == 503 and retry: - time.sleep(retry_period) - return test_url( - url, retry=retry - 1, retry_period=retry_period, - timeout=timeout) e.msg += ' (%s)' % url raise except URLError as e: - if isinstance(e.reason, socket.timeout) and retry: - time.sleep(retry_period) - return test_url( - url, retry=retry - 1, retry_period=retry_period, - timeout=timeout) raise URLError(str(e) + ' (%s)' % url) return response -def _authenticated_urlopen(uri, timeout=None): +def _urlopen_retry(f): + @functools.wraps(f) + def _retryable_function(url, retry=2, retry_period=1, timeout=10): + retry += 1 + + while True: + try: + retry -= 1 + return f(url, timeout=timeout) + except HTTPError as e: + if e.code != 503 or retry <= 0: + raise + except URLError as e: + if not isinstance(e.reason, socket.timeout) or retry <= 0: + raise + + if retry > 0: + time.sleep(retry_period) + else: + break + + return _retryable_function + + +@_urlopen_retry +def _retryable_urlopen(url, timeout=10): + return urlopen(url, timeout=timeout) + + +@_urlopen_retry +def _authenticated_urlopen(uri, timeout=10): machine = urlparse(uri).netloc if not machine: return None credentials = _credentials_for_machine(machine) if credentials is None: - logger.warning('No credentials found for "%s"' % machine) return None (username, account, password) = credentials