Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InsertDWC and multiple input/output nodes #1201

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 13 additions & 24 deletions src/finn/transformation/fpgadataflow/insert_dwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def apply(self, model):
for n in graph.node:
node_ind += 1
if _suitable_node(n):
for output_name in n.output:
for out_idx, output_name in enumerate(n.output):
consumers = model.find_consumers(output_name)
if consumers == []:
continue
Expand All @@ -78,39 +78,28 @@ def apply(self, model):
if _suitable_node(consumer) is True:
n0 = getCustomOp(n)
n1 = getCustomOp(consumer)
n0_out_shape = n0.get_folded_output_shape()
# in some special cases, we need to get folded shapes of
# non-default inputs for the consumer
# - if FC and external mem, it could be connected to input 1
# - if concat, could be connected to any input
if (
consumer.op_type.startswith("MVAU")
and n1.get_nodeattr("mem_mode") == "external"
) or (consumer.op_type.startswith("StreamingConcat")):
# get input idx
in_idx = None
for idx, n_input in enumerate(consumer.input):
if output_name == n_input:
in_idx = idx
assert in_idx is not None, "Malformed model"
n1_in_shape = n1.get_folded_input_shape(in_idx)
else:
# use default folded input shape
n1_in_shape = n1.get_folded_input_shape()
n0_out_shape = n0.get_folded_output_shape(out_idx)
# get input idx
in_idx = None
for idx, n_input in enumerate(consumer.input):
if output_name == n_input:
in_idx = idx
assert in_idx is not None, "Malformed model"
n1_in_shape = n1.get_folded_input_shape(in_idx)

if n0_out_shape[-1] != n1_in_shape[-1]:
graph_modified = True
# determine dwc inwidth
dwc_in_width = n0.get_outstream_width()
dwc_in_width = n0.get_outstream_width(out_idx)
# determine dwc outwidth
dwc_out_width = n1.get_instream_width()
dwc_out_width = n1.get_instream_width(in_idx)
node_optype = "StreamingDataWidthConverter"

# determine shape for dwc
dwc_shape = n0.get_normal_output_shape()
dwc_shape = n0.get_normal_output_shape(out_idx)

# determine FINN dtype for dwc
dtype = n0.get_output_datatype()
dtype = n0.get_output_datatype(out_idx)
# determine onnx tensor dtype for dwc
n0_otensor = model.get_tensor_valueinfo(output_name)
n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type
Expand Down
Loading