
Fix quantization model loading logic
danielz02 committed Sep 26, 2023
1 parent f332775 commit d2b66be
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/helm/proxy/clients/huggingface_client.py
@@ -70,12 +70,12 @@ def __init__(self, model_config: HuggingFaceModelConfig):
         with htrack_block(f"Loading Hugging Face model for config {model_config}"):
             # WARNING this may fail if your GPU does not have enough memory
             quantization_config: Optional[HuggingfaceModelQuantizationConfig] = model_config.quantization_config
-            if quantization_config and model_config.quantization_config.model_loader == ModelLoader.AWQ:
+            if quantization_config and quantization_config.model_loader == ModelLoader.AWQ:
                 from awq import AutoAWQForCausalLM
                 self.model = AutoAWQForCausalLM.from_quantized(
                     model_name, model_config.quantization_config.quant_file, fuse_layers=True
                 )
-            elif quantization_config and model_config.quantization_config.model_loader == ModelLoader.GPTQ:
+            elif quantization_config and quantization_config.model_loader == ModelLoader.GPTQ:
                 from auto_gptq import AutoGPTQForCausalLM
                 self.model = AutoGPTQForCausalLM.from_quantized(
                     model_name, trust_remote_code=True,
@@ -87,7 +87,8 @@ def __init__(self, model_config: HuggingFaceModelConfig):
                     model_name, trust_remote_code=True, **model_kwargs
                 ).to(self.device)
         with htrack_block(f"Loading Hugging Face tokenizer model for config {model_config}"):
-            if model_config.quantization_config.tokenizer_name:
+            # When the quantized model uses a different tokenizer than its model name
+            if quantization_config and model_config.quantization_config.tokenizer_name:
                 tokenizer_name: str = model_config.quantization_config.tokenizer_name
                 if "revision" in model_kwargs:
                     model_kwargs.pop("revision")
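
For context, here is a minimal self-contained sketch of the loading logic as it stands after this commit. The ModelLoader members, the HuggingfaceModelQuantizationConfig fields, and the from_quantized calls come from the diff above; the load_model and pick_tokenizer_name wrappers, the enum and dataclass definitions, and the AutoModelForCausalLM fallback are assumptions added for illustration (the real code lives in HuggingFaceClient.__init__):

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class ModelLoader(Enum):
    # Assumed values; the diff only shows that AWQ and GPTQ members exist.
    AWQ = "awq"
    GPTQ = "gptq"


@dataclass
class HuggingfaceModelQuantizationConfig:
    # Fields inferred from the attributes the diff dereferences.
    model_loader: ModelLoader
    quant_file: Optional[str] = None
    tokenizer_name: Optional[str] = None


def load_model(model_name: str, quantization_config: Optional[HuggingfaceModelQuantizationConfig], **model_kwargs):
    # Guard on the local Optional before touching model_loader, as the commit does.
    if quantization_config and quantization_config.model_loader == ModelLoader.AWQ:
        from awq import AutoAWQForCausalLM

        return AutoAWQForCausalLM.from_quantized(model_name, quantization_config.quant_file, fuse_layers=True)
    elif quantization_config and quantization_config.model_loader == ModelLoader.GPTQ:
        from auto_gptq import AutoGPTQForCausalLM

        return AutoGPTQForCausalLM.from_quantized(model_name, trust_remote_code=True, **model_kwargs)
    else:
        from transformers import AutoModelForCausalLM

        return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, **model_kwargs)


def pick_tokenizer_name(model_name: str, quantization_config: Optional[HuggingfaceModelQuantizationConfig]) -> str:
    # The pre-commit code read model_config.quantization_config.tokenizer_name with no
    # None check, so plain (non-quantized) models crashed here; the added guard fixes that.
    if quantization_config and quantization_config.tokenizer_name:
        return quantization_config.tokenizer_name
    return model_name

The substantive fix is the tokenizer guard: the old code dereferenced model_config.quantization_config.tokenizer_name unconditionally, which raises AttributeError when quantization_config is None. The two model_loader comparisons were also switched to the local quantization_config alias, which reads more cleanly and matches the guard that precedes them.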

