Skip to content

Commit

Permalink
Merge pull request #1151 from iksnagreb/fix/convolutioninputgenerator…
Browse files Browse the repository at this point in the history
…-datatype-inference

[ConvolutionInputGenerator] Make infer_node_datatype update attributes
  • Loading branch information
auphelia authored Aug 13, 2024
2 parents 313316d + 84cbc0c commit 2580269
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from onnx import TensorProto, helper
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
Expand Down Expand Up @@ -141,6 +142,27 @@ def infer_node_datatype(self, model):
node = self.onnx_node
# data type stays the same
dtype = model.get_tensor_datatype(node.input[0])

# Test for changing input datatype
if dtype != self.get_nodeattr("inputDataType"):
# Issue a warning message
warnings.warn(
f"{node.name}: inputDataType changing from"
f" {self.get_nodeattr('inputDataType')} to {dtype}"
)
# Set the new datatype attribute
self.set_nodeattr("inputDataType", dtype.name)

# Test for changing output datatype
if dtype != self.get_nodeattr("outputDataType"):
# Issue a warning message
warnings.warn(
f"{node.name}: outputDataType changing from"
f" {self.get_nodeattr('outputDataType')} to {dtype}"
)
# Set the new datatype attribute
self.set_nodeattr("outputDataType", dtype.name)
# Propagate the datatype through the model graph
model.set_tensor_datatype(node.output[0], dtype)

def verify_node(self):
Expand Down

0 comments on commit 2580269

Please sign in to comment.