Cherry-pick a small number of early, independent, non-breaking commits about "offsets" out of multi_sdfg branch. #1730

Draft: wants to merge 6 commits into base: main
29 changes: 20 additions & 9 deletions dace/frontend/fortran/ast_components.py
@@ -488,18 +488,29 @@ def parse_shape_specification(self, dim: f03.Explicit_Shape_Spec, size: List[FAS
#now to add the dimension to the size list after processing it if necessary
size.append(self.create_ast(dim_expr))
offset.append(1)

# Here we support arrays that have size declaration - with initial offset.
elif len(dim_expr) == 2:
# extract offets
for expr in dim_expr:
if not isinstance(expr, f03.Int_Literal_Constant):
raise TypeError("Array offsets must be constant expressions!")
offset.append(int(dim_expr[0].tostr()))

fortran_size = int(dim_expr[1].tostr()) - int(dim_expr[0].tostr()) + 1
fortran_ast_size = f03.Int_Literal_Constant(str(fortran_size))

size.append(self.create_ast(fortran_ast_size))
if isinstance(dim_expr[0], f03.Int_Literal_Constant):
#raise TypeError("Array offsets must be constant expressions!")
offset.append(int(dim_expr[0].tostr()))
else:
expr = self.create_ast(dim_expr[0])
offset.append(expr)

fortran_size = ast_internal_classes.BinOp_Node(
lval=self.create_ast(dim_expr[1]),
rval=self.create_ast(dim_expr[0]),
op="-",
type="INTEGER"
)
size.append(ast_internal_classes.BinOp_Node(
lval=fortran_size,
rval=ast_internal_classes.Int_Literal_Node(value=str(1)),
op="+",
type="INTEGER")
)
else:
raise TypeError("Array dimension must be at most two expressions")

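For context, the hunk above extends parse_shape_specification so that a Fortran dimension declared with an explicit lower bound, e.g. a(2:10) or a(lb:ub) with symbolic bounds, yields both an offset and a size expression. A minimal sketch of the arithmetic, with illustrative values only (plain Python, not the frontend's AST node classes):

# Sketch: for a Fortran dimension spec "lb:ub" the frontend records
#   offset = lb            (the first valid index)
#   size   = ub - lb + 1   (the number of elements)
# Constant bounds can be folded to integers; symbolic bounds must stay
# expressions, which is what the BinOp_Node construction above produces.
def dimension_size_and_offset(lb, ub):
    return ub - lb + 1, lb

size, offset = dimension_size_and_offset(2, 10)
assert (size, offset) == (9, 2)   # a(2:10): nine elements, first index 2
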
21 changes: 18 additions & 3 deletions dace/frontend/fortran/ast_transforms.py
@@ -467,14 +467,18 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
variable = self.scope_vars.get_var(child.parent, var_name)
offset = variable.offsets[idx]

# it can be a symbol - Name_Node - or a value
if not isinstance(offset, ast_internal_classes.Name_Node):
offset = ast_internal_classes.Int_Literal_Node(value=str(offset))

newbody.append(
ast_internal_classes.BinOp_Node(
op="=",
lval=ast_internal_classes.Name_Node(name=tmp_name),
rval=ast_internal_classes.BinOp_Node(
op="-",
lval=i,
rval=ast_internal_classes.Int_Literal_Node(value=str(offset)),
rval=offset,
line_number=child.line_number),
line_number=child.line_number))
else:
@@ -752,7 +756,11 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node,

lower_boundary = None
if offsets[idx] != 1:
lower_boundary = ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))
# support symbols and integer literals
if isinstance(offsets[idx], ast_internal_classes.Name_Node):
lower_boundary = offsets[idx]
else:
lower_boundary = ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))
else:
lower_boundary = ast_internal_classes.Int_Literal_Node(value="1")

@@ -765,10 +773,17 @@
But since the generated loop has `<=` condition, we need to subtract 1.
"""
if offsets[idx] != 1:

# support symbols and integer literals
if isinstance(offsets[idx], ast_internal_classes.Name_Node):
offset = offsets[idx]
else:
offset = ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))

upper_boundary = ast_internal_classes.BinOp_Node(
lval=upper_boundary,
op="+",
rval=ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))
rval=offset
)
upper_boundary = ast_internal_classes.BinOp_Node(
lval=upper_boundary,
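As a worked example of the boundary arithmetic described in the docstring above: an array declared a(3:7) has offset 3 and size 5, and because the generated loop uses a <= condition its upper bound must be offset + size - 1 (values here are illustrative, not taken from the PR's tests):

# Illustrative values: Fortran a(3:7) -> offset 3, size 5.
offset, size = 3, 5
lower_boundary = offset                      # loop starts at the declared lower bound
upper_boundary = lower_boundary + size - 1   # inclusive, because the loop condition is "<="
assert list(range(lower_boundary, upper_boundary + 1)) == [3, 4, 5, 6, 7]
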
20 changes: 18 additions & 2 deletions dace/sdfg/propagation.py
@@ -1103,7 +1103,15 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node):
if internal_memlet is None:
continue
try:
iedge.data = unsqueeze_memlet(internal_memlet, iedge.data, True)
ext_desc = parent_sdfg.arrays[iedge.data.data]
int_desc = sdfg.arrays[iedge.dst_conn]
iedge.data = unsqueeze_memlet(
internal_memlet,
iedge.data,
True,
internal_offset=int_desc.offset,
external_offset=ext_desc.offset
)
# If no appropriate memlet found, use array dimension
for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[iedge.data.data].shape)):
if rng[1] + 1 == s:
@@ -1123,7 +1131,15 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node):
if internal_memlet is None:
continue
try:
oedge.data = unsqueeze_memlet(internal_memlet, oedge.data, True)
ext_desc = parent_sdfg.arrays[oedge.data.data]
int_desc = sdfg.arrays[oedge.src_conn]
oedge.data = unsqueeze_memlet(
internal_memlet,
oedge.data,
True,
internal_offset=int_desc.offset,
external_offset=ext_desc.offset
)
# If no appropriate memlet found, use array dimension
for i, (rng, s) in enumerate(zip(internal_memlet.subset, parent_sdfg.arrays[oedge.data.data].shape)):
if rng[1] + 1 == s:
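For context on the offsets forwarded to unsqueeze_memlet here: DaCe data descriptors can carry a per-dimension offset, which is what the Fortran frontend now uses for non-default lower bounds. A minimal sketch of declaring such an array through the standard dace API (names and values are illustrative; the sign convention the frontend applies to offsets is defined elsewhere in the branch):

import dace

# Sketch: an array whose descriptor carries a non-zero per-dimension offset,
# roughly what the Fortran frontend emits for a declaration like a(2:10).
sdfg = dace.SDFG("offset_example")
sdfg.add_array("a", shape=[9], dtype=dace.float64, offset=[2])
print(sdfg.arrays["a"].offset)   # -> [2], the value read by the propagation code above
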
81 changes: 64 additions & 17 deletions dace/transformation/interstate/sdfg_nesting.py
@@ -508,14 +508,24 @@ def apply(self, state: SDFGState, sdfg: SDFG):
if (edge not in modified_edges and edge.data.data == node.data):
for e in state.memlet_tree(edge):
if e._data.get_dst_subset(e, state):
new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_dst_subset=True)
offset = sdfg.arrays[e.data.data].offset
new_memlet = helpers.unsqueeze_memlet(e.data,
outer_edge.data,
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)
e._data.dst_subset = new_memlet.subset
# NOTE: Node is source
for edge in state.out_edges(node):
if (edge not in modified_edges and edge.data.data == node.data):
for e in state.memlet_tree(edge):
if e._data.get_src_subset(e, state):
new_memlet = helpers.unsqueeze_memlet(e.data, outer_edge.data, use_src_subset=True)
offset = sdfg.arrays[e.data.data].offset
new_memlet = helpers.unsqueeze_memlet(e.data,
outer_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
e._data.src_subset = new_memlet.subset

# If source/sink node is not connected to a source/destination access
@@ -624,10 +634,17 @@ def _modify_access_to_access(self,
state.out_edges_by_connector(nsdfg_node, inner_data))
# Create memlet by unsqueezing both w.r.t. src and
# dst subsets
in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True)
offset = state.parent.arrays[top_edge.data.data].offset
in_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
out_memlet = helpers.unsqueeze_memlet(inner_edge.data,
matching_edge.data,
use_dst_subset=True)
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)
new_memlet = in_memlet
new_memlet.other_subset = out_memlet.subset

@@ -650,10 +667,18 @@ def _modify_access_to_access(self,
state.out_edges_by_connector(nsdfg_node, inner_data))
# Create memlet by unsqueezing both w.r.t. src and
# dst subsets
in_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data, use_src_subset=True)
offset = state.parent.arrays[top_edge.data.data].offset
in_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
use_src_subset=True,
internal_offset=offset,
external_offset=offset)
out_memlet = helpers.unsqueeze_memlet(inner_edge.data,
matching_edge.data,
use_dst_subset=True)
use_dst_subset=True,
internal_offset=offset,
external_offset=offset)

new_memlet = in_memlet
new_memlet.other_subset = out_memlet.subset

@@ -688,7 +713,11 @@ def _modify_memlet_path(
if inner_edge in edges_to_ignore:
new_memlet = inner_edge.data
else:
new_memlet = helpers.unsqueeze_memlet(inner_edge.data, top_edge.data)
offset = state.parent.arrays[top_edge.data.data].offset
new_memlet = helpers.unsqueeze_memlet(inner_edge.data,
top_edge.data,
internal_offset=offset,
external_offset=offset)
if inputs:
if inner_edge.dst in inner_to_outer:
dst = inner_to_outer[inner_edge.dst]
@@ -707,15 +736,19 @@ def _modify_memlet_path(
mtree = state.memlet_tree(new_edge)

# Modify all memlets going forward/backward
def traverse(mtree_node):
def traverse(mtree_node, state, nstate):
result.add(mtree_node.edge)
mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data, top_edge.data)
offset = state.parent.arrays[top_edge.data.data].offset
mtree_node.edge._data = helpers.unsqueeze_memlet(mtree_node.edge.data,
top_edge.data,
internal_offset=offset,
external_offset=offset)
for child in mtree_node.children:
traverse(child)
traverse(child, state, nstate)

result.add(new_edge)
for child in mtree.children:
traverse(child)
traverse(child, state, nstate)

return result

@@ -1035,7 +1068,8 @@ def _check_cand(candidates, outer_edges):

# If there are any symbols here that are not defined
# in "defined_symbols"
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - set(nsdfg.symbol_mapping.keys()))
missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) -
set(nsdfg.symbol_mapping.keys()))
if missing_symbols:
ignore.add(cname)
continue
@@ -1044,10 +1078,13 @@ def _check_cand(candidates, outer_edges):
_check_cand(out_candidates, state.out_edges_by_connector)

# Return result, filtering out the states
return ({k: (dc(v), ind)
for k, (v, _, ind) in in_candidates.items()
if k not in ignore}, {k: (dc(v), ind)
for k, (v, _, ind) in out_candidates.items() if k not in ignore})
return ({
k: (dc(v), ind)
for k, (v, _, ind) in in_candidates.items() if k not in ignore
}, {
k: (dc(v), ind)
for k, (v, _, ind) in out_candidates.items() if k not in ignore
})

def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False):
nsdfg = self.nsdfg
@@ -1070,7 +1107,17 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]],
outer_edge = next(iter(outer_edges(nsdfg_node, aname)))
except StopIteration:
continue
new_memlet = helpers.unsqueeze_memlet(refine, outer_edge.data)

if isinstance(outer_edge.dst, nodes.NestedSDFG):
conn = outer_edge.dst_conn
else:
conn = outer_edge.src_conn
int_desc = nsdfg.arrays[conn]
ext_desc = sdfg.arrays[outer_edge.data.data]
new_memlet = helpers.unsqueeze_memlet(refine,
outer_edge.data,
internal_offset=int_desc.offset,
external_offset=ext_desc.offset)
outer_edge.data.subset = subsets.Range([
ns if i in indices else os
for i, (os, ns) in enumerate(zip(outer_edge.data.subset, new_memlet.subset))
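The changes in this file make InlineSDFG and the related nesting transformations account for those descriptor offsets when memlets are unsqueezed across the nested-SDFG boundary. A hedged usage sketch of driving the transformation through the standard dace API (the SDFG being inlined is not part of this PR):

import dace
from dace.transformation.interstate import InlineSDFG

# Sketch: repeatedly inline nested SDFGs; with this PR the unsqueezed memlets
# take the internal and external descriptor offsets into account.
def inline_all(sdfg: dace.SDFG) -> int:
    return sdfg.apply_transformations_repeated(InlineSDFG)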