Commit

ensure scalars are float32 in test cases for dropout
iclementine committed Aug 22, 2024
1 parent adb2094 commit cad9c37
Showing 1 changed file with 10 additions and 2 deletions.

tests/test_special_ops.py
```diff
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import numpy as np
 import pytest
 import torch
 
@@ -24,6 +25,11 @@ def test_accuracy_dropout(shape, p, dtype):
     inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
     ref_inp = to_reference(inp)
 
+    # NOTE: ensure that scalars are float32 (instead of float64);
+    # in some cases, casting up then casting down gives a different result
+    p = np.float32(p)
+    one_minus_p = np.float32(1.0) - p
+
     ref_out = torch.nn.functional.dropout(ref_inp, p, True)
     with flag_gems.use_gems():
         res_out = torch.nn.functional.dropout(inp, p, True)
@@ -37,13 +43,15 @@ def test_accuracy_dropout(shape, p, dtype):
     res_out = to_reference(res_out)
     res_in_grad = to_reference(res_in_grad)
 
-    exp_equal = (p * p + (1 - p) * (1 - p)) * inp.numel()
+    exp_equal = (p * p + one_minus_p * one_minus_p) * inp.numel()
     num_equal = torch.sum(torch.isclose(ref_out, res_out)).item()
     if TO_CPU:
         zero_equal = torch.eq(res_out, torch.zeros_like(res_out))
         num_zero = torch.sum(zero_equal).item()
         assert abs(num_zero / inp.numel() - p) <= 0.05
-        scale_equal = torch.isclose(res_out, ref_inp / (1 - p), rtol=RESOLUTION[dtype])
+        scale_equal = torch.isclose(
+            res_out, ref_inp / one_minus_p, rtol=RESOLUTION[dtype]
+        )
         assert torch.all(torch.logical_or(zero_equal, scale_equal))
     else:
         assert (
```
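
For readers unfamiliar with the pitfall the NOTE describes, here is a minimal sketch (not part of the commit) of why the scalar's width matters: a Python float literal is an IEEE-754 float64, and its float32 counterpart is a different number, so arithmetic performed at float64 precision and then cast down need not match arithmetic kept in float32 throughout.

```python
import numpy as np

p = 0.3                # Python float, i.e. IEEE-754 float64
p32 = np.float32(p)    # nearest float32 to 0.3 -- a different number

print(f"{p:.17f}")     # 0.29999999999999999
print(f"{p32:.17f}")   # 0.30000001192092896

# Pinning the whole expression to float32, as the commit does, keeps
# the reference computation in a single precision:
one_minus_p = np.float32(1.0) - p32
```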

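As context for the `exp_equal` line (my gloss, not from the commit): at each element, two independent dropout draws agree when both drop it (probability p·p) or both keep it (probability (1-p)·(1-p), with identical scaling), so the expected number of matching elements is (p² + (1-p)²)·numel(). A quick CPU check of that identity:

```python
import torch

# Illustration only: two independent dropout draws over the same input.
p = 0.3
x = torch.randn(1_000_000)
a = torch.nn.functional.dropout(x, p, True)
b = torch.nn.functional.dropout(x, p, True)

# Positions agree when both are zeroed or both keep x / (1 - p).
agree = torch.isclose(a, b).float().mean().item()
print(agree, p * p + (1 - p) * (1 - p))  # both close to 0.58
```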