diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 18b2893a0..55a8e3684 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -214,7 +214,7 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: dim_to_index_lambda_components, ) - bindings = {f"in{k}": self.rec(arg) for k, arg in enumerate(expr.args)} + bindings = {f"_in{k}": self.rec(arg) for k, arg in enumerate(expr.args)} redn_bounds: dict[str, tuple[ScalarExpression, ScalarExpression]] = {} args_as_pym_expr: list[prim.Subscript] = [] namegen = UniqueNameGenerator(set(bindings)) @@ -253,7 +253,7 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: subscript_indices.append(prim.Variable(redn_idx_name)) - args_as_pym_expr.append(prim.Subscript(prim.Variable(f"in{iarg}"), + args_as_pym_expr.append(prim.Subscript(prim.Variable(f"_in{iarg}"), tuple(subscript_indices))) # }}}