From e2091ed478a30d42c72d5eb27a28333b982cd53d Mon Sep 17 00:00:00 2001
From: Lj Miranda
Date: Sun, 4 Aug 2024 15:37:18 -0700
Subject: [PATCH] Support torch_bfloat16

---
 scripts/run_rewardbench.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/scripts/run_rewardbench.py b/scripts/run_rewardbench.py
index 4543983..83e9c37 100644
--- a/scripts/run_rewardbench.py
+++ b/scripts/run_rewardbench.py
@@ -32,6 +32,7 @@
 from datasets import load_dataset
 from rewardbench import DPO_MODEL_CONFIG, REWARD_MODEL_CONFIG
 from rewardbench import check_tokenizer_chat_template, load_eval_dataset
+from rewardbench import torch_dtype_mapping
 from rewardbench.constants import EXAMPLE_COUNTS, SUBSET_MAPPING
 from rewardbench.utils import calculate_scores_per_section
 from tqdm import tqdm
@@ -39,6 +40,9 @@
 
 from scripts.utils import load_multilingual_eval_dataset
 
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
 
 def main():
     parser = argparse.ArgumentParser(description="Evaluate a reward model.")
@@ -62,8 +66,10 @@ def main():
     parser.add_argument("--output_dir", type=str, default="results/", help="the output directory to save results")
     parser.add_argument("--save_all", action="store_true", default=False, help="save all results (include scores per instance)")
     parser.add_argument("--force_truncation", action="store_true", default=False, help="force truncation (if model errors)")
+    parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32", "float64"], help="set PyTorch dtype (default: float16)")
     # fmt: on
     args = parser.parse_args()
+    args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
 
     ###############
     # Setup logging
@@ -111,6 +117,14 @@ def main():
         config = MODEL_CONFIGS["default"]
     logger.info(f"Using reward model config: {config}")
 
+    torch_dtype = config.get("torch_dtype", None)
+    if torch_dtype is None:
+        # if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
+        if args.torch_dtype == torch.bfloat16:
+            quantized = False
+            logger.info("Disabling quantization for bfloat16 datatype")
+        torch_dtype = args.torch_dtype
+
     # Default entries
     # "model_builder": AutoModelForSequenceClassification.from_pretrained,
     # "pipeline_builder": pipeline,
@@ -126,6 +140,7 @@
         or ("Llama3" in args.model)
         or ("Llama-3" in args.model)
         or ("LLaMA3" in args.model)
+        or ("llama3" in args.model)
         or args.not_quantized
     ):
         quantized = False
@@ -184,7 +199,7 @@
         model_kwargs = {
             "load_in_8bit": True,
             "device_map": "auto",
-            "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
         }
         model = model_builder(
             args.model,
@@ -247,11 +262,14 @@
             model_kwargs = {
                 "load_in_8bit": True,
                 "device_map": {"": current_device},
-                "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+                "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
             }
         else:
             # note, device map auto does not work for quantized models
-            model_kwargs = {"device_map": "auto"}
+            model_kwargs = {
+                "device_map": "auto",
+                "torch_dtype": torch_dtype,
+            }
 
         model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code)
         reward_pipe = pipeline_builder(
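
Note on the helper used above: torch_dtype_mapping is imported from the rewardbench
package, but its implementation is not part of this diff. As a rough sketch of the
assumed behavior (names and error handling below are illustrative, not the package's
actual code), it translates the --torch_dtype string into the matching torch.dtype:

    # Assumed sketch only: the real torch_dtype_mapping ships with the
    # rewardbench package; this shows what the patch relies on it doing.
    import torch

    _DTYPE_MAP = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
        "float64": torch.float64,
    }

    def torch_dtype_mapping(dtype_str: str) -> torch.dtype:
        """Translate a CLI dtype string (e.g. "bfloat16") into a torch.dtype."""
        try:
            return _DTYPE_MAP[dtype_str]
        except KeyError as err:
            raise ValueError(f"Invalid torch dtype: {dtype_str}") from err

With the patch applied, running scripts/run_rewardbench.py with --torch_dtype bfloat16
should load the model in bfloat16 and, per the added config logic, skip bitsandbytes
8-bit quantization; the default dtype remains float16.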