Skip to content

Commit

Permalink
Added a test for the RedundantArray fix.
Browse files Browse the repository at this point in the history
I also verified that without the fix the test will fail.
  • Loading branch information
philip-paul-mueller committed Jun 18, 2024
1 parent f1ab0d4 commit e3f6cfe
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions tests/transformations/redundant_copy_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import numpy as np
import pytest
from typing import Tuple

import dace
from dace import nodes
Expand All @@ -9,6 +10,105 @@
RedundantArrayCopyingIn)


def test_reshaping_with_redundant_arrays():
def make_sdfg() -> Tuple[dace.SDFG, dace.nodes.AccessNode, dace.nodes.AccessNode, dace.nodes.AccessNode]:
sdfg = dace.SDFG("slicing_sdfg")
_, input_desc = sdfg.add_array(
"input",
shape=(6, 6, 6),
transient=False,
strides=(1, 6, 36),
dtype=dace.float64,
)
_, a_desc = sdfg.add_array(
"a",
shape=(6, 6, 6),
transient=True,
strides=(36, 6, 1),
dtype=dace.float64,
)
_, b_desc = sdfg.add_array(
"b",
shape=(36, 1, 6),
transient=True,
strides=(6, 6, 1),
dtype=dace.float64,
)
_, output_desc = sdfg.add_array(
"output",
shape=(36, 1, 6),
transient=False,
strides=(6, 6, 1 ),
dtype=dace.float64,
)
state = sdfg.add_state("state", is_start_block=True)
input_an = state.add_access("input")
a_an = state.add_access("a")
b_an = state.add_access("b")
output_an = state.add_access("output")

state.add_edge(
input_an,
None,
a_an,
None,
dace.Memlet.from_array("input", input_desc),
)
state.add_edge(
a_an,
None,
b_an,
None,
dace.Memlet.simple(
"a",
subset_str="0:6, 0:6, 0:6",
other_subset_str="0:36, 0, 0:6",
)
)
state.add_edge(
b_an,
None,
output_an,
None,
dace.Memlet.from_array("b", b_desc),
)
sdfg.validate()
return sdfg, a_an, b_an, output_an

def apply_trafo(
sdfg: dace.SDFG,
in_array: dace.nodes.AccessNode,
out_array: dace.nodes.AccessNode,
) -> None:
trafo = RedundantArray()

candidate = {type(trafo).in_array: in_array, type(trafo).out_array: out_array}
state = sdfg.start_block
state_id = sdfg.node_id(state)
assert state.number_of_nodes() == 4
assert len(sdfg.arrays) == 4

trafo.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True)
if trafo.can_be_applied(state, 0, sdfg):
ret = trafo.apply(state, sdfg)
if ret is not None: # A view was created
assert False, f"A view was created instead removing '{in_array.data}'."
sdfg.validate()
assert state.number_of_nodes() == 3
assert len(sdfg.arrays) == 3
assert in_array.data not in sdfg.arrays
return
assert False, "Could not apply the transformation."

# Case 1: Removing a
sdfg, a_an, b_an, _ = make_sdfg()
apply_trafo(sdfg, in_array=a_an, out_array=b_an)

# Case 2: Removing b
sdfg, _, b_an, output_an = make_sdfg()
apply_trafo(sdfg, in_array=b_an, out_array=output_an)


def test_out():
sdfg = dace.SDFG("test_redundant_copy_out")
state = sdfg.add_state()
Expand Down Expand Up @@ -331,6 +431,7 @@ def flip_and_flatten(a, b):


if __name__ == '__main__':
test_slicing_with_redundant_arrays()
test_in()
test_out()
test_out_success()
Expand Down

0 comments on commit e3f6cfe

Please sign in to comment.