Skip to content

Commit

Permalink
fix roi align symbolic function in onnx opset>=16 (#2428)
Browse files Browse the repository at this point in the history
  • Loading branch information
CescMessi committed Sep 14, 2023
1 parent bb031c6 commit ec35b40
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions mmdeploy/mmcv/ops/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,38 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
else:
from torch.onnx.symbolic_opset9 import _cast_Long
from torch.onnx.symbolic_opset11 import add, select
batch_indices = _cast_Long(
g,
g.op(
'Squeeze',
select(
g, rois, 1,
g.op(
'Constant',
value_t=torch.tensor([0], dtype=torch.long))),
axes_i=[1]), False)
ir_cfg = get_ir_config(ctx.cfg)
opset_version = ir_cfg.get('opset_version', 11)
if opset_version < 13:
batch_indices = _cast_Long(
g,
g.op(
'Squeeze',
select(
g, rois, 1,
g.op(
'Constant',
value_t=torch.tensor([0], dtype=torch.long))),
axes_i=[1]), False)
else:
axes = g.op(
'Constant', value_t=torch.tensor([1], dtype=torch.long))
batch_indices = _cast_Long(
g,
g.op(
'Squeeze',
select(
g, rois, 1,
g.op(
'Constant',
value_t=torch.tensor([0], dtype=torch.long))),
axes), False)
rois = select(
g, rois, 1,
g.op(
'Constant',
value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
ir_cfg = get_ir_config(ctx.cfg)
opset_version = ir_cfg.get('opset_version', 11)

if opset_version < 16:
# preprocess rois to make compatible with opset 16-
# as for opset 16+, `aligned` get implemented inside onnxruntime.
Expand All @@ -96,6 +111,10 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
sampling_ratio_i=sampling_ratio,
mode_s=pool_mode)
else:
if aligned:
coordinate_transformation_mode = 'half_pixel'
else:
coordinate_transformation_mode = 'output_half_pixel'
return g.op(
'RoiAlign',
input,
Expand All @@ -106,4 +125,5 @@ def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
spatial_scale_f=spatial_scale,
sampling_ratio_i=sampling_ratio,
mode_s=pool_mode,
aligned_i=aligned)
coordinate_transformation_mode_s=coordinate_transformation_mode
)

0 comments on commit ec35b40

Please sign in to comment.