Skip to content

Commit

Permalink
Add python-side backend validation in conditional expressions.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Sep 16, 2024
1 parent 6305131 commit 2e89b84
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions dali/python/nvidia/dali/_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,19 @@ def _verify_branch_outputs(outputs, symbol_names, branch_name):
)


def _validate_logical(value, expression_name, expression_side):
v = fn._conditional.validate_logical(
value, expression_name=expression_name, expression_side=expression_side
)
if v.device != "cpu":
raise RuntimeError(
f"Logical expression `{value}` is restricted to scalar (0-d tensors)"
f" inputs of `bool` type, that are placed on CPU."
f" Got a GPU input as the {expression_side} argument in logical expression."
)
return v


class DaliOperatorOverload(_autograph.OperatorBase):
def detect_overload_ld(self, v):
return isinstance(v, _DataNode)
Expand Down Expand Up @@ -647,13 +660,11 @@ def lazy_and(self, a_value, b):
# and_output = b()
# else:
# and_output = a_val
a_validated = fn._conditional.validate_logical(
a_value, expression_name="and", expression_side="left"
)
a_validated = _validate_logical(a_value, expression_name="and", expression_side="left")
with _cond_manager(a_validated) as split_predicate:
with _cond_true():
b_value = b()
b_validated = fn._conditional.validate_logical(
b_validated = _validate_logical(
b_value, expression_name="and", expression_side="right"
)
body_outputs = apply_conditional_split(b_validated)
Expand All @@ -676,15 +687,14 @@ def lazy_or(self, a_value, b):
# or_output = a_val
# else:
# or_output = b()
a_validated = fn._conditional.validate_logical(
a_value, expression_name="or", expression_side="left"
)
a_validated = _validate_logical(a_value, expression_name="or", expression_side="left")

with _cond_manager(a_validated) as split_predicate:
with _cond_true():
body_outputs = apply_conditional_split(split_predicate)
with _cond_false():
b_value = b()
b_validated = fn._conditional.validate_logical(
b_validated = _validate_logical(
b_value, expression_name="or", expression_side="right"
)
else_outputs = apply_conditional_split(b_validated)
Expand Down

0 comments on commit 2e89b84

Please sign in to comment.