Skip to content

Commit

Permalink
add depthwise
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiang Bin committed Oct 31, 2024
1 parent 534bda2 commit ce12a38
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 81 deletions.
81 changes: 0 additions & 81 deletions benchmark/test_special_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,84 +244,3 @@ def upsample_nearest2d_input_fn(shape, dtype, device):
dtypes=FLOAT_DTYPES,
)
bench.run()


def test_conv1d():
    """Run the ``conv1d`` benchmark against ``torch.nn.functional.conv1d``."""

    def conv1d_arg(dtype, batch, size):
        # Seed before sampling so the generated tensors are reproducible.
        torch.manual_seed(0)
        inp = torch.randn(
            [batch, 2, size], dtype=dtype, device="cuda", requires_grad=True
        )
        weight = torch.randn(
            [batch // 2, 2, size // 2], dtype=dtype, device="cuda"
        )

        # Remaining entries of the returned tuple, in order.
        # NOTE(review): presumably Benchmark unpacks the tuple positionally
        # into torch_op - confirm against the Benchmark implementation.
        bias, stride, padding, dilation = None, 2, 1, 1
        return (inp, weight, bias, stride, padding, dilation)

    Benchmark(
        op_name="conv1d",
        torch_op=torch.nn.functional.conv1d,
        arg_func=conv1d_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=[2, 4, 8, 16, 32, 64],
    ).run()


def test_conv2d_fwd():
    """Run the ``conv2d`` benchmark against ``torch.nn.functional.conv2d``."""

    def conv2d_arg(dtype, batch, size):
        input_shape = [batch, 4, 4, size]
        weight_shape = [batch // 2, 2, 2, size // 2]

        # Seed before sampling so the generated tensors are reproducible.
        torch.manual_seed(0)
        inp = torch.randn(
            input_shape, dtype=dtype, device="cuda", requires_grad=True
        )
        weight = torch.randn(weight_shape, dtype=dtype, device="cuda")

        # Remaining entries of the returned tuple, in order.
        # NOTE(review): presumably Benchmark unpacks the tuple positionally
        # into torch_op - confirm against the Benchmark implementation.
        bias, stride, padding, dilation, groups = None, 2, 2, 1, 2
        return (inp, weight, bias, stride, padding, dilation, groups)

    Benchmark(
        op_name="conv2d",
        torch_op=torch.nn.functional.conv2d,
        arg_func=conv2d_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=[16, 32, 64, 128, 256, 512],
    ).run()


def test_conv2d_bwd():
    """Run the ``conv2d`` benchmark with ``is_backward=True``."""

    def conv2d_arg(dtype, batch, size):
        input_shape = [batch, 4, 4, size]
        weight_shape = [batch // 2, 2, 2, size // 2]

        # Seed before sampling so the generated tensors are reproducible.
        torch.manual_seed(0)
        inp = torch.randn(
            input_shape, dtype=dtype, device="cuda", requires_grad=True
        )
        weight = torch.randn(weight_shape, dtype=dtype, device="cuda")

        # Remaining entries of the returned tuple, in order.
        # NOTE(review): presumably Benchmark unpacks the tuple positionally
        # into torch_op - confirm against the Benchmark implementation.
        bias, stride, padding, dilation, groups = None, 2, 2, 1, 2
        return (inp, weight, bias, stride, padding, dilation, groups)

    Benchmark(
        op_name="conv2d",
        torch_op=torch.nn.functional.conv2d,
        arg_func=conv2d_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=[16, 32, 64, 128, 256, 512],
        is_backward=True,
    ).run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def enable(lib=aten_lib):
lib.impl("repeat_interleave.Tensor", repeat_interleave_tensor, "CUDA")
lib.impl("conv2d", conv2d, "CUDA")
lib.impl("conv1d", conv1d, "CUDA")
lib.impl("_conv_depthwise2d", _conv_depthwise2d, "CUDA")


class use_gems:
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .clamp import clamp, clamp_tensor
from .conv1d import conv1d
from .conv2d import conv2d
from .conv_depthwise2d import _conv_depthwise2d
from .cos import cos
from .cross_entropy_loss import cross_entropy_loss
from .cumsum import cumsum, normed_cumsum
Expand Down Expand Up @@ -247,4 +248,5 @@
"repeat_interleave_tensor",
"conv2d",
"conv1d",
"_conv_depthwise2d",
]
39 changes: 39 additions & 0 deletions tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,42 @@ def test_accuracy_conv2d(shape, kernel, stride, padding, groups, dtype):
# (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
# import numpy as np
# gems_assert_close(res_in_grad, ref_in_grad, dtype)


# Parametrized cases for test_accuracy_depthwise2d. Each entry is
# (shape_input, shape_weight, kernel), matching the parametrize argnames;
# `kernel` is forwarded as the third positional argument of
# torch._C._nn._conv_depthwise2d. Every weight shape has 1 as its second
# dimension.
SHAPE_DEPTHWISE = [
((32, 4, 8, 8), (32, 1, 2, 2), (2, 2)),
((18, 16, 4, 4), (16, 1, 2, 2), (2, 2)),
((9, 32, 4, 4), (128, 1, 2, 2), (2, 2)),
((32, 16, 8, 8), (32, 1, 4, 4), (4, 4)),
((18, 8, 4, 4), (16, 1, 2, 2), (2, 2)),
((9, 4, 4, 4), (128, 1, 2, 2), (2, 2)),
]


@pytest.mark.conv_depthwise2d
@pytest.mark.parametrize("shape_input, shape_weight, kernel", SHAPE_DEPTHWISE)
@pytest.mark.parametrize("stride", [2])
@pytest.mark.parametrize("padding", [2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_accuracy_depthwise2d(
    shape_input, shape_weight, kernel, stride, padding, dtype
):
    """Check ``_conv_depthwise2d`` under ``flag_gems.use_gems()`` against a reference.

    The same ``torch._C._nn._conv_depthwise2d`` call is made twice: once on the
    tensors returned by ``to_reference`` and once on the original tensors
    inside ``flag_gems.use_gems()``. The two outputs are compared with
    ``gems_assert_close``.

    Args:
        shape_input: Shape of the randomly generated input tensor.
        shape_weight: Shape of the randomly generated weight tensor.
        kernel: Kernel-size argument forwarded to ``_conv_depthwise2d``.
        stride: Stride forwarded to ``_conv_depthwise2d``.
        padding: Padding forwarded to ``_conv_depthwise2d``.
        dtype: dtype used for the input and weight tensors.
    """
    # allow_tf32 is a process-wide switch: remember the previous value and
    # restore it afterwards so this test does not change the numerics of
    # whichever tests happen to run after it.
    previous_allow_tf32 = torch.backends.cudnn.allow_tf32
    torch.backends.cudnn.allow_tf32 = False
    try:
        inp = torch.randn(
            shape_input, dtype=dtype, device="cuda", requires_grad=True
        )
        ref_inp = to_reference(inp, True)
        weight = torch.randn(shape_weight, dtype=dtype, device="cuda")
        ref_weight = to_reference(weight, True)

        # NOTE(review): the gems kernel is registered for the "CUDA" key only;
        # if to_reference can move tensors off the GPU, confirm that this
        # reference call is still dispatchable on that device.
        ref_out = torch._C._nn._conv_depthwise2d(
            ref_inp,
            ref_weight,
            kernel,
            bias=None,
            stride=stride,
            padding=padding,
            dilation=1,
        )
        with flag_gems.use_gems():
            res_out = torch._C._nn._conv_depthwise2d(
                inp,
                weight,
                kernel,
                bias=None,
                stride=stride,
                padding=padding,
                dilation=1,
            )
    finally:
        torch.backends.cudnn.allow_tf32 = previous_allow_tf32

    gems_assert_close(res_out, ref_out, dtype)

0 comments on commit ce12a38

Please sign in to comment.