Skip to content

Commit

Permalink
[Fix] InsertDWC now properly handles multiple input/output nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Danilowicz committed Sep 24, 2024
1 parent 71b546b commit cedfc44
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions src/finn/transformation/fpgadataflow/insert_dwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def apply(self, model):
for n in graph.node:
node_ind += 1
if _suitable_node(n):
for output_name in n.output:
for out_idx, output_name in enumerate(n.output):
consumers = model.find_consumers(output_name)
if consumers == []:
continue
Expand All @@ -78,39 +78,28 @@ def apply(self, model):
if _suitable_node(consumer) is True:
n0 = getCustomOp(n)
n1 = getCustomOp(consumer)
n0_out_shape = n0.get_folded_output_shape()
# in some special cases, we need to get folded shapes of
# non-default inputs for the consumer
# - if FC and external mem, it could be connected to input 1
# - if concat, could be connected to any input
if (
consumer.op_type.startswith("MVAU")
and n1.get_nodeattr("mem_mode") == "external"
) or (consumer.op_type.startswith("StreamingConcat")):
# get input idx
in_idx = None
for idx, n_input in enumerate(consumer.input):
if output_name == n_input:
in_idx = idx
assert in_idx is not None, "Malformed model"
n1_in_shape = n1.get_folded_input_shape(in_idx)
else:
# use default folded input shape
n1_in_shape = n1.get_folded_input_shape()
n0_out_shape = n0.get_folded_output_shape(out_idx)
# get input idx
in_idx = None
for idx, n_input in enumerate(consumer.input):
if output_name == n_input:
in_idx = idx
assert in_idx is not None, "Malformed model"
n1_in_shape = n1.get_folded_input_shape(in_idx)

if n0_out_shape[-1] != n1_in_shape[-1]:
graph_modified = True
# determine dwc inwidth
dwc_in_width = n0.get_outstream_width()
dwc_in_width = n0.get_outstream_width(out_idx)
# determine dwc outwidth
dwc_out_width = n1.get_instream_width()
dwc_out_width = n1.get_instream_width(in_idx)
node_optype = "StreamingDataWidthConverter"

# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()
dwc_shape = n0.get_normal_output_shape(out_idx)

# determine FINN dtype for dwc
dtype = n0.get_output_datatype()
dtype = n0.get_output_datatype(out_idx)
# determine onnx tensor dtype for dwc
n0_otensor = model.get_tensor_valueinfo(output_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type
Expand Down

0 comments on commit cedfc44

Please sign in to comment.