From 9cf903eb4fc50dc0d4314d7b14503e9a9f1ba6f5 Mon Sep 17 00:00:00 2001 From: Kert Date: Mon, 24 Apr 2023 17:08:26 -0700 Subject: [PATCH] Refactor Raspi launcher (#119) This makes the following changes * Add test coverage to Raspi launcher * Adds a standalone Python retry module Adds a `retry` module to use from AbstractLauncher code. This is intended to clean up retry handling in Raspberry Pi launcher specifically. * Refactor and fix up Raspi launcher retry code Refactors launcher code to use shared `retry` module code, and re-applies fixes from reverted #106 after clean-up and refactor of retry code. b/279249837 b/278447902 b/240984469 --- starboard/raspi/shared/launcher.py | 196 +++++++++++------- starboard/raspi/shared/launcher_test.py | 208 +++++++++++++++++++ starboard/raspi/shared/retry.py | 121 +++++++++++ starboard/raspi/shared/retry_test.py | 257 ++++++++++++++++++++++++ 4 files changed, 705 insertions(+), 77 deletions(-) create mode 100644 starboard/raspi/shared/launcher_test.py create mode 100644 starboard/raspi/shared/retry.py create mode 100644 starboard/raspi/shared/retry_test.py diff --git a/starboard/raspi/shared/launcher.py b/starboard/raspi/shared/launcher.py index 5b3951683bd3..9d831a893c73 100644 --- a/starboard/raspi/shared/launcher.py +++ b/starboard/raspi/shared/launcher.py @@ -22,14 +22,16 @@ import sys import threading import time +import contextlib import _env # pylint: disable=unused-import import pexpect from starboard.tools import abstract_launcher +from starboard.raspi.shared import retry # pylint: disable=unused-argument -def _SigIntOrSigTermHandler(signum, frame): +def _sigint_or_sigterm_handler(signum, frame): """Clean up and exit with status |signum|. Args: @@ -42,7 +44,7 @@ def _SigIntOrSigTermHandler(signum, frame): # First call returns True, otherwise return false. -def FirstRun(): +def first_run(): v = globals() if not v.has_key('first_run'): v['first_run'] = False @@ -64,16 +66,35 @@ class Launcher(abstract_launcher.AbstractLauncher): # pexpect times out each second to allow Kill to quickly stop a test run _PEXPECT_TIMEOUT = 1 - # Wait up to 30 seconds for the password prompt from the raspi - _PEXPECT_PASSWORD_TIMEOUT_MAX_RETRIES = 30 + # SSH shell command retries + _PEXPECT_SPAWN_RETRIES = 20 + + # pexpect.sendline retries + _PEXPECT_SENDLINE_RETRIES = 3 + + # Old process kill retries + _KILL_RETRIES = 3 + + _PEXPECT_SHUTDOWN_SLEEP_TIME = 3 + # Time to wait after processes were killed + _PROCESS_KILL_SLEEP_TIME = 10 + + # Retrys for getting a clean prompt + _PROMPT_WAIT_MAX_RETRIES = 5 + # Wait up to 10 seconds for the password prompt from the raspi + _PEXPECT_PASSWORD_TIMEOUT_MAX_RETRIES = 10 # Wait up to 900 seconds for new output from the raspi _PEXPECT_READLINE_TIMEOUT_MAX_RETRIES = 900 # Delay between subsequent SSH commands - _INTER_COMMAND_DELAY_SECONDS = 0.5 + _INTER_COMMAND_DELAY_SECONDS = 1.5 # This is used to strip ansi color codes from pexpect output. _PEXPECT_SANITIZE_LINE_RE = re.compile(r'\x1b[^m]*m') + # Exceptions to retry + _RETRY_EXCEPTIONS = (pexpect.TIMEOUT, pexpect.ExceptionPexpect, + pexpect.exceptions.EOF, OSError) + def __init__(self, platform, target_name, config, device_id, **kwargs): # pylint: disable=super-with-arguments super(Launcher, self).__init__(platform, target_name, config, device_id, @@ -101,14 +122,17 @@ def __init__(self, platform, target_name, config, device_id, **kwargs): self.log_targets = kwargs.get('log_targets', True) - signal.signal(signal.SIGINT, functools.partial(_SigIntOrSigTermHandler)) - signal.signal(signal.SIGTERM, functools.partial(_SigIntOrSigTermHandler)) + signal.signal(signal.SIGINT, functools.partial(_sigint_or_sigterm_handler)) + signal.signal(signal.SIGTERM, functools.partial(_sigint_or_sigterm_handler)) self.last_run_pexpect_cmd = '' def _InitPexpectCommands(self): """Initializes all of the pexpect commands needed for running the test.""" + # Ensure no trailing slashes + self.out_directory = self.out_directory.rstrip('/') + test_dir = os.path.join(self.out_directory, 'deploy', self.target_name) test_file = self.target_name @@ -154,6 +178,18 @@ def _InitPexpectCommands(self): test_success_output, test_failure_output) + # pylint: disable=no-method-argument + def _CommandBackoff(): + time.sleep(Launcher._INTER_COMMAND_DELAY_SECONDS) + + def _ShutdownBackoff(self): + Launcher._CommandBackoff() + return self.shutdown_initiated.is_set() + + @retry.retry( + exceptions=_RETRY_EXCEPTIONS, + retries=_PEXPECT_SPAWN_RETRIES, + backoff=_CommandBackoff) def _PexpectSpawnAndConnect(self, command): """Spawns a process with pexpect and connect to the raspi. @@ -166,70 +202,67 @@ def _PexpectSpawnAndConnect(self, command): command, timeout=Launcher._PEXPECT_TIMEOUT) # Let pexpect output directly to our output stream self.pexpect_process.logfile_read = self.output_file - retry_count = 0 expected_prompts = [ r'.*Are\syou\ssure.*', # Fingerprint verification r'.* password:', # Password prompt '.*[a-zA-Z]+.*', # Any other text input ] - while True: - try: - i = self.pexpect_process.expect(expected_prompts) - if i == 0: - self._PexpectSendLine('yes') - elif i == 1: - self._PexpectSendLine(Launcher._RASPI_PASSWORD) - break - else: - # If any other input comes in, maybe we've logged in with rsa key or - # raspi does not have password. Check if we've logged in by echoing - # a special sentence and expect it back. - self._PexpectSendLine('echo ' + Launcher._SSH_LOGIN_SIGNAL) - i = self.pexpect_process.expect([Launcher._SSH_LOGIN_SIGNAL]) - break - except pexpect.TIMEOUT: - if self.shutdown_initiated.is_set(): - return - retry_count += 1 - # Check if the max retry count has been exceeded. If it has, then - # re-raise the timeout exception. - if retry_count > Launcher._PEXPECT_PASSWORD_TIMEOUT_MAX_RETRIES: - raise + # pylint: disable=unnecessary-lambda + @retry.retry( + exceptions=Launcher._RETRY_EXCEPTIONS, + retries=Launcher._PEXPECT_PASSWORD_TIMEOUT_MAX_RETRIES, + backoff=lambda: self._ShutdownBackoff(), + wrap_exceptions=False) + def _inner(): + i = self.pexpect_process.expect(expected_prompts) + if i == 0: + self._PexpectSendLine('yes') + elif i == 1: + self._PexpectSendLine(Launcher._RASPI_PASSWORD) + else: + # If any other input comes in, maybe we've logged in with rsa key or + # raspi does not have password. Check if we've logged in by echoing + # a special sentence and expect it back. + self._PexpectSendLine('echo ' + Launcher._SSH_LOGIN_SIGNAL) + i = self.pexpect_process.expect([Launcher._SSH_LOGIN_SIGNAL]) + + _inner() + + @retry.retry( + exceptions=_RETRY_EXCEPTIONS, + retries=_PEXPECT_SENDLINE_RETRIES, + wrap_exceptions=False) def _PexpectSendLine(self, cmd): """Send lines to Pexpect and record the last command for logging purposes""" + logging.info('sending >> : %s ', cmd) self.last_run_pexpect_cmd = cmd self.pexpect_process.sendline(cmd) def _PexpectReadLines(self): """Reads all lines from the pexpect process.""" - - retry_count = 0 - while True: - try: + # pylint: disable=unnecessary-lambda + @retry.retry( + exceptions=Launcher._RETRY_EXCEPTIONS, + retries=Launcher._PEXPECT_READLINE_TIMEOUT_MAX_RETRIES, + backoff=lambda: self.shutdown_initiated.is_set(), + wrap_exceptions=False) + def _readloop(): + while True: # Sanitize the line to remove ansi color codes. line = Launcher._PEXPECT_SANITIZE_LINE_RE.sub( '', self.pexpect_process.readline()) self.output_file.flush() if not line: - break + return # Check for the test complete tag. It will be followed by either a # success or failure tag. if line.startswith(self.test_complete_tag): if line.find(self.test_success_tag) != -1: self.return_value = 0 - break - # A line was successfully read without timing out; reset the retry - # count before attempting to read the next line. - retry_count = 0 - except pexpect.TIMEOUT: - if self.shutdown_initiated.is_set(): return - retry_count += 1 - # Check if the max retry count has been exceeded. If it has, then - # re-raise the timeout exception. - if retry_count > Launcher._PEXPECT_READLINE_TIMEOUT_MAX_RETRIES: - raise + + _readloop() def _Sleep(self, val): self._PexpectSendLine('sleep {};echo {}'.format(val, @@ -243,39 +276,38 @@ def _CleanupPexpectProcess(self): # Check if kernel logged OOM kill or any other system failure message if self.return_value: logging.info('Sending dmesg') - self._PexpectSendLine('dmesg -P --color=never | tail -n 100') - time.sleep(3) - try: + with contextlib.suppress(Launcher._RETRY_EXCEPTIONS): + self._PexpectSendLine('dmesg -P --color=never | tail -n 100') + time.sleep(self._PEXPECT_SHUTDOWN_SLEEP_TIME) + with contextlib.suppress(Launcher._RETRY_EXCEPTIONS): self.pexpect_process.readlines() - except pexpect.TIMEOUT: - logging.info('Timeout exception during cleanup command: %s', - self.last_run_pexpect_cmd) - pass logging.info('Done sending dmesg') # Send ctrl-c to the raspi and close the process. - self._PexpectSendLine(chr(3)) - time.sleep(1) # Allow a second for normal shutdown - self.pexpect_process.close() + with contextlib.suppress(Launcher._RETRY_EXCEPTIONS): + self._PexpectSendLine(chr(3)) + time.sleep(self._PEXPECT_TIMEOUT) # Allow time for normal shutdown + with contextlib.suppress(Launcher._RETRY_EXCEPTIONS): + self.pexpect_process.close() def _WaitForPrompt(self): """Sends empty commands, until a bash prompt is returned""" - retry_count = 5 - while True: - try: - self.pexpect_process.expect(self._RASPI_PROMPT) - break - except pexpect.TIMEOUT: - logging.info('Timeout exception during WaitForPrompt command: %s', - self.last_run_pexpect_cmd) - if self.shutdown_initiated.is_set(): - return - retry_count -= 1 - if not retry_count: - raise - self._PexpectSendLine('echo ' + Launcher._SSH_SLEEP_SIGNAL) - time.sleep(self._INTER_COMMAND_DELAY_SECONDS) + def backoff(): + self._PexpectSendLine('echo ' + Launcher._SSH_SLEEP_SIGNAL) + return self._ShutdownBackoff() + + retry.with_retry( + lambda: self.pexpect_process.expect(self._RASPI_PROMPT), + exceptions=Launcher._RETRY_EXCEPTIONS, + retries=Launcher._PROMPT_WAIT_MAX_RETRIES, + backoff=backoff, + wrap_exceptions=False) + + @retry.retry( + exceptions=_RETRY_EXCEPTIONS, + retries=_KILL_RETRIES, + backoff=_CommandBackoff) def _KillExistingCobaltProcesses(self): """If there are leftover Cobalt processes, kill them. @@ -294,7 +326,7 @@ def _KillExistingCobaltProcesses(self): logging.warning('Forced to pkill existing instance(s) of cobalt. ' 'Pausing to ensure no further operations are run ' 'before processes shut down.') - time.sleep(10) + time.sleep(Launcher._PROCESS_KILL_SLEEP_TIME) logging.info('Done killing existing processes') def Run(self): @@ -330,12 +362,19 @@ def Run(self): if self.test_result_xml_path: first_run_commands.append('touch {}'.format(self.test_result_xml_path)) first_run_commands.extend(['free -mh', 'ps -ux', 'df -h']) - if FirstRun(): + if first_run(): for cmd in first_run_commands: if not self.shutdown_initiated.is_set(): self._PexpectSendLine(cmd) - line = self.pexpect_process.readline() - self.output_file.write(line) + + def _readline(): + line = self.pexpect_process.readline() + self.output_file.write(line) + + retry.with_retry( + _readline, + exceptions=Launcher._RETRY_EXCEPTIONS, + retries=Launcher._PROMPT_WAIT_MAX_RETRIES) self._WaitForPrompt() self.output_file.flush() self._Sleep(self._INTER_COMMAND_DELAY_SECONDS) @@ -346,6 +385,9 @@ def Run(self): self._PexpectSendLine(self.test_command) self._PexpectReadLines() + except retry.RetriesExceeded: + logging.exception('Command retry exceeded (cmd: %s)', + self.last_run_pexpect_cmd) except pexpect.EOF: logging.exception('pexpect encountered EOF while reading line. (cmd: %s)', self.last_run_pexpect_cmd) @@ -377,7 +419,7 @@ def Kill(self): # Initiate the shutdown. This causes the run to abort within one second. self.shutdown_initiated.set() # Wait up to three seconds for the run to be set to inactive. - self.run_inactive.wait(3) + self.run_inactive.wait(Launcher._PEXPECT_SHUTDOWN_SLEEP_TIME) def GetDeviceIp(self): """Gets the device IP.""" diff --git a/starboard/raspi/shared/launcher_test.py b/starboard/raspi/shared/launcher_test.py new file mode 100644 index 000000000000..5fe679404618 --- /dev/null +++ b/starboard/raspi/shared/launcher_test.py @@ -0,0 +1,208 @@ +# +# Copyright 2023 The Cobalt Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Raspi launcher""" + +import logging +from starboard.raspi.shared import launcher +import sys +import argparse +import unittest +import os +from unittest.mock import patch, ANY, call, Mock +import tempfile +from pathlib import Path +import pexpect + +# pylint: disable=missing-class-docstring + + +class LauncherTest(unittest.TestCase): + + def setUp(self): + self.target = 'baz' + self.device_id = '198.51.100.1' # Reserved address + # Current launcher requires real files, so we generate one + # pylint: disable=consider-using-with + self.tmpdir = tempfile.TemporaryDirectory() + target_path = os.path.join(self.tmpdir.name, 'install', self.target) + os.makedirs(target_path) + Path(os.path.join(target_path, self.target)).touch() + # Minimal set of params required to crete one + self.params = { + 'device_id': self.device_id, + 'platform': 'raspi-2', + 'target_name': self.target, + 'config': 'test', + 'out_directory': self.tmpdir.name + } + self.fake_timeout = 0.11 + + # pylint: disable=protected-access + def _make_launcher(self): + launcher.Launcher._PEXPECT_TIMEOUT = self.fake_timeout + launcher.Launcher._PEXPECT_PASSWORD_TIMEOUT_MAX_RETRIES = 0 + launcher.Launcher._PEXPECT_SHUTDOWN_SLEEP_TIME = 0.12 + launcher.Launcher._INTER_COMMAND_DELAY_SECONDS = 0.013 + launcher.Launcher._PEXPECT_READLINE_TIMEOUT_MAX_RETRIES = 2 + launch = launcher.Launcher(**self.params) + return launch + + +class LauncherAPITest(LauncherTest): + + def test_construct(self): + launch = self._make_launcher() + self.assertIsNotNone(launch) + self.assertEqual(launch.device_id, self.device_id) + self.assertEqual(launch.platform_name, 'raspi-2') + self.assertEqual(launch.target_name, self.target) + self.assertEqual(launch.config, 'test') + self.assertEqual(launch.out_directory, self.tmpdir.name) + + def test_run(self): + result = self._make_launcher().Run() + # Expect test failure + self.assertEqual(result, 1) + + def test_ip(self): + self.assertEqual(self._make_launcher().GetDeviceIp(), self.device_id) + + def test_output(self): + # The path is hardcoded in the launcher + self.assertEqual(self._make_launcher().GetDeviceOutputPath(), '/tmp') + + def test_kill(self): + self.assertIsNone(self._make_launcher().Kill()) + + +class StringContains(str): + + def __eq__(self, value): + return self in value + + +# Tests here test implementation details, rather than behavior. +# pylint: disable=protected-access +class LauncherInternalsTest(LauncherTest): + + def setUp(self): + super().setUp() + self.launch = self._make_launcher() + self.launch.pexpect_process = Mock( + spec_set=['expect', 'sendline', 'readline']) + + @patch('starboard.raspi.shared.launcher.pexpect.spawn') + def test_spawn(self, spawn): + mock_pexpect = spawn.return_value + self.launch._PexpectSpawnAndConnect('echo test') + spawn.assert_called_once_with('echo test', timeout=ANY, encoding=ANY) + mock_pexpect.sendline.assert_called_once_with( + 'echo cobalt-launcher-login-success') + mock_pexpect.expect.assert_any_call(['cobalt-launcher-login-success']) + + def test_sleep(self): + self.launch._Sleep(42) + self.launch.pexpect_process.sendline.assert_called_once_with( + 'sleep 42;echo cobalt-launcher-done-sleeping') + self.launch.pexpect_process.expect.assert_called_once_with( + ['cobalt-launcher-done-sleeping']) + + def test_waitforconnect(self): + self.launch._WaitForPrompt() + self.launch.pexpect_process.expect.assert_called_once_with( + 'pi@raspberrypi:') + + # trigger one timeout + self.launch.pexpect_process.expect = Mock( + side_effect=[pexpect.TIMEOUT(1), None]) + self.launch._WaitForPrompt() + self.launch.pexpect_process.expect.assert_has_calls([ + call('pi@raspberrypi:'), + call('pi@raspberrypi:'), + ]) + + # infinite timeout + self.launch.pexpect_process.expect = Mock(side_effect=pexpect.TIMEOUT(1)) + with self.assertRaises(pexpect.TIMEOUT): + self.launch._WaitForPrompt() + + def test_readlines(self): + # Return empty string + self.launch.pexpect_process.readline = Mock(return_value='') + self.launch._PexpectReadLines() + self.launch.pexpect_process.readline.assert_called_once() + self.assertIsNone(getattr(self.launch, 'return_value', None)) + + # Return default success tag + self.launch.pexpect_process.readline = Mock( + return_value=self.launch.test_complete_tag) + self.launch._PexpectReadLines() + self.launch.pexpect_process.readline.assert_called_once() + # This is a bug + self.assertIsNone(getattr(self.launch, 'return_value', None)) + + line = self.launch.test_complete_tag + self.launch.test_success_tag + self.launch.pexpect_process.readline = Mock(return_value=line) + self.launch._PexpectReadLines() + self.assertEqual(self.launch.return_value, 0) + + self.launch.pexpect_process.readline = Mock(side_effect=pexpect.TIMEOUT(1)) + with self.assertRaises(pexpect.TIMEOUT): + self.launch._PexpectReadLines() + + def test_readlines_multiple(self): + self.launch.pexpect_process.readline = Mock(side_effect=['abc', 'bbc', '']) + self.launch._PexpectReadLines() + self.assertEqual(3, self.launch.pexpect_process.readline.call_count) + + self.launch.pexpect_process.readline = Mock( + side_effect=['abc', 'bbc', '', 'none']) + self.launch._PexpectReadLines() + self.assertEqual(3, self.launch.pexpect_process.readline.call_count) + + def test_kill_processes(self): + self.launch._KillExistingCobaltProcesses() + self.launch.pexpect_process.sendline.assert_any_call( + StringContains('pkill')) + + @patch('starboard.raspi.shared.launcher.pexpect.spawn') + def test_run_with_mock(self, spawn): + pexpect_ = Mock() + pexpect_.readline = Mock(return_value='') + spawn.return_value = pexpect_ + self.launch.Run() + self.assertEqual(self.launch.return_value, 1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('device_id') + parser.add_argument('--target', default='eztime_test') + parser.add_argument('--out_directory') + parser.add_argument('--config', default='devel') + parser.add_argument('--verbose', '-v', action='store_true') + args = parser.parse_args() + logging.basicConfig( + stream=sys.stdout, level=logging.DEBUG if args.verbose else logging.INFO) + path = os.path.join( + os.path.dirname(launcher.__file__), f'../../../out/raspi-2_{args.config}') + logging.info('path: %s', path) + launch_test = launcher.Launcher( + platform='raspi-2', + target_name=args.target, + config=args.config, + device_id=args.device_id, + out_directory=path) + launch_test.Run() diff --git a/starboard/raspi/shared/retry.py b/starboard/raspi/shared/retry.py new file mode 100644 index 000000000000..1e8fd0e0483d --- /dev/null +++ b/starboard/raspi/shared/retry.py @@ -0,0 +1,121 @@ +# +# Copyright 2023 The Cobalt Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""General retry wrapper module + +Allows retrying a function call either with a decorator or inline call. +This is a substitute for more comprehensive Python retry wrapper packages like +`tenacity`, `retry`, `backoff` and others. +The only reason this exists is that Python package deployment for on-device +tests cannot currently dynamically include dependencies. +TODO(b/279249837): Remove this and use an off the shelf package. +""" + +from typing import Sequence, Callable +import functools +import logging + + +class RetriesExceeded(RuntimeError): + """Exception recording retry failure conditions""" + + def __init__(self, retries: int, function: Callable, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.retries = retries + self.function = function + + def __str__(self) -> str: + callable_str = getattr(self.function, '__name__', repr(self.function)) + return (f'Retries exceeded while calling {callable_str}' + f' with max {self.retries}') + super().__str__() + + +def _retry_function(function: Callable, exceptions: Sequence, retries: int, + backoff: Callable, wrap_exceptions: bool): + current_retry = 0 + while current_retry <= retries: + try: + return function() + except exceptions as inner: + current_retry += 1 + logging.debug('Exception running %s, retry %d/%d', function, retry, + retries) + if current_retry > retries: + # If 0 retries were attempted, pass up original exception + if not retries or not wrap_exceptions: + raise + raise RetriesExceeded(retries, function) from inner + if backoff: + if backoff(): + raise StopIteration() from inner + + raise RuntimeError('Bug: we should never get here') + + +def with_retry(function: Callable, + args: tuple = (), + kwargs: dict = None, + exceptions: Sequence = (Exception,), + retries: int = 0, + backoff: Callable = None, + wrap_exceptions: bool = True): + """Call a function with retry on exception + + :param args: Called function positional args. + :param kwargs: Called function named args. + :param exceptions: Sequence of exception types that will be retried. + :param retries: Max retries attempted. + :param backoff: Optional backoff callable. Truthy return from callable + terminates the loop. + :param wrap_exceptions: If true ( default ) wrap underlying exceptions in + RetriesExceeded exception type + : + """ + return _retry_function( + functools.partial(function, *args, **(kwargs if kwargs else {})), + exceptions=exceptions, + retries=retries, + backoff=backoff, + wrap_exceptions=wrap_exceptions, + ) + + +def retry(exceptions: Sequence = (Exception,), + retries: int = 0, + backoff: Callable = None, + wrap_exceptions: bool = True): + """Decorator for self-retrying function on thrown exception + + :param exceptions: Sequence of exception types that will be retried. + :param retries: Max retries attempted. + :param backoff: Optional backoff callable. Truthy return from callable + terminates the loop. + :param wrap_exceptions: If true ( default ) wrap underlying exceptions in + RetriesExceeded exception type + """ + + def decorator(function): + + @functools.wraps(function) + def wrapper(*args, **kwargs): + return _retry_function( + functools.partial(function, *args, **kwargs), + exceptions=exceptions, + retries=retries, + backoff=backoff, + wrap_exceptions=wrap_exceptions) + + return wrapper + + return decorator diff --git a/starboard/raspi/shared/retry_test.py b/starboard/raspi/shared/retry_test.py new file mode 100644 index 000000000000..465edda5e265 --- /dev/null +++ b/starboard/raspi/shared/retry_test.py @@ -0,0 +1,257 @@ +# +# Copyright 2023 The Cobalt Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for `retry` module""" + +import unittest +from starboard.raspi.shared import retry +import argparse +import logging +import sys +import time +from enum import IntEnum + + +class Behavior(IntEnum): + OK = 0 + OS_ERROR = 1 + RUNTIME_ERROR = 2 + OTHER_ERROR = 3 + + +def _problem(param: int, caller=str): + logging.info('%s: param=%d', caller, param) + if param == Behavior.OS_ERROR: + raise OSError('OS made an oops') + if param == Behavior.RUNTIME_ERROR: + raise RuntimeError('Runtime oops') + if param == Behavior.OTHER_ERROR: + raise MemoryError('Download more RAM') + return 100 + param * 3 + + +def problem(param: int): + return _problem(param, 'undecorated problem') + + +@retry.retry(exceptions=(RuntimeError,), retries=1) +def decorated_runtimeerror(param: int): + return _problem(param, 'decorated with runtimeerror') + + +@retry.retry(exceptions=(OSError,), retries=1) +def decorated_oserror(param: int): + return _problem(param, 'decorated with oserror') + + +@retry.retry(exceptions=(OSError, RuntimeError), retries=1) +def decorated_both(param: int): + return _problem(param, 'decorated with oserror+runtimeerror') + + +@retry.retry( + exceptions=(OSError,), + retries=2, + backoff=lambda: (logging.info('sleeping 0.2'), time.sleep(0.2))) +def decorated_oserror_backoff_2(param: int): + return _problem(param, 'decorated with oserror, 2 retries and sleep backoff') + + +class RetryTest(unittest.TestCase): + + def setUp(self) -> None: + self.actual_calls = 0 + self.call_counter = 0 + return super().setUp() + + def problem(self, param): + self.actual_calls += 1 + return _problem(param, 'undecorated problem method') + + @retry.retry(exceptions=(OSError,), retries=1) + def decorated_os_problem(self, param): + self.actual_calls += 1 + return _problem(param, 'decorated problem method') + + @retry.retry(exceptions=(OSError,), retries=1, wrap_exceptions=False) + def decorated_os_problem_nowrap(self, param): + self.actual_calls += 1 + return _problem(param, 'decorated problem method, pass-through exceptions') + + @retry.retry(exceptions=(OSError,), retries=5) + def decorated_os_problem_3(self, param): + self.actual_calls += 1 + self.call_counter += 1 + if self.call_counter == 3: + return 200 + return _problem(param, 'decorated problem that succeeds on 3rd try') + + def test_ok_call_undecorated(self): + self.assertEqual(100, retry.with_retry(problem, (Behavior.OK,))) + self.assertEqual(100, retry.with_retry(self.problem, (Behavior.OK,))) + self.assertEqual(self.actual_calls, 1) + + def test_ok_call_decorated(self): + self.assertEqual(100, decorated_both(Behavior.OK)) + self.assertEqual(100, self.decorated_os_problem(Behavior.OK)) + self.assertEqual(self.actual_calls, 1) + + def test_retry_exceeds(self): + with self.assertRaises(OSError): + retry.with_retry(problem, (Behavior.OS_ERROR,), retries=0) + with self.assertRaises(retry.RetriesExceeded): + retry.with_retry(problem, (Behavior.OS_ERROR,), retries=1) + with self.assertRaises(retry.RetriesExceeded): + retry.with_retry(problem, (Behavior.OS_ERROR,), retries=50) + with self.assertRaises(retry.RetriesExceeded): + retry.with_retry(self.problem, (Behavior.OS_ERROR,), retries=1) + self.assertEqual(self.actual_calls, 2) + + def test_retry_exceeds_decorated(self): + with self.assertRaises(retry.RetriesExceeded): + decorated_oserror(Behavior.OS_ERROR) + with self.assertRaises(retry.RetriesExceeded): + self.decorated_os_problem(Behavior.OS_ERROR) + self.assertEqual(self.actual_calls, 2) + with self.assertRaises(retry.RetriesExceeded): + decorated_runtimeerror(Behavior.RUNTIME_ERROR) + + def test_other_exceptions_propagate(self): + with self.assertRaises(RuntimeError): + retry.with_retry( + problem, (Behavior.RUNTIME_ERROR,), + exceptions=(OSError, MemoryError), + retries=0) + with self.assertRaises(RuntimeError): + retry.with_retry( + problem, (Behavior.RUNTIME_ERROR,), + exceptions=(OSError, MemoryError), + retries=4) + with self.assertRaises(RuntimeError): + retry.with_retry( + self.problem, (Behavior.RUNTIME_ERROR,), + exceptions=(OSError, MemoryError), + retries=0) + self.assertEqual(self.actual_calls, 1) + with self.assertRaises(RuntimeError): + retry.with_retry( + self.problem, (Behavior.RUNTIME_ERROR,), + exceptions=(OSError, MemoryError), + retries=50) + self.assertEqual(self.actual_calls, 2) + + def test_original_exceptions(self): + with self.assertRaises(OSError): + retry.with_retry( + problem, (Behavior.OS_ERROR,), retries=0, wrap_exceptions=False) + with self.assertRaises(OSError): + retry.with_retry( + problem, (Behavior.OS_ERROR,), retries=1, wrap_exceptions=False) + with self.assertRaises(OSError): + retry.with_retry( + problem, (Behavior.OS_ERROR,), retries=50, wrap_exceptions=False) + with self.assertRaises(OSError): + self.decorated_os_problem_nowrap(Behavior.OS_ERROR) + + def test_call_can_succeed_1(self): + self.assertEqual(100, self.decorated_os_problem_3(Behavior.OK)) + self.assertEqual(self.actual_calls, 1) + + def test_call_can_succeed_2(self): + self.assertEqual(200, self.decorated_os_problem_3(Behavior.OS_ERROR)) + self.assertEqual(self.actual_calls, 3) + with self.assertRaises(RuntimeError): # ensure other errors still throw + self.decorated_os_problem_3(Behavior.RUNTIME_ERROR) + + def test_backoff_gets_called(self): + backoff_calls = 0 + + def inrcement(): + nonlocal backoff_calls + backoff_calls += 1 + + with self.assertRaises(RuntimeError): + retry.with_retry( + self.problem, (Behavior.RUNTIME_ERROR,), + exceptions=(OSError), + retries=2, + backoff=inrcement) + self.assertEqual(backoff_calls, 0) + with self.assertRaises(RuntimeError): + retry.with_retry( + self.problem, (Behavior.RUNTIME_ERROR,), + exceptions=(RuntimeError), + retries=2, + backoff=inrcement) + self.assertEqual(backoff_calls, 2) + + def test_backoff_terminates_loop(self): + with self.assertRaises(StopIteration): + retry.with_retry( + self.problem, (Behavior.RUNTIME_ERROR,), + exceptions=(RuntimeError), + retries=2, + backoff=lambda: True) + + def test_exception_has_details(self): + with self.assertRaises(retry.RetriesExceeded) as context: + retry.with_retry(problem, (Behavior.OS_ERROR,), retries=50) + self.assertIn('50', str(context.exception)) + self.assertIn('problem', str(context.exception)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--verbose', '-v', action='store_true') + parser.add_argument('func_behavior', type=int) + parser.add_argument('--oserror', action='store_true') + parser.add_argument('--runtimeerror', action='store_true') + parser.add_argument('--retries', type=int, default=1) + parser.add_argument('--backoff', action='store_true') + parser.add_argument('--decorated', action='store_true') + args = parser.parse_args() + logging.basicConfig( + stream=sys.stdout, level=logging.DEBUG if args.verbose else logging.INFO) + exceptions = [] + if args.oserror: + exceptions.append(OSError) + if args.runtimeerror: + exceptions.append(RuntimeError) + backoff = None if not args.backoff else lambda: (print('Backoff'), + time.sleep(1)) + if not args.decorated: + if exceptions: + print( + retry.with_retry( + problem, (args.func_behavior,), + retries=args.retries, + exceptions=tuple(exceptions), + backoff=backoff)) + else: # default, accept all exceptions + print( + retry.with_retry( + problem, (args.func_behavior,), + retries=args.retries, + backoff=backoff)) + else: + if args.backoff: + print(decorated_oserror_backoff_2(args.func_behavior)) + elif args.oserror and args.runtimeerror: + print(decorated_both(args.func_behavior)) + elif args.oserror: + print(decorated_oserror(args.func_behavior)) + elif args.runtimeerror: + print(decorated_runtimeerror(args.func_behavior)) + else: + raise NotImplementedError('No test implemented for these args')