From b3db8291cabe320b76ddbb8477d1caaa25f117e3 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 May 2024 14:00:39 +0200 Subject: [PATCH] fix typing --- optimum/utils/preprocessing/task_processors_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/utils/preprocessing/task_processors_manager.py b/optimum/utils/preprocessing/task_processors_manager.py index 2720ed41fb..0426d1a2b4 100644 --- a/optimum/utils/preprocessing/task_processors_manager.py +++ b/optimum/utils/preprocessing/task_processors_manager.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: - from .base import DatasetProcessing + from .base import TaskProcessor class TaskProcessorsManager: @@ -35,7 +35,7 @@ class TaskProcessorsManager: } @classmethod - def get_task_processor_class_for_task(cls, task: str) -> Type: + def get_task_processor_class_for_task(cls, task: str) -> Type["TaskProcessor"]: if task not in cls._TASK_TO_DATASET_PROCESSING_CLASS: supported_tasks = ", ".join(cls._TASK_TO_DATASET_PROCESSING_CLASS.keys()) raise KeyError( @@ -45,5 +45,5 @@ def get_task_processor_class_for_task(cls, task: str) -> Type: return cls._TASK_TO_DATASET_PROCESSING_CLASS[task] @classmethod - def for_task(cls, task: str, *dataset_processing_args, **dataset_processing_kwargs: Any) -> "DatasetProcessing": + def for_task(cls, task: str, *dataset_processing_args, **dataset_processing_kwargs: Any) -> "TaskProcessor": return cls.get_task_processor_class_for_task(task)(*dataset_processing_args, **dataset_processing_kwargs)