
Converting to onnx #154

Open
cjenkins5614 opened this issue Nov 20, 2021 · 2 comments

Comments

@cjenkins5614

Hello,

Thanks for the great work. I'm trying to convert this model into onnx, but have met a few issues.

The mv and dot operators used by PyTorch's spectral_norm were among them. Following onnx/onnx#3006 (comment), I converted them to matmul in my own implementation of spectral_norm, and that issue went away.
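For reference, the substitution can be sketched like this (mv_as_matmul and dot_as_matmul are hypothetical helper names, not functions from the repo): torch.mv and torch.dot lack ONNX symbolics in some torch versions, while torch.matmul exports cleanly as onnx::MatMul.

```python
import torch

def mv_as_matmul(mat, vec):
    # (n, m) @ (m, 1) -> (n, 1), squeezed back to (n,)
    return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1)

def dot_as_matmul(a, b):
    # (1, n) @ (n, 1) -> (1, 1), squeezed to a scalar
    return torch.matmul(a.unsqueeze(0), b.unsqueeze(-1)).squeeze()

# Sanity check against the original ops
mat, vec = torch.randn(4, 3), torch.randn(3)
assert torch.allclose(mv_as_matmul(mat, vec), torch.mv(mat, vec), atol=1e-5)
a, b = torch.randn(5), torch.randn(5)
assert torch.allclose(dot_as_matmul(a, b), torch.dot(a, b), atol=1e-5)
```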

Now it's complaining:

Traceback (most recent call last):
    out = torch.onnx.export(model, input_dict["image"], "model.onnx", verbose=False, opset_version=11,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 271, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 694, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 463, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 206, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 309, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 997, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 148, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 1285, in batch_norm
    if weight is None or sym_help._is_none(weight):
RuntimeError: Unsupported: ONNX export of batch_norm for unknown channel size.

The code to convert this is:

    opt = EasyDict(aspect_ratio=1.0,
                checkpoints_dir='Face_Enhancement/checkpoints',
                contain_dontcare_label=False,
                crop_size=256,
                gpu_ids=[0],
                init_type='xavier',
                init_variance=0.02,
                injection_layer='all',
                isTrain=False,
                label_nc=18,
                load_size=256,
                model='pix2pix',
                name='Setting_9_epoch_100',
                nef=16,
                netG='spade',
                ngf=64,
                no_flip=True,
                no_instance=True,
                no_parsing_map=True,
                norm_D='spectralinstance',
                norm_E='spectralinstance',
                # norm_G='spectralspadebatch3x3',
                norm_G='spectralspadesyncbatch3x3',
                num_upsampling_layers='normal',
                output_nc=3,
                preprocess_mode='resize',
                semantic_nc=18,
                use_vae=False,
                which_epoch='latest',
                z_dim=256)

    model = Pix2PixModel(opt)
    model.eval()

    input_dict = {
        "label": torch.zeros((1, 18, 256, 256)),
        "image": torch.randn(1, 3, 256, 256),
        "path": None,
    }

    # from torchsummary import summary
    # summary(model, (3, 256, 256))
    out = torch.onnx.export(model, input_dict, "model.onnx", verbose=False, opset_version=11,
                      input_names = ['input'],
                      output_names = ['output'])
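One side note: the traceback shows the export being called with input_dict["image"], while the snippet above passes the whole dict; torch.onnx.export expects args as a Tensor or a tuple of Tensors, and a bare dict is treated specially. A common workaround is a thin wrapper module with a single-tensor interface. This is a hedged sketch only: DummyModel stands in for Pix2PixModel, and the assumption that its forward takes a dict is based on the input_dict in this issue.

```python
import torch

class DummyModel(torch.nn.Module):
    # Hypothetical stand-in for Pix2PixModel: forward takes a data dict.
    def forward(self, data):
        return data["image"] * 2

class ExportWrapper(torch.nn.Module):
    # torch.onnx.export traces positional tensor args, so wrap the
    # dict-taking forward behind a single-tensor interface.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        # label fixed to zeros for tracing, matching the issue's input_dict
        label = torch.zeros((1, 18, 256, 256))
        return self.model({"label": label, "image": image, "path": None})

wrapped = ExportWrapper(DummyModel()).eval()
out = wrapped(torch.randn(1, 3, 256, 256))
assert tuple(out.shape) == (1, 3, 256, 256)
```

With a wrapper like this, the export call becomes torch.onnx.export(wrapped, dummy_image, "model.onnx", ...).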

I printed out the graph g from https://github.com/pytorch/pytorch/blob/e56d3b023818f54553f2dc5d30b6b7aaf6b6a325/torch/onnx/symbolic_opset9.py#L1337

...
  %450 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPULongType{2} ]]()
  %451 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %452 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= 0  0 [ CPULongType{2} ]]()
  %453 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %454 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %455 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %456 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %457 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %458 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%436, %447, %netG.head_0.conv_1.bias) # /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:395:0
  %459 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %460 : Float(*, 1024, *, *, strides=[65536, 64, 8, 1], requires_grad=0, device=cuda:0) = onnx::Add(%266, %458) # /workdir/Face_Enhancement/models/networks/architecture.py:56:0
  %461 : None = prim::Constant()
  %462 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %463 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 1  1 [ CPUFloatType{2} ]]()
  %464 : Float(2, strides=[1], device=cpu) = onnx::Constant[value= 2  2 [ CPUFloatType{2} ]]()
  %465 : Float(4, strides=[1], device=cpu) = onnx::Concat[axis=0](%463, %464)
  %466 : Float(0, strides=[1], device=cpu) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
  %468 : None = prim::Constant()
  %469 : None = prim::Constant()
  %470 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
  %471 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={0.1}]()
  %472 : Double(requires_grad=0, device=cpu) = onnx::Constant[value={1e-05}]()
  %473 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  return ()

ipdb> input
467 defined in (%467 : Float(*, *, *, *, strides=[262144, 256, 16, 1], requires_grad=0, device=cuda:0) = onnx::Resize[coordinate_transformation_mode="asymmetric", cubic_coeff_a=-0.75, mode="nearest", nearest_mode="floor"](%460, %466, %465) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3535:0
)
ipdb> weight
468 defined in (%468 : None = prim::Constant()
)
ipdb> bias
469 defined in (%469 : None = prim::Constant()
)

Float(*, *, *, *) stood out to me, but I'm not sure how to interpret this.


ymzlygw commented Mar 31, 2022

Hi, have you found a solution yet?


eaidova commented Aug 18, 2022

I'm not sure this solution is right, but I downgraded torch to 1.7 and the model then converted to onnx. It looks like a bug on the torch-to-onnx conversion side: in newer versions, upsample produces dynamic shapes, which leads to the error in batch norm.
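If downgrading is not an option, one workaround consistent with that diagnosis is to give the upsample a concrete output size, so onnx::Resize does not erase the static shape that the following batch norm needs. This is a minimal sketch under the assumption that the failing Resize comes from an F.interpolate(..., scale_factor=2) call in the upsampling path (StaticUpsample is a hypothetical module, not code from the repo):

```python
import torch
import torch.nn.functional as F

class StaticUpsample(torch.nn.Module):
    # Read the spatial dims as plain ints at trace time and pass an
    # explicit size, instead of scale_factor, to F.interpolate. The
    # traced graph then carries concrete shapes through the Resize,
    # so downstream BatchNorm sees a known channel size.
    def forward(self, x):
        h, w = int(x.shape[2]), int(x.shape[3])
        return F.interpolate(x, size=(2 * h, 2 * w), mode="nearest")

x = torch.randn(1, 8, 16, 16)
y = StaticUpsample()(x)
assert tuple(y.shape) == (1, 8, 32, 32)
```

The trade-off is that the exported model is pinned to the traced input resolution, which matches the fixed 256x256 crop_size used in this issue anyway.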
