diff --git a/dali/python/nvidia/dali/_conditionals.py b/dali/python/nvidia/dali/_conditionals.py index 77d1b9733e..f7ae1f54d1 100644 --- a/dali/python/nvidia/dali/_conditionals.py +++ b/dali/python/nvidia/dali/_conditionals.py @@ -541,6 +541,19 @@ def _verify_branch_outputs(outputs, symbol_names, branch_name): ) +def _validate_logical(value, expression_name, expression_side): + v = fn._conditional.validate_logical( + value, expression_name=expression_name, expression_side=expression_side + ) + if v.device != "cpu": + raise RuntimeError( + f"Logical expression `{value}` is restricted to scalar (0-d tensors)" + f" inputs of `bool` type, that are placed on CPU." + f" Got a GPU input as the {expression_side} argument in logical expression." + ) + return v + + class DaliOperatorOverload(_autograph.OperatorBase): def detect_overload_ld(self, v): return isinstance(v, _DataNode) @@ -647,13 +660,11 @@ def lazy_and(self, a_value, b): # and_output = b() # else: # and_output = a_val - a_validated = fn._conditional.validate_logical( - a_value, expression_name="and", expression_side="left" - ) + a_validated = _validate_logical(a_value, expression_name="and", expression_side="left") with _cond_manager(a_validated) as split_predicate: with _cond_true(): b_value = b() - b_validated = fn._conditional.validate_logical( + b_validated = _validate_logical( b_value, expression_name="and", expression_side="right" ) body_outputs = apply_conditional_split(b_validated) @@ -676,15 +687,14 @@ def lazy_or(self, a_value, b): # or_output = a_val # else: # or_output = b() - a_validated = fn._conditional.validate_logical( - a_value, expression_name="or", expression_side="left" - ) + a_validated = _validate_logical(a_value, expression_name="or", expression_side="left") + with _cond_manager(a_validated) as split_predicate: with _cond_true(): body_outputs = apply_conditional_split(split_predicate) with _cond_false(): b_value = b() - b_validated = fn._conditional.validate_logical( + b_validated = _validate_logical( b_value, expression_name="or", expression_side="right" ) else_outputs = apply_conditional_split(b_validated)