Skip to content

Commit

Permalink
add more test for conv2d
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiang Bin committed Nov 4, 2024
1 parent 052c619 commit efd6191
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,12 @@ def test_accuracy_conv1d(shape, kernel, stride, padding, dtype):
((32, 16, 8, 8), (32, 4, 4, 4), 4),
((18, 16, 4, 4), (16, 8, 2, 2), 2),
((9, 16, 4, 4), (128, 8, 2, 2), 2),
((32, 8, 8, 8), (32, 8, 3, 3), 1),
((18, 16, 5, 5), (16, 16, 3, 3), 1),
((9, 16, 7, 7), (128, 4, 3, 3), 4),
((32, 16, 9, 9), (32, 4, 5, 5), 4),
((18, 16, 11, 11), (16, 8, 3, 3), 2),
((9, 16, 6, 6), (128, 8, 3, 3), 2),
]


Expand Down Expand Up @@ -623,6 +629,12 @@ def test_accuracy_conv2d(shape, kernel, stride, padding, groups, dtype):
((32, 16, 8, 8), (32, 1, 4, 4), (4, 4)),
((18, 8, 4, 4), (16, 1, 2, 2), (2, 2)),
((9, 4, 4, 4), (128, 1, 2, 2), (2, 2)),
((32, 4, 8, 8), (32, 1, 3, 3), (3, 3)),
((18, 16, 13, 13), (16, 1, 5, 5), (5, 5)),
((9, 32, 8, 8), (128, 1, 3, 3), (3, 3)),
((32, 16, 9, 9), (32, 1, 5, 5), (5, 5)),
((18, 8, 7, 7), (16, 1, 3, 3), (3, 3)),
((9, 4, 6, 6), (128, 1, 3, 3), (3, 3)),
]


Expand All @@ -635,13 +647,11 @@ def test_accuracy_depthwise2d(
shape_input, shape_weight, kernel, stride, padding, dtype
):
inp = torch.randn(shape_input, dtype=dtype, device="cuda", requires_grad=True)
ref_inp = to_reference(inp, True)
torch.backends.cudnn.allow_tf32 = False
weight = torch.randn(shape_weight, dtype=dtype, device="cuda")
ref_weight = to_reference(weight, True)
ref_out = torch._C._nn._conv_depthwise2d(
ref_inp,
ref_weight,
inp,
weight,
kernel,
bias=None,
stride=stride,
Expand Down

0 comments on commit efd6191

Please sign in to comment.