diff --git a/mmdeploy/codebase/mmpretrain/deploy/classification.py b/mmdeploy/codebase/mmpretrain/deploy/classification.py index 499b921349..093d6402e0 100644 --- a/mmdeploy/codebase/mmpretrain/deploy/classification.py +++ b/mmdeploy/codebase/mmpretrain/deploy/classification.py @@ -332,7 +332,8 @@ def get_postprocess(self, *args, **kwargs) -> Dict: dict: Composed of the postprocess information. """ postprocess = self.model_cfg.model.head - if postprocess['type'] in ('EfficientFormerClsHead', 'StackedLinearClsHead'): + if postprocess['type'] in ('EfficientFormerClsHead', + 'StackedLinearClsHead'): postprocess['type'] = 'LinearClsHead' if 'topk' not in postprocess: