Skip to content

Commit

Permalink
Update acclerator
Browse files Browse the repository at this point in the history
  • Loading branch information
liuhongwei committed May 14, 2024
1 parent 8209348 commit ec6b394
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions opencompass/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from mmengine.config import Config

from opencompass.datasets.custom import make_custom_dataset_config
from opencompass.models import VLLM, HuggingFaceCausalLM, TurboMindModel
from opencompass.models import (VLLM, HuggingFace, HuggingFaceCausalLM,
HuggingFaceChatGLM3, TurboMindModel)
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner
from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask
Expand Down Expand Up @@ -190,7 +191,9 @@ def change_accelerator(models, accelerator):
for model in models:
get_logger().info(f'Transforming {model["abbr"]} to {accelerator}')
# change HuggingFace model to VLLM or TurboMindModel
if model['type'] is HuggingFaceCausalLM:
if model['type'] in [
HuggingFace, HuggingFaceCausalLM, HuggingFaceChatGLM3
]:
gen_args = dict()
if model.get('generation_kwargs') is not None:
generation_kwargs = model['generation_kwargs'].copy()
Expand Down Expand Up @@ -265,6 +268,8 @@ def change_accelerator(models, accelerator):
acc_model[item] = model[item]
else:
raise ValueError(f'Unsupported accelerator {accelerator}')
else:
raise ValueError(f'Unsupported model type {model["type"]}')
model_accels.append(acc_model)
return model_accels

Expand Down

0 comments on commit ec6b394

Please sign in to comment.