Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: typhoonzero <[email protected]>
  • Loading branch information
typhoonzero committed May 9, 2023
1 parent c32a1be commit 5cb02ab
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion elyra/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def gpu_vendor(self) -> Optional[str]:

@property
def parallel_count(self) -> Optional[str]:
return self._component_props.get("parallel_count")
return self._component_props.get("parallel_count", 1)

def __eq__(self, other: GenericOperation) -> bool:
if isinstance(self, other.__class__):
Expand Down
8 changes: 6 additions & 2 deletions elyra/templates/kubeflow/v1/python_dsl_template.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ def generated_pipeline(
{% set task_name = "task_" + workflow_task.escaped_task_id %}
# Task for node '{{ workflow_task.name }}'
{% set parallel_indent = 0 %}
{% if workflow_task.task_modifiers.parallel_count > 1 %}
{% if 'parallel_count' in workflow_task.task_modifiers and workflow_task.task_modifiers.parallel_count is not none %}
{% if workflow_task.task_modifiers.parallel_count > 1 %}
{% set parallel_indent = 4 %}
parallel_count = {{workflow_task.task_modifiers.parallel_count}}
with kfp.dsl.ParallelFor(list(range(parallel_count))) as rank:
{% endif %}
{% endif %}

{% filter indent(width=parallel_indent) %}
Expand Down Expand Up @@ -81,9 +83,11 @@ def generated_pipeline(
{% for env_var_name, env_var_value in workflow_task.task_modifiers.env_variables.items() %}
{{ task_name }}.add_env_variable(V1EnvVar(name="{{ env_var_name }}", value="{{ env_var_value | string_delimiter_safe }}"))
{% endfor %}
{% if workflow_task.task_modifiers.parallel_count > 1 %}
{% if 'parallel_count' in workflow_task.task_modifiers and workflow_task.task_modifiers.parallel_count is not none %}
{% if workflow_task.task_modifiers.parallel_count > 1 %}
{{ task_name }}.add_env_variable(V1EnvVar(name="NRANKS", value=str(parallel_count)))
{{ task_name }}.add_env_variable(V1EnvVar(name="RANK", value=str(rank)))
{% endif %}
{% endif %}
{% if workflow_engine == "argo" %}
{{ task_name }}.add_env_variable(V1EnvVar(
Expand Down
14 changes: 5 additions & 9 deletions elyra/tests/pipeline/kfp/test_processor_kfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def test_generate_pipeline_dsl_compile_pipeline_dsl_one_generic_node_pipeline_te

# Verify component definition information (see generic_component_definition_template.jinja2)
# - property 'name'
assert node_template["name"] == "run-a-file"
assert node_template["name"] == sanitize_label_value(op.name)
# - property 'implementation.container.command'
assert node_template["container"]["command"] == ["sh", "-c"]
# - property 'implementation.container.args'
Expand Down Expand Up @@ -1416,11 +1416,9 @@ def test_generate_pipeline_dsl_compile_pipeline_dsl_generic_components_data_exch
assert len(compiled_spec["spec"]["templates"]) >= 3
template_specs = {}
for node_template in compiled_spec["spec"]["templates"]:
if node_template["name"] == compiled_spec["spec"]["entrypoint"] or not node_template["name"].startswith(
"run-a-file"
):
if node_template["name"] == compiled_spec["spec"]["entrypoint"]:
continue
template_specs[node_template["name"]] = node_template
template_specs[sanitize_label_value(node_template["name"])] = node_template

# Iterate through sorted operations and verify that their inputs
# and outputs are properly represented in their respective template
Expand All @@ -1430,10 +1428,8 @@ def test_generate_pipeline_dsl_compile_pipeline_dsl_generic_components_data_exch
if not op.is_generic:
# ignore custom nodes
continue
if template_index == 1:
template_name = "run-a-file"
else:
template_name = f"run-a-file-{template_index}"
template_name = sanitize_label_value(op.name)
template_name = template_name.replace("_", "-") # kubernetes does this replace
template_index = template_index + 1
# compare outputs
if len(op.outputs) > 0:
Expand Down

0 comments on commit 5cb02ab

Please sign in to comment.