Skip to content

Commit

Permalink
[ConvolutionInputGenerator] Make infer_node_datatype update attributes
Browse files Browse the repository at this point in the history
Without updating the datatype attributes of the node, there might be a
mismatch between tensor annotations (the actual datatype) and the type
assumed by the node. This becomes an issue for example when querying the
bit-width of the stream when inserting data-width converters.
  • Loading branch information
iksnagreb committed Aug 7, 2024
1 parent 912cadf commit 84cbc0c
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from onnx import TensorProto, helper
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
Expand Down Expand Up @@ -141,6 +142,27 @@ def infer_node_datatype(self, model):
node = self.onnx_node
# data type stays the same
dtype = model.get_tensor_datatype(node.input[0])

# Test for changing input datatype
if dtype != self.get_nodeattr("inputDataType"):
# Issue a warning message
warnings.warn(
f"{node.name}: inputDataType changing from"
f" {self.get_nodeattr('inputDataType')} to {dtype}"
)
# Set the new datatype attribute
self.set_nodeattr("inputDataType", dtype.name)

# Test for changing output datatype
if dtype != self.get_nodeattr("outputDataType"):
# Issue a warning message
warnings.warn(
f"{node.name}: outputDataType changing from"
f" {self.get_nodeattr('outputDataType')} to {dtype}"
)
# Set the new datatype attribute
self.set_nodeattr("outputDataType", dtype.name)
# Propagate the datatype through the model graph
model.set_tensor_datatype(node.output[0], dtype)

def verify_node(self):
Expand Down

0 comments on commit 84cbc0c

Please sign in to comment.