diff --git a/src/brevitas/export/onnx/qonnx/handler.py b/src/brevitas/export/onnx/qonnx/handler.py index 34a4dbd9f..6aa17eebc 100644 --- a/src/brevitas/export/onnx/qonnx/handler.py +++ b/src/brevitas/export/onnx/qonnx/handler.py @@ -36,7 +36,7 @@ def prepare_for_export(self, module): 'bit_width': module.bit_width(), 'narrow_range': module.is_narrow_range, 'signed': module.is_signed, - 'rounding_mode': module.rounding_mode} + 'rounding_mode': module.rounding_mode.upper()} def symbolic_execution(self, x: Tensor): scale = self.symbolic_kwargs['scale'] @@ -116,7 +116,7 @@ class BrevitasTruncQuantProxyHandler(ONNXBaseHandler): def prepare_for_export(self, module: TruncQuantProxyFromInjector): self.symbolic_kwargs = { - 'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode} + 'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode.upper()} def symbolic_execution( self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor, diff --git a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py index 06f8ea875..fd4281a3c 100644 --- a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py +++ b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py @@ -69,6 +69,10 @@ def parse_args(args): parser.add_argument("--network", default="LFC_1W1A", type=str, help="neural network") parser.add_argument("--pretrained", action='store_true', help="Load pretrained model") parser.add_argument("--strict", action='store_true', help="Strict state dictionary loading") + parser.add_argument( + "--state_dict_to_pth", + action='store_true', + help="Saves a model state_dict into a pth and then exits") return parser.parse_args(args) diff --git a/src/brevitas_examples/bnn_pynq/cfg/resnet18_3w3a.ini b/src/brevitas_examples/bnn_pynq/cfg/resnet18_3w3a.ini deleted file mode 100644 index 59ce49871..000000000 --- a/src/brevitas_examples/bnn_pynq/cfg/resnet18_3w3a.ini +++ /dev/null @@ -1,8 +0,0 @@ -[MODEL] -ARCH: RESNET18 -DATASET: CIFAR10 -NUM_CLASSES: 10 - -[QUANT] -WEIGHT_BIT_WIDTH: 3 -ACT_BIT_WIDTH: 3 diff --git a/src/brevitas_examples/bnn_pynq/cfg/resnet18_4w4a.ini b/src/brevitas_examples/bnn_pynq/cfg/resnet18_4w4a.ini index f6ebe2404..b695e522e 100644 --- a/src/brevitas_examples/bnn_pynq/cfg/resnet18_4w4a.ini +++ b/src/brevitas_examples/bnn_pynq/cfg/resnet18_4w4a.ini @@ -2,6 +2,7 @@ ARCH: RESNET18 DATASET: CIFAR10 NUM_CLASSES: 10 +PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r2/resnet18_4w4a-a172f334.pth [QUANT] WEIGHT_BIT_WIDTH: 4 diff --git a/src/brevitas_examples/bnn_pynq/models/__init__.py b/src/brevitas_examples/bnn_pynq/models/__init__.py index b958d510f..33753c7ab 100644 --- a/src/brevitas_examples/bnn_pynq/models/__init__.py +++ b/src/brevitas_examples/bnn_pynq/models/__init__.py @@ -104,11 +104,6 @@ def lfc_1w2a(pretrained=True): return model -def resnet18_4w4a(pretrained=False): +def resnet18_4w4a(pretrained=True): model, _ = model_with_cfg('resnet18_4w4a', pretrained) return model - - -def resnet18_3w3a(pretrained=False): - model, _ = model_with_cfg('resnet18_3w3a', pretrained) - return model diff --git a/src/brevitas_examples/bnn_pynq/trainer.py b/src/brevitas_examples/bnn_pynq/trainer.py index d9b1283a1..78c1db97e 100644 --- a/src/brevitas_examples/bnn_pynq/trainer.py +++ b/src/brevitas_examples/bnn_pynq/trainer.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from datetime import datetime +from hashlib import sha256 import os import random import time @@ -135,6 +136,19 @@ def __init__(self, args): model_state_dict = package['state_dict'] model.load_state_dict(model_state_dict, strict=args.strict) + if args.state_dict_to_pth: + state_dict = model.state_dict() + name = args.network.lower() + path = os.path.join(self.checkpoints_dir_path, name) + torch.save(state_dict, path) + with open(path, "rb") as f: + bytes = f.read() + readable_hash = sha256(bytes).hexdigest()[:8] + new_path = path + '-' + readable_hash + '.pth' + os.rename(path, new_path) + self.logger.info("Saving checkpoint model to {}".format(new_path)) + exit(0) + if args.gpus is not None and len(args.gpus) == 1: model = model.to(device=self.device) if args.gpus is not None and len(args.gpus) > 1: