Skip to content

Commit

Permalink
Pass scheduler as an artifact (#29)
Browse files Browse the repository at this point in the history
* change interface of scheduler from parameter to artifact

* adjust test for artifact scheduler

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
amcadmus and Han Wang authored May 11, 2022
1 parent a1fd77b commit 1044207
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 47 deletions.
57 changes: 34 additions & 23 deletions dpgen2/flow/dpgen_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ class SchedulerWrapper(OP):
@classmethod
def get_input_sign(cls):
return OPIOSign({
"exploration_scheduler" : ExplorationScheduler,
"exploration_scheduler" : Artifact(Path),
"exploration_report": ExplorationReport,
"trajs": Artifact(List[Path]),
})

@classmethod
def get_output_sign(cls):
return OPIOSign({
"exploration_scheduler" : ExplorationScheduler,
"exploration_scheduler" : Artifact(Path),
"converged" : bool,
"lmp_task_grp" : Artifact(Path),
"conf_selector" : ConfSelector,
Expand All @@ -58,28 +58,36 @@ def execute(
self,
ip : OPIO,
) -> OPIO:
scheduler = ip['exploration_scheduler']
scheduler_in = ip['exploration_scheduler']
report = ip['exploration_report']
trajs = ip['trajs']
lmp_task_grp_file = Path('lmp_task_grp.dat')
scheduler_file = Path('scheduler.dat')

with open(scheduler_in, 'rb') as fp:
scheduler = pickle.load(fp)

conv, lmp_task_grp, selector = scheduler.plan_next_iteration(report, trajs)

with open('lmp_task_grp.dat', 'wb') as fp:
with open(lmp_task_grp_file, 'wb') as fp:
pickle.dump(lmp_task_grp, fp)

with open(scheduler_file, 'wb') as fp:
pickle.dump(scheduler, fp)

return OPIO({
"exploration_scheduler" : scheduler,
"exploration_scheduler" : scheduler_file,
"converged" : conv,
"conf_selector" : selector,
"lmp_task_grp" : Path('lmp_task_grp.dat'),
"lmp_task_grp" : lmp_task_grp_file,
})


class MakeBlockId(OP):
@classmethod
def get_input_sign(cls):
return OPIOSign({
"exploration_scheduler" : ExplorationScheduler,
"exploration_scheduler" : Artifact(Path),
})

@classmethod
Expand All @@ -93,8 +101,11 @@ def execute(
self,
ip : OPIO,
) -> OPIO:
scheduler = ip['exploration_scheduler']
scheduler_in = ip['exploration_scheduler']

with open(scheduler_in, 'rb') as fp:
scheduler = pickle.load(fp)

stage = scheduler.get_stage()
iteration = scheduler.get_iteration()

Expand All @@ -121,18 +132,18 @@ def __init__(
"conf_selector" : InputParameter(),
"fp_inputs" : InputParameter(),
"fp_config" : InputParameter(),
"exploration_scheduler" : InputParameter(),
}
self._input_artifacts={
"exploration_scheduler" : InputArtifact(),
"init_models" : InputArtifact(),
"init_data" : InputArtifact(),
"iter_data" : InputArtifact(),
"lmp_task_grp" : InputArtifact(),
}
self._output_parameters={
"exploration_scheduler": OutputParameter(),
}
self._output_artifacts={
"exploration_scheduler": OutputArtifact(),
"models": OutputArtifact(),
"iter_data" : OutputArtifact(),
}
Expand Down Expand Up @@ -213,17 +224,17 @@ def __init__(
"lmp_config" : InputParameter(),
"fp_inputs" : InputParameter(),
"fp_config" : InputParameter(),
"exploration_scheduler" : InputParameter(),
}
self._input_artifacts={
"exploration_scheduler" : InputArtifact(),
"init_models" : InputArtifact(),
"init_data" : InputArtifact(),
"iter_data" : InputArtifact(),
}
self._output_parameters={
"exploration_scheduler": OutputParameter(),
}
self._output_artifacts={
"exploration_scheduler": OutputArtifact(),
"models": OutputArtifact(),
"iter_data" : OutputArtifact(),
}
Expand Down Expand Up @@ -321,10 +332,10 @@ def _loop (
python_packages = upload_python_package,
),
parameters={
"exploration_scheduler": steps.inputs.parameters['exploration_scheduler'],
"exploration_report": block_step.outputs.parameters['exploration_report'],
},
artifacts={
"exploration_scheduler": steps.inputs.artifacts['exploration_scheduler'],
"trajs" : block_step.outputs.artifacts['trajs'],
},
key = step_keys['scheduler'],
Expand All @@ -339,9 +350,9 @@ def _loop (
python_packages = upload_python_package,
),
parameters={
"exploration_scheduler": scheduler_step.outputs.parameters['exploration_scheduler'],
},
artifacts={
"exploration_scheduler": scheduler_step.outputs.artifacts['exploration_scheduler'],
},
key = step_keys['id'],
)
Expand All @@ -360,9 +371,9 @@ def _loop (
"conf_selector" : scheduler_step.outputs.parameters["conf_selector"],
"fp_inputs" : steps.inputs.parameters["fp_inputs"],
"fp_config" : steps.inputs.parameters["fp_config"],
"exploration_scheduler" : scheduler_step.outputs.parameters["exploration_scheduler"],
},
artifacts={
"exploration_scheduler" : scheduler_step.outputs.artifacts["exploration_scheduler"],
"lmp_task_grp" : scheduler_step.outputs.artifacts["lmp_task_grp"],
"init_models" : block_step.outputs.artifacts['models'],
"init_data" : steps.inputs.artifacts['init_data'],
Expand All @@ -372,11 +383,11 @@ def _loop (
)
steps.add(next_step)

steps.outputs.parameters['exploration_scheduler'].value_from_expression = \
steps.outputs.artifacts['exploration_scheduler'].from_expression = \
if_expression(
_if = (scheduler_step.outputs.parameters['converged'] == True),
_then = scheduler_step.outputs.parameters['exploration_scheduler'],
_else = next_step.outputs.parameters['exploration_scheduler'],
_then = scheduler_step.outputs.artifacts['exploration_scheduler'],
_else = next_step.outputs.artifacts['exploration_scheduler'],
)
steps.outputs.artifacts['models'].from_expression = \
if_expression(
Expand Down Expand Up @@ -411,10 +422,10 @@ def _dpgen(
python_packages = upload_python_package,
),
parameters={
"exploration_scheduler": steps.inputs.parameters['exploration_scheduler'],
"exploration_report": None,
},
artifacts={
"exploration_scheduler": steps.inputs.artifacts['exploration_scheduler'],
"trajs" : None,
},
key = step_keys['scheduler'],
Expand All @@ -429,9 +440,9 @@ def _dpgen(
python_packages = upload_python_package,
),
parameters={
"exploration_scheduler": scheduler_step.outputs.parameters['exploration_scheduler'],
},
artifacts={
"exploration_scheduler": scheduler_step.outputs.artifacts['exploration_scheduler'],
},
key = step_keys['id'],
)
Expand All @@ -450,9 +461,9 @@ def _dpgen(
"lmp_config" : steps.inputs.parameters['lmp_config'],
"fp_inputs" : steps.inputs.parameters['fp_inputs'],
"fp_config" : steps.inputs.parameters['fp_config'],
"exploration_scheduler" : scheduler_step.outputs.parameters['exploration_scheduler'],
},
artifacts={
"exploration_scheduler" : scheduler_step.outputs.artifacts['exploration_scheduler'],
"lmp_task_grp" : scheduler_step.outputs.artifacts['lmp_task_grp'],
"init_models": steps.inputs.artifacts["init_models"],
"init_data": steps.inputs.artifacts["init_data"],
Expand All @@ -462,8 +473,8 @@ def _dpgen(
)
steps.add(loop_step)

steps.outputs.parameters["exploration_scheduler"].value_from_parameter = \
loop_step.outputs.parameters["exploration_scheduler"]
steps.outputs.artifacts["exploration_scheduler"]._from = \
loop_step.outputs.artifacts["exploration_scheduler"]
steps.outputs.artifacts["models"]._from = \
loop_step.outputs.artifacts["models"]
steps.outputs.artifacts["iter_data"]._from = \
Expand Down
Loading

0 comments on commit 1044207

Please sign in to comment.