Support torch_bfloat16
ljvmiranda921 committed Aug 4, 2024
1 parent b57e690 commit e2091ed
Showing 1 changed file with 21 additions and 3 deletions.
scripts/run_rewardbench.py (21 additions, 3 deletions)

@@ -32,13 +32,17 @@
 from datasets import load_dataset
 from rewardbench import DPO_MODEL_CONFIG, REWARD_MODEL_CONFIG
 from rewardbench import check_tokenizer_chat_template, load_eval_dataset
+from rewardbench import torch_dtype_mapping
 from rewardbench.constants import EXAMPLE_COUNTS, SUBSET_MAPPING
 from rewardbench.utils import calculate_scores_per_section
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
 from scripts.utils import load_multilingual_eval_dataset
 
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
 
 def main():
     parser = argparse.ArgumentParser(description="Evaluate a reward model.")
@@ -62,8 +66,10 @@ def main():
     parser.add_argument("--output_dir", type=str, default="results/", help="the output directory to save results")
     parser.add_argument("--save_all", action="store_true", default=False, help="save all results (include scores per instance)")
     parser.add_argument("--force_truncation", action="store_true", default=False, help="force truncation (if model errors)")
+    parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32", "float64"], help="set PyTorch dtype (default: float16)")
     # fmt: on
     args = parser.parse_args()
+    args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
 
     ###############
     # Setup logging
@@ -111,6 +117,14 @@ def main():
         config = MODEL_CONFIGS["default"]
     logger.info(f"Using reward model config: {config}")
 
+    torch_dtype = config.get("torch_dtype", None)
+    if torch_dtype is None:
+        # if the dtype is bfloat16, manually turn off quantization (done with bitsandbytes)
+        if args.torch_dtype == torch.bfloat16:
+            quantized = False
+            logger.info("Disabling quantization for bfloat16 datatype")
+        torch_dtype = args.torch_dtype
+
     # Default entries
     # "model_builder": AutoModelForSequenceClassification.from_pretrained,
     # "pipeline_builder": pipeline,
@@ -126,6 +140,7 @@ def main():
         or ("Llama3" in args.model)
         or ("Llama-3" in args.model)
         or ("LLaMA3" in args.model)
+        or ("llama3" in args.model)
         or args.not_quantized
     ):
         quantized = False
@@ -184,7 +199,7 @@ def main():
         model_kwargs = {
             "load_in_8bit": True,
             "device_map": "auto",
-            "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
         }
         model = model_builder(
             args.model,
@@ -247,11 +262,14 @@ def main():
             model_kwargs = {
                 "load_in_8bit": True,
                 "device_map": {"": current_device},
-                "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+                "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
            }
         else:
             # note, device map auto does not work for quantized models
-            model_kwargs = {"device_map": "auto"}
+            model_kwargs = {
+                "device_map": "auto",
+                "torch_dtype": torch_dtype,
+            }
 
         model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code)
         reward_pipe = pipeline_builder(
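
Editor's note: the new --torch_dtype flag is parsed as a string and handed to rewardbench's torch_dtype_mapping, which converts it into an actual torch.dtype before it reaches the model loader. The real helper lives in the rewardbench package; below is only a minimal sketch of the behavior this script relies on, assuming a plain string-to-dtype lookup (the sketch's function name is illustrative, not rewardbench's API):

import torch

# Illustrative sketch only: rewardbench.torch_dtype_mapping is expected to
# behave roughly like this, turning the CLI string into a torch.dtype.
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64,
}

def torch_dtype_mapping_sketch(name: str) -> torch.dtype:
    if name not in DTYPE_MAP:
        raise ValueError(f"Invalid torch dtype: {name}")
    return DTYPE_MAP[name]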
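
The heart of the commit is the new dtype-resolution block: a torch_dtype pinned in the per-model config takes precedence; otherwise the CLI value is used, and a bfloat16 request also switches off 8-bit bitsandbytes quantization, per the commit's own comment. A standalone paraphrase of that resolution order (not the committed code; function and argument names are illustrative):

import torch

def resolve_dtype_and_quantization(config: dict, cli_dtype: torch.dtype, quantized: bool):
    """Paraphrase of the commit's resolution order, for illustration."""
    # A dtype pinned in the per-model config takes precedence.
    torch_dtype = config.get("torch_dtype", None)
    if torch_dtype is None:
        # bfloat16 from the CLI disables 8-bit (bitsandbytes) quantization.
        if cli_dtype == torch.bfloat16:
            quantized = False
        torch_dtype = cli_dtype
    return torch_dtype, quantized

# No config override, bfloat16 requested -> (torch.bfloat16, False)
print(resolve_dtype_and_quantization({}, torch.bfloat16, True))

In practice, a run such as python scripts/run_rewardbench.py --model <your-model> --torch_dtype bfloat16 therefore loads the model unquantized in bfloat16.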
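
Downstream, the resolved torch_dtype replaces the hard-coded torch.float16 in each model_kwargs dict, and the previously dtype-less unquantized branch now forwards it as well. Per the default entries quoted in the diff, model_builder is AutoModelForSequenceClassification.from_pretrained, so the unquantized path ends up roughly as follows (the model id is hypothetical, shown only for illustration):

import torch
from transformers import AutoModelForSequenceClassification

torch_dtype = torch.bfloat16  # as resolved from --torch_dtype bfloat16

# Unquantized branch after this commit: the dtype now reaches from_pretrained
# instead of the model silently loading in its default precision.
model_kwargs = {
    "device_map": "auto",
    "torch_dtype": torch_dtype,
}
model = AutoModelForSequenceClassification.from_pretrained(
    "my-org/my-reward-model",  # hypothetical model id
    **model_kwargs,
)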
