diff --git a/grudge/pytato_transforms/pytato_indirection_transforms.py b/grudge/pytato_transforms/pytato_indirection_transforms.py index a5f05c64..9d4961a5 100644 --- a/grudge/pytato_transforms/pytato_indirection_transforms.py +++ b/grudge/pytato_transforms/pytato_indirection_transforms.py @@ -46,8 +46,12 @@ def _is_materialized(expr: Array) -> bool: def _can_index_lambda_propagate_indirections_without_changing_axes( - expr: IndexLambda) -> bool: - + expr: IndexLambda, iel_axis: Optional[int], idof_axis: Optional[int] +) -> bool: + """ + Returns *True* only if the axes being reindexed appear at the same + positions in the bindings' indexing locations. + """ from pytato.utils import are_shapes_equal from pytato.raising import (index_lambda_to_high_level_op, BinaryOp) @@ -219,8 +223,8 @@ def _fuse_from_element_indices(from_element_indices: Tuple[Array, ...]): return result -def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...], from_element_indices: - Tuple[Array, ...]): +def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...], + from_element_indices: Tuple[Array, ...]): assert all(from_el_idx.ndim == 2 for from_el_idx in from_element_indices) assert all(dof_pick_list.ndim == 2 for dof_pick_list in dof_pick_lists) assert all(from_el_idx.shape[1] == 1 for from_el_idx in from_element_indices) @@ -239,7 +243,10 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array, from_element_indices: Tuple[Array, ...], dof_pick_lists: Tuple[Array, ...] ) -> Array: - + raise NotImplementedError("We still need to port this from" + " the previous version, where only" + " indirections only along the element" + " axes.") if iel_axis is not None: assert idof_axis is not None assert len(from_element_indices) != 0 @@ -263,6 +270,56 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array, return rec_expr +def _is_iel_idof_picking(expr: AdvancedIndexInContiguousAxes, + iel_axis: Optional[int], + idof_axis: Optional[int], + ) -> bool: + if expr.ndim != 2: + return False + + if expr.array.ndim != 2: + return False + + if not ((iel_axis is None and idof_axis is None) + or (iel_axis == 0 and idof_axis == 1)): + return False + + if (isinstance(expr.indices[0], Array) + and isinstance(expr.indices[1], Array)): + from pytato.utils import are_shape_components_equal + from_el_indices, dof_pick_lists = expr.indices + assert isinstance(from_el_indices, Array) + assert isinstance(dof_pick_lists, Array) + + if dof_pick_lists.ndim != 1: + return False + if from_el_indices.ndim != 2: + return False + if are_shape_components_equal(from_el_indices.shape[1], 1): + return False + + return True + else: + return False + + +def _is_iel_only_picking(expr: AdvancedIndexInContiguousAxes, + iel_axis: Optional[int]) -> bool: + if expr.ndim != 1: + return False + + if expr.array.ndim != 1: + return False + + if not isinstance(expr.indices[0], Array): + return False + + if iel_axis not in [0, None]: + return False + + return True + + class PickListFusers(Mapper): def __init__(self) -> None: self.can_pick_indirections_be_propagated = _CanPickIndirectionsBePropagated() @@ -283,18 +340,22 @@ def rec(self, # type: ignore[override] " is illegal for PickListFusers. Pass arrays" " instead.") - if iel_axis is not None: - assert idof_axis is not None + if idof_axis is not None: + assert iel_axis is not None assert 0 <= iel_axis < expr.ndim assert 0 <= idof_axis < expr.ndim # the condition below ensures that we are only dealing with indirections # appearing at contiguous locations. assert abs(iel_axis-idof_axis) == 1 - else: + assert len(dof_pick_lists) == len(from_element_indices) + elif iel_axis is not None: assert idof_axis is None + assert len(dof_pick_lists) == 0 + assert len(from_element_indices) > 0 + else: + assert iel_axis is None assert len(from_element_indices) == 0 - - assert len(dof_pick_lists) == len(from_element_indices) + assert len(dof_pick_lists) == 0 key = (expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) try: @@ -318,8 +379,8 @@ def __call__(self, # type: ignore[override] def _map_input_base(self, expr: InputArgumentBase, - iel_axis: int, - idof_axis: int, + iel_axis: Optional[int], + idof_axis: Optional[int], from_element_indices: Tuple[Array, ...], dof_pick_lists: Tuple[Array, ...]) -> Array: return _pick_list_fusers_map_materialized_node( @@ -351,30 +412,36 @@ def map_index_lambda(self, rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) if iel_axis is not None: - assert idof_axis is not None assert _can_index_lambda_propagate_indirections_without_changing_axes( - expr) - from pytato.utils import are_shapes_equal - new_el_dim, new_dofs_dim = dof_pick_lists[0].shape - assert are_shapes_equal(from_element_indices[0].shape, (new_el_dim, 1)) - - new_shape = tuple( - new_el_dim if idim == iel_axis else ( - new_dofs_dim if idim == idof_axis else dim) - for idim, dim in enumerate(expr.shape)) - - return IndexLambda( - expr.expr, - new_shape, - expr.dtype, - Map({name: self.rec(bnd, iel_axis, idof_axis, - from_element_indices, - dof_pick_lists) - for name, bnd in expr.bindings.items()}), - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - axes=expr.axes - ) + expr, iel_axis, idof_axis) + if idof_axis is None: + # TODO: Not encountered any practical DAGs that take this code path. + # Implement this branch only if seen in any practical applications. + raise NotImplementedError + else: + assert idof_axis is not None + from pytato.utils import are_shapes_equal + new_el_dim, new_dofs_dim = dof_pick_lists[0].shape + assert are_shapes_equal(from_element_indices[0].shape, + (new_el_dim, 1)) + + new_shape = tuple( + new_el_dim if idim == iel_axis else ( + new_dofs_dim if idim == idof_axis else dim) + for idim, dim in enumerate(expr.shape)) + + return IndexLambda( + expr.expr, + new_shape, + expr.dtype, + Map({name: self.rec(bnd, iel_axis, idof_axis, + from_element_indices, + dof_pick_lists) + for name, bnd in expr.bindings.items()}), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes + ) else: return IndexLambda( expr.expr, @@ -405,14 +472,17 @@ def map_contiguous_advanced_index(self, return _pick_list_fusers_map_materialized_node( rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) - if self.can_pick_indirections_be_propagated(expr, - iel_axis or 0, - idof_axis or 1): - idx1, idx2 = expr.indices - assert isinstance(idx1, Array) and isinstance(idx2, Array) - return self.rec(expr.array, 0, 1, - from_element_indices + (idx1,), - dof_pick_lists + (idx2,)) + if (_is_iel_idof_picking(expr, iel_axis, idof_axis) + and self.can_pick_indirections_be_propagated(expr, + iel_axis or 0, + idof_axis or 1)): + raise NotImplementedError + elif (_is_iel_only_picking(expr, iel_axis) + and self.can_pick_indirections_be_propagated(expr, + iel_axis or 0, + None)): + assert idof_axis is None + raise NotImplementedError else: assert iel_axis is None and idof_axis is None return AdvancedIndexInContiguousAxes(