diff --git a/law/patches.py b/law/patches.py index e6a83bd5..a9d9d01b 100644 --- a/law/patches.py +++ b/law/patches.py @@ -13,6 +13,7 @@ import logging import luigi +import law logger = logging.getLogger(__name__) @@ -20,10 +21,6 @@ _patched = False -_sandbox_switched = os.getenv("LAW_SANDBOX_SWITCHED", "") == "1" - -_sandbox_task_id = os.getenv("LAW_SANDBOX_WORKER_TASK", "") - def patch_all(): """ @@ -80,7 +77,7 @@ def patch_worker_add_task(): _add_task = luigi.worker.Worker._add_task def add_task(self, *args, **kwargs): - if _sandbox_switched and "deps" in kwargs: + if law.sandbox.base._sandbox_switched and "deps" in kwargs: kwargs["deps"] = None return _add_task(self, *args, **kwargs) @@ -100,8 +97,8 @@ def patch_worker_add(): def add(self, task, *args, **kwargs): # _add returns a generator, which we simply drain here # when we are in a sandbox - if _sandbox_switched: - task.task_id = _sandbox_task_id + if law.sandbox.base._sandbox_switched: + task.task_id = law.sandbox.base._sandbox_task_id for _ in _add(self, task, *args, **kwargs): pass return [] @@ -124,16 +121,16 @@ def run_task(self, task_id): task = self._scheduled_tasks[task_id] task._worker_id = self._id - task._worker_task = self._first_task + task._worker_first_task_id = self._first_task try: _run_task(self, task_id) finally: task._worker_id = None - task._worker_task = None + task._worker_first_task_id = None # make worker disposable when sandboxed - if _sandbox_switched: + if law.sandbox.base._sandbox_switched: self._start_phasing_out() luigi.worker.Worker._run_task = run_task @@ -151,10 +148,10 @@ def patch_worker_get_work(): _get_work = luigi.worker.Worker._get_work def get_work(self): - if _sandbox_switched: + if law.sandbox.base._sandbox_switched: # when the worker is configured to stop requesting work, as triggered by the patched # _run_task method (see above), the worker response should contain an empty task_id - task_id = None if self._stop_requesting_work else os.environ["LAW_SANDBOX_WORKER_TASK"] + task_id = None if self._stop_requesting_work else law.sandbox.base._sandbox_task_id return luigi.worker.GetWorkResponse( task_id=task_id, running_tasks=[], @@ -178,8 +175,8 @@ def patch_worker_factory(): """ def create_worker(self, scheduler, worker_processes, assistant=False): worker = luigi.worker.Worker(scheduler=scheduler, worker_processes=worker_processes, - assistant=assistant, worker_id=os.getenv("LAW_SANDBOX_WORKER_ID")) - worker._first_task = os.getenv("LAW_SANDBOX_WORKER_ROOT_TASK") + assistant=assistant, worker_id=law.sandbox.base._sandbox_worker_id or None) + worker._first_task = law.sandbox.base._sandbox_worker_first_task_id or None return worker luigi.interface._WorkerSchedulerFactory.create_worker = create_worker @@ -196,7 +193,7 @@ def patch_keepalive_run(): def run(self): # do not run the keep-alive loop when sandboxed - if _sandbox_switched: + if law.sandbox.base._sandbox_switched: self.stop() else: _run(self) diff --git a/law/sandbox/base.py b/law/sandbox/base.py index 8b6d4016..318442a8 100644 --- a/law/sandbox/base.py +++ b/law/sandbox/base.py @@ -38,19 +38,29 @@ _sandbox_switched = os.getenv("LAW_SANDBOX_SWITCHED", "") == "1" +_sandbox_task_id = os.getenv("LAW_SANDBOX_TASK_ID", "") + +_sandbox_worker_id = os.getenv("LAW_SANDBOX_WORKER_ID", "") + +_sandbox_worker_first_task_id = os.getenv("LAW_SANDBOX_WORKER_FIRST_TASK_ID", "") + _sandbox_is_root_task = os.getenv("LAW_SANDBOX_IS_ROOT_TASK", "") == "1" _sandbox_stagein_dir = os.getenv("LAW_SANDBOX_STAGEIN_DIR", "") _sandbox_stageout_dir = os.getenv("LAW_SANDBOX_STAGEOUT_DIR", "") -_sandbox_task_id = os.getenv("LAW_SANDBOX_WORKER_TASK", "") - -_sandbox_root_task_id = os.getenv("LAW_SANDBOX_WORKER_ROOT_TASK", "") -# the task id must be set when in a sandbox -if not _sandbox_task_id and _sandbox_switched: - raise Exception("LAW_SANDBOX_WORKER_TASK must not be empty in a sandbox") +# certain values must be present in a sandbox +if _sandbox_switched: + if not _current_sandbox or not _current_sandbox[0]: + raise Exception("LAW_SANDBOX must not be empty in a sandbox") + if not _sandbox_task_id: + raise Exception("LAW_SANDBOX_TASK_ID must not be empty in a sandbox") + elif not _sandbox_worker_id: + raise Exception("LAW_SANDBOX_WORKER_ID must not be empty in a sandbox") + elif not _sandbox_worker_first_task_id: + raise Exception("LAW_SANDBOX_WORKER_FIRST_TASK_ID must not be empty in a sandbox") class StageInfo(object): @@ -172,12 +182,13 @@ def _get_env(self): env["LAW_SANDBOX"] = self.key.replace("$", r"\$") env["LAW_SANDBOX_SWITCHED"] = "1" if self.task: + env["LAW_SANDBOX_TASK_ID"] = self.task.live_task_id + env["LAW_SANDBOX_ROOT_TASK_ID"] = root_task().task_id + env["LAW_SANDBOX_IS_ROOT_TASK"] = str(int(self.task.is_root_task())) if getattr(self.task, "_worker_id", None): env["LAW_SANDBOX_WORKER_ID"] = self.task._worker_id - if getattr(self.task, "_worker_task", None): - env["LAW_SANDBOX_WORKER_TASK"] = self.task.live_task_id - env["LAW_SANDBOX_WORKER_ROOT_TASK"] = root_task().task_id - env["LAW_SANDBOX_IS_ROOT_TASK"] = str(int(self.task.is_root_task())) + if getattr(self.task, "_worker_first_task_id", None): + env["LAW_SANDBOX_WORKER_FIRST_TASK_ID"] = self.task._worker_first_task_id # extend by variables from the config file cfg = Config.instance() @@ -469,6 +480,10 @@ def is_root_task(self): return is_root def _staged_input(self): + if not _sandbox_stagein_dir: + raise Exception("LAW_SANDBOX_STAGEIN_DIR must not be empty in a sandbox when target " + "stage-in is required") + # get the original inputs inputs = self.__getattribute__("input", proxy=False)() @@ -479,6 +494,10 @@ def _staged_input(self): return mask_struct(self.sandbox_stagein(), staged_inputs, inputs) def _staged_output(self): + if not _sandbox_stagein_dir: + raise Exception("LAW_SANDBOX_STAGEOUT_DIR must not be empty in a sandbox when target " + "stage-out is required") + # get the original outputs outputs = self.__getattribute__("output", proxy=False)()