Commit
enabled load_in_4/8bit with hf models (#73)
kywch committed Aug 9, 2023
1 parent 7976492 commit 3e0627e
Showing 2 changed files with 32 additions and 13 deletions.
src/openelm/codegen/codegen_utilities.py: 30 additions & 13 deletions
@@ -64,21 +64,36 @@ def find_re(string, pattern, start_pos):
     else:
         return completion
 
+
+def is_codegen_model(cfg: ModelConfig):
+    return "codegen" in cfg.model_path.lower()
+
+
+def config_to_kwargs(cfg: ModelConfig):
+    # TODO: need a better way to determine the priority of these options
+    if cfg.load_in_8bit:
+        return {"load_in_8bit": True, "device_map": "auto"}
+
+    if cfg.load_in_4bit:
+        return {"load_in_4bit": True, "device_map": "auto"}
+
+    return {
+        "torch_dtype": torch.float16 if cfg.fp16 else None,
+        "low_cpu_mem_usage": cfg.fp16
+    }
+
+
 def model_setup(cfg: ModelConfig, device=None, codegen_tokenizer: bool = True):
     set_seed(cfg.seed)
     if device is None:
-        device = torch.device("cuda")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    tokenizer = AutoTokenizer.from_pretrained(cfg.model_path)
-    # TODO: may need to check model type to determine padding
-    tokenizer.padding_side = "left"
-    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer = AutoTokenizer.from_pretrained(
+        cfg.model_path, trust_remote_code=cfg.trust_remote_code
+    )
+    if is_codegen_model(cfg):
+        tokenizer.padding_side = "left"
+        tokenizer.pad_token = tokenizer.eos_token
+        if tokenizer.model_max_length > 32768:
+            tokenizer.model_max_length = 2048
+
+    tokenizer.pad_token = tokenizer.eos_token
 
     autoconfig = AutoConfig.from_pretrained(
         cfg.model_path, trust_remote_code=cfg.trust_remote_code
     )
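The tokenizer hunk above restricts left padding to CodeGen-style checkpoints. Left padding matters for batched generation with decoder-only models: padding on the right would put pad tokens between the prompt and the model's continuation. A minimal standalone sketch of the same setup (the checkpoint name is an assumed example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
tok.padding_side = "left"
tok.pad_token = tok.eos_token  # CodeGen checkpoints ship without a pad token

# With left padding, every prompt in the batch ends at the last position,
# so generation continues the prompt instead of a run of pad tokens.
batch = tok(["def add(a, b):", "print("], return_tensors="pt", padding=True)
print(batch["input_ids"])  # the shorter prompt is padded on the left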
@@ -87,13 +102,15 @@ def model_setup(cfg: ModelConfig, device=None, codegen_tokenizer: bool = True):
         model_cls = AutoModelForSeq2SeqLM
     else:
         model_cls = AutoModelForCausalLM
 
+    model_kwargs = config_to_kwargs(cfg)
     model = model_cls.from_pretrained(
-        cfg.model_path,
-        torch_dtype=torch.float16 if cfg.fp16 else None,
-        low_cpu_mem_usage=cfg.fp16,
-        trust_remote_code=cfg.trust_remote_code,
-        # device_map="auto",
-    ).to(device)
+        cfg.model_path, trust_remote_code=cfg.trust_remote_code,
+        **model_kwargs
+    )
+
+    if "device_map" not in model_kwargs:
+        model = model.to(device)
 
     return model, tokenizer, device
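Putting the pieces together, a hedged usage sketch of the new loading path (import paths follow this repo's file layout; the checkpoint and field values are hypothetical, and ModelConfig is assumed to default its remaining fields; the 8-/4-bit path additionally requires bitsandbytes and accelerate to be installed):

from openelm.configs import ModelConfig
from openelm.codegen.codegen_utilities import model_setup

cfg = ModelConfig(model_path="Salesforce/codegen-2B-mono", load_in_8bit=True)
model, tokenizer, device = model_setup(cfg)
# config_to_kwargs(cfg) yields {"load_in_8bit": True, "device_map": "auto"},
# so accelerate places the quantized weights across available devices and
# the explicit .to(device) call is skipped.

Skipping .to(device) when device_map is present is the crux of the change: a bitsandbytes-quantized model is already dispatched by accelerate, and calling .to() on it raises an error, so moving it again is both unnecessary and unsupported.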

src/openelm/configs.py: 2 additions & 0 deletions
@@ -27,6 +27,8 @@ class ModelConfig(BaseConfig):
     do_sample: bool = True
     num_return_sequences: int = 1
     trust_remote_code: bool = True  # needed for mosaicml/mpt-7b-instruct
+    load_in_8bit: bool = False  # need to install bitsandbytes
+    load_in_4bit: bool = False
 
 
 @dataclass
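A consequence of the branch order in config_to_kwargs, flagged by its TODO: when both flags are set, 8-bit wins because it is checked first. A small illustration under that assumption (keyword construction of the dataclass is assumed; the model path is just the one from the comment above):

from openelm.configs import ModelConfig
from openelm.codegen.codegen_utilities import config_to_kwargs

cfg = ModelConfig(
    model_path="mosaicml/mpt-7b-instruct",
    load_in_8bit=True,
    load_in_4bit=True,
)
print(config_to_kwargs(cfg))
# -> {"load_in_8bit": True, "device_map": "auto"}
# The 8-bit branch returns first, so load_in_4bit is silently ignored.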