Skip to content

Commit

Permalink
Fix (export/qonnx): uppercase rounding mode
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Pappalardo <[email protected]>
  • Loading branch information
volcacius committed Apr 20, 2023
1 parent 798ddc4 commit d781df2
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/export/onnx/qonnx/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def prepare_for_export(self, module):
'bit_width': module.bit_width(),
'narrow_range': module.is_narrow_range,
'signed': module.is_signed,
'rounding_mode': module.rounding_mode}
'rounding_mode': module.rounding_mode.upper()}

def symbolic_execution(self, x: Tensor):
scale = self.symbolic_kwargs['scale']
Expand Down Expand Up @@ -116,7 +116,7 @@ class BrevitasTruncQuantProxyHandler(ONNXBaseHandler):

def prepare_for_export(self, module: TruncQuantProxyFromInjector):
self.symbolic_kwargs = {
'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode}
'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode.upper()}

def symbolic_execution(
self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor,
Expand Down
4 changes: 4 additions & 0 deletions src/brevitas_examples/bnn_pynq/bnn_pynq_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def parse_args(args):
parser.add_argument("--network", default="LFC_1W1A", type=str, help="neural network")
parser.add_argument("--pretrained", action='store_true', help="Load pretrained model")
parser.add_argument("--strict", action='store_true', help="Strict state dictionary loading")
parser.add_argument(
"--state_dict_to_pth",
action='store_true',
help="Saves a model state_dict into a pth and then exits")
return parser.parse_args(args)


Expand Down
8 changes: 0 additions & 8 deletions src/brevitas_examples/bnn_pynq/cfg/resnet18_3w3a.ini

This file was deleted.

1 change: 1 addition & 0 deletions src/brevitas_examples/bnn_pynq/cfg/resnet18_4w4a.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ARCH: RESNET18
DATASET: CIFAR10
NUM_CLASSES: 10
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r2/resnet18_4w4a-a172f334.pth

[QUANT]
WEIGHT_BIT_WIDTH: 4
Expand Down
7 changes: 1 addition & 6 deletions src/brevitas_examples/bnn_pynq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ def lfc_1w2a(pretrained=True):
return model


def resnet18_4w4a(pretrained=False):
def resnet18_4w4a(pretrained=True):
model, _ = model_with_cfg('resnet18_4w4a', pretrained)
return model


def resnet18_3w3a(pretrained=False):
model, _ = model_with_cfg('resnet18_3w3a', pretrained)
return model
14 changes: 14 additions & 0 deletions src/brevitas_examples/bnn_pynq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from datetime import datetime
from hashlib import sha256
import os
import random
import time
Expand Down Expand Up @@ -135,6 +136,19 @@ def __init__(self, args):
model_state_dict = package['state_dict']
model.load_state_dict(model_state_dict, strict=args.strict)

if args.state_dict_to_pth:
state_dict = model.state_dict()
name = args.network.lower()
path = os.path.join(self.checkpoints_dir_path, name)
torch.save(state_dict, path)
with open(path, "rb") as f:
bytes = f.read()
readable_hash = sha256(bytes).hexdigest()[:8]
new_path = path + '-' + readable_hash + '.pth'
os.rename(path, new_path)
self.logger.info("Saving checkpoint model to {}".format(new_path))
exit(0)

if args.gpus is not None and len(args.gpus) == 1:
model = model.to(device=self.device)
if args.gpus is not None and len(args.gpus) > 1:
Expand Down

0 comments on commit d781df2

Please sign in to comment.