From 5dc838c3cb2b7e93440767f5feb16bf1bfbc7083 Mon Sep 17 00:00:00 2001
From: shibing624
Date: Fri, 26 Jan 2024 17:05:17 +0800
Subject: [PATCH] update bnb model load.

---
 supervised_finetuning.py | 35 ++++++++++++++++++++++-------------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/supervised_finetuning.py b/supervised_finetuning.py
index 2d8ef02..d0f61a0 100644
--- a/supervised_finetuning.py
+++ b/supervised_finetuning.py
@@ -1214,7 +1214,7 @@ def filter_empty_labels(example):
             model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}
             training_args.gradient_accumulation_steps = training_args.gradient_accumulation_steps // world_size or 1
         if script_args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
-            logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
+            logger.warning("FSDP and DeepSpeed ZeRO-3 are both currently incompatible with QLoRA.")
 
         config_kwargs = {
             "trust_remote_code": model_args.trust_remote_code,
@@ -1270,29 +1270,38 @@ def filter_empty_labels(example):
         load_in_4bit = model_args.load_in_4bit
         load_in_8bit = model_args.load_in_8bit
         load_in_8bit_skip_modules = None
-        if load_in_8bit or load_in_4bit:
+        quantization_config = None
+        if load_in_4bit and load_in_8bit:
+            raise ValueError("Error, load_in_4bit and load_in_8bit cannot be set at the same time")
+        elif load_in_8bit or load_in_4bit:
             logger.info(f"Quantizing model, load_in_4bit: {load_in_4bit}, load_in_8bit: {load_in_8bit}")
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
             if script_args.modules_to_save is not None:
                 load_in_8bit_skip_modules = script_args.modules_to_save.split(',')
+            quantization_config = BitsAndBytesConfig(
+                load_in_8bit=load_in_8bit,
+                load_in_4bit=load_in_4bit,
+                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
+            )
+            if script_args.qlora:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_8bit=load_in_8bit,
+                    load_in_4bit=load_in_4bit,
+                    load_in_8bit_skip_modules=load_in_8bit_skip_modules,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch_dtype,
+                )
+
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             torch_dtype=torch_dtype,
-            load_in_4bit=load_in_4bit,
-            load_in_8bit=load_in_8bit,
             low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
             device_map=model_args.device_map,
-            trust_remote_code=model_args.trust_remote_code,
-            quantization_config=BitsAndBytesConfig(
-                load_in_4bit=load_in_4bit,
-                load_in_8bit=load_in_8bit,
-                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch_dtype,
-            ) if script_args.qlora else None,
+            quantization_config=quantization_config,
+            **config_kwargs,
         )
 
     # Fix ChatGLM2 and ChatGLM3 LM head
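
Illustration (not part of the patch): a minimal standalone sketch of the loading path this change produces, assuming the transformers and bitsandbytes packages are installed; the model id, AutoModelForCausalLM class, and compute dtype below are placeholders standing in for the script's model_class, model_name_or_path, and torch_dtype.

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # QLoRA-style 4-bit load: build a BitsAndBytesConfig up front and pass it once
    # via quantization_config, instead of also passing load_in_4bit/load_in_8bit
    # as separate from_pretrained arguments.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,  # placeholder; mirrors torch_dtype in the script
    )

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )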