diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 4d379bef..2affe395 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -575,6 +575,12 @@ def test_accuracy_conv1d(shape, kernel, stride, padding, dtype): ((32, 16, 8, 8), (32, 4, 4, 4), 4), ((18, 16, 4, 4), (16, 8, 2, 2), 2), ((9, 16, 4, 4), (128, 8, 2, 2), 2), + ((32, 8, 8, 8), (32, 8, 3, 3), 1), + ((18, 16, 5, 5), (16, 16, 3, 3), 1), + ((9, 16, 7, 7), (128, 4, 3, 3), 4), + ((32, 16, 9, 9), (32, 4, 5, 5), 4), + ((18, 16, 11, 11), (16, 8, 3, 3), 2), + ((9, 16, 6, 6), (128, 8, 3, 3), 2), ] @@ -623,6 +629,12 @@ def test_accuracy_conv2d(shape, kernel, stride, padding, groups, dtype): ((32, 16, 8, 8), (32, 1, 4, 4), (4, 4)), ((18, 8, 4, 4), (16, 1, 2, 2), (2, 2)), ((9, 4, 4, 4), (128, 1, 2, 2), (2, 2)), + ((32, 4, 8, 8), (32, 1, 3, 3), (3, 3)), + ((18, 16, 13, 13), (16, 1, 5, 5), (5, 5)), + ((9, 32, 8, 8), (128, 1, 3, 3), (3, 3)), + ((32, 16, 9, 9), (32, 1, 5, 5), (5, 5)), + ((18, 8, 7, 7), (16, 1, 3, 3), (3, 3)), + ((9, 4, 6, 6), (128, 1, 3, 3), (3, 3)), ] @@ -635,13 +647,11 @@ def test_accuracy_depthwise2d( shape_input, shape_weight, kernel, stride, padding, dtype ): inp = torch.randn(shape_input, dtype=dtype, device="cuda", requires_grad=True) - ref_inp = to_reference(inp, True) torch.backends.cudnn.allow_tf32 = False weight = torch.randn(shape_weight, dtype=dtype, device="cuda") - ref_weight = to_reference(weight, True) ref_out = torch._C._nn._conv_depthwise2d( - ref_inp, - ref_weight, + inp, + weight, kernel, bias=None, stride=stride,