diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py index 35afb1d36..9961d10b9 100644 --- a/tests/distributed/test_model_parallelization.py +++ b/tests/distributed/test_model_parallelization.py @@ -85,7 +85,6 @@ CLASSES_TO_IGNORE = [ - "T5ForSequenceClassification", # TODO: enable this class when it can be traced for pipeline parallelism. "LlamaForQuestionAnswering", ] @@ -128,7 +127,7 @@ def _generate_supported_model_classes( for task in supported_tasks: config_class = CONFIG_MAPPING[model_type] model_class = task_mapping[task].get(config_class, None) - if model_class is not None and model_class not in CLASSES_TO_IGNORE: + if model_class is not None and model_class.__name__ not in CLASSES_TO_IGNORE: model_classes.append(model_class) return list(set(model_classes))