diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/code-format-check.yml similarity index 100% rename from .github/workflows/pre-commit.yml rename to .github/workflows/code-format-check.yml diff --git a/.github/workflows/model-test.yaml b/.github/workflows/model-test.yaml new file mode 100644 index 00000000..c1ad59f8 --- /dev/null +++ b/.github/workflows/model-test.yaml @@ -0,0 +1,29 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: model-test + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + container-model-test: + runs-on: [self-hosted, docker] + container: + image: localhost:5000/flag-gems-ci:v1.0 + ports: + - 82 + options: --gpus all --hostname flag-gems_cicd_model -v /home/flaggems_cicd/huggingface_cache_bert:/__w/_temp/_github_home/.cache/huggingface + steps: + - name: checkout-code + uses: actions/checkout@v4 + + - name: check-gpu-free + run: tests/scripts/gpu_check.sh + + - name: examples-flag-gems + run: | + CUDA_VISIBLE_DEVICES=5 pytest -s examples/model_bert_test.py diff --git a/.github/workflows/python-test.yaml b/.github/workflows/op-unit-test.yaml similarity index 78% rename from .github/workflows/python-test.yaml rename to .github/workflows/op-unit-test.yaml index 7d7b2e9b..b3b8bf82 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/op-unit-test.yaml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: flag-gems-test +name: op-unit-test on: push: @@ -62,21 +62,3 @@ jobs: done exit $overall_status - - container-model-test: - runs-on: [self-hosted, docker] - container: - image: localhost:5000/flag-gems-ci:v1.0 - ports: - - 82 - options: --gpus all --hostname flag-gems_cicd_model -v /home/flaggems_cicd/huggingface_cache_bert:/__w/_temp/_github_home/.cache/huggingface - steps: - - name: checkout-code - uses: actions/checkout@v4 - - - name: check-gpu-free - run: tests/scripts/gpu_check.sh - - - name: examples-flag-gems - run: | - CUDA_VISIBLE_DEVICES=5 pytest -s examples/model_bert_test.py diff --git a/.github/workflows/python-coverage.yaml b/.github/workflows/python-coverage.yaml new file mode 100644 index 00000000..2ba26af4 --- /dev/null +++ b/.github/workflows/python-coverage.yaml @@ -0,0 +1,69 @@ +# https://github.com/marketplace/actions/python-coverage + +name: python-coverage + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +permissions: + pull-requests: write + +jobs: + container-coverage-test: + runs-on: [self-hosted, docker] + container: + image: localhost:5000/flag-gems-ci:v1.0 + ports: + - 81 + options: --gpus all --hostname flag-gems_cicd_coverage + steps: + + - name: check-gpu-free + run: tests/scripts/gpu_check.sh + + - name: run-pytest + shell: bash + run: | + cmds=( + "CUDA_VISIBLE_DEVICES=0 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_unary_pointwise_ops.py &" + "CUDA_VISIBLE_DEVICES=0 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_pointwise_type_promotion.py &" + "CUDA_VISIBLE_DEVICES=1 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_binary_pointwise_ops.py &" + "CUDA_VISIBLE_DEVICES=1 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_tensor_constructor_ops.py &" + "CUDA_VISIBLE_DEVICES=1 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_distribution_ops.py &" + "CUDA_VISIBLE_DEVICES=2 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_blas_ops.py &" + "CUDA_VISIBLE_DEVICES=3 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_reduction_ops.py &" + "CUDA_VISIBLE_DEVICES=4 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_special_ops.py &" + "CUDA_VISIBLE_DEVICES=5 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s tests/test_libentry.py &" + "CUDA_VISIBLE_DEVICES=5 coverage run --parallel-mode --omit "*/.flaggems/*","*/usr/lib/*" -m pytest -s examples/model_bert_test.py &" + ) + + declare -a exit_statuses + + for cmd in "${cmds[@]}"; do + eval "$cmd" + done + + for job in $(jobs -p); do + wait $job + exit_statuses+=($?) + echo "Task $pid completed with exit status ${exit_statuses[-1]}" + done + + echo "Exit statuses of all tasks: ${exit_statuses[@]}" + + - name: get-coverage + run: | + coverage combine --append + coverage report -m + coverage xml -o coverage.xml + + - name: report-coverage + uses: orgoro/coverage@v3.2 + with: + coverageFile: coverage.xml + thresholdNew: 0.8 + thresholdModified: 0.0 + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index d7d9a425..cf5c8601 100644 --- a/README.md +++ b/README.md @@ -187,7 +187,7 @@ Operators will be implemented according to [OperatorList.md](./OperatorList.md). The following chart shows the speedup of FlagGems compared with PyTorch ATen library in eager mode. The speedup is calculated by averaging the speedup on each shape, representing the overall performance of the operator. -![Operator Speedup](./assets/speedup-0708-eng.png) +![Operator Speedup](./assets/speedup-0814-eng.png) ## Contributions diff --git a/README_cn.md b/README_cn.md index 86ee5796..9b32279d 100644 --- a/README_cn.md +++ b/README_cn.md @@ -186,7 +186,7 @@ pip install . FlagGems相比Torch Eager模式下ATen算子库的加速比如下图所示。其中,每个算子的加速比综合了多个形状测例的数据,代表该算子的整体性能。 -![算子加速比](./assets/speedup-0708-chn.png) +![算子加速比](./assets/speedup-0814-chn.png) ## 贡献代码 diff --git a/assets/speedup-0814-chn.png b/assets/speedup-0814-chn.png new file mode 100644 index 00000000..ef73109d Binary files /dev/null and b/assets/speedup-0814-chn.png differ diff --git a/assets/speedup-0814-eng.png b/assets/speedup-0814-eng.png new file mode 100644 index 00000000..49edf099 Binary files /dev/null and b/assets/speedup-0814-eng.png differ diff --git a/benchmark/test_distribution_perf.py b/benchmark/test_distribution_perf.py index 3e7cd605..03202ba4 100644 --- a/benchmark/test_distribution_perf.py +++ b/benchmark/test_distribution_perf.py @@ -9,50 +9,6 @@ ) -def test_perf_rand(): - def rand_kwargs(dtype, batch, size): - return {"size": (batch, size), "dtype": dtype, "device": "cuda"} - - bench = Benchmark( - op_name="rand", - torch_op=torch.rand, - arg_func=None, - dtypes=FLOAT_DTYPES, - batch=POINTWISE_BATCH, - sizes=SIZES, - kwargs_func=rand_kwargs, - ) - bench.run() - - -def test_perf_randn(): - def randn_kwargs(dtype, batch, size): - return {"size": (batch, size), "dtype": dtype, "device": "cuda"} - - bench = Benchmark( - op_name="randn", - torch_op=torch.randn, - arg_func=None, - dtypes=FLOAT_DTYPES, - batch=POINTWISE_BATCH, - sizes=SIZES, - kwargs_func=randn_kwargs, - ) - bench.run() - - -def test_perf_rand_like(): - bench = Benchmark( - op_name="rand_like", - torch_op=torch.rand_like, - arg_func=unary_arg, - dtypes=FLOAT_DTYPES, - batch=POINTWISE_BATCH, - sizes=SIZES, - ) - bench.run() - - def test_perf_normal(): def normal_arg(dtype, batch, size): loc = torch.full(size=(size, batch), fill_value=3.0, dtype=dtype, device="cuda") diff --git a/benchmark/test_pointwise_perf.py b/benchmark/test_pointwise_perf.py index 01234c69..03dbd0c5 100644 --- a/benchmark/test_pointwise_perf.py +++ b/benchmark/test_pointwise_perf.py @@ -134,6 +134,30 @@ def test_perf_eq(): bench.run() +def test_perf_maximum(): + bench = Benchmark( + op_name="maximum", + torch_op=torch.maximum, + arg_func=binary_args, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_minimum(): + bench = Benchmark( + op_name="minimum", + torch_op=torch.minimum, + arg_func=binary_args, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + def test_perf_exp(): bench = Benchmark( op_name="exp", @@ -158,7 +182,59 @@ def test_perf_ge(): bench.run() -def test_perf_gelu(): +def test_perf_gelu_tanh(): + def gelu_kwargs(dtype, batch, size): + return {"approximate": "tanh"} + + bench = Benchmark( + op_name="gelu", + torch_op=torch.nn.functional.gelu, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=gelu_kwargs, + ) + bench.run() + + +def test_perf_gelu_none(): + def gelu_kwargs(dtype, batch, size): + return {"approximate": "none"} + + bench = Benchmark( + op_name="gelu", + torch_op=torch.nn.functional.gelu, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=gelu_kwargs, + ) + bench.run() + + +def test_perf_gelu_backward_tanh(): + def gelu_kwargs(dtype, batch, size): + return {"approximate": "tanh"} + + bench = Benchmark( + op_name="gelu", + torch_op=torch.nn.functional.gelu, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=gelu_kwargs, + is_backward=True, + ) + bench.run() + + +def test_perf_gelu_backward_none(): + def gelu_kwargs(dtype, batch, size): + return {"approximate": "none"} + bench = Benchmark( op_name="gelu", torch_op=torch.nn.functional.gelu, @@ -166,6 +242,8 @@ def test_perf_gelu(): dtypes=FLOAT_DTYPES, batch=POINTWISE_BATCH, sizes=SIZES, + kwargs_func=gelu_kwargs, + is_backward=True, ) bench.run() diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index c5f42847..c38c5235 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -286,3 +286,28 @@ def test_perf_vector_norm(): sizes=SIZES, ) bench.run() + + +def test_perf_index_select(): + def index_select_args(dtype, batch, size): + inp = torch.randn([batch, size], dtype=dtype, device="cuda") + + threshold = 0.1 + dim = 0 + index_size = inp.size(dim) + from math import floor + + index = torch.randint( + 0, index_size, [floor(index_size * threshold)], device="cuda" + ) + return (inp, dim, index) + + bench = Benchmark( + op_name="index_select", + torch_op=torch.index_select, + arg_func=index_select_args, + dtypes=FLOAT_DTYPES, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() diff --git a/benchmark/test_special_perf.py b/benchmark/test_special_perf.py index b6b13181..bc48d6f5 100644 --- a/benchmark/test_special_perf.py +++ b/benchmark/test_special_perf.py @@ -1,6 +1,13 @@ import torch -from .performance_utils import FLOAT_DTYPES, POINTWISE_BATCH, SIZES, Benchmark +from .performance_utils import ( + FLOAT_DTYPES, + INT_DTYPES, + POINTWISE_BATCH, + SIZES, + Benchmark, + unary_int_arg, +) def test_perf_embedding(): @@ -73,3 +80,19 @@ def resolve_conj_arg(dtype, batch, size): sizes=SIZES, ) bench.run() + + +def test_perf_unique(): + def unique_kwargs(dtype, batch, size): + return {"sorted": True, "return_inverse": True, "return_counts": False} + + bench = Benchmark( + op_name="unique", + torch_op=torch.unique, + arg_func=unary_int_arg, + dtypes=INT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=unique_kwargs, + ) + bench.run() diff --git a/benchmark/test_tensor_constructor_perf.py b/benchmark/test_tensor_constructor_perf.py index e033c154..57162005 100644 --- a/benchmark/test_tensor_constructor_perf.py +++ b/benchmark/test_tensor_constructor_perf.py @@ -9,6 +9,62 @@ ) +def test_perf_rand(): + def rand_kwargs(dtype, batch, size): + return {"size": (batch, size), "dtype": dtype, "device": "cuda"} + + bench = Benchmark( + op_name="rand", + torch_op=torch.rand, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=rand_kwargs, + ) + bench.run() + + +def test_perf_randn(): + def randn_kwargs(dtype, batch, size): + return {"size": (batch, size), "dtype": dtype, "device": "cuda"} + + bench = Benchmark( + op_name="randn", + torch_op=torch.randn, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=randn_kwargs, + ) + bench.run() + + +def test_perf_rand_like(): + bench = Benchmark( + op_name="rand_like", + torch_op=torch.rand_like, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_randn_like(): + bench = Benchmark( + op_name="randn_like", + torch_op=torch.randn_like, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + def test_perf_ones(): def ones_kwargs(dtype, batch, size): return {"size": (batch, size), "dtype": dtype, "device": "cuda"} diff --git a/pyproject.toml b/pyproject.toml index b6903de8..b698c2c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,3 +39,9 @@ testpaths = [ pythonpath = [ "src", ] + +[tool.coverage.run] +omit = [ + "*/.flaggems/*", + "*/usr/lib/*", + ] diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 743387fd..d0c825b6 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -24,7 +24,18 @@ def enable(lib=aten_lib): lib.impl("clamp.Tensor", clamp_tensor, "CUDA") lib.impl("cos", cos, "CUDA") lib.impl("cumsum", cumsum, "CUDA") - lib.impl("div.Tensor", div, "CUDA") + lib.impl("div.Tensor", true_divide, "CUDA") + lib.impl("div.Scalar", true_divide, "CUDA") + lib.impl("div.Tensor_mode", div_mode, "CUDA") + lib.impl("div.Scalar_mode", div_mode, "CUDA") + lib.impl("divide.Tensor", true_divide, "CUDA") # divide, an alias for div + lib.impl("divide.Scalar", true_divide, "CUDA") + lib.impl("divide.Tensor_mode", div_mode, "CUDA") + lib.impl("divide.Scalar_mode", div_mode, "CUDA") + lib.impl("true_divide.Tensor", true_divide, "CUDA") # true_divide, an alias for div + lib.impl("true_divide.Scalar", true_divide, "CUDA") + lib.impl("floor_divide", floor_divide, "CUDA") + lib.impl("floor_divide.Scalar", floor_divide, "CUDA") lib.impl("native_dropout", native_dropout, "AutogradCUDA") lib.impl("erf", erf, "CUDA") lib.impl("embedding", embedding, "AutogradCUDA") @@ -34,13 +45,15 @@ def enable(lib=aten_lib): lib.impl("exponential_", exponential_, "CUDA") lib.impl("ge.Tensor", ge, "CUDA") lib.impl("ge.Scalar", ge_scalar, "CUDA") - lib.impl("gelu", gelu, "CUDA") + lib.impl("gelu", gelu, "AutogradCUDA") lib.impl("native_group_norm", group_norm, "AutogradCUDA") lib.impl("gt.Tensor", gt, "CUDA") lib.impl("gt.Scalar", gt_scalar, "CUDA") lib.impl("isfinite", isfinite, "CUDA") lib.impl("isinf", isinf, "CUDA") lib.impl("isnan", isnan, "CUDA") + lib.impl("minimum", minimum, "CUDA") + lib.impl("maximum", maximum, "CUDA") lib.impl("native_layer_norm", layer_norm, "AutogradCUDA") lib.impl("le.Tensor", le, "CUDA") lib.impl("le.Scalar", le_scalar, "CUDA") @@ -50,6 +63,7 @@ def enable(lib=aten_lib): lib.impl("rand", rand, "CUDA") lib.impl("randn", randn, "CUDA") lib.impl("rand_like", rand_like, "CUDA") + lib.impl("randn_like", randn_like, "CUDA") lib.impl("zeros", zeros, "CUDA") lib.impl("ones", ones, "CUDA") lib.impl("full", full, "CUDA") @@ -113,7 +127,9 @@ def enable(lib=aten_lib): lib.impl("allclose", allclose, "CUDA") lib.impl("flip", flip, "CUDA") lib.impl("tile", tile, "CUDA") + lib.impl("index_select", index_select, "CUDA") lib.impl("masked_fill", masked_fill, "CUDA") + lib.impl("_unique2", _unique2, "CUDA") class use_gems: diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 352dde60..20e8dd00 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -17,7 +17,7 @@ from .cos import cos from .cross_entropy_loss import cross_entropy_loss from .cumsum import cumsum -from .div import div +from .div import div_mode, floor_divide, true_divide from .dropout import native_dropout from .embedding import embedding from .eq import eq, eq_scalar @@ -31,6 +31,7 @@ from .gelu import gelu from .groupnorm import group_norm from .gt import gt, gt_scalar +from .index_select import index_select from .isclose import allclose, isclose from .isfinite import isfinite from .isinf import isinf @@ -41,8 +42,10 @@ from .lt import lt, lt_scalar from .masked_fill import masked_fill from .max import max, max_dim +from .maximum import maximum from .mean import mean, mean_dim from .min import min, min_dim +from .minimum import minimum from .mm import mm from .mul import mul from .mv import mv @@ -62,6 +65,7 @@ from .rand import rand from .rand_like import rand_like from .randn import randn +from .randn_like import randn_like from .reciprocal import reciprocal from .relu import relu from .resolve_conj import resolve_conj @@ -79,6 +83,7 @@ from .topk import topk from .triu import triu from .uniform import uniform_ +from .unique import _unique2 from .var_mean import var_mean from .vector_norm import vector_norm from .where import where_scalar_other, where_scalar_self, where_self @@ -108,7 +113,9 @@ "clamp_tensor", "cos", "cumsum", - "div", + "true_divide", + "div_mode", + "floor_divide", "zeros", "ones", "full", @@ -129,6 +136,7 @@ "group_norm", "gt", "gt_scalar", + "index_select", "isclose", "isfinite", "isinf", @@ -143,8 +151,12 @@ "mean_dim", "mm", "mul", + "maximum", + "minimum", "rand", "randn", + "rand_like", + "randn_like", "resolve_neg", "resolve_conj", "normal_tensor_float", @@ -152,7 +164,6 @@ "normal_tensor_tensor", "normal_float_float", "uniform_", - "rand_like", "mv", "ne", "ne_scalar", @@ -191,4 +202,5 @@ "where_scalar_self", "where_scalar_other", "masked_fill", + "_unique2", ] diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py index b16a91b9..a3ea8c0e 100644 --- a/src/flag_gems/ops/argmax.py +++ b/src/flag_gems/ops/argmax.py @@ -54,12 +54,9 @@ def heur_block_n(args): @libentry() @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_warps=8), + triton.Config({"BLOCK_M": 16}, num_warps=8), + triton.Config({"BLOCK_M": 32}, num_warps=8), ], key=[ "M", diff --git a/src/flag_gems/ops/cumsum.py b/src/flag_gems/ops/cumsum.py index 9abd623d..73ae4f06 100644 --- a/src/flag_gems/ops/cumsum.py +++ b/src/flag_gems/ops/cumsum.py @@ -14,12 +14,9 @@ def heur_block_n(args): @libentry() @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_warps=8), + triton.Config({"BLOCK_M": 16}, num_warps=8), + triton.Config({"BLOCK_M": 32}, num_warps=8), ], key=[ "M", diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index 46ba3565..9c6fc0a8 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -2,6 +2,7 @@ import torch import triton +import triton.language as tl from ..utils import pointwise_dynamic @@ -40,19 +41,19 @@ def true_divide(A, B): @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def trunc_div_func(x, y): - return triton.div_rz(x, y) + return tl.math.trunc(tl.math.div_rz(x, y)) @pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def trunc_div_func_tensor_scalar(x, y): - return triton.div_rz(x, y) + return tl.math.trunc(tl.math.div_rz(x, y)) @pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def trunc_div_func_scalar_tensor(x, y): - return triton.div_rz(x, y) + return tl.math.trunc(tl.math.div_rz(x, y)) def trunc_divide(A, B): @@ -71,19 +72,19 @@ def trunc_divide(A, B): @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def floor_div_func(x, y): - return x // y + return tl.math.floor(tl.math.div_rd(x, y)) @pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def floor_div_func_tensor_scalar(x, y): - return x // y + return tl.math.floor(tl.math.div_rd(x, y)) @pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def floor_div_func_scalar_tensor(x, y): - return x // y + return tl.math.floor(tl.math.div_rd(x, y)) def floor_divide(A, B): @@ -99,7 +100,7 @@ def floor_divide(A, B): return A // B -def div(A, B, rounding_mode=None): +def div_mode(A, B, rounding_mode=None): if rounding_mode is None: return true_divide(A, B) elif rounding_mode == "trunc": diff --git a/src/flag_gems/ops/gelu.py b/src/flag_gems/ops/gelu.py index 856e4fc6..fc6dd324 100644 --- a/src/flag_gems/ops/gelu.py +++ b/src/flag_gems/ops/gelu.py @@ -1,23 +1,24 @@ import logging +import torch import triton import triton.language as tl from ..utils import pointwise_dynamic try: - from triton.language.extra.cuda.libdevice import erf, pow, tanh + from triton.language.extra.cuda.libdevice import erf, exp, pow, tanh except ImportError: try: - from triton.language.math import erf, pow, tanh + from triton.language.math import erf, exp, pow, tanh except ImportError: - from triton.language.libdevice import erf, pow, tanh + from triton.language.libdevice import erf, exp, pow, tanh @pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def gelu_none(x): - scale: tl.constexpr = 0.7071067811 + scale: tl.constexpr = 0.7071067811 # 1 / math.sqrt(2) output = 0.5 * x * (1 + erf(x * scale)) return output @@ -31,9 +32,57 @@ def gelu_tanh(x): return output +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def gelu_backward_none(x, dy): + scale1: tl.constexpr = 0.7071067811 # 1 / math.sqrt(2) + scale2: tl.constexpr = 0.3989422803 # 1 / math.sqrt(2 * math.pi) + x_fp32 = x.to(tl.float32) + dydx = ( + scale2 * x_fp32 * exp(-pow(scale1 * x_fp32, 2)) + + 0.5 * erf(scale1 * x_fp32) + + 0.5 + ) + dx = dydx * dy + return dx + + +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def gelu_backward_tanh(x, dy): + x_fp32 = x.to(tl.float32) + # 0.79788456 = math.sqrt(2 / math.pi) + tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * pow(x_fp32, 2))) + dydx = 0.5 * x * ( + (1 - pow(tanh_out, 2)) * (0.79788456 + 0.1070322243 * pow(x_fp32, 2)) + ) + 0.5 * (1 + tanh_out) + dx = dydx * dy + return dx + + +class Gelu(torch.autograd.Function): + @staticmethod + def forward(ctx, A, approximate): + logging.debug("GEMS GELU FORWARD") + if approximate == "tanh": + out = gelu_tanh(A) + else: + out = gelu_none(A) + ctx.save_for_backward(A) + ctx.approximate = approximate + return out + + @staticmethod + def backward(ctx, out_grad): + logging.debug("GEMS GELU BACKWARD") + (inp,) = ctx.saved_tensors + approximate = ctx.approximate + if approximate == "tanh": + in_grad = gelu_backward_tanh(inp, out_grad) + else: + in_grad = gelu_backward_none(inp, out_grad) + return in_grad, None + + def gelu(A, *, approximate="none"): - logging.debug("GEMS GELU") - if approximate == "tanh": - return gelu_tanh(A) - else: - return gelu_none(A) + return Gelu.apply(A, approximate) diff --git a/src/flag_gems/ops/index_select.py b/src/flag_gems/ops/index_select.py new file mode 100644 index 00000000..aa342914 --- /dev/null +++ b/src/flag_gems/ops/index_select.py @@ -0,0 +1,75 @@ +import logging + +import torch +import triton +import triton.language as tl + +from ..utils import dim_compress, libentry + + +def cfggen(): + block_m = [1, 2, 4] + block_n = [1024, 2048, 4096] + configs = [ + triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4) + for m in block_m + for n in block_n + ] + return configs + + +@libentry() +@triton.autotune(configs=cfggen(), key=["M", "N"]) +@triton.jit +def index_select_kernel( + inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr +): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + rows_mask = rows_offsets < M + cols_offsets = pid_y + tl.arange(0, BLOCK_N) + cols_mask = cols_offsets < N + + block_mask = rows_mask and cols_mask + out_mask = rows_mask and (cols_offsets < index_len) + + indices = tl.load(index + cols_offsets, mask=(cols_offsets < index_len), other=0) + inp_off = rows_offsets * N + indices[None, :] + out_off = rows_offsets * index_len + cols_offsets[None, :] + + selected = tl.load(inp + inp_off, mask=block_mask, other=0.0) + tl.store(out + out_off, selected, mask=out_mask) + + +def index_select(inp, dim, index): + logging.debug("GEMS INDEX SELECT") + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + assert index.ndim <= 1, "Index should have dimension 1 or 0" + assert ((i >= 0 and i < inp.size(dim)) for i in index), "Index out of range" + + if index.ndim == 0: + index = index.unsqueeze(0) + dim = dim % inp.ndim + inp_shape = list(inp.shape) + index_len = index.numel() + + # with dim_compress + inp = dim_compress(inp, dim) + N = inp_shape[dim] + M = inp.numel() // N + out_shape = list(inp.shape) + out_shape[inp.ndim - 1] = index_len + out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(index_len, meta["BLOCK_N"]), + ) + index_select_kernel[grid](inp, out, M, N, index, index_len) + if dim != out.ndim - 1: + order = [i for i in range(out.ndim - 1)] + order.insert(dim, out.ndim - 1) + return out.permute(order) + else: + return out diff --git a/src/flag_gems/ops/log_softmax.py b/src/flag_gems/ops/log_softmax.py index 9daa3a78..ac1ad8da 100644 --- a/src/flag_gems/ops/log_softmax.py +++ b/src/flag_gems/ops/log_softmax.py @@ -23,14 +23,10 @@ def heur_num_warps(args): @libentry() @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 1}, num_stages=4), - triton.Config({"BLOCK_M": 1}, num_stages=5), - triton.Config({"BLOCK_M": 2}, num_stages=4), - triton.Config({"BLOCK_M": 2}, num_stages=5), - triton.Config({"BLOCK_M": 4}, num_stages=4), - triton.Config({"BLOCK_M": 4}, num_stages=5), - triton.Config({"BLOCK_M": 8}, num_stages=4), - triton.Config({"BLOCK_M": 8}, num_stages=5), + triton.Config({"BLOCK_M": 1}), + triton.Config({"BLOCK_M": 2}), + triton.Config({"BLOCK_M": 4}), + triton.Config({"BLOCK_M": 8}), ], key=[ "M", @@ -72,14 +68,10 @@ def log_softmax_kernel( @libentry() @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 1}, num_stages=4), - triton.Config({"BLOCK_M": 1}, num_stages=5), - triton.Config({"BLOCK_M": 2}, num_stages=4), - triton.Config({"BLOCK_M": 2}, num_stages=5), - triton.Config({"BLOCK_M": 4}, num_stages=4), - triton.Config({"BLOCK_M": 4}, num_stages=5), - triton.Config({"BLOCK_M": 8}, num_stages=4), - triton.Config({"BLOCK_M": 8}, num_stages=5), + triton.Config({"BLOCK_M": 1}), + triton.Config({"BLOCK_M": 2}), + triton.Config({"BLOCK_M": 4}), + triton.Config({"BLOCK_M": 8}), ], key=[ "M", diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index e43a32cc..36c24454 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -45,12 +45,9 @@ def heur_block_n(args): @libentry() @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_warps=8), + triton.Config({"BLOCK_M": 16}, num_warps=8), + triton.Config({"BLOCK_M": 32}, num_warps=8), ], key=[ "M", diff --git a/src/flag_gems/ops/maximum.py b/src/flag_gems/ops/maximum.py new file mode 100644 index 00000000..cc9236fb --- /dev/null +++ b/src/flag_gems/ops/maximum.py @@ -0,0 +1,22 @@ +import logging + +import triton +import triton.language as tl + +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def maximum_kernel(X, Y): + if X.dtype == tl.bfloat16: + X = X.to(tl.float32) + Y = Y.to(tl.float32) + + return tl.maximum(X, Y) + + +def maximum(X, Y): + logging.debug("GEMS MAXIMUM") + assert X.is_cuda and Y.is_cuda + return maximum_kernel(X, Y) diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py index e873c4c7..1236d4fc 100644 --- a/src/flag_gems/ops/min.py +++ b/src/flag_gems/ops/min.py @@ -45,12 +45,9 @@ def heur_block_n(args): @libentry() @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_warps=8), + triton.Config({"BLOCK_M": 16}, num_warps=8), + triton.Config({"BLOCK_M": 32}, num_warps=8), ], key=[ "M", diff --git a/src/flag_gems/ops/minimum.py b/src/flag_gems/ops/minimum.py new file mode 100644 index 00000000..d16f27dd --- /dev/null +++ b/src/flag_gems/ops/minimum.py @@ -0,0 +1,21 @@ +import logging + +import triton +import triton.language as tl + +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 0, "DEFAULT")]) +@triton.jit +def minimum_kernel(X, Y): + if X.dtype == tl.bfloat16: + X = X.to(tl.float32) + Y = Y.to(tl.float32) + return tl.minimum(X, Y) + + +def minimum(X, Y): + logging.debug("GEMS MINIMUM") + assert X.is_cuda and Y.is_cuda + return minimum_kernel(X, Y) diff --git a/src/flag_gems/ops/prod.py b/src/flag_gems/ops/prod.py index dd81a432..35199e9a 100644 --- a/src/flag_gems/ops/prod.py +++ b/src/flag_gems/ops/prod.py @@ -68,12 +68,9 @@ def heur_block_n(args): @libentry() @triton.autotune( configs=[ - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 8}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 16}, num_warps=8, num_stages=5), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=4), - triton.Config({"BLOCK_M": 32}, num_warps=8, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_warps=8), + triton.Config({"BLOCK_M": 16}, num_warps=8), + triton.Config({"BLOCK_M": 32}, num_warps=8), ], key=[ "M", diff --git a/src/flag_gems/ops/randn_like.py b/src/flag_gems/ops/randn_like.py new file mode 100644 index 00000000..77034ce6 --- /dev/null +++ b/src/flag_gems/ops/randn_like.py @@ -0,0 +1,29 @@ +import logging + +import torch +import triton + +from flag_gems.ops.randn import randn_kernel +from flag_gems.utils.random_utils import philox_cuda_seed_offset + +UNROLL = 4 + + +def randn_like( + x, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + logging.debug("GEMS RANDN_LIKE") + if device is None: + device = x.device.index + if dtype is None: + dtype = x.dtype + out = torch.empty_like(x, device=device, dtype=dtype) + N = x.numel() + grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) + # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, + # hence we cannot obtain the per thread offset as in Pytorch. + increment = triton.cdiv(N, UNROLL) + philox_seed, philox_offset = philox_cuda_seed_offset(increment) + with torch.cuda.device(x.device): + randn_kernel[grid_fn](out, N, philox_seed, philox_offset) + return out diff --git a/src/flag_gems/ops/unique.py b/src/flag_gems/ops/unique.py new file mode 100644 index 00000000..80605b60 --- /dev/null +++ b/src/flag_gems/ops/unique.py @@ -0,0 +1,838 @@ +import torch +import triton +import triton.language as tl + +from flag_gems.utils.libentry import libentry + + +@libentry() +@triton.jit +def simple_unique_flat_kernel( + sorted_data_ptr: tl.tensor, + sorted_indices_ptr: tl.tensor, # in + data_out_ptr: tl.tensor, + inverse_indices_ptr: tl.tensor, + idx_ptr: tl.tensor, + unique_size_ptr: tl.tensor, # out + return_inverse: tl.constexpr, + return_counts: tl.constexpr, + num_tasks: int, + tile_size: tl.constexpr, +): + i0 = tl.arange(0, tile_size) + mask = i0 < num_tasks + + # load + a = tl.load(sorted_data_ptr + i0, mask=mask) + i0_prev = tl.where(i0 > 0, i0 - 1, 0) + b = tl.load(sorted_data_ptr + i0_prev, mask=mask) + + # ne & cumsum + ne_result = tl.where(i0 > 0, a != b, 0) + cumsum = tl.cumsum(ne_result) + + # unique_size + unique_size_mask = i0 == tile_size - 1 + tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask) + + # data_out: scatter_(to=cumsum, sorted_data) + tl.store(data_out_ptr + cumsum, a, mask=mask) + + # inverse_indices: scatter_(to=sorted_indices, cumsum) + if return_inverse: + sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) + tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) + + # idx + if return_counts: + idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & mask + tl.store(idx_ptr + cumsum, i0, mask=idx_mask) + + +@triton.jit +def output_counts_flat_impl( + global_pid, + idx_ptr: tl.tensor, + origin_num_tasks: int, # in + counts_ptr: tl.tensor, # out + num_tasks: int, + tile_size: tl.constexpr, +): + r = tl.arange(0, tile_size) + + # load idx + i0 = global_pid * tile_size + r + mask = i0 < num_tasks + idx = tl.load(idx_ptr + i0, mask=mask) + + # load idx_next + i0_next = i0 + 1 + next_mask = i0_next < num_tasks + idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) + + # diff + counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) + + # store counts + tl.store(counts_ptr + i0, counts, mask=mask) + + +@libentry() +@triton.jit +def output_counts_flat_kernel( + idx_ptr: tl.tensor, + origin_num_tasks: int, # in + counts_ptr: tl.tensor, # out + num_tasks: int, + tiles_per_cta: int, + tile_size: tl.constexpr, + one_tile_per_cta: tl.constexpr, +): + pid = tl.program_id(0) + num_ctas = tl.num_programs(0) + if one_tile_per_cta: # monolitic kernel style + output_counts_flat_impl( + pid, + idx_ptr, + origin_num_tasks, # in + counts_ptr, # out + num_tasks, + tile_size, + ) + else: # grid-stride-loop style kernel + for j in range(0, tiles_per_cta): + global_pid = pid + j * num_ctas if j > 0 else pid + output_counts_flat_impl( + global_pid, + idx_ptr, + origin_num_tasks, # in + counts_ptr, # out + num_tasks, + tile_size, + ) + + +@triton.jit +def quick_output_flat_impl( + global_pid, + sorted_data_ptr: tl.tensor, + idx_ptr: tl.tensor, + origin_num_tasks: int, # in + data_out_ptr: tl.tensor, + counts_ptr: tl.tensor, # out + num_tasks: int, + tile_size: tl.constexpr, +): + r = tl.arange(0, tile_size) + + # load idx + i0 = global_pid * tile_size + r + mask = i0 < num_tasks + idx = tl.load(idx_ptr + i0, mask=mask) + + # load idx_next + i0_next = i0 + 1 + next_mask = i0_next < num_tasks + idx_next = tl.load(idx_ptr + i0_next, mask=next_mask) + + # diff + counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx) + + # store counts + tl.store(counts_ptr + i0, counts, mask=mask) + + # data_out: gather(sorted_data, from=idx) + sorted_data = tl.load(sorted_data_ptr + idx, mask=mask) + tl.store(data_out_ptr + i0, sorted_data, mask=mask) + + +@libentry() +@triton.jit +def quick_output_flat_kernel( + sorted_data_ptr: tl.tensor, + idx_ptr: tl.tensor, + origin_num_tasks: int, # in + data_out_ptr: tl.tensor, + counts_ptr: tl.tensor, # out + num_tasks: int, + tiles_per_cta: int, + tile_size: tl.constexpr, + one_tile_per_cta: tl.constexpr, +): + pid = tl.program_id(0) + num_ctas = tl.num_programs(0) + if one_tile_per_cta: # monolitic kernel style + quick_output_flat_impl( + pid, + sorted_data_ptr, + idx_ptr, + origin_num_tasks, # in + data_out_ptr, + counts_ptr, # out + num_tasks, + tile_size, + ) + else: # grid-stride-loop style kernel + for j in range(0, tiles_per_cta): + global_pid = pid + j * num_ctas if j > 0 else pid + quick_output_flat_impl( + global_pid, + sorted_data_ptr, + idx_ptr, + origin_num_tasks, # in + data_out_ptr, + counts_ptr, # out + num_tasks, + tile_size, + ) + + +@triton.jit +def local_quick_unique_flat_impl( + global_pid, + sorted_data_ptr: tl.tensor, # in + local_unique_ptr: tl.tensor, + origin_idx_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # out + global_num_ctas: int, + num_tasks: int, + tile_size: tl.constexpr, + return_counts: tl.constexpr, +): + offset = global_pid * tile_size + r = tl.arange(0, tile_size) + i0 = offset + r + mask = i0 < num_tasks + + # load + a = tl.load(sorted_data_ptr + i0, mask=mask) + i0_prev = tl.where(i0 > 0, i0 - 1, 0) + b = tl.load(sorted_data_ptr + i0_prev, mask=mask) + + # ne & cumsum + ne_result = tl.where(i0 > 0, a != b, 0) + cumsum = tl.cumsum(ne_result) + + # local_id or local_unique + local_unique_offset = cumsum - (1 if global_pid > 0 else 0) + local_unique_mask = (local_unique_offset >= 0) & mask + if return_counts: + # origin_idx: scatter_(to=cumsum, i0) + origin_idx_mask = ((i0 == 0) | ne_result.to(tl.int1)) & local_unique_mask + tl.store( + origin_idx_ptr + (offset + local_unique_offset), + i0, + mask=origin_idx_mask, + ) + else: + # local_unique: scatter_(to=cumsum, sorted_data) + tl.store( + local_unique_ptr + (offset + local_unique_offset), a, mask=local_unique_mask + ) + + # tile_sum + tile_sum_mask = (r == tile_size - 1) & (global_pid < global_num_ctas) + tile_sum = tl.where(tile_sum_mask & (global_pid == 0), cumsum + 1, cumsum) + tl.store(tile_sum_ptr + global_pid + tl.zeros_like(r), tile_sum, mask=tile_sum_mask) + + +@libentry() +@triton.jit +def local_quick_unique_flat_kernel( + sorted_data_ptr: tl.tensor, # in + local_unique_ptr: tl.tensor, + origin_idx_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # out + global_num_ctas: int, + num_tasks: int, + tiles_per_cta: int, + tile_size: tl.constexpr, + one_tile_per_cta: tl.constexpr, + return_counts: tl.constexpr, +): + pid = tl.program_id(0) + num_ctas = tl.num_programs(0) + if one_tile_per_cta: # monolitic kernel style + local_quick_unique_flat_impl( + pid, + sorted_data_ptr, # in + local_unique_ptr, + origin_idx_ptr, + tile_sum_ptr, # out + global_num_ctas, + num_tasks, + tile_size, + return_counts, + ) + else: # grid-stride-loop style kernel + for j in range(0, tiles_per_cta): + global_pid = pid + j * num_ctas if j > 0 else pid + local_quick_unique_flat_impl( + global_pid, + sorted_data_ptr, # in + local_unique_ptr, + origin_idx_ptr, + tile_sum_ptr, # out + global_num_ctas, + num_tasks, + tile_size, + return_counts, + ) + + +@triton.jit +def global_quick_unique_flat_impl( + global_pid, + total, + local_unique_ptr: tl.tensor, + origin_idx_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # in + data_out_ptr: tl.tensor, + idx_ptr: tl.tensor, # out + num_ctas: int, + global_num_ctas: int, + next_power_global_num_ctas: tl.constexpr, + num_tasks: int, + tile_size: tl.constexpr, + return_counts: tl.constexpr, +): + r = tl.arange(0, tile_size) + i0 = global_pid * tile_size + r + mask = i0 < num_tasks + + # load tile_sum + p = tl.arange(0, next_power_global_num_ctas) + pre_tile_sum_mask = ( + (p >= global_pid - num_ctas) + & (p < global_pid) + & (p >= 0) + & (p < global_num_ctas) + ) + pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) + cur_tile_sum_mask = global_pid < global_num_ctas + cur_tile_sum = tl.load(tile_sum_ptr + global_pid, mask=cur_tile_sum_mask) + + # total + total += tl.sum(pre_tile_sum) + if global_pid == global_num_ctas - 1: + last_tile_sum_mask = p == global_pid + tl.store(tile_sum_ptr + p, total + cur_tile_sum, mask=last_tile_sum_mask) + + # idx or data_out + tile_mask = r < cur_tile_sum + out_offset = total + r + if return_counts: + # move origin_idx to idx_ptr + origin_idx = tl.load(origin_idx_ptr + i0, mask=mask) + tl.store(idx_ptr + out_offset, origin_idx, mask=tile_mask) + else: + # move local_unique to data_out_ptr + local_unique = tl.load(local_unique_ptr + i0, mask=mask) + tl.store(data_out_ptr + out_offset, local_unique, mask=tile_mask) + + return total + + +@libentry() +@triton.jit +def global_quick_unique_flat_kernel( + local_unique_ptr: tl.tensor, + origin_idx_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # in + data_out_ptr: tl.tensor, + idx_ptr: tl.tensor, # out + num_ctas: int, + global_num_ctas: int, + next_power_global_num_ctas: tl.constexpr, + num_tasks: int, + tiles_per_cta: int, + tile_size: tl.constexpr, + one_tile_per_cta: tl.constexpr, + return_counts: tl.constexpr, +): + pid = tl.program_id(0) + num_ctas = tl.num_programs(0) + if one_tile_per_cta: # monolitic kernel style + global_quick_unique_flat_impl( + pid, + 0, + local_unique_ptr, + origin_idx_ptr, + tile_sum_ptr, # in + data_out_ptr, + idx_ptr, # out + num_ctas, + global_num_ctas, + next_power_global_num_ctas, + num_tasks, + tile_size, + return_counts, + ) + else: # grid-stride-loop style kernel + total = tl.zeros([1], dtype=tl.int64) + for j in range(0, tiles_per_cta): + global_pid = pid + j * num_ctas + total = global_quick_unique_flat_impl( + global_pid, + total, + local_unique_ptr, + origin_idx_ptr, + tile_sum_ptr, # in + data_out_ptr, + idx_ptr, # out + num_ctas, + global_num_ctas, + next_power_global_num_ctas, + num_tasks, + tile_size, + return_counts, + ) + + +def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): + num_tasks = sorted_data.numel() + next_power_num_tasks = triton.next_power_of_2(num_tasks) + tile_size = min(8192, next_power_num_tasks) + global_num_ctas = triton.cdiv(num_tasks, tile_size) + if global_num_ctas <= 8192: + tile_size = max( + 32, min(triton.next_power_of_2(global_num_ctas), next_power_num_tasks) + ) + global_num_ctas = triton.cdiv(num_tasks, tile_size) + next_power_global_num_ctas = triton.next_power_of_2(global_num_ctas) + num_ctas = global_num_ctas if global_num_ctas < 65536 else 2048 + tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas) + num_warps = 8 if tiles_per_cta == 1 else 32 + grid = (num_ctas, 1, 1) + + # allocate tensor + if return_counts: + local_unique = None + origin_idx = torch.empty_like(sorted_data, dtype=torch.int32) + idx = torch.empty_like(origin_idx) + else: + local_unique = torch.empty_like(sorted_data) + origin_idx = None + idx = None + counts = None + tile_sum = torch.empty( + (global_num_ctas,), dtype=torch.int32, device=sorted_data.device + ) + data_out = torch.empty_like(sorted_data) + + # launch kernel + with torch.cuda.device(sorted_data.device.index): + local_quick_unique_flat_kernel[grid]( + sorted_data, # in + local_unique, + origin_idx, + tile_sum, # out + global_num_ctas, + num_tasks, + tiles_per_cta=tiles_per_cta, + tile_size=tile_size, + one_tile_per_cta=tiles_per_cta == 1, + return_counts=return_counts, + num_warps=num_warps, + ) + global_quick_unique_flat_kernel[grid]( + local_unique, + origin_idx, + tile_sum, # in + data_out, + idx, # out + num_ctas, + global_num_ctas, + next_power_global_num_ctas, + num_tasks, + tiles_per_cta=tiles_per_cta, + tile_size=tile_size, + one_tile_per_cta=tiles_per_cta == 1, + return_counts=return_counts, + num_warps=num_warps, + ) + out_size = tile_sum[-1].item() + if return_counts: + idx = idx[:out_size] + counts = origin_idx[:out_size] + quick_output_flat_kernel[grid]( + sorted_data, + idx, + num_tasks, # in + data_out, + counts, # out + out_size, + tiles_per_cta, + tile_size, + one_tile_per_cta=tiles_per_cta == 1, + num_warps=num_warps, + ) + + if return_counts: + return data_out[:out_size], None, counts + else: + return data_out[:out_size], None, None + + +@triton.jit +def local_ne_flat_impl( + global_pid, + sorted_data_ptr: tl.tensor, # in + ne_result_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # out + global_num_ctas: int, + num_tasks: int, + tile_size: tl.constexpr, +): + r = tl.arange(0, tile_size) + i0 = global_pid * tile_size + r + mask = i0 < num_tasks + i0_prev = tl.where(i0 > 0, i0 - 1, 0) + + # load + a = tl.load(sorted_data_ptr + i0, mask=mask) + b = tl.load(sorted_data_ptr + i0_prev, mask=mask) + + # compute + ne_result = tl.where(i0 > 0, a != b, 0) + + # store ne_result + tl.store(ne_result_ptr + i0, ne_result, mask=mask) + + # store tile_sum + tile_sum = tl.sum(ne_result) + tile_sum_mask = global_pid < global_num_ctas + tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask) + + +@libentry() +@triton.jit +def local_ne_flat_kernel( + sorted_data_ptr: tl.tensor, # in + ne_result_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # out + global_num_ctas: int, + num_tasks: int, + tiles_per_cta: int, + tile_size: tl.constexpr, + one_tile_per_cta: tl.constexpr, +): + pid = tl.program_id(0) + num_ctas = tl.num_programs(0) + if one_tile_per_cta: # monolitic kernel style + local_ne_flat_impl( + pid, + sorted_data_ptr, # in + ne_result_ptr, + tile_sum_ptr, # out + global_num_ctas, + num_tasks, + tile_size, + ) + else: # grid-stride-loop style kernel + for j in range(0, tiles_per_cta): + global_pid = pid + j * num_ctas if j > 0 else pid + local_ne_flat_impl( + global_pid, + sorted_data_ptr, # in + ne_result_ptr, + tile_sum_ptr, # out + global_num_ctas, + num_tasks, + tile_size, + ) + + +@triton.jit +def global_cumsum_flat_impl( + global_pid, + total, + ne_result_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # in + sorted_data_ptr: tl.tensor, + sorted_indices_ptr: tl.tensor, # in + data_out_ptr: tl.tensor, + inverse_indices_ptr: tl.tensor, + idx_ptr: tl.tensor, # out + num_ctas: tl.constexpr, + global_num_ctas: int, + next_power_global_num_ctas: tl.constexpr, + num_tasks: int, + tile_size: tl.constexpr, + return_counts: tl.constexpr, +): + offset = global_pid * tile_size + r = tl.arange(0, tile_size) + i0 = offset + r + mask = i0 < num_tasks + + # load sorted_data, sorted_indices + sorted_data = tl.load(sorted_data_ptr + i0, mask=mask) + sorted_indices = tl.load(sorted_indices_ptr + i0, mask=mask) + + # load tile_sum + p = tl.arange(0, next_power_global_num_ctas) + pre_tile_sum_mask = ( + (p >= global_pid - num_ctas) + & (p < global_pid) + & (p >= 0) + & (p < global_num_ctas) + ) + pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0) + + # cumsum + total += tl.sum(pre_tile_sum) + ne_result = tl.load(ne_result_ptr + i0, mask=mask) + ne_result_i1 = ne_result.to(tl.int1) + ne_result = ne_result.to(tl.int32) + cumsum = tl.cumsum(ne_result) + + # tile_sum + if global_pid == global_num_ctas - 1: + last_tile_sum_mask = i0 == num_tasks - 1 + tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum) + tl.store( + tile_sum_ptr + global_pid + tl.zeros_like(r), + tile_sum, + mask=last_tile_sum_mask, + ) + cumsum += total + + # data_out: scatter_(to=cumsum, sorted_data) + tl.store(data_out_ptr + cumsum, sorted_data, mask=mask) + + # inverse_indices: scatter_(to=sorted_indices, cumsum) + tl.store(inverse_indices_ptr + sorted_indices, cumsum, mask=mask) + + # idx + if return_counts: + idx_mask = ((i0 == 0) | ne_result_i1) & mask + tl.store(idx_ptr + cumsum, i0, mask=idx_mask) + + return total + + +@libentry() +@triton.jit +def global_cumsum_flat_kernel( + ne_result_ptr: tl.tensor, + tile_sum_ptr: tl.tensor, # in + sorted_data_ptr: tl.tensor, + sorted_indices_ptr: tl.tensor, # in + data_out_ptr: tl.tensor, + inverse_indices_ptr: tl.tensor, + idx_ptr: tl.tensor, # out + num_ctas: int, + global_num_ctas: int, + next_power_global_num_ctas: tl.constexpr, + num_tasks: int, + tiles_per_cta: int, + tile_size: tl.constexpr, + one_tile_per_cta: tl.constexpr, + return_counts: tl.constexpr, +): + pid = tl.program_id(0) + num_ctas = tl.num_programs(0) + if one_tile_per_cta: # monolitic kernel style + global_cumsum_flat_impl( + pid, + 0, + ne_result_ptr, + tile_sum_ptr, # in + sorted_data_ptr, + sorted_indices_ptr, # in + data_out_ptr, + inverse_indices_ptr, + idx_ptr, # out + num_ctas, + global_num_ctas, + next_power_global_num_ctas, + num_tasks, + tile_size, + return_counts, + ) + else: # grid-stride-loop style kernel + total = tl.zeros([1], dtype=tl.int64) + for j in range(0, tiles_per_cta): + global_pid = pid + j * num_ctas + total = global_cumsum_flat_impl( + global_pid, + total, + ne_result_ptr, + tile_sum_ptr, # in + sorted_data_ptr, + sorted_indices_ptr, # in + data_out_ptr, + inverse_indices_ptr, + idx_ptr, # out + num_ctas, + global_num_ctas, + next_power_global_num_ctas, + num_tasks, + tile_size, + return_counts, + ) + + +def sorted_indices_unique_flat( + sorted_data: torch.Tensor, sorted_indices: torch.Tensor, return_counts: bool +): + num_tasks = sorted_data.numel() + next_power_num_tasks = triton.next_power_of_2(num_tasks) + tile_size = min(8192, next_power_num_tasks) + global_num_ctas = triton.cdiv(num_tasks, tile_size) + if global_num_ctas <= 8192: + min_tile_size = 512 if global_num_ctas > 32 else 256 + tile_size = max( + min_tile_size, + min(triton.next_power_of_2(global_num_ctas), next_power_num_tasks), + ) + global_num_ctas = triton.cdiv(num_tasks, tile_size) + next_power_global_num_ctas = triton.next_power_of_2(global_num_ctas) + num_ctas = global_num_ctas if global_num_ctas < 32768 else 8192 + tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas) + num_warps = 8 if tiles_per_cta == 1 else 32 + grid = (num_ctas, 1, 1) + + # allocate tensor + ne_result = torch.empty_like(sorted_data, dtype=torch.bool) + tile_sum = torch.empty( + (global_num_ctas,), dtype=torch.int32, device=sorted_data.device + ) + data_out = torch.empty_like(sorted_data) + inverse_indices = torch.empty_like(sorted_data, dtype=torch.int32) + idx = None + if return_counts: + idx = torch.empty_like(inverse_indices) + + # launch kernel + with torch.cuda.device(sorted_data.device.index): + local_ne_flat_kernel[grid]( + sorted_data, # in + ne_result, + tile_sum, # out + global_num_ctas, + num_tasks, + tiles_per_cta=tiles_per_cta, + tile_size=tile_size, + one_tile_per_cta=tiles_per_cta == 1, + num_warps=num_warps, + ) + global_cumsum_flat_kernel[grid]( + ne_result, + tile_sum, # in + sorted_data, + sorted_indices, # in + data_out, + inverse_indices, + idx, # out + num_ctas, + global_num_ctas, + next_power_global_num_ctas, + num_tasks, + tiles_per_cta=tiles_per_cta, + tile_size=tile_size, + one_tile_per_cta=tiles_per_cta == 1, + return_counts=return_counts, + num_warps=num_warps, + ) + out_size = tile_sum[-1].item() + 1 + counts = None + if return_counts: + idx = idx[:out_size] + counts = torch.empty_like(idx) + output_counts_flat_kernel[grid]( + idx, + num_tasks, # in + counts, # out + out_size, + tiles_per_cta, + tile_size, + one_tile_per_cta=tiles_per_cta == 1, + num_warps=num_warps, + ) + + return data_out[:out_size], inverse_indices, counts + + +def simple_unique_flat( + sorted_data: torch.Tensor, + sorted_indices: torch.Tensor, + return_inverse: bool, + return_counts: bool, +): + num_tasks = sorted_data.numel() + grid = (1, 1, 1) + + # allocate tensor + data_out = torch.empty_like(sorted_data) + if return_inverse: + inverse_indices = torch.empty_like(sorted_data, dtype=torch.int32) + else: + inverse_indices = None + if return_counts: + idx = torch.empty_like(sorted_data, dtype=torch.int32) + else: + idx = None + unique_size = torch.empty([1], dtype=torch.int32, device=sorted_data.device) + + # launch kernel + with torch.cuda.device(sorted_data.device.index): + simple_unique_flat_kernel[grid]( + sorted_data, + sorted_indices, # in + data_out, + inverse_indices, + idx, + unique_size, # out + return_inverse, + return_counts, + num_tasks, + tile_size=triton.next_power_of_2(num_tasks), + num_warps=8, + ) + out_size = unique_size.item() + 1 + counts = None + if return_counts: + idx = idx[:out_size] + counts = torch.empty_like(idx) + with torch.cuda.device(sorted_data.device.index): + output_counts_flat_kernel[grid]( + idx, + num_tasks, # in + counts, # out + num_tasks=out_size, + tiles_per_cta=1, + tile_size=triton.next_power_of_2(out_size), + one_tile_per_cta=True, + num_warps=8, + ) + return data_out[:out_size], inverse_indices, counts + + +def _unique2( + in0: torch.Tensor, + sorted: bool = True, + return_inverse: bool = False, + return_counts: bool = False, +): + if in0.numel() <= 8192: + sorted_data, sorted_indices = torch.sort(in0.ravel(), stable=False) + data_out, inverse_indices, counts = simple_unique_flat( + sorted_data, sorted_indices, return_inverse, return_counts + ) + elif return_inverse: + sorted_data, sorted_indices = torch.sort(in0.ravel(), stable=False) + data_out, inverse_indices, counts = sorted_indices_unique_flat( + sorted_data, sorted_indices, return_counts + ) + else: + sorted_data, _ = torch.sort(in0.ravel(), stable=False) + data_out, inverse_indices, counts = sorted_quick_unique_flat( + sorted_data, return_counts + ) + return ( + data_out, + inverse_indices if inverse_indices is None else inverse_indices.view_as(in0), + counts, + ) diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py index f7944e92..0600966f 100644 --- a/tests/test_binary_pointwise_ops.py +++ b/tests/test_binary_pointwise_ops.py @@ -255,6 +255,42 @@ def test_accuracy_div_scalar_tensor(shape, scalar, dtype): gems_assert_close(res_out, ref_out, dtype, equal_nan=True) +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", [torch.float32]) +# Note : tl.math.div_rz only support float32, cast will cause diff +# with torch, so we only do float32 test for now. +def test_accuracy_trunc_div(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.div(ref_inp1, ref_inp2, rounding_mode="trunc") + with flag_gems.use_gems(): + res_out = torch.div(inp1, inp2, rounding_mode="trunc") + + logging.debug( + f"The maximum difference between torch and triton is " + f"{torch.max(torch.abs(ref_out - res_out))}" + ) + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_accuracy_floor_div(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.div(ref_inp1, ref_inp2, rounding_mode="floor") + with flag_gems.use_gems(): + res_out = torch.div(inp1, inp2, rounding_mode="floor") + + gems_assert_equal(res_out, ref_out) + + @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_eq(shape, dtype): @@ -507,6 +543,32 @@ def test_accuracy_pow(shape, dtype): gems_assert_close(res_out, ref_out, dtype, equal_nan=True) +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_maximum(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + ref_out = torch.maximum(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.maximum(inp1, inp2) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_minimum(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + ref_out = torch.minimum(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.minimum(inp1, inp2) + gems_assert_equal(res_out, ref_out) + + @pytest.mark.parametrize("scalar", SCALARS) @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) diff --git a/tests/test_named_ops.py b/tests/test_named_ops.py index 053c98bf..906a01b6 100644 --- a/tests/test_named_ops.py +++ b/tests/test_named_ops.py @@ -220,7 +220,7 @@ for file_name, collection in op_name_to_unit_test_maps.items(): for op, uts in collection.items(): for ut in uts: - cmd = f"{file_name}::{ut} --device {device}" + cmd = f"{file_name}::{ut}" result = pytest.main(["-s", cmd, "--device", device]) print("final_result: ", final_result) exit(final_result) diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index eb658a0c..6c9c7ee7 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -670,3 +670,22 @@ def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype): res_out = torch.linalg.vector_norm(inp, ord, dim, keepdim) gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_index_select(shape, dim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + index_size = inp.size(dim) + from math import floor + + index = torch.randint(0, index_size, [floor(index_size * 0.8)], device="cuda") + + ref_inp = to_reference(inp) + ref_index = to_reference(index) + ref_out = torch.index_select(ref_inp, dim, ref_index) + with flag_gems.use_gems(): + res_out = torch.index_select(inp, dim, index) + + gems_assert_equal(res_out, ref_out) diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py index ddf2f2a0..90f12cb6 100644 --- a/tests/test_special_ops.py +++ b/tests/test_special_ops.py @@ -7,6 +7,7 @@ from .accuracy_utils import ( FLOAT_DTYPES, + INT_DTYPES, POINTWISE_SHAPES, RESOLUTION, gems_assert_close, @@ -254,3 +255,83 @@ def test_accuracy_resolve_conj(shape, dtype): with flag_gems.use_gems(): z = y.resolve_conj() assert not z.is_conj() + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES + [(8191,), (8192, 73739)]) +@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("sorted", [True]) +@pytest.mark.parametrize("return_inverse", [True, False]) +@pytest.mark.parametrize("return_counts", [False, True]) +def test_accuracy_unique(shape, dtype, sorted, return_inverse, return_counts): + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device="cuda") + else: + inp = torch.randint(-10, 10, shape, device="cuda").to(dtype) + ref_inp = to_reference(inp, False) + + if return_counts: + if return_inverse: + with flag_gems.use_gems(): + res_out, res_unique_order, res_counts = torch.unique( + inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + ref_out, ref_unique_order, ref_counts = torch.unique( + ref_inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + assert res_out.numel() == ref_out.numel() + gems_assert_equal(res_unique_order, ref_unique_order) + else: + with flag_gems.use_gems(): + res_out, res_counts = torch.unique( + inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + ref_out, ref_counts = torch.unique( + ref_inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + assert res_out.numel() == ref_out.numel() + gems_assert_equal(res_counts, ref_counts) + else: + if return_inverse: + with flag_gems.use_gems(): + res_out, res_unique_order = torch.unique( + inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + ref_out, ref_unique_order = torch.unique( + ref_inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + assert res_out.numel() == ref_out.numel() + gems_assert_equal(res_unique_order, ref_unique_order) + else: + with flag_gems.use_gems(): + res_out = torch.unique( + inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + ref_out = torch.unique( + ref_inp, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + ) + assert res_out.numel() == ref_out.numel() + gems_assert_equal(res_out, ref_out) diff --git a/tests/test_tensor_constructor_ops.py b/tests/test_tensor_constructor_ops.py index d39d60a7..88829e71 100644 --- a/tests/test_tensor_constructor_ops.py +++ b/tests/test_tensor_constructor_ops.py @@ -41,6 +41,18 @@ def test_accuracy_rand_like(shape, dtype): assert (res_out >= 0.0).all() +@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_randn_like(shape, dtype): + x = torch.randn(size=shape, dtype=dtype, device="cuda") + with flag_gems.use_gems(): + res_out = torch.randn_like(x) + mean = torch.mean(res_out) + std = torch.std(res_out) + assert torch.abs(mean) < 0.01 + assert torch.abs(std - 1) < 0.01 + + @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_zeros(shape, dtype): diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py index ac961a45..a73fd0cc 100644 --- a/tests/test_unary_pointwise_ops.py +++ b/tests/test_unary_pointwise_ops.py @@ -74,16 +74,24 @@ def test_accuracy_exp(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_accuracy_gelu(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") +@pytest.mark.parametrize("approximate", ["none", "tanh"]) +def test_accuracy_gelu(shape, dtype, approximate): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) ref_inp = to_reference(inp, True) - ref_out = torch.nn.functional.gelu(ref_inp) + ref_out = torch.nn.functional.gelu(ref_inp, approximate=approximate) with flag_gems.use_gems(): - res_out = torch.nn.functional.gelu(inp) + res_out = torch.nn.functional.gelu(inp, approximate=approximate) gems_assert_close(res_out, ref_out, dtype) + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype) + @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES)