Commit

update quantization train.
shibing624 committed Jan 26, 2024
1 parent 5dc838c commit 3845803
Showing 2 changed files with 30 additions and 38 deletions.
34 changes: 15 additions & 19 deletions pretraining.py
@@ -635,38 +635,34 @@ def group_text_function(examples):
         config = config_class.from_pretrained(model_args.model_name_or_path, **config_kwargs)
         load_in_4bit = model_args.load_in_4bit
         load_in_8bit = model_args.load_in_8bit
-        load_in_8bit_skip_modules = None
-        quantization_config = None
         if load_in_4bit and load_in_8bit:
             raise ValueError("Error, load_in_4bit and load_in_8bit cannot be set at the same time")
         elif load_in_8bit or load_in_4bit:
             logger.info(f"Quantizing model, load_in_4bit: {load_in_4bit}, load_in_8bit: {load_in_8bit}")
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-            if script_args.modules_to_save is not None:
-                load_in_8bit_skip_modules = script_args.modules_to_save.split(',')
-            quantization_config = BitsAndBytesConfig(
-                load_in_8bit=load_in_8bit,
-                load_in_4bit=load_in_4bit,
-                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-            )
-            if script_args.qlora:
-                quantization_config = BitsAndBytesConfig(
-                    load_in_8bit=load_in_8bit,
-                    load_in_4bit=load_in_4bit,
-                    load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-                    bnb_4bit_use_double_quant=True,
-                    bnb_4bit_quant_type="nf4",
-                    bnb_4bit_compute_dtype=torch_dtype,
-                )
+            if load_in_8bit:
+                config_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+            elif load_in_4bit:
+                if script_args.qlora:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_use_double_quant=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
+                else:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
 
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             torch_dtype=torch_dtype,
             low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
             device_map=model_args.device_map,
-            quantization_config=quantization_config,
             **config_kwargs,
         )
     else:
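
The commit replaces a single BitsAndBytesConfig built from mixed 8-bit/4-bit arguments (plus a load_in_8bit_skip_modules list derived from modules_to_save) with mutually exclusive branches that stash the config in config_kwargs. A minimal standalone sketch of the new selection logic, with hypothetical flag values and a placeholder checkpoint standing in for the script's model_args and script_args:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Hypothetical stand-ins for the script's model_args / script_args.
load_in_8bit = False
load_in_4bit = True
qlora = True
torch_dtype = torch.bfloat16

config_kwargs = {}
if load_in_8bit:
    # Plain 8-bit loading; none of the bnb_4bit_* options apply.
    config_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
elif load_in_4bit:
    if qlora:
        # QLoRA-style 4-bit: NF4 weights plus double quantization.
        config_kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        )
    else:
        # Plain 4-bit loading with bitsandbytes defaults.
        config_kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
        )

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder checkpoint for illustration
    torch_dtype=torch_dtype,
    device_map="auto",
    **config_kwargs,
)

One effect of routing the config through config_kwargs is that from_pretrained no longer receives an explicit quantization_config=None on the unquantized path.
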
34 changes: 15 additions & 19 deletions supervised_finetuning.py
@@ -1269,38 +1269,34 @@ def filter_empty_labels(example):
 
         load_in_4bit = model_args.load_in_4bit
         load_in_8bit = model_args.load_in_8bit
-        load_in_8bit_skip_modules = None
-        quantization_config = None
         if load_in_4bit and load_in_8bit:
             raise ValueError("Error, load_in_4bit and load_in_8bit cannot be set at the same time")
         elif load_in_8bit or load_in_4bit:
             logger.info(f"Quantizing model, load_in_4bit: {load_in_4bit}, load_in_8bit: {load_in_8bit}")
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-            if script_args.modules_to_save is not None:
-                load_in_8bit_skip_modules = script_args.modules_to_save.split(',')
-            quantization_config = BitsAndBytesConfig(
-                load_in_8bit=load_in_8bit,
-                load_in_4bit=load_in_4bit,
-                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-            )
-            if script_args.qlora:
-                quantization_config = BitsAndBytesConfig(
-                    load_in_8bit=load_in_8bit,
-                    load_in_4bit=load_in_4bit,
-                    load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-                    bnb_4bit_use_double_quant=True,
-                    bnb_4bit_quant_type="nf4",
-                    bnb_4bit_compute_dtype=torch_dtype,
-                )
+            if load_in_8bit:
+                config_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+            elif load_in_4bit:
+                if script_args.qlora:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_use_double_quant=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
+                else:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
 
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             torch_dtype=torch_dtype,
             low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
             device_map=model_args.device_map,
-            quantization_config=quantization_config,
             **config_kwargs,
         )
 
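
supervised_finetuning.py receives the identical restructuring. Downstream of this loading step, a 4-bit model is typically wrapped for QLoRA training with peft; a minimal sketch under that assumption (the LoRA rank, alpha, and target modules below are illustrative, not taken from this commit):

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# `model` is the quantized model returned by model_class.from_pretrained(...) above.
model = prepare_model_for_kbit_training(model)  # upcasts layer norms, enables input grads

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                                  # illustrative rank
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # depends on the model architecture
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the LoRA adapters remain trainable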
