diff --git a/scripts/run_rewardbench.py b/scripts/run_rewardbench.py index 83e9c37..fd31a34 100644 --- a/scripts/run_rewardbench.py +++ b/scripts/run_rewardbench.py @@ -324,8 +324,9 @@ def main(): score_rejected_batch = [result["score"] for result in rewards_rejected] # for classes that directly output scores (custom code) else: - score_chosen_batch = rewards_chosen.cpu().numpy().tolist() - score_rejected_batch = rewards_rejected.cpu().numpy().tolist() + # Cast to float in case of bfloat16 + score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist() + score_rejected_batch = rewards_rejected.float().cpu().numpy().tolist() # log results [