Commit 8328ea3

fix int8
grimoire committed Feb 18, 2024
1 parent 6e24f27 commit 8328ea3
Showing 3 changed files with 34 additions and 36 deletions.
18 changes: 9 additions & 9 deletions docs/FAQ.md
@@ -6,9 +6,9 @@
 - [Model Inference](#model-inference)
   - [**Q: Inference take a long time on a single image.**](#q-inference-take-a-long-time-on-a-single-image)
   - [**Q: Memory leak when inference.**](#q-memory-leak-when-inference)
-  - [**Q: error: parameter check failed at: engine.cpp::setBindingDimensions::1046, condition: profileMinDims.d[i] <= dimensions.d[i]**](#q-error-parameter-check-failed-at-enginecppsetbindingdimensions1046-condition-profilemindimsdi--dimensionsdi)
+  - [**Q: error: parameter check failed at: engine.cpp::setBindingDimensions::1046, condition: profileMinDims.d\[i\] \<= dimensions.d\[i\]**](#q-error-parameter-check-failed-at-enginecppsetbindingdimensions1046-condition-profilemindimsdi--dimensionsdi)
   - [**Q: FP16 model is slower than FP32 model**](#q-fp16-model-is-slower-than-fp32-model)
-  - [**Q: error: [TensorRT] INTERNAL ERROR: Assertion failed: cublasStatus == CUBLAS_STATUS_SUCCESS**](#q-error-tensorrt-internal-error-assertion-failed-cublasstatus--cublas_status_success)
+  - [**Q: error: \[TensorRT\] INTERNAL ERROR: Assertion failed: cublasStatus == CUBLAS\_STATUS\_SUCCESS**](#q-error-tensorrt-internal-error-assertion-failed-cublasstatus--cublas_status_success)
 
 This page provides some frequently asked questions and their solutions.
@@ -33,13 +33,13 @@ This is a bug in an old version of TensorRT, read [this](https://forums.developer.n
 The input tensor shape is out of range. Please enlarge the `shape_ranges` when converting the model.
 
 ```python
-opt_shape_param=[
-    [
-        [1,3,224,224],    # min tensor shape
-        [1,3,800,1312],   # shape used to do int8 calib
-        [1,3,1344,1344],  # max tensor shape
-    ]
-]
+shape_ranges=dict(
+    x=dict(
+        min=[1,3,320,320],
+        opt=[1,3,800,1344],
+        max=[1,3,1344,1344],
+    )
+)
 ```
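For reference, a minimal sketch of rebuilding the engine with enlarged ranges (the config/checkpoint paths are placeholders; any input whose height and width fall inside `min`/`max` is then accepted at runtime):

```python
from mmdet2trt import mmdet2trt

shape_ranges = dict(
    x=dict(
        min=[1, 3, 320, 320],
        opt=[1, 3, 800, 1344],
        max=[1, 3, 1344, 1344],
    ))

# Rebuild the TensorRT model with the wider shape profile.
trt_model = mmdet2trt('path/to/config.py', 'path/to/checkpoint.pth',
                      shape_ranges=shape_ranges)
```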

### **Q: FP16 model is slower than FP32 model**
28 changes: 13 additions & 15 deletions docs/getting_started.md
@@ -46,27 +46,25 @@ trt_model = mmdet2trt( ...,

 ## int8 support
 
 **int8 mode** needs a few more configs.
 
 - set `int8_mode=True`.
-- provide calibrate dataset, the `__getitem__()` method of dataset should return a list of tensor with shape (C,H,W), the shape **must** be the same as `opt_shape_param[0][1][1:]` (optimize shape). The tensor should do the same preprocess as the model. There is a default dataset, you can also set your custom one.
+- provide a calibration dataset; its `__getitem__()` method should return a dict of input tensors with shape (C,H,W). The shape **must** match `shape_ranges['x']['opt'][1:]` (the optimize shape), and the tensors should receive the same preprocessing as the model input. There is a default dataset; you can also set a custom one (see the sketch after the example below).
 - set the calibration algorithm: `entropy` and `minmax` are supported.

 ```python
 from mmdet2trt import mmdet2trt, Int8CalibDataset
-cfg_path="..." # mmdetection config path
-model_path="..." # mmdetection checkpoint path
-image_path_list = [...] # lists of image pathes
-opt_shape_param=[
-    [
-        [...],
-        [...],
-        [...],
-    ]
-]
-calib_dataset = Int8CalibDataset(image_path_list, cfg_path, opt_shape_param)
+cfg_path="..." # MMDetection config path
+model_path="..." # MMDetection checkpoint path
+image_path_list = [...] # list of image paths
+shape_ranges=dict(
+    x=dict(
+        min=[...],
+        opt=[...],
+        max=[...],
+    )
+)
+calib_dataset = Int8CalibDataset(image_path_list, cfg_path, shape_ranges)
 trt_model = mmdet2trt(cfg_path, model_path,
-                      opt_shape_param=opt_shape_param,
+                      shape_ranges=shape_ranges,
                       int8_mode=True,
                       int8_calib_dataset=calib_dataset,
                       int8_calib_alg="entropy")
 ```
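For a custom calibration dataset, only `__len__` and `__getitem__` are needed. Below is a minimal, hypothetical sketch of that contract; `MyCalibDataset`, the normalization constants, and the resize-only preprocessing are placeholders — match the preprocessing to your own model config:

```python
import cv2
import numpy as np
import torch


class MyCalibDataset:
    """Hypothetical int8 calibration dataset (sketch only).

    Each item is a dict keyed by the model input name ('x'), holding one
    preprocessed (C, H, W) tensor resized to the opt shape.
    """

    def __init__(self, image_paths, opt_shape):
        self.image_paths = image_paths
        self.opt_hw = tuple(opt_shape[-2:])  # (H, W) of shape_ranges['x']['opt']

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img = cv2.imread(self.image_paths[index])   # HWC, BGR, uint8
        img = cv2.resize(img, self.opt_hw[::-1])    # cv2.resize expects (W, H)
        tensor = torch.from_numpy(np.ascontiguousarray(img))
        tensor = tensor.permute(2, 0, 1).float()    # HWC -> CHW
        tensor = (tensor - 127.5) / 127.5           # placeholder normalization
        return dict(x=tensor.cuda())


# Drop-in replacement for the default dataset above:
# calib_dataset = MyCalibDataset(image_path_list, shape_ranges['x']['opt'])
```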
24 changes: 12 additions & 12 deletions mmdet2trt/mmdet2trt.py
@@ -2,15 +2,14 @@
 import time
 from typing import Any, Dict
 
+import mmengine
 import tensorrt as trt
 import torch
 from mmdet2trt.models.builder import build_wrapper
 from mmdet2trt.models.detectors import TwoStageDetectorWraper
 from mmdet.apis import init_detector
 from torch2trt_dynamic import BuildEngineConfig, module2trt
 
-import mmcv
-
 logger = logging.getLogger('mmdet2trt')

@@ -20,25 +19,26 @@ class Int8CalibDataset():
     feed to int8_calib_dataset
     """
 
-    def __init__(self, image_paths, config, opt_shape_param):
+    def __init__(self, image_paths, config, shape_ranges):
         r"""
         data used to calibrate the int8 model,
         fed to int8_calib_dataset
         Args:
             image_paths (list[str]): image paths to calib
             config (str|dict): config of the MMDetection model
-            opt_shape_param: same as mmdet2trt
+            shape_ranges: same as mmdet2trt
         """
-        from mmdet.apis.inference import LoadImage
-        from mmdet.datasets.pipelines import Compose
+        from mmcv.transforms import Compose
+        from mmengine.registry import init_default_scope
         if isinstance(config, str):
-            config = mmcv.Config.fromfile(config)
+            config = mmengine.Config.fromfile(config)
 
+        init_default_scope(config.get('default_scope', 'mmdet'))
         self.cfg = config
         self.image_paths = image_paths
-        self.opt_shape = opt_shape_param[0][1]
+        self.opt_shape = shape_ranges['x']['opt']
 
-        test_pipeline = [LoadImage()] + config.data.test.pipeline[1:]
+        test_pipeline = config.val_dataloader.dataset.pipeline
         self.test_pipeline = Compose(test_pipeline)
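As context for the pipeline change, a hedged sketch of where the test-time transforms live in an MMDetection 3.x / mmengine config (the path and the transform list are illustrative):

```python
import mmengine

cfg = mmengine.Config.fromfile('path/to/config.py')  # placeholder path
pipeline = cfg.val_dataloader.dataset.pipeline
# Typically something like:
# [dict(type='LoadImageFromFile'), dict(type='Resize', ...), dict(type='PackDetInputs')]
```

In standard MMDetection 3.x configs, the final `PackDetInputs` transform is what produces the `data['inputs']` tensor read in `__getitem__` below.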
 
@@ -47,14 +47,14 @@ def __len__(self):
     def __getitem__(self, index):
         image_path = self.image_paths[index]
 
-        data = dict(img=image_path)
+        data = dict(img=image_path, img_path=image_path)
         data = self.test_pipeline(data)
 
-        tensor = data['img'][0].unsqueeze(0)
+        tensor = data['inputs'].unsqueeze(0)
         tensor = torch.nn.functional.interpolate(
             tensor, self.opt_shape[-2:]).squeeze(0)
 
-        return [tensor]
+        return dict(x=tensor.cuda())
 
 
 def _get_shape_ranges(config):
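As a usage note, a quick sanity check of what the updated dataset yields per item — this sketch assumes the `image_path_list`, `cfg_path`, and `shape_ranges` variables from the getting-started example above:

```python
calib_dataset = Int8CalibDataset(image_path_list, cfg_path, shape_ranges)
item = calib_dataset[0]

# Each item is a dict keyed by the model input name...
assert set(item) == {'x'}
# ...holding one CHW tensor resized to the opt shape, e.g. (3, 800, 1344).
assert tuple(item['x'].shape) == tuple(shape_ranges['x']['opt'][1:])
```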
