From 812b84e0db14672e5edef91642f07bdaca5cb93d Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 6 Aug 2024 16:31:50 +0200 Subject: [PATCH] fix: Don't handle CTRL_C_EVENT, call previous signal handler --- neps/runtime.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 0baa4b7a..cd733dc8 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -9,7 +9,7 @@ import signal import time from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import ( TYPE_CHECKING, @@ -51,7 +51,6 @@ def _default_worker_name() -> str: SIGNALS_TO_HANDLE_IF_AVAILABLE = [ "SIGINT", "SIGTERM", - "CTRL_C_EVENT", ] @@ -155,7 +154,7 @@ class DefaultWorker(Generic[Loc]): worker_cumulative_evaluation_time_seconds: float = 0.0 """The time spent evaluating configurations by this worker.""" - _SIGNAL_HANDLER_FIRED: bool = False + _PREVIOUS_SIGNAL_HANDLERS: dict[int, signal._HANDLER] = field(default_factory=dict) @classmethod def new( @@ -356,7 +355,8 @@ def _set_signal_handlers(self) -> None: # HACK: Despite what python documentation says, the existance of a signal # is not enough to guarantee that it can be caught. with contextlib.suppress(ValueError): - signal.signal(sig, self._emergency_cleanup) + previous_signal_handler = signal.signal(sig, self._emergency_cleanup) + self._PREVIOUS_SIGNAL_HANDLERS[sig] = previous_signal_handler def run(self) -> None: # noqa: C901, PLR0915 """Run the worker. @@ -446,9 +446,8 @@ def run(self) -> None: # noqa: C901, PLR0915 default_report_values=self.settings.default_report_values, ) except KeyboardInterrupt as e: - if not self._SIGNAL_HANDLER_FIRED: - # This throws and we have stopped the worker at this point - self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e) + # This throws and we have stopped the worker at this point + self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e) evaluation_duration = evaluated_trial.metadata.evaluation_duration assert evaluation_duration is not None @@ -491,12 +490,10 @@ def run(self) -> None: # noqa: C901, PLR0915 def _emergency_cleanup( self, signum: int, - frame: Any, # noqa: ARG002 + frame: Any, rethrow: KeyboardInterrupt | None = None, ) -> None: """Handle signals.""" - self._SIGNAL_HANDLER_FIRED = True - global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603 logger.error( f"Worker '{self.worker_id}' received signal {signum}. Stopping worker now!" @@ -516,6 +513,9 @@ def _emergency_cleanup( finally: _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None + previous_handler = self._PREVIOUS_SIGNAL_HANDLERS.get(signum) + if previous_handler is not None and callable(previous_handler): + previous_handler(signum, frame) if rethrow is not None: raise rethrow raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.")