Skip to content

Commit

Permalink
Cleanup sandbox environment variable handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Jan 13, 2020
1 parent c15c30a commit 183dd59
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
27 changes: 12 additions & 15 deletions law/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
import logging

import luigi
import law


logger = logging.getLogger(__name__)


_patched = False

_sandbox_switched = os.getenv("LAW_SANDBOX_SWITCHED", "") == "1"

_sandbox_task_id = os.getenv("LAW_SANDBOX_WORKER_TASK", "")


def patch_all():
"""
Expand Down Expand Up @@ -80,7 +77,7 @@ def patch_worker_add_task():
_add_task = luigi.worker.Worker._add_task

def add_task(self, *args, **kwargs):
if _sandbox_switched and "deps" in kwargs:
if law.sandbox.base._sandbox_switched and "deps" in kwargs:
kwargs["deps"] = None
return _add_task(self, *args, **kwargs)

Expand All @@ -100,8 +97,8 @@ def patch_worker_add():
def add(self, task, *args, **kwargs):
# _add returns a generator, which we simply drain here
# when we are in a sandbox
if _sandbox_switched:
task.task_id = _sandbox_task_id
if law.sandbox.base._sandbox_switched:
task.task_id = law.sandbox.base._sandbox_task_id
for _ in _add(self, task, *args, **kwargs):
pass
return []
Expand All @@ -124,16 +121,16 @@ def run_task(self, task_id):
task = self._scheduled_tasks[task_id]

task._worker_id = self._id
task._worker_task = self._first_task
task._worker_first_task_id = self._first_task

try:
_run_task(self, task_id)
finally:
task._worker_id = None
task._worker_task = None
task._worker_first_task_id = None

# make worker disposable when sandboxed
if _sandbox_switched:
if law.sandbox.base._sandbox_switched:
self._start_phasing_out()

luigi.worker.Worker._run_task = run_task
Expand All @@ -151,10 +148,10 @@ def patch_worker_get_work():
_get_work = luigi.worker.Worker._get_work

def get_work(self):
if _sandbox_switched:
if law.sandbox.base._sandbox_switched:
# when the worker is configured to stop requesting work, as triggered by the patched
# _run_task method (see above), the worker response should contain an empty task_id
task_id = None if self._stop_requesting_work else os.environ["LAW_SANDBOX_WORKER_TASK"]
task_id = None if self._stop_requesting_work else law.sandbox.base._sandbox_task_id
return luigi.worker.GetWorkResponse(
task_id=task_id,
running_tasks=[],
Expand All @@ -178,8 +175,8 @@ def patch_worker_factory():
"""
def create_worker(self, scheduler, worker_processes, assistant=False):
worker = luigi.worker.Worker(scheduler=scheduler, worker_processes=worker_processes,
assistant=assistant, worker_id=os.getenv("LAW_SANDBOX_WORKER_ID"))
worker._first_task = os.getenv("LAW_SANDBOX_WORKER_ROOT_TASK")
assistant=assistant, worker_id=law.sandbox.base._sandbox_worker_id or None)
worker._first_task = law.sandbox.base._sandbox_worker_first_task_id or None
return worker

luigi.interface._WorkerSchedulerFactory.create_worker = create_worker
Expand All @@ -196,7 +193,7 @@ def patch_keepalive_run():

def run(self):
# do not run the keep-alive loop when sandboxed
if _sandbox_switched:
if law.sandbox.base._sandbox_switched:
self.stop()
else:
_run(self)
Expand Down
39 changes: 29 additions & 10 deletions law/sandbox/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,29 @@

_sandbox_switched = os.getenv("LAW_SANDBOX_SWITCHED", "") == "1"

_sandbox_task_id = os.getenv("LAW_SANDBOX_TASK_ID", "")

_sandbox_worker_id = os.getenv("LAW_SANDBOX_WORKER_ID", "")

_sandbox_worker_first_task_id = os.getenv("LAW_SANDBOX_WORKER_FIRST_TASK_ID", "")

_sandbox_is_root_task = os.getenv("LAW_SANDBOX_IS_ROOT_TASK", "") == "1"

_sandbox_stagein_dir = os.getenv("LAW_SANDBOX_STAGEIN_DIR", "")

_sandbox_stageout_dir = os.getenv("LAW_SANDBOX_STAGEOUT_DIR", "")

_sandbox_task_id = os.getenv("LAW_SANDBOX_WORKER_TASK", "")

_sandbox_root_task_id = os.getenv("LAW_SANDBOX_WORKER_ROOT_TASK", "")

# the task id must be set when in a sandbox
if not _sandbox_task_id and _sandbox_switched:
raise Exception("LAW_SANDBOX_WORKER_TASK must not be empty in a sandbox")
# certain values must be present in a sandbox
if _sandbox_switched:
if not _current_sandbox or not _current_sandbox[0]:
raise Exception("LAW_SANDBOX must not be empty in a sandbox")
if not _sandbox_task_id:
raise Exception("LAW_SANDBOX_TASK_ID must not be empty in a sandbox")
elif not _sandbox_worker_id:
raise Exception("LAW_SANDBOX_WORKER_ID must not be empty in a sandbox")
elif not _sandbox_worker_first_task_id:
raise Exception("LAW_SANDBOX_WORKER_FIRST_TASK_ID must not be empty in a sandbox")


class StageInfo(object):
Expand Down Expand Up @@ -172,12 +182,13 @@ def _get_env(self):
env["LAW_SANDBOX"] = self.key.replace("$", r"\$")
env["LAW_SANDBOX_SWITCHED"] = "1"
if self.task:
env["LAW_SANDBOX_TASK_ID"] = self.task.live_task_id
env["LAW_SANDBOX_ROOT_TASK_ID"] = root_task().task_id
env["LAW_SANDBOX_IS_ROOT_TASK"] = str(int(self.task.is_root_task()))
if getattr(self.task, "_worker_id", None):
env["LAW_SANDBOX_WORKER_ID"] = self.task._worker_id
if getattr(self.task, "_worker_task", None):
env["LAW_SANDBOX_WORKER_TASK"] = self.task.live_task_id
env["LAW_SANDBOX_WORKER_ROOT_TASK"] = root_task().task_id
env["LAW_SANDBOX_IS_ROOT_TASK"] = str(int(self.task.is_root_task()))
if getattr(self.task, "_worker_first_task_id", None):
env["LAW_SANDBOX_WORKER_FIRST_TASK_ID"] = self.task._worker_first_task_id

# extend by variables from the config file
cfg = Config.instance()
Expand Down Expand Up @@ -469,6 +480,10 @@ def is_root_task(self):
return is_root

def _staged_input(self):
if not _sandbox_stagein_dir:
raise Exception("LAW_SANDBOX_STAGEIN_DIR must not be empty in a sandbox when target "
"stage-in is required")

# get the original inputs
inputs = self.__getattribute__("input", proxy=False)()

Expand All @@ -479,6 +494,10 @@ def _staged_input(self):
return mask_struct(self.sandbox_stagein(), staged_inputs, inputs)

def _staged_output(self):
if not _sandbox_stageout_dir:
raise Exception("LAW_SANDBOX_STAGEOUT_DIR must not be empty in a sandbox when target "
"stage-out is required")

# get the original outputs
outputs = self.__getattribute__("output", proxy=False)()

Expand Down

0 comments on commit 183dd59

Please sign in to comment.