diff --git a/src/finn/transformation/fpgadataflow/insert_dwc.py b/src/finn/transformation/fpgadataflow/insert_dwc.py index b56c8b74e..2e3f0bf4d 100644 --- a/src/finn/transformation/fpgadataflow/insert_dwc.py +++ b/src/finn/transformation/fpgadataflow/insert_dwc.py @@ -67,7 +67,7 @@ def apply(self, model): for n in graph.node: node_ind += 1 if _suitable_node(n): - for output_name in n.output: + for out_idx, output_name in enumerate(n.output): consumers = model.find_consumers(output_name) if consumers == []: continue @@ -78,39 +78,28 @@ def apply(self, model): if _suitable_node(consumer) is True: n0 = getCustomOp(n) n1 = getCustomOp(consumer) - n0_out_shape = n0.get_folded_output_shape() - # in some special cases, we need to get folded shapes of - # non-default inputs for the consumer - # - if FC and external mem, it could be connected to input 1 - # - if concat, could be connected to any input - if ( - consumer.op_type.startswith("MVAU") - and n1.get_nodeattr("mem_mode") == "external" - ) or (consumer.op_type.startswith("StreamingConcat")): - # get input idx - in_idx = None - for idx, n_input in enumerate(consumer.input): - if output_name == n_input: - in_idx = idx - assert in_idx is not None, "Malformed model" - n1_in_shape = n1.get_folded_input_shape(in_idx) - else: - # use default folded input shape - n1_in_shape = n1.get_folded_input_shape() + n0_out_shape = n0.get_folded_output_shape(out_idx) + # get input idx + in_idx = None + for idx, n_input in enumerate(consumer.input): + if output_name == n_input: + in_idx = idx + assert in_idx is not None, "Malformed model" + n1_in_shape = n1.get_folded_input_shape(in_idx) if n0_out_shape[-1] != n1_in_shape[-1]: graph_modified = True # determine dwc inwidth - dwc_in_width = n0.get_outstream_width() + dwc_in_width = n0.get_outstream_width(out_idx) # determine dwc outwidth - dwc_out_width = n1.get_instream_width() + dwc_out_width = n1.get_instream_width(in_idx) node_optype = "StreamingDataWidthConverter" # determine shape for dwc - dwc_shape = n0.get_normal_output_shape() + dwc_shape = n0.get_normal_output_shape(out_idx) # determine FINN dtype for dwc - dtype = n0.get_output_datatype() + dtype = n0.get_output_datatype(out_idx) # determine onnx tensor dtype for dwc n0_otensor = model.get_tensor_valueinfo(output_name) n0_tensor_dtype = n0_otensor.type.tensor_type.elem_type