Support torch_bfloat16 #17

Merged · 3 commits · Aug 5, 2024
Changes from all commits
scripts/run_rewardbench.py: 43 changes (38 additions, 5 deletions)
@@ -39,6 +39,24 @@

from scripts.utils import load_multilingual_eval_dataset

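# Enable TensorFloat-32 kernels on supported (Ampere+) GPUs: faster matmuls at slightly reduced precision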
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def torch_dtype_mapping(dtype_str):
"""
Helper function for argparse to map string to torch dtype.
"""
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
"float64": torch.float64,
}
if dtype_str not in dtype_map:
raise argparse.ArgumentTypeError(f"Invalid torch dtype: {dtype_str}")
return dtype_map[dtype_str]


def main():
parser = argparse.ArgumentParser(description="Evaluate a reward model.")
@@ -62,8 +80,10 @@ def main():
parser.add_argument("--output_dir", type=str, default="results/", help="the output directory to save results")
parser.add_argument("--save_all", action="store_true", default=False, help="save all results (include scores per instance)")
parser.add_argument("--force_truncation", action="store_true", default=False, help="force truncation (if model errors)")
parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32", "float64"], help="set PyTorch dtype (default: float16)")
# fmt: on
args = parser.parse_args()
args.torch_dtype = torch_dtype_mapping(args.torch_dtype)

###############
# Setup logging
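For context, the two additions above work together as a plain argparse pattern: the flag is validated as a string via choices, then mapped to a real torch.dtype right after parsing. A minimal self-contained sketch (the names match the diff; the example invocation is illustrative):

import argparse
import torch

def torch_dtype_mapping(dtype_str):
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16,
                 "float32": torch.float32, "float64": torch.float64}
    if dtype_str not in dtype_map:
        raise argparse.ArgumentTypeError(f"Invalid torch dtype: {dtype_str}")
    return dtype_map[dtype_str]

parser = argparse.ArgumentParser()
parser.add_argument("--torch_dtype", type=str, default="float16",
                    choices=["float16", "bfloat16", "float32", "float64"])

# Equivalent to: python scripts/run_rewardbench.py ... --torch_dtype bfloat16
args = parser.parse_args(["--torch_dtype", "bfloat16"])
args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
assert args.torch_dtype is torch.bfloat16  # ready to pass to model loading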
@@ -111,6 +131,14 @@ def main():
config = MODEL_CONFIGS["default"]
logger.info(f"Using reward model config: {config}")

torch_dtype = config.get("torch_dtype", None)
if torch_dtype is None:
# if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
if args.torch_dtype == torch.bfloat16:
quantized = False
logger.info("Disabling quantization for bfloat16 datatype")
torch_dtype = args.torch_dtype

# Default entries
# "model_builder": AutoModelForSequenceClassification.from_pretrained,
# "pipeline_builder": pipeline,
@@ -126,6 +154,7 @@
or ("Llama3" in args.model)
or ("Llama-3" in args.model)
or ("LLaMA3" in args.model)
or ("llama3" in args.model)
or args.not_quantized
):
quantized = False
@@ -184,7 +213,7 @@ def main():
model_kwargs = {
"load_in_8bit": True,
"device_map": "auto",
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
"torch_dtype": torch_dtype if torch.cuda.is_available() else None,
}
model = model_builder(
args.model,
@@ -247,11 +276,14 @@ def main():
model_kwargs = {
"load_in_8bit": True,
"device_map": {"": current_device},
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
"torch_dtype": torch_dtype if torch.cuda.is_available() else None,
}
else:
# note, device map auto does not work for quantized models
-model_kwargs = {"device_map": "auto"}
+model_kwargs = {
+    "device_map": "auto",
+    "torch_dtype": torch_dtype,
+}

model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code)
reward_pipe = pipeline_builder(
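With the default builders noted in the comments above (model_builder = AutoModelForSequenceClassification.from_pretrained), the non-quantized bfloat16 path boils down to roughly the following sketch; the model id is a placeholder, not something from this PR:

import torch
from transformers import AutoModelForSequenceClassification

# Placeholder model id; in the script it comes from --model.
model = AutoModelForSequenceClassification.from_pretrained(
    "org/example-reward-model",
    device_map="auto",
    torch_dtype=torch.bfloat16,   # resolved from --torch_dtype bfloat16
    trust_remote_code=False,
)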
@@ -306,8 +338,9 @@ def main():
score_rejected_batch = [result["score"] for result in rewards_rejected]
# for classes that directly output scores (custom code)
else:
-score_chosen_batch = rewards_chosen.cpu().numpy().tolist()
-score_rejected_batch = rewards_rejected.cpu().numpy().tolist()
+# Cast to float in case of bfloat16
+score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist()
+score_rejected_batch = rewards_rejected.float().cpu().numpy().tolist()

# log results
[
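Background on the new .float() casts: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises a TypeError; upcasting to float32 first avoids it. A minimal illustration:

import torch

scores = torch.tensor([0.25, 0.75], dtype=torch.bfloat16)
# scores.cpu().numpy()  # TypeError: Got unsupported ScalarType BFloat16
print(scores.float().cpu().numpy().tolist())  # [0.25, 0.75]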