Skip to content

Commit

Permalink
fix: Don't handle CTRL_C_EVENT, call previous signal handler
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Aug 6, 2024
1 parent 6afd09d commit 812b84e
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import signal
import time
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -51,7 +51,6 @@ def _default_worker_name() -> str:
SIGNALS_TO_HANDLE_IF_AVAILABLE = [
"SIGINT",
"SIGTERM",
"CTRL_C_EVENT",
]


Expand Down Expand Up @@ -155,7 +154,7 @@ class DefaultWorker(Generic[Loc]):
worker_cumulative_evaluation_time_seconds: float = 0.0
"""The time spent evaluating configurations by this worker."""

_SIGNAL_HANDLER_FIRED: bool = False
_PREVIOUS_SIGNAL_HANDLERS: dict[int, signal._HANDLER] = field(default_factory=dict)

@classmethod
def new(
Expand Down Expand Up @@ -356,7 +355,8 @@ def _set_signal_handlers(self) -> None:
# HACK: Despite what the Python documentation says, the existence of a signal
# is not enough to guarantee that it can be caught.
with contextlib.suppress(ValueError):
signal.signal(sig, self._emergency_cleanup)
previous_signal_handler = signal.signal(sig, self._emergency_cleanup)
self._PREVIOUS_SIGNAL_HANDLERS[sig] = previous_signal_handler

def run(self) -> None: # noqa: C901, PLR0915
"""Run the worker.
Expand Down Expand Up @@ -446,9 +446,8 @@ def run(self) -> None: # noqa: C901, PLR0915
default_report_values=self.settings.default_report_values,
)
except KeyboardInterrupt as e:
if not self._SIGNAL_HANDLER_FIRED:
# This throws and we have stopped the worker at this point
self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e)
# This throws and we have stopped the worker at this point
self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e)

evaluation_duration = evaluated_trial.metadata.evaluation_duration
assert evaluation_duration is not None
Expand Down Expand Up @@ -491,12 +490,10 @@ def run(self) -> None: # noqa: C901, PLR0915
def _emergency_cleanup(
self,
signum: int,
frame: Any, # noqa: ARG002
frame: Any,
rethrow: KeyboardInterrupt | None = None,
) -> None:
"""Handle signals."""
self._SIGNAL_HANDLER_FIRED = True

global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603
logger.error(
f"Worker '{self.worker_id}' received signal {signum}. Stopping worker now!"
Expand All @@ -516,6 +513,9 @@ def _emergency_cleanup(
finally:
_CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None

previous_handler = self._PREVIOUS_SIGNAL_HANDLERS.get(signum)
if previous_handler is not None and callable(previous_handler):
previous_handler(signum, frame)
if rethrow is not None:
raise rethrow
raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.")
Expand Down

0 comments on commit 812b84e

Please sign in to comment.