From 3845803f490678011f76ba9d4682e492970ee509 Mon Sep 17 00:00:00 2001
From: shibing624
Date: Fri, 26 Jan 2024 17:48:42 +0800
Subject: [PATCH] update quantization train.

---
 pretraining.py           | 34 +++++++++++++++-------------------
 supervised_finetuning.py | 34 +++++++++++++++-------------------
 2 files changed, 30 insertions(+), 38 deletions(-)

diff --git a/pretraining.py b/pretraining.py
index e95ef08..9452c27 100644
--- a/pretraining.py
+++ b/pretraining.py
@@ -635,30 +635,27 @@ def group_text_function(examples):
         config = config_class.from_pretrained(model_args.model_name_or_path, **config_kwargs)
         load_in_4bit = model_args.load_in_4bit
         load_in_8bit = model_args.load_in_8bit
-        load_in_8bit_skip_modules = None
-        quantization_config = None
         if load_in_4bit and load_in_8bit:
             raise ValueError("Error, load_in_4bit and load_in_8bit cannot be set at the same time")
         elif load_in_8bit or load_in_4bit:
             logger.info(f"Quantizing model, load_in_4bit: {load_in_4bit}, load_in_8bit: {load_in_8bit}")
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-            if script_args.modules_to_save is not None:
-                load_in_8bit_skip_modules = script_args.modules_to_save.split(',')
-            quantization_config = BitsAndBytesConfig(
-                load_in_8bit=load_in_8bit,
-                load_in_4bit=load_in_4bit,
-                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-            )
-            if script_args.qlora:
-                quantization_config = BitsAndBytesConfig(
-                    load_in_8bit=load_in_8bit,
-                    load_in_4bit=load_in_4bit,
-                    load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-                    bnb_4bit_use_double_quant=True,
-                    bnb_4bit_quant_type="nf4",
-                    bnb_4bit_compute_dtype=torch_dtype,
-                )
+            if load_in_8bit:
+                config_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+            elif load_in_4bit:
+                if script_args.qlora:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_use_double_quant=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
+                else:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
 
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
@@ -666,7 +663,6 @@ def group_text_function(examples):
             torch_dtype=torch_dtype,
             low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
             device_map=model_args.device_map,
-            quantization_config=quantization_config,
             **config_kwargs,
         )
     else:
diff --git a/supervised_finetuning.py b/supervised_finetuning.py
index d0f61a0..0541622 100644
--- a/supervised_finetuning.py
+++ b/supervised_finetuning.py
@@ -1269,30 +1269,27 @@ def filter_empty_labels(example):
 
         load_in_4bit = model_args.load_in_4bit
         load_in_8bit = model_args.load_in_8bit
-        load_in_8bit_skip_modules = None
-        quantization_config = None
         if load_in_4bit and load_in_8bit:
             raise ValueError("Error, load_in_4bit and load_in_8bit cannot be set at the same time")
         elif load_in_8bit or load_in_4bit:
             logger.info(f"Quantizing model, load_in_4bit: {load_in_4bit}, load_in_8bit: {load_in_8bit}")
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-            if script_args.modules_to_save is not None:
-                load_in_8bit_skip_modules = script_args.modules_to_save.split(',')
-            quantization_config = BitsAndBytesConfig(
-                load_in_8bit=load_in_8bit,
-                load_in_4bit=load_in_4bit,
-                load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-            )
-            if script_args.qlora:
-                quantization_config = BitsAndBytesConfig(
-                    load_in_8bit=load_in_8bit,
-                    load_in_4bit=load_in_4bit,
-                    load_in_8bit_skip_modules=load_in_8bit_skip_modules,
-                    bnb_4bit_use_double_quant=True,
-                    bnb_4bit_quant_type="nf4",
-                    bnb_4bit_compute_dtype=torch_dtype,
-                )
+            if load_in_8bit:
+                config_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+            elif load_in_4bit:
+                if script_args.qlora:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_use_double_quant=True,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
+                else:
+                    config_kwargs['quantization_config'] = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch_dtype,
+                    )
 
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
@@ -1300,7 +1297,6 @@ def filter_empty_labels(example):
             torch_dtype=torch_dtype,
             low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
             device_map=model_args.device_map,
-            quantization_config=quantization_config,
             **config_kwargs,
         )
     else:
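
Below is a minimal, self-contained sketch of the quantization path this patch introduces, for anyone who wants to exercise it outside the training scripts. The model id, the hard-coded flag values, and the plain `qlora` variable are illustrative assumptions standing in for `model_args`/`script_args`; the BitsAndBytesConfig choices mirror the added lines above. Actually running it requires a CUDA GPU with bitsandbytes installed.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative stand-ins for model_args / script_args (assumptions, not from the patch)
load_in_4bit, load_in_8bit, qlora = True, False, True
torch_dtype = torch.bfloat16
model_name = "meta-llama/Llama-2-7b-hf"  # hypothetical model id

config_kwargs = {}
if load_in_4bit and load_in_8bit:
    raise ValueError("load_in_4bit and load_in_8bit cannot be set at the same time")
elif load_in_8bit:
    # Plain 8-bit quantization needs no extra knobs
    config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif load_in_4bit:
    if qlora:
        # QLoRA-style 4-bit: NF4 quant type with double quantization
        config_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
        )
    else:
        # 4-bit without QLoRA extras: only set the compute dtype
        config_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
        )

# The config rides along in config_kwargs, exactly as in the patched scripts
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    device_map="auto",
    **config_kwargs,
)

Routing the config through `config_kwargs` rather than a separate `quantization_config=` keyword means `from_pretrained` receives the setting exactly once, which is what lets the patch drop the explicit argument from both call sites.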