
Commit

update bnb model load.
shibing624 committed Jan 26, 2024
1 parent f60d309 commit 5dc838c
Showing 1 changed file with 22 additions and 13 deletions.
supervised_finetuning.py (35 changes: 22 additions & 13 deletions)
@@ -1214,7 +1214,7 @@ def filter_empty_labels(example):
         model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         training_args.gradient_accumulation_steps = training_args.gradient_accumulation_steps // world_size or 1
     if script_args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
-        logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
+        logger.warning("FSDP and DeepSpeed ZeRO-3 are both currently incompatible with QLoRA.")
 
     config_kwargs = {
         "trust_remote_code": model_args.trust_remote_code,
@@ -1270,29 +1270,38 @@ def filter_empty_labels(example):
         load_in_4bit = model_args.load_in_4bit
         load_in_8bit = model_args.load_in_8bit
         load_in_8bit_skip_modules = None
-        if load_in_8bit or load_in_4bit:
+        quantization_config = None
+        if load_in_4bit and load_in_8bit:
+            raise ValueError("Error, load_in_4bit and load_in_8bit cannot be set at the same time")
+        elif load_in_8bit or load_in_4bit:
             logger.info(f"Quantizing model, load_in_4bit: {load_in_4bit}, load_in_8bit: {load_in_8bit}")
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
             if script_args.modules_to_save is not None:
                 load_in_8bit_skip_modules = script_args.modules_to_save.split(',')
+            quantization_config = BitsAndBytesConfig(
+                load_in_8bit=load_in_8bit,
+                load_in_4bit=load_in_4bit,
+                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
+            )
+            if script_args.qlora:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_8bit=load_in_8bit,
+                    load_in_4bit=load_in_4bit,
+                    load_in_8bit_skip_modules=load_in_8bit_skip_modules,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch_dtype,
+                )
 
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             torch_dtype=torch_dtype,
-            load_in_4bit=load_in_4bit,
-            load_in_8bit=load_in_8bit,
             low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
             device_map=model_args.device_map,
             trust_remote_code=model_args.trust_remote_code,
-            quantization_config=BitsAndBytesConfig(
-                load_in_4bit=load_in_4bit,
-                load_in_8bit=load_in_8bit,
-                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch_dtype,
-            ) if script_args.qlora else None,
+            quantization_config=quantization_config,
             **config_kwargs,
         )
 
         # Fix ChatGLM2 and ChatGLM3 LM head
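For reference, the new load path builds a single BitsAndBytesConfig up front (plain 8-bit/4-bit loading, or the NF4 variant when QLoRA is enabled) and passes it to from_pretrained via quantization_config, instead of passing load_in_4bit/load_in_8bit directly. Below is a minimal standalone sketch of that pattern; the model id and flag values are placeholders, not taken from this commit, and in supervised_finetuning.py they come from model_args and script_args.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Placeholder settings for illustration only.
model_name_or_path = "meta-llama/Llama-2-7b-hf"  # hypothetical model id
load_in_4bit, load_in_8bit, qlora = True, False, True
torch_dtype = torch.float16

# The two bitsandbytes modes are mutually exclusive.
if load_in_4bit and load_in_8bit:
    raise ValueError("load_in_4bit and load_in_8bit cannot be set at the same time")

quantization_config = None
if load_in_4bit or load_in_8bit:
    # Plain bitsandbytes 8-bit/4-bit loading.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
    )
    if qlora:
        # QLoRA-style NF4 quantization with double quantization and a compute dtype.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        )

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch_dtype,
    device_map="auto",
    quantization_config=quantization_config,
)

Quantized loading requires the bitsandbytes package and a CUDA device; with both flags off, quantization_config stays None and the model loads in torch_dtype as before.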
