From cad9c377352244ca35ff55d6bfa3d3d5e0991876 Mon Sep 17 00:00:00 2001
From: Clement Chan
Date: Thu, 22 Aug 2024 16:19:10 +0800
Subject: [PATCH] ensure scalars are float32 in test cases for dropout

---
 tests/test_special_ops.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py
index 90f12cb6..1b0f11f2 100644
--- a/tests/test_special_ops.py
+++ b/tests/test_special_ops.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import numpy as np
 import pytest
 import torch
 
@@ -24,6 +25,11 @@ def test_accuracy_dropout(shape, p, dtype):
     inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
     ref_inp = to_reference(inp)
 
+    # NOTE: ensure that scalars are float32 (instead of float64);
+    # in some cases, casting up and then casting down gives a different result
+    p = np.float32(p)
+    one_minus_p = np.float32(1.0) - p
+
     ref_out = torch.nn.functional.dropout(ref_inp, p, True)
     with flag_gems.use_gems():
         res_out = torch.nn.functional.dropout(inp, p, True)
@@ -37,13 +43,15 @@ def test_accuracy_dropout(shape, p, dtype):
     res_out = to_reference(res_out)
     res_in_grad = to_reference(res_in_grad)
 
-    exp_equal = (p * p + (1 - p) * (1 - p)) * inp.numel()
+    exp_equal = (p * p + one_minus_p * one_minus_p) * inp.numel()
     num_equal = torch.sum(torch.isclose(ref_out, res_out)).item()
     if TO_CPU:
         zero_equal = torch.eq(res_out, torch.zeros_like(res_out))
         num_zero = torch.sum(zero_equal).item()
         assert abs(num_zero / inp.numel() - p) <= 0.05
-        scale_equal = torch.isclose(res_out, ref_inp / (1 - p), rtol=RESOLUTION[dtype])
+        scale_equal = torch.isclose(
+            res_out, ref_inp / one_minus_p, rtol=RESOLUTION[dtype]
+        )
         assert torch.all(torch.logical_or(zero_equal, scale_equal))
     else:
         assert (
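
A minimal sketch of the rounding behaviour the NOTE in the patch refers to (illustration only, not part of the commit): computing the dropout scale 1 / (1 - p) in float64 and then casting the result down to float32 is not guaranteed to match doing the same arithmetic entirely in float32, which is why the test pins p to np.float32 before building its expectations. The probability grid below is arbitrary; the loop simply reports how many values disagree on the current platform.

    import numpy as np

    mismatches = []
    for p in np.linspace(0.01, 0.99, 99):
        p64 = float(p)  # plain Python float, i.e. float64
        # compute the dropout scale in float64, then cast the result down
        cast_down = np.float32(1.0 / (1.0 - p64))
        # compute the same scale entirely in float32
        pure_f32 = np.float32(1.0) / (np.float32(1.0) - np.float32(p64))
        if cast_down != pure_f32:
            mismatches.append(p64)

    print(f"{len(mismatches)} of 99 probabilities give a different float32 scale")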