diff --git a/Util.py b/Util.py index 2a495db4..9ccd707d 100644 --- a/Util.py +++ b/Util.py @@ -72,14 +72,14 @@ def load_calibration_dataset(directory: str, input_shape: List[int], batchsize: def quantize(working_directory: str, setting: QuantizationSetting, model_type: NetworkFramework, executing_device: str, input_shape: List[int], target_platform: TargetPlatform, - dataloader: DataLoader, calib_steps: int =32) -> BaseGraph: + dataloader: DataLoader, calib_steps: int=32) -> BaseGraph: if model_type == NetworkFramework.ONNX: if not os.path.exists(os.path.join(working_directory, 'model.onnx')): raise FileNotFoundError(f'无法找到你的模型: {os.path.join(working_directory, "model.onnx")},' '如果你使用caffe的模型,请设置MODEL_TYPE为CAFFE') return quantize_onnx_model( onnx_import_file=os.path.join(working_directory, 'model.onnx'), - calib_dataloader=dataloader, calib_steps=32, input_shape=input_shape, setting=setting, + calib_dataloader=dataloader, calib_steps=calib_steps, input_shape=input_shape, setting=setting, platform=target_platform, device=executing_device, collate_fn=lambda x: x.to(executing_device) ) if model_type == NetworkFramework.CAFFE: