diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md
index 84e1fe5922..dba7b25d27 100644
--- a/docs/en/04-supported-codebases/mmdet.md
+++ b/docs/en/04-supported-codebases/mmdet.md
@@ -190,35 +190,40 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
 
 ## Supported models
 
-| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
-| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
-| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
-| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
-| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
-| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
-| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
-| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
-| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
-| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
-| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
-| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
-| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
-| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
-| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
-| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
-| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
-| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
-| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
-| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
-| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
-| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
-| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
+| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
+| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
+| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
+| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
+| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
+| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
+| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
+| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
+| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
+| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
+| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
+| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
+| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
+| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
+| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
+| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
+| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
+| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
+| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
+| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
+| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
 
 ## Reminder
 
 - For transformer based models, strongly suggest use `TensorRT>=8.4`.
 - Mask2Former should use `TensorRT>=8.6.1` for dynamic shape inference.
+- DETR-like models do not support multi-batch inference.
diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md
index 17c501630f..37bfe072a1 100644
--- a/docs/zh_cn/04-supported-codebases/mmdet.md
+++ b/docs/zh_cn/04-supported-codebases/mmdet.md
@@ -192,35 +192,40 @@ cv2.imwrite('output_detection.png', img)
 
 ## 模型支持列表
 
-| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
-| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
-| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
-| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
-| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
-| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
-| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
-| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
-| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
-| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
-| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
-| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
-| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
-| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
-| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
-| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
-| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
-| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
-| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
-| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
-| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
-| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
-| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
+| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
+| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
+| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
+| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
+| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
+| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
+| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
+| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
+| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
+| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
+| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
+| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
+| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
+| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
+| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
+| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
+| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
+| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
+| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
+| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
+| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
 
 ## 注意事项
 
 - 强烈建议使用`TensorRT>=8.4`来转换基于 `transformer` 的模型.
 - Mask2Former 请使用 `TensorRT>=8.6.1` 以保证动态尺寸正常推理.
+- DETR系列模型 不支持多批次推理.
diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
index 08121bdfbb..bb2bdee2a8 100644
--- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
+++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
@@ -8,6 +8,8 @@
 from mmdeploy.core import FUNCTION_REWRITER
 
 
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.dense_heads.DeformableDETRHead.predict_by_feat')
 @FUNCTION_REWRITER.register_rewriter(
     'mmdet.models.dense_heads.DETRHead.predict_by_feat')
 def detrhead__predict_by_feat__default(self,
@@ -17,10 +19,15 @@ def detrhead__predict_by_feat__default(self,
                                        rescale: bool = True):
     """Rewrite `predict_by_feat` of `FoveaHead` for default backend."""
     from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
+
     cls_scores = all_cls_scores_list[-1]
     bbox_preds = all_bbox_preds_list[-1]
-
     img_shape = batch_img_metas[0]['img_shape']
+    if isinstance(img_shape, list):
+        img_shape = torch.tensor(
+            img_shape, dtype=torch.long, device=cls_scores.device)
+    img_shape = img_shape.unsqueeze(0)
+
     max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0]))
     batch_size = cls_scores.size(0)
     # `batch_index_offset` is used for the gather of concatenated tensor
@@ -49,19 +56,9 @@ def detrhead__predict_by_feat__default(self,
                                        ...].squeeze(-1)
 
     det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
-
-    if isinstance(img_shape, torch.Tensor):
-        hw = img_shape.flip(0).to(det_bboxes.device)
-    else:
-        hw = det_bboxes.new_tensor([img_shape[1], img_shape[0]])
-    shape_scale = torch.cat([hw, hw])
-    shape_scale = shape_scale.view(1, 1, -1)
+    det_bboxes.clamp_(min=0., max=1.)
+    shape_scale = img_shape.flip(1).repeat(1, 2).unsqueeze(1)
     det_bboxes = det_bboxes * shape_scale
-    # dynamically clip bboxes
-    x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
-    from mmdeploy.codebase.mmdet.deploy import clip_bboxes
-    x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, img_shape)
-    det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
     det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)
 
     return det_bboxes, det_labels
diff --git a/mmdeploy/codebase/mmdet/models/detectors/__init__.py b/mmdeploy/codebase/mmdet/models/detectors/__init__.py
index 460694aa72..ac1d82d7a6 100644
--- a/mmdeploy/codebase/mmdet/models/detectors/__init__.py
+++ b/mmdeploy/codebase/mmdet/models/detectors/__init__.py
@@ -3,6 +3,10 @@
                single_stage, single_stage_instance_seg, two_stage)
 
 __all__ = [
-    'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage',
-    'panoptic_two_stage_segmentor', 'maskformer'
+    'base_detr',
+    'single_stage',
+    'single_stage_instance_seg',
+    'two_stage',
+    'panoptic_two_stage_segmentor',
+    'maskformer',
 ]
diff --git a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py
index 3531c9183c..42c0cf45f6 100644
--- a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py
+++ b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py
@@ -47,8 +47,8 @@ def _set_metainfo(data_samples, img_shape):
 
 
 @FUNCTION_REWRITER.register_rewriter(
-    'mmdet.models.detectors.base_detr.DetectionTransformer.predict')
-def detection_transformer__predict(self,
+    'mmdet.models.detectors.base_detr.DetectionTransformer.forward')
+def detection_transformer__forward(self,
                                    batch_inputs: torch.Tensor,
                                    data_samples: OptSampleList = None,
                                    rescale: bool = True,
@@ -79,11 +79,11 @@ def detection_transformer__predict(self,
 
     # get origin input shape as tensor to support onnx dynamic shape
     is_dynamic_flag = is_dynamic_shape(deploy_cfg)
-    img_shape = torch._shape_as_tensor(batch_inputs)[2:]
+    img_shape = torch._shape_as_tensor(batch_inputs)[2:].to(
+        batch_inputs.device)
     if not is_dynamic_flag:
         img_shape = [int(val) for val in img_shape]
     # set the metainfo
     data_samples = _set_metainfo(data_samples, img_shape)
-
     return __predict_impl(self, batch_inputs, data_samples, rescale)
 
diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py
index 5565596fee..6af418421a 100644
--- a/mmdeploy/utils/config_utils.py
+++ b/mmdeploy/utils/config_utils.py
@@ -245,7 +245,7 @@ def is_dynamic_shape(deploy_cfg: Union[str, mmengine.Config],
         return False
 
     # check if 2 (height) and 3 (width) in input axes
-    if 2 in input_axes and 3 in input_axes:
+    if 2 in input_axes or 3 in input_axes:
         return True
     return False
 
diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml
index 1df7404e5e..f0e813ce8e 100644
--- a/tests/regression/mmdet.yml
+++ b/tests/regression/mmdet.yml
@@ -311,10 +311,7 @@ models:
       - configs/detr/detr_r50_8xb2-150e_coco.py
     pipelines:
      - *pipeline_ort_dynamic_fp32
-      - deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
-        convert_image: *convert_image
-        backend_test: *default_backend_test
-        sdk_config: *sdk_dynamic
+      - *pipeline_trt_dynamic_fp32
 
   - name: CenterNet
     metafile: configs/centernet/metafile.yml
@@ -416,3 +413,35 @@ models:
       - deploy_config: configs/mmdet/panoptic-seg/panoptic-seg_maskformer_tensorrt_static-800x1344.py
         convert_image: *convert_image
         backend_test: *default_backend_test
+
+  - name: DINO
+    metafile: configs/dino/metafile.yml
+    model_configs:
+      - configs/dino/dino-4scale_r50_8xb2-12e_coco.py
+    pipelines:
+      - *pipeline_ort_dynamic_fp32
+      - *pipeline_trt_dynamic_fp32
+
+  - name: ConditionalDETR
+    metafile: configs/conditional_detr/metafile.yml
+    model_configs:
+      - configs/conditional_detr/conditional-detr_r50_8xb2-50e_coco.py
+    pipelines:
+      - *pipeline_ort_dynamic_fp32
+      - *pipeline_trt_dynamic_fp32
+
+  - name: DAB-DETR
+    metafile: configs/dab_detr/metafile.yml
+    model_configs:
+      - configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py
+    pipelines:
+      - *pipeline_ort_dynamic_fp32
+      - *pipeline_trt_dynamic_fp32
+
+  - name: DeformableDETR
+    metafile: configs/deformable_detr/metafile.yml
+    model_configs:
+      - configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
+    pipelines:
+      - *pipeline_ort_dynamic_fp32
+      - *pipeline_trt_dynamic_fp32
diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py
index 78b3255b06..ca1c5c1255 100644
--- a/tests/test_codebase/test_mmdet/test_mmdet_models.py
+++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py
@@ -727,7 +727,7 @@ def test_predict_of_detr_detector(model_cfg_path, backend):
     from mmdet.structures import DetDataSample
     data_sample = DetDataSample(metainfo=dict(batch_input_shape=(64, 64)))
     rewrite_inputs = {'batch_inputs': img}
-    wrapped_model = WrapModel(model, 'predict', data_samples=[data_sample])
+    wrapped_model = WrapModel(model, 'forward', data_samples=[data_sample])
     rewrite_outputs, _ = get_rewrite_outputs(
         wrapped_model=wrapped_model,
         model_inputs=rewrite_inputs,
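The box rescaling introduced in `detr_head.py` above (clamp the normalized `xyxy` boxes, then broadcast-multiply by the flipped image shape) can be reproduced on its own with plain PyTorch. The sketch below is a minimal standalone illustration, not part of the patch; the image size, batch size, and query count are made-up values chosen only to show the tensor shapes involved.

```python
import torch

# Hypothetical inputs: a batch of 2 images of size 800x1344, 5 queries each.
h, w = 800, 1344
img_shape = torch.tensor([h, w], dtype=torch.long).unsqueeze(0)  # shape (1, 2)

# Normalized xyxy boxes (e.g. the result of a cxcywh -> xyxy conversion);
# values may drift slightly outside [0, 1], hence the in-place clamp.
det_bboxes = torch.rand(2, 5, 4)
det_bboxes.clamp_(min=0., max=1.)

# (1, 2) -> flip(1) gives (w, h) -> repeat(1, 2) gives (w, h, w, h)
# -> unsqueeze(1) gives (1, 1, 4), which broadcasts over batch and queries.
shape_scale = img_shape.flip(1).repeat(1, 2).unsqueeze(1)
det_bboxes = det_bboxes * shape_scale  # boxes now in pixel coordinates

print(det_bboxes.shape)  # torch.Size([2, 5, 4])
```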