diff --git a/.gitignore b/.gitignore index 686ca07df..ca788f982 100644 --- a/.gitignore +++ b/.gitignore @@ -69,4 +69,10 @@ models/frozenmodels/ .pytest_cache # .hypothesis -.hypothesis \ No newline at end of file +.hypothesis + +# .ruff_cache +.ruff_cache + +# .bitblas_database +.bitblas_database diff --git a/3rdparty/.gitignore b/3rdparty/.gitignore new file mode 100644 index 000000000..f2ce68266 --- /dev/null +++ b/3rdparty/.gitignore @@ -0,0 +1,3 @@ +clang* + +llvm* diff --git a/3rdparty/tvm b/3rdparty/tvm index 0fd8338b8..93840447d 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 0fd8338b815194f46b3b8ea40cd9cd1c670bfc0a +Subproject commit 93840447d3be9a34e0e32c4496ca5b62856bb220 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fa25b190e..bdf41586c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,6 +27,8 @@ Please ask questions in issues. All pull requests are super welcomed and greatly appreciated! Issues in need of a solution are marked with a [`♥ help`](https://github.com/ianstormtaylor/BitBLAS/issues?q=is%3Aissue+is%3Aopen+label%3A%22%E2%99%A5+help%22) label if you're looking for somewhere to start. +Please run `./format.sh` before submitting a pull request to make sure that your code is formatted correctly. + Please include tests and docs with every pull request! ## Repository Setup diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..d4920f73a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +recursive-include 3rdparty/tvm * +recursive-exclude 3rdparty/tvm/build * +recursive-exclude 3rdparty/clang* * +recursive-exclude 3rdparty/llvm* * diff --git a/README.md b/README.md index 8d101a066..c765027b8 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,83 @@ # BitBLAS -BitBLAS is a lightweight framework designed to generate high-performance CUDA/HIP code for BLAS operators, featuring swizzling and layout propagation. It achieves performance comparable to vendor libraries across various platforms and hardware. BitBLAS aims to assist algorithm developers working on projects like BitNet, GPTQ, and similar endeavors by enabling the rapid implementation of accelerated kernels and their efficient deployment. +BitBLAS is a library to support mixed-precision BLAS operations on GPUs, for example, the $W_{wdtype}A_{adtype}$ mixed-precision matrix multiplication where $C_{cdtype}[M, N] = A_{adtype}[M, K] \times W_{wdtype}[N, K]$. +BitBLAS aims to support efficient mixed-precision DNN model deployment, especially the $W_{wdtype}A_{adtype}$ quantization in large language models (LLMs), for example, the $W_{INT4}A_{FP16}$ in [GPTQ](https://arxiv.org/abs/2210.17323), the $W_{INT2}A_{FP16}$ in [BitDistiller](https://arxiv.org/abs/2402.10631), the $W_{INT1}A_{INT8}$ and $W_{INT2}A_{INT8}$ in [BitNet](https://arxiv.org/abs/2310.11453) and [BitNet-b1.58](https://arxiv.org/abs/2402.17764). BitBLAS is based on techniques from our accepted submission at OSDI'24. + Some of the key features of BitBLAS include: - - Auto Tensorize compute with TensorCore-like hardware instructions. - - High Performance (Not only FP16xFP16, INT8xINT8, but also FP16xINT4/2/1, INT8xINT4/2/1). - - With the flexible DSL (TIR Script) to effortlessly craft domain-specific kernels for your situations. - - Support with dynamic symbolic throuth tvm unity -> generate source code with dynamic shape. - - BitBLAS first proposed int8xint1 gemv/gemm with 10x/2x speedup over float16xfloat16 on A100, please checkout [op_benchmark_a100_int1_scaling](images/figures/op_benchmark_a100_int1_scaling.png) for detailed input scaling benchmark results. + - High performance matrix multiplication for both GEMV (e.g., the single batch auto-regressive decode phase in LLM) and GEMM (e.g., the batched auto-regressive decode phase and the prefill phase in LLM): + - $W_{wdtype}A_{adtype}$ mixed-precision matrix multiplication including FP16xINT4/2/1, INT8xINT4/2/1, etc. Please checkout [support matrix](#support-matrix) for detailed data types support. + - Matrix multiplication like FP16xFP16 and INT8xINT8. + - Auto-Tensorization for TensorCore-like hardware instructions. + - Implemented [integration](./integration/) to [PyTorch](https://pytorch.org/), [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) and [vLLM](https://github.com/vllm-project/vllm) for LLM deployment. Please checkout [benchmark summary](#benchmark-summary) for detailed end2end LLM inference performance. + - BitBLAS first implemented $W_{INT1}A_{INT8}$ GEMV/GEMM with 10x/2x speedup over $W_{FP16}A_{FP16}$ on A100, please checkout [op_benchmark_a100_int1_scaling](images/figures/op_benchmark_a100_int1_scaling.png) for detailed benchmark results. + - Support customizing mixed-precision DNN operations for your specific scenarios via the flexible DSL (TIR Script). + +## Integration Example of FasterTransformer with BitBLAS +![FasterTransformer Integration](images/gif/FasterTransformer.gif) + + +## Benchmark Summary + +BitBLAS achieves exceptional performance across a variety of computational patterns. Below are selected results showcasing its capabilities: + +- End2End Integration with Quantize Inference Kernel for AutoGPTQ and vLLM. + +
+ AutoGPTQ end2end performance of llama13b on A100 + AutoGPTQ end2end performance of llama13b on A100 + vLLM end2end performance of llama13b on A100 + vLLM end2end performance of llama13b on A100 +
+ +- Weight Only Matmul performance on A100 +
+ gemm weight only performance on A100 + gemm weight only performance on A100 +
-## Benchmark -BitBLAS can achieve optimal performance across various compute patterns: -- GTX 3090 - - FLOAT16xFLOAT16 with TensorCore ![3090-gemm-fp16](./images/figures/op_benchmark_3090_fp16_gemm.png) - - INT8xINT8 with TensorCore ![3090-gemm-s8](./images/figures/op_benchmark_3090_s8_gemm.png) - - FLOAT16xAF4(LUT4) GEMV ![3090-af4-gemv](./images/figures/op_benchmark_3090_af4_gemv.png) - - FLOAT16xAF4(LUT4) with TensorCore ![3090-af4-gemm](./images/figures/op_benchmark_3090_af4_gemm.png) +- TensorCore FP16/INT8 GEMM Performance Vs. Vendor Library on A100 and RTX4090 -- A100 - - WeightOnly GEMV ![a100-wq-gemv](./images/figures/op_benchmark_a100_wq_gemv.png) - - WeightOnly GEMM with TensorCore ![a100-wq-gemm](./images/figures/op_benchmark_a100_wq_gemm.png) +
+ gemm fp16 performance on 4090 and a100 + gemm int8 performance on 4090 and a100 +
-See more details in our [benchmark](./benchmark) directory. +For more detailed information on benchmark sets with other formats (NF4/FP4) and other devices (GTX 3090), please refer to the [benchmark](./benchmark/README.md). + +## Support Matrix + +| **A_dtype** | **W_dtype** | **Accum_dtype** | **Out_dtype** | **BitBLAS
Support** | **Tested
Platform** | +|:-----------:|:-----------:|:---------------:|:---------------:|:----------------------:|:----------------------:| +| FP16 | FP16 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | FP4_E2M1 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | INT8 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | INT4 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | INT2 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | INT1 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | NF4 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | INT8 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | INT4 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | INT2 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | INT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | + +We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR. ## Getting Started -- Installation: - To manually install BitBLAS, please checkout `maint/scripts/installation.sh`. Also Make sure you already have the cuda toolkit (version >= 11) installed in the system. Or you can install from `python setup.py install` or `pip install .` in the root directory. +- [Installation](./docs/Installation.md): + To install BitBLAS, please checkout the document [installation](./docs/Installation.md). Also Make sure you already have the cuda toolkit (version >= 11) installed in the system. Or you can easily install from `pip install bitblas` in the root directory. + +- [QuickStart](./docs/QuickStart.md): BitBLAS provides two Python APIs to perform mixed-precision matrix multiplication: + - ```bitblas.Matmul``` implements the $W_{wdtype}A_{adtype}$ mixed-precision matrix multiplication of $C_{cdtype}[M, N] = A_{adtype}[M, K] \times W_{wdtype}[N, K]$. + - ```bitblas.Linear``` is a PyTorch ```nn.Linear```-like module to support a Linear of mixed-precision. -- [QuickStart](./docs/QuickStart.md): We provide two primary ways to do the code generation: using a high-level DSL (TensorIR Script), or using packed Operators, from the quick start guide, you can learn how to use BitBLAS to generate high performance kernels with both methods. +- [Integration](./integration/): Explore how BitBLAS seamlessly integrates with LLM deployment frameworks through our examples. Discover the ease of integrating BitBLAS with PyTorch, AutoGPTQ, and vLLM in the 3rd-party integration examples. + +- [Customization](./docs/ExtendOperatorsWithDSL.md): BitBLAS supports implementing customized mixed-precision DNN operations rather than matrix multiplication with the flexible DSL (TIR Script). -- [3rd Party Integration](./integration/): BitBLAS can also be easily integrated to other frameworks, the integration provides some examples of integrating BitBLAS with PyTorch, AutoGPTQ and vLLM. ## Contributing @@ -46,9 +91,3 @@ This project has adopted the Microsoft Open Source Code of Conduct. For more inf This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies. -## Acknowledgement - -We learned a lot from the following projects. - -- [Apache TVM](https://github.com/apache/tvm): BitBLAS havs adopted TensorIR as our DSL. Additionally, we have customized TVM from the unity branch to incorporate specific features that were required for our project. -- [Microsoft Roller](https://github.com/microsoft/nnfusion/tree/roller): The design and algo inspiration of hardware aware tuning in BitBLAS comes from Roller,. diff --git a/SUPPORT.md b/SUPPORT.md index 291d4d437..3e6ec834e 100644 --- a/SUPPORT.md +++ b/SUPPORT.md @@ -1,25 +1,29 @@ -# TODO: The maintainer of this repo has not yet edited this file +# Support -**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? +Welcome to the BitBLAS support page! BitBLAS is a cutting-edge framework designed for generating high-performance CUDA/HIP code for BLAS operators. Whether you're working on projects like BitNet, GPTQ, or similar, BitBLAS is here to accelerate your development with its robust features. -- **No CSS support:** Fill out this template with information about how to file issues and get help. -- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. -- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. +## How to File Issues and Get Help -*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* +### Reporting Bugs or Requesting Features -# Support +If you encounter a bug or have a feature request, we encourage you to file an issue through our GitHub Issues page. Please follow these steps: + +1. **Search Existing Issues**: Before creating a new issue, please search the existing ones to avoid duplicates. +2. **Create a New Issue**: If your issue is new, go ahead and file it as a new issue. Provide as much detail as possible to help us understand and address it efficiently. + +### Seeking Help and Questions + +For questions and help with using BitBLAS, we offer the following channels: + +- **GitHub Discussions**: For community support, sharing ideas, and discussing best practices, please visit our [GitHub Discussions](https://github.com/YOUR_REPO/discussions). +- **Stack Overflow**: Use the tag `BitBLAS` when posting questions. This is monitored by our team and the community. -## How to file issues and get help +## Microsoft Support Policy -This project uses GitHub Issues to track bugs and feature requests. Please search the existing -issues before filing new issues to avoid duplicates. For new issues, file your bug or -feature request as a new Issue. +Support for BitBLAS is primarily provided through the above-mentioned community channels. We strive to address issues and questions in a timely manner, leveraging the collective knowledge and experience of the BitBLAS community. -For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE -FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER -CHANNEL. WHERE WILL YOU HELP PEOPLE?**. +## Contributing to BitBLAS -## Microsoft Support Policy +We warmly welcome contributions to the BitBLAS project. Whether it's improving the documentation, adding new features, or fixing bugs, your contributions are invaluable to us. Please refer to our [CONTRIBUTING.md](./CONTRIBUTING.md) file for more details on how to contribute. -Support for this **PROJECT or PRODUCT** is limited to the resources listed above. +Before submitting a pull request, you may need to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. The CLA process is straightforward and only needs to be completed once. diff --git a/benchmark/README.md b/benchmark/README.md index e80c302db..dc86280a7 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1,13 +1,13 @@ # Speedup Benchmark vs Vendor Libraries -This part presents a benchmark comparison between our custom library, BitBLAS, and various vendor libraries (cuBLAS, CUTLASS, bitsandbytes, faster-transformer, tensorrt-llm, vLLM, and Marlin) across different matrix operation types (GEMM, GEMV) and data formats (float16xfloat16, int8xint8, float16xint4/af4). The benchmarks are conducted on NVIDIA GPUs - 24GB GTX 3090 and 80GB A100, with CUDA 12.1 installed. +This part presents a benchmark comparison between our custom library, BitBLAS, and various vendor libraries (cuBLAS, CUTLASS, bitsandbytes, faster-transformer, tensorrt-llm, vLLM, and Marlin) across different matrix operation types (GEMM, GEMV) and data formats (float16xfloat16, int8xint8, float16xint4/nf4). The benchmarks are conducted on NVIDIA GPUs - 24GB GTX 3090 and 80GB A100, with CUDA 12.1 installed. ## Benchmark Overview ### Tested Operations and Formats - GEMM (General Matrix Multiply) and GEMV (General Matrix-Vector Multiply) -- Data formats: float16, int8, float16xint4/af4 +- Data formats: float16, int8, float16xint4/nf4 ### Hardware @@ -18,13 +18,20 @@ This part presents a benchmark comparison between our custom library, BitBLAS, a - CUDA 12.1 - Compared libraries: cuBLAS, CUTLASS, bitsandbytes, faster-transformer, tensorrt-llm, vLLM, Marlin +- Commit ID: + - bitsandbytes == 0.43.0 + - vLLM: 865732342b4e3b8a4ef38f28a2a5bdb87cf3f970 + - FasterTransformer: 1afbf20129647a35d108152fc6789bc1d029cda5 + - TensorRT-LLM: 2bf3a0a4287069ac55ee3304c285b08592d3d1bc + - CUTLASS: 629f4653c3ea3db3264030382956fabe715f3436 + - Marlin: 512f1b1ba39ff708bcc95419f11cfd1285cd31b3 ## Results Summary ### GTX 3090 Benchmarks - **Float16 and Int8 GEMM with Tensorcore**: BitBLAS matches the performance of cuBLAS and CUTLASS. -- **Float16xaf4 GEMV and GEMM**: BitBLAS achieves 2x the speed of bitsandbytes and 4x the base float16 performance. +- **Float16xnf4 GEMV and GEMM**: BitBLAS achieves 2x the speed of bitsandbytes and 4x the base float16 performance. - **Optimal performance** in float16xint4 GEMM. ### A100 Benchmarks @@ -35,48 +42,60 @@ This part presents a benchmark comparison between our custom library, BitBLAS, a The benchmark configurations for each test scenario are detailed below: - -|config|Provider|M|N|K| -|:---:|:---:|:---:|:---:|:---:| -|V0|None|1|16384|16384| -|V1|BLOOM|1|43008|14336| -|V2|BLOOM|1|14336|14336| -|V3|BLOOM|1|57344|14336| -|V4|BLOOM|1|14336|57344| -|V5|OPT|1|9216|9216| -|V6|OPT|1|36864|9216| -|V7|OPT|1|9216|36864| -|V8|LLAMA|1|22016|8192| -|V9|LLAMA|1|8192|22016| -|V10|LLAMA-2|1|8192|8192| -|V11|LLAMA-2|1|28672|8192| -|V12|LLAMA-2|1|8192|28672| -|M0|None|16384|16384|16384| -|M1|BLOOM|8192|43008|14336| -|M2|BLOOM|8192|14336|14336| -|M3|BLOOM|8192|57344|14336| -|M4|BLOOM|8192|14336|57344| -|M5|OPT|8192|9216|9216| -|M6|OPT|8192|36864|9216| -|M7|OPT|8192|9216|36864| -|M8|LLAMA|8192|22016|8192| -|M9|LLAMA|8192|8192|22016| -|M10|LLAMA-2|8192|8192|8192| -|M11|LLAMA-2|8192|28672|8192| -|M12|LLAMA-2|8192|8192|28672| - + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
configProviderMNK
V0None11638416384
V1BLOOM14300814336
V2BLOOM11433614336
V3BLOOM15734414336
V4BLOOM11433657344
V5OPT192169216
V6OPT1368649216
V7OPT1921636864
V8LLAMA1220168192
V9LLAMA1819222016
V10LLAMA-2181928192
V11LLAMA-21286728192
V12LLAMA-21819228672
M0None163841638416384
M1BLOOM81924300814336
M2BLOOM81921433614336
M3BLOOM81925734414336
M4BLOOM81921433657344
M5OPT819292169216
M6OPT8192368649216
M7OPT8192921636864
M8LLAMA8192220168192
M9LLAMA8192819222016
M10LLAMA-2819281928192
M11LLAMA-28192286728192
M12LLAMA-28192819228672
+
**Note:** To reproduce the 3rdparty frameworks' benchmark results, please refer to [mlc-benchmark](https://github.com/LeiWang1999/mlc-benchmark). ## Benchmark Images -- GTX 3090 - - ![3090-gemm-fp16](../images/figures/op_benchmark_3090_fp16_gemm.png) - - ![3090-gemm-s8](../images/figures/op_benchmark_3090_s8_gemm.png) - - ![3090-af4-gemv](../images/figures/op_benchmark_3090_af4_gemv.png) - - ![3090-af4-gemm](../images/figures/op_benchmark_3090_af4_gemm.png) +INT8xINT1 Matmul BS Scaling on A100. + +![int8xint1_scaling](../images/figures/op_benchmark_a100_int1_scaling.png) + +3090 Related benchmark numbers + +![3090-gemm-fp16](../images/figures/op_benchmark_3090_fp16_gemm.png) + +![3090-gemm-s8](../images/figures/op_benchmark_3090_s8_gemm.png) + +![3090-nf4-gemv](../images/figures/op_benchmark_3090_nf4_gemv.png) + +![3090-nf4-gemm](../images/figures/op_benchmark_3090_nf4_gemm.png) + +A100 Related Benchmark Result -- A100 - - ![a100-wq-gemv](../images/figures/op_benchmark_a100_wq_gemv.png) - - ![a100-wq-gemm](../images/figures/op_benchmark_a100_wq_gemm.png) +![a100-wq-gemv](../images/figures/op_benchmark_a100_wq_gemv_e8.png) +![a100-wq-gemm](../images/figures/op_benchmark_a100_wq_gemm_e8.png) diff --git a/benchmark/dsl/convolution.py b/benchmark/dsl/convolution.py index 592544c3b..3d9b5ac87 100644 --- a/benchmark/dsl/convolution.py +++ b/benchmark/dsl/convolution.py @@ -2,19 +2,16 @@ # Licensed under the MIT License. import numpy as np import tvm -from tvm.script import tir as T from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.base.roller.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.gpu import Matmul, matmul_mma +from bitblas.gpu import Matmul from bitblas.base.utils import apply_and_build import time from tvm import te, tir -def conv2d_nhwc_hwio( - n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dtype="float16" -): +def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dtype="float16"): A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) B = te.placeholder((kh, kw, c, f), name="weight", dtype=in_dtype) @@ -46,13 +43,8 @@ def conv2d_nhwc_hwio( C = te.compute( out_shape, lambda n, h, w, f: te.sum( - pad[ - n, - h * stride_h + kh * dilation_h, - w * stride_w + kw * dilation_w, - c, - ] - * B[kh, kw, c, f], + pad[n, h * stride_h + kh * dilation_h, w * stride_w + kw * dilation_w, c,] * B[kh, kw, + c, f], axis=[kh, kw, c], ), name="C", @@ -64,8 +56,8 @@ def conv2d_nhwc_hwio( benchmark_sets = [ # (prim_func, input_args, BitBLAS_default_schedule), (conv2d_nhwc_hwio, (128, 64, 224, 224, 64, 1, 1, 2, 1, 3, "float16", "float16"), Matmul), - # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float32", "float32"), Matmul), - # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), + (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float32", "float32"), Matmul), + (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), ] # fmt:on benchmark_results = {} @@ -77,7 +69,7 @@ def conv2d_nhwc_hwio( policy = DefaultPolicy(func=func, arch=arch) try: tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: + except Exception: tags = None if tags: policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) @@ -86,14 +78,8 @@ def conv2d_nhwc_hwio( tune_start = time.time() cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) fast_tune_time = time.time() - tune_start - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) - ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency * 1e3)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3)) # evaluate the performance of the default schedule @@ -112,12 +98,9 @@ def conv2d_nhwc_hwio( tvm.nd.array( np.random.uniform(0, 1, [int(i) for i in arg.shape]).astype(arg.dtype), device=arch.device, - ) - ) + )) - timer_cuda_mod = mod_default.time_evaluator( - mod_default.entry_name, arch.device, number=5 - ) + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) t = timer_cuda_mod(*profile_tensors).mean print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) @@ -143,10 +126,8 @@ def conv2d_nhwc_hwio( "BitBLAS Default Latency", ] -col_width = ( - max(len(word) for row in [headers] + list(profile_config.values()) for word in row) - + 2 -) # padding +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding print("".join(word.ljust(col_width) for word in headers)) diff --git a/benchmark/dsl/matmul.py b/benchmark/dsl/matmul.py index 19a0083b8..85068b1ef 100644 --- a/benchmark/dsl/matmul.py +++ b/benchmark/dsl/matmul.py @@ -1,13 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import numpy as np + import tvm -from tvm.script import tir as T from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.base.roller.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul -from bitblas.utils import get_target_from_env +from bitblas.utils import auto_detect_nvidia_target from bitblas.base.utils import apply_and_build from bitblas.ops.impl.matmul_impl import ( matmul_nn, @@ -16,7 +15,6 @@ ) import time - # fmt:off test_shapes = [ # (prim_func, input_args, default_dlight_schedule), @@ -29,10 +27,10 @@ (matmul_nn, (8192, 8192, 8192, "float16", "float16"), Matmul), (matmul_nn, (16384, 16384, 16384, "float16", "float16"), Matmul), (matmul_nt, (1024, 1024, 1024, "float32", "float32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16"), + Matmul), ] - llm_shapes = [ # square test (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16"), Matmul), @@ -51,7 +49,7 @@ (matmul_nt_propagate_a_propagate_b, (8192, 8192, 8192, "float16", "float16"), Matmul), (matmul_nt_propagate_a_propagate_b, (8192, 28672, 8192, "float16", "float16"), Matmul), (matmul_nt_propagate_a_propagate_b, (8192, 8192, 28672, "float16", "float16"), Matmul), - + # square test (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "int8", "int8", "int32"), Matmul), # BLOOM-176B @@ -76,7 +74,7 @@ # fmt:on -target = tvm.target.Target(get_target_from_env()) +target = tvm.target.Target(auto_detect_nvidia_target()) benchmark_results = {} for get_prim_func, input_args, d_schedule in benchmark_sets: @@ -86,7 +84,7 @@ policy = DefaultPolicy(func=func, arch=arch) try: tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: + except Exception: tags = None if tags: policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) @@ -96,14 +94,8 @@ tune_start = time.time() cpresults, best = apply_and_build(func, configs, arch, parallel_build=False) fast_tune_time = time.time() - tune_start - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) - ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency * 1e3)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3)) # evaluate the performance of the default schedule @@ -118,9 +110,7 @@ profile_tensors = best.profile_tensors - timer_cuda_mod = mod_default.time_evaluator( - mod_default.entry_name, arch.device, number=5 - ) + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) t = timer_cuda_mod(*profile_tensors).mean print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) @@ -147,10 +137,8 @@ "DefaultDLight Latency", ] -col_width = ( - max(len(word) for row in [headers] + list(profile_config.values()) for word in row) - + 2 -) # padding +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding print("".join(word.ljust(col_width) for word in headers)) diff --git a/benchmark/dsl/matmul_dequantize_af.py b/benchmark/dsl/matmul_dequantize_af.py new file mode 100644 index 000000000..e370de3c0 --- /dev/null +++ b/benchmark/dsl/matmul_dequantize_af.py @@ -0,0 +1,228 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul +from bitblas.utils import auto_detect_nvidia_target +from bitblas.base.utils import apply_and_build +from bitblas.ops.impl.matmul_dequantize_impl import ( + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, +) +import time +import argparse + +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), +) +parser.add_argument( + "--batch_seq", + type=int, + default=1, + help="The batch size of the sequence", +) +parser.add_argument( + "--group_size", + type=int, + default=-1, + help="The group size of the sequence", +) +parser.add_argument( + "--benchmark_sets", + nargs="+", + default=["llm_shape_fp16xnf4"], + help="List of benchmark sets, e.g., llm_shape_fp16xnf4", +) + +args = parser.parse_args() +group_size = args.group_size + +# fmt:off +llm_shape_fp16xnf4 = [ + # square test + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b, (1, 43008, 14336, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 14336, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 57344, 14336, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 57344, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + # # OPT-65B + (matmul_nt_dequantize_b, (1, 9216, 9216, "float16", "float16", "float16", 4, "int8", "nf", True, + False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 36864, 9216, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 9216, 36864, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 22016, 8192, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b, (1, 8192, 22016, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 8192, "float16", "float16", "float16", 4, "int8", "nf", True, + False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 28672, 8192, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 28672, "float16", "float16", "float16", 4, "int8", "nf", + True, False, group_size, False, False), Matmul), + + # square test + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, + group_size, False, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 43008, 14336, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 14336, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 57344, 14336, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 57344, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + # OPT-65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 9216, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 36864, 9216, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 36864, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 22016, 8192, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 22016, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 8192, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 28672, 8192, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 28672, "float16", "float16", "float16", 4, "int8", "nf", True, False, group_size, + False, False), Matmul), +] + +target = tvm.target.Target(args.target) +benchmark_sets = [] +for benchmark_set in args.benchmark_sets: + benchmark_sets.extend(eval(benchmark_set)) +benchmark_results = {} +# fmt:on + +target = tvm.target.Target(auto_detect_nvidia_target()) + +benchmark_results = {} +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency)) + + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + with arch.target: + mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )( + ir_module) + try: + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(mod, target="cuda") + except Exception: + mod_default = None + + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = best.profile_tensors + if mod_default is not None: + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) + t = timer_cuda_mod(*profile_tensors).mean + else: + t = 1e4 - 1 + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "fast_dlight_top20_tune_time": fast_tune_time, + "fast_dlight_top1_latency": cpresults[0].latency, + "fast_dlight_top20_latency": best.latency, + "default_dlight_tune_time": default_tune_time, + "default_dlight_latency": t * 1e3 if t is not None else "Failed", + } + } + + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Tune Time", + "BitBLAS Top1 Latency", + "BitBLAS Top20 Latency", + "DefaultDLight Tune Time", + "DefaultDLight Latency", +] + +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['fast_dlight_top20_tune_time'])} s", + f"{values['fast_dlight_top1_latency']:.3f} ms", + f"{values['fast_dlight_top20_latency']:.3f} ms", + str(values["default_dlight_tune_time"]), + f"{values['default_dlight_latency']:.3e} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/dsl/matmul_dequantize_fp.py b/benchmark/dsl/matmul_dequantize_fp.py new file mode 100644 index 000000000..66774aafa --- /dev/null +++ b/benchmark/dsl/matmul_dequantize_fp.py @@ -0,0 +1,227 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul +from bitblas.utils import auto_detect_nvidia_target +from bitblas.base.utils import apply_and_build +from bitblas.ops.impl.matmul_dequantize_impl import ( + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, +) +import time +import argparse + +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), +) +parser.add_argument( + "--batch_seq", + type=int, + default=1, + help="The batch size of the sequence", +) +parser.add_argument( + "--group_size", + type=int, + default=-1, + help="The group size of the sequence", +) +parser.add_argument( + "--benchmark_sets", + nargs="+", + default=["llm_shape_fp16xfp4"], + help="List of benchmark sets, e.g., llm_shape_fp16xnf4", +) + +args = parser.parse_args() +group_size = args.group_size + +# fmt:off +llm_shape_fp16xfp4 = [ + # square test + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, -1, False, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b, (1, 43008, 14336, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 14336, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 57344, 14336, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 57344, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + # # OPT-65B + (matmul_nt_dequantize_b, (1, 9216, 9216, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 36864, 9216, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 9216, 36864, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 22016, 8192, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b, (1, 8192, 22016, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 8192, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 28672, 8192, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 28672, "float16", "float16", "float16", 4, "uint32", "fp", + True, False, group_size, False, False), Matmul), + + # square test + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 43008, 14336, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 14336, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 57344, 14336, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 57344, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + # OPT-65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 9216, "float16", "float16", "float16", 4, "uint32", "fp", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 36864, 9216, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 36864, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 22016, 8192, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 22016, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 8192, "float16", "float16", "float16", 4, "uint32", "fp", True, False, group_size, + False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 28672, 8192, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 28672, "float16", "float16", "float16", 4, "uint32", "fp", True, False, + group_size, False, False), Matmul), +] + +target = tvm.target.Target(args.target) +benchmark_sets = [] +for benchmark_set in args.benchmark_sets: + benchmark_sets.extend(eval(benchmark_set)) +benchmark_results = {} +# fmt:on + +target = tvm.target.Target(auto_detect_nvidia_target()) + +benchmark_results = {} +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency)) + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + with arch.target: + mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )( + ir_module) + try: + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(mod, target="cuda") + except Exception: + mod_default = None + + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = best.profile_tensors + if mod_default is not None: + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) + t = timer_cuda_mod(*profile_tensors).mean + else: + t = 1e4 - 1 + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "fast_dlight_top20_tune_time": fast_tune_time, + "fast_dlight_top1_latency": cpresults[0].latency, + "fast_dlight_top20_latency": best.latency, + "default_dlight_tune_time": default_tune_time, + "default_dlight_latency": t * 1e3 if t is not None else "Failed", + } + } + + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Tune Time", + "BitBLAS Top1 Latency", + "BitBLAS Top20 Latency", + "DefaultDLight Tune Time", + "DefaultDLight Latency", +] + +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['fast_dlight_top20_tune_time'])} s", + f"{values['fast_dlight_top1_latency']:.3f} ms", + f"{values['fast_dlight_top20_latency']:.3f} ms", + str(values["default_dlight_tune_time"]), + f"{values['default_dlight_latency']:.3e} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/dsl/matmul_dequantize_int1.py b/benchmark/dsl/matmul_dequantize_int1.py new file mode 100644 index 000000000..1f8b1775e --- /dev/null +++ b/benchmark/dsl/matmul_dequantize_int1.py @@ -0,0 +1,230 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul +from bitblas.utils import auto_detect_nvidia_target +from bitblas.base.utils import apply_and_build +from bitblas.ops.impl.matmul_dequantize_impl import ( + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, +) +import time +import argparse + +# append a parser for the benchmark set + +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int8xint1 on a specific target.") +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), +) +parser.add_argument( + "--batch_seq", + type=int, + default=1, + help="The batch size of the sequence", +) +parser.add_argument( + "--group_size", + type=int, + default=-1, + help="The group size of the sequence", +) +parser.add_argument( + "--benchmark_sets", + nargs="+", + default=["llm_int8xint1"], + help="List of benchmark sets, e.g., llm_int8xint1_bs4096", +) + +args = parser.parse_args() +batch_seq = args.batch_seq +group_size = args.group_size + +# fmt:off + +llm_int8xint1 = [ + # square test + (matmul_nt_dequantize_b, (1, 16384, 16384, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b, (1, 43008, 14336, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 14336, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 57344, 14336, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 57344, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + # # OPT-65B + (matmul_nt_dequantize_b, (1, 9216, 9216, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 36864, 9216, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 9216, 36864, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 22016, 8192, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b, (1, 8192, 22016, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 8192, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 28672, 8192, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 28672, "int8", "int8", "int32", 1, "int8", "uint", False, + False, group_size, True, False), Matmul), + + # square test + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (16384, 16384, 16384, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, + True, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 43008, 14336, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 14336, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 57344, 14336, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 57344, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, + True, False), Matmul), + # OPT-65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 9216, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 36864, 9216, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 36864, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 22016, 8192, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 22016, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 8192, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 28672, 8192, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 28672, "int8", "int8", "int32", 1, "int8", "uint", False, False, group_size, True, + False), Matmul), +] + +# fmt:on + +target = tvm.target.Target(args.target) +benchmark_sets = [] +for benchmark_set in args.benchmark_sets: + benchmark_sets.extend(eval(benchmark_set)) + +benchmark_results = {} +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency)) + + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + with arch.target: + mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )( + ir_module) + try: + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(mod, target="cuda") + except Exception: + mod_default = None + + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = best.profile_tensors + if mod_default is not None: + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) + t = timer_cuda_mod(*profile_tensors).mean + else: + t = 1e4 - 1 + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "fast_dlight_top20_tune_time": fast_tune_time, + "fast_dlight_top1_latency": cpresults[0].latency, + "fast_dlight_top20_latency": best.latency, + "default_dlight_tune_time": default_tune_time, + "default_dlight_latency": t * 1e3 if t is not None else "Failed", + } + } + + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Tune Time", + "BitBLAS Top1 Latency", + "BitBLAS Top20 Latency", + "DefaultDLight Tune Time", + "DefaultDLight Latency", +] + +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['fast_dlight_top20_tune_time'])} s", + f"{values['fast_dlight_top1_latency']:.3f} ms", + f"{values['fast_dlight_top20_latency']:.3f} ms", + str(values["default_dlight_tune_time"]), + f"{values['default_dlight_latency']:.3e} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/dsl/matmul_dequantize_int4.py b/benchmark/dsl/matmul_dequantize_int4.py new file mode 100644 index 000000000..7eb6cf8dd --- /dev/null +++ b/benchmark/dsl/matmul_dequantize_int4.py @@ -0,0 +1,297 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul +from bitblas.utils import auto_detect_nvidia_target +from bitblas.base.utils import apply_and_build +from bitblas.ops.impl.matmul_dequantize_impl import ( + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, +) +import time +import argparse + +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), +) +parser.add_argument( + "--group_size", + type=int, + default=-1, + help="The group size of the sequence", +) +parser.add_argument( + "--benchmark_sets", + nargs="+", + default=["llm_shape_fp16xint4"], + help="List of benchmark sets, e.g., llm_int8xint1_bs4096", +) + +args = parser.parse_args() +group_size = args.group_size + +# fmt:off +llm_shape_fp16xint4 = [ + # square test + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b, (1, 43008, 14336, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 14336, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 57344, 14336, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 57344, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + # # OPT-65B + (matmul_nt_dequantize_b, (1, 9216, 9216, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 36864, 9216, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 9216, 36864, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 22016, 8192, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b, (1, 8192, 22016, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 8192, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 28672, 8192, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 28672, "float16", "float16", "float16", 4, "int8", "int", + True, False, group_size, True, False), Matmul), + + # square test + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "int", True, False, + group_size, True, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 43008, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, + group_size, True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, + group_size, True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 57344, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, + group_size, True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 57344, "float16", "float16", "float16", 4, "int8", "int", True, False, + group_size, True, False), Matmul), + # OPT-65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 9216, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 36864, 9216, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 36864, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 22016, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 22016, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 28672, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 28672, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, + True, False), Matmul), +] + +llm_shape_int8xint4 = [ + # square test + (matmul_nt_dequantize_b, (1, 16384, 16384, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b, (1, 43008, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 57344, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 14336, 57344, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + # # OPT-65B + (matmul_nt_dequantize_b, (1, 9216, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 36864, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 9216, 36864, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 22016, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b, (1, 8192, 22016, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 28672, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + (matmul_nt_dequantize_b, (1, 8192, 28672, "int8", "int8", "int32", 2, "int8", "uint", False, + False, -1, True, False), Matmul), + + # square test + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (16384, 16384, 16384, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + # BLOOM-176B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 43008, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 57344, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 14336, 57344, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + # # OPT-65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 36864, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 9216, 36864, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 22016, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + # LLAMA-70B/65B + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 22016, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 28672, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), + (matmul_nt_dequantize_b_propagate_a_propagate_b, + (8192, 8192, 28672, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, + False), Matmul), +] + +# fmt:on + +target = tvm.target.Target(args.target) +benchmark_sets = [] +for benchmark_set in args.benchmark_sets: + benchmark_sets.extend(eval(benchmark_set)) +benchmark_results = {} + +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency)) + + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + with arch.target: + mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )( + ir_module) + try: + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(mod, target="cuda") + except Exception: + mod_default = None + + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = best.profile_tensors + if mod_default is not None: + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) + t = timer_cuda_mod(*profile_tensors).mean + else: + t = 1e4 - 1 + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "fast_dlight_top20_tune_time": fast_tune_time, + "fast_dlight_top1_latency": cpresults[0].latency, + "fast_dlight_top20_latency": best.latency, + "default_dlight_tune_time": default_tune_time, + "default_dlight_latency": t * 1e3 if t is not None else "Failed", + } + } + + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Tune Time", + "BitBLAS Top1 Latency", + "BitBLAS Top20 Latency", + "DefaultDLight Tune Time", + "DefaultDLight Latency", +] + +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['fast_dlight_top20_tune_time'])} s", + f"{values['fast_dlight_top1_latency']:.3f} ms", + f"{values['fast_dlight_top20_latency']:.3f} ms", + str(values["default_dlight_tune_time"]), + f"{values['default_dlight_latency']:.3e} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/dsl/weight_propagate.py b/benchmark/dsl/weight_propagate.py new file mode 100644 index 000000000..e69310af9 --- /dev/null +++ b/benchmark/dsl/weight_propagate.py @@ -0,0 +1,559 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul +from bitblas.utils import auto_detect_nvidia_target +from bitblas.base.utils import apply_and_build +from bitblas.ops.impl.matmul_impl import ( + matmul_nt, + matmul_nt_propagate_b, + matmul_nt_propagate_a_propagate_b, +) +from bitblas.ops.impl.matmul_dequantize_impl import ( + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_b, +) +import time +import argparse + +bitblas.set_log_level("DEBUG") +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), +) +parser.add_argument( + "--group_size", + type=int, + default=-1, + help="The group size of the sequence", +) +parser.add_argument( + "--benchmark_sets", + nargs="+", + default=[ + "llm_shape_fp16xint4_with_scaling_zeros_original_g128", + "llm_shape_fp16xint4_with_scaling_zeros_rescale_g128", + "llm_shape_fp16xint4_with_scaling_zeros_quantized_g128" + ], + help="List of benchmark sets, e.g., llm_int8xint1_bs4096", +) + +args = parser.parse_args() +group_size = args.group_size + +llm_shape_fp16xfp16 = [ + (matmul_nt, (1, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt, (16, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt, (32, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt, (64, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt, (128, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt, (256, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt, (512, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt, (16384, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (1, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (16, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (32, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (64, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (128, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (256, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (512, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16"), Matmul), + (matmul_nt_propagate_a_propagate_b, (16, 16384, 16384, "float16", "float16", "float16"), + Matmul), + (matmul_nt_propagate_a_propagate_b, (32, 16384, 16384, "float16", "float16", "float16"), + Matmul), + (matmul_nt_propagate_a_propagate_b, (64, 16384, 16384, "float16", "float16", "float16"), + Matmul), + (matmul_nt_propagate_a_propagate_b, (128, 16384, 16384, "float16", "float16", "float16"), + Matmul), + (matmul_nt_propagate_a_propagate_b, (256, 16384, 16384, "float16", "float16", "float16"), + Matmul), + (matmul_nt_propagate_a_propagate_b, (512, 16384, 16384, "float16", "float16", "float16"), + Matmul), + (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16"), + Matmul), +] + +group_size = 128 +with_scaling = False +with_zeros = False +zeros_mode = "original" +# fmt:off +llm_shape_fp16xint4 = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = True +with_zeros = False +zeros_mode = "original" +# fmt:off +llm_shape_fp16xint4_with_scaling = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = -1 +with_scaling = True +with_zeros = True +zeros_mode = "original" +# fmt:off +llm_shape_fp16xint4_with_scaling_zeros_original = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = -1 +with_scaling = True +with_zeros = True +zeros_mode = "rescale" +# fmt:off +llm_shape_fp16xint4_with_scaling_zeros_rescale = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = -1 +with_scaling = True +with_zeros = True +zeros_mode = "quantized" +# fmt:off +llm_shape_fp16xint4_with_scaling_zeros_quantized = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = True +with_zeros = True +zeros_mode = "original" +# fmt:off +llm_shape_fp16xint4_with_scaling_zeros_original_g128 = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = True +with_zeros = True +zeros_mode = "rescale" +# fmt:off +llm_shape_fp16xint4_with_scaling_zeros_rescale_g128 = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = True +with_zeros = True +zeros_mode = "quantized" +# fmt:off +llm_shape_fp16xint4_with_scaling_zeros_quantized_g128 = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = True +with_zeros = True +zeros_mode = "original" +# fmt:off +llm_shape_fp16xint4_with_scaling_zeros = [ + (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + with_scaling, with_zeros, group_size, True, False, zeros_mode), + Matmul), + (matmul_nt_dequantize_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = False +with_zeros = False +zeros_mode = "original" +llm_shape_fp16xint4_propagate_b = [ + # (matmul_nt_dequantize_b_propagate_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + # with_scaling, with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + # (matmul_nt_dequantize_b_propagate_b, (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + # with_scaling, with_zeros, group_size, True, False, zeros_mode), Matmul), + # (matmul_nt_dequantize_b_propagate_b, (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + # with_scaling, with_zeros, group_size, True, False, zeros_mode), Matmul), + # (matmul_nt_dequantize_b_propagate_b, (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + # with_scaling, with_zeros, group_size, True, False, zeros_mode), Matmul), + # (matmul_nt_dequantize_b_propagate_b, (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + # with_scaling, with_zeros, group_size, True, False, zeros_mode), Matmul), + # (matmul_nt_dequantize_b_propagate_b, (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + # with_scaling, with_zeros, group_size, True, False, zeros_mode), Matmul), + # (matmul_nt_dequantize_b_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", + # with_scaling, with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = True +with_zeros = False +zeros_mode = "original" +llm_shape_fp16xint4_propagate_b_with_scaling = [ + (matmul_nt_dequantize_b_propagate_b, + (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, with_zeros, + group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +group_size = 128 +with_scaling = True +with_zeros = True +zeros_mode = "original" +llm_shape_fp16xint4_propagate_b_with_scaling_zeros = [ + (matmul_nt_dequantize_b_propagate_b, + (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, with_zeros, + group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (16, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (32, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (64, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (128, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (256, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (512, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), + (matmul_nt_dequantize_b_propagate_b, + (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", with_scaling, + with_zeros, group_size, True, False, zeros_mode), Matmul), +] + +# fmt:on + +target = tvm.target.Target(args.target) +benchmark_sets = [] +for benchmark_set in args.benchmark_sets: + benchmark_sets.extend(eval(benchmark_set)) +benchmark_results = {} + +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print(best.code) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency)) + + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + with arch.target: + mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )( + ir_module) + try: + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(mod, target="cuda") + except Exception: + mod_default = None + + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = best.profile_tensors + if mod_default is not None: + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) + t = timer_cuda_mod(*profile_tensors).mean + else: + t = 1e4 - 1 + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "fast_dlight_top20_tune_time": fast_tune_time, + "fast_dlight_top1_latency": cpresults[0].latency, + "fast_dlight_top20_latency": best.latency, + "default_dlight_tune_time": default_tune_time, + "default_dlight_latency": t * 1e3 if t is not None else "Failed", + } + } + + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Tune Time", + "BitBLAS Top1 Latency", + "BitBLAS Top20 Latency", + "DefaultDLight Tune Time", + "DefaultDLight Latency", +] + +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['fast_dlight_top20_tune_time'])} s", + f"{values['fast_dlight_top1_latency']:.3f} ms", + f"{values['fast_dlight_top20_latency']:.3f} ms", + str(values["default_dlight_tune_time"]), + f"{values['default_dlight_latency']:.3e} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/operators/benchmark_bitblas_matmul.py b/benchmark/operators/benchmark_bitblas_matmul.py new file mode 100644 index 000000000..fb3927bdb --- /dev/null +++ b/benchmark/operators/benchmark_bitblas_matmul.py @@ -0,0 +1,196 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas.utils.target_detector import auto_detect_nvidia_target +from bitblas import Matmul, MatmulConfig +import argparse + + +# Initialize the parser +parser = argparse.ArgumentParser( + description="Benchmark BitBLAS int4 on a specific target." +) + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking." +) +parser.add_argument( + "--group_size", + type=int, + default=None, + help="Group size for grouped quantization." +) +parser.add_argument( + "--A_dtype", + type=str, + default="float16", + choices=["float16", "float32", "float64", "int32", "int8"], # Assuming these are the valid choices + help="Data type of activation A." +) +parser.add_argument( + "--W_dtype", + type=str, + default="int4", + choices=["float16", "float32", "float64", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1"], # Assuming these are the valid choices + help="Data type of weight W." +) +parser.add_argument( + "--accum_dtype", + type=str, + default="float16", + choices=["float16", "int32"], # Assuming these are the valid choices + help="Data type for accumulation." +) +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=["float16", "float32", "int32", "int8"], # Assuming these are the valid choices + help="Data type for output." +) +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], # Assuming these are the valid choices + help="Matrix layout, 'nt' for non-transpose A and transpose W." +) +parser.add_argument( + "--with_bias", + action="store_true", + help="Include bias in the benchmark." +) +parser.add_argument( + "--with_scaling", + action="store_true", + help="Include scaling factor in the quantization." +) +parser.add_argument( + "--with_zeros", + action="store_true", + help="Include zeros in the quantization." +) +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=["original", "rescale", "quantized"], # Replace with actual modes if applicable + help="Specify the mode for calculating zeros." +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +group_size = args.group_size +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +group_size = args.group_size +with_scaling = args.with_scaling +with_zeros = args.with_zeros +zeros_mode = args.zeros_mode + +test_shapes = [ + # square test + (MatmulConfig, Matmul, (1, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + # BLOOM-176B + (MatmulConfig, Matmul, (1, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + # # OPT-65B + (MatmulConfig, Matmul, (1, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + # # LLAMA-70B/65B + (MatmulConfig, Matmul, (1, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + + # square test + (MatmulConfig, Matmul, (16384, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + # BLOOM-176B + (MatmulConfig, Matmul, (8192, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + # # OPT-65B + (MatmulConfig, Matmul, (8192, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + # # LLAMA-70B/65B + (MatmulConfig, Matmul, (8192, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (8192, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)), +] + +benchmark_sets = [] +benchmark_sets.extend(test_shapes) + +# fmt:on + +benchmark_results = {} +for config, operator, input_args in benchmark_sets: + config = config(*input_args) + matmul = operator(config, target=target, enable_tuning=True) + kernel_latency = matmul.profile_latency() + if matmul.input_transform is not None: + kernel_latency += matmul.ladder_permutate_a.profile_latency() + + print("Time cost is: {:.3f} ms".format(kernel_latency)) + + profile_config = { + f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { + "BitBLAS_top20_latency": kernel_latency, + } + } + + benchmark_results.update(profile_config) + +# Define headers for the table +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Latency", +] + +col_widths = [0, 0, 0] +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0]) + col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1])) + col_widths[2] = max(max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, col_widths[2]) + break + +for i, header in enumerate(headers): + headers[i] = header.ljust(col_widths[i]) + +print("".join(headers)) + +print("-" * sum(col_widths)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f"{values['BitBLAS_top20_latency']:.3f} ms", + ] + print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]) + "\n") diff --git a/benchmark/operators/matmul.py b/benchmark/operators/matmul.py deleted file mode 100644 index e66e8b402..000000000 --- a/benchmark/operators/matmul.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import numpy as np -import tvm -from tvm.script import tir as T -import bitblas -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.gpu import Matmul -from bitblas.utils import get_target_from_env -from bitblas.base.utils import apply_and_build -from bitblas.ops.impl.matmul_impl import ( - matmul_nn, - matmul_nt, - matmul_nt, - matmul_nt_propagate_a_propagate_b, -) -import time - - -# fmt:off -test_shapes = [ - # (prim_func, input_args, default_dlight_schedule), - (matmul_nt, (1024, 1024, 1024, "float16", "float16"), Matmul), - (matmul_nt, (16, 8192, 8192, "float16", "float16"), Matmul), - (matmul_nt, (32, 8192, 8192, "float16", "float16"), Matmul), - (matmul_nt, (16384, 16384, 16384, "float16", "float16"), Matmul), - (matmul_nt, (16384, 16384, 16384, "int8", "int32"), Matmul), - (matmul_nn, (1024, 1024, 1024, "float16", "float16"), Matmul), - (matmul_nn, (8192, 8192, 8192, "float16", "float16"), Matmul), - (matmul_nn, (16384, 16384, 16384, "float16", "float16"), Matmul), - (matmul_nt, (1024, 1024, 1024, "float32", "float32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16"), Matmul), -] - - -llm_shapes = [ - # square test - (matmul_nt, (1, 16384, 16384, "float16", "float16"), Matmul), - # BLOOM-176B - (matmul_nt, (1, 43008, 14336, "float16", "float16"), Matmul), - (matmul_nt, (1, 14336, 14336, "float16", "float16"), Matmul), - (matmul_nt, (1, 57344, 14336, "float16", "float16"), Matmul), - (matmul_nt, (1, 14336, 57344, "float16", "float16"), Matmul), - # # OPT-65B - (matmul_nt, (1, 9216, 9216, "float16", "float16"), Matmul), - (matmul_nt, (1, 36864, 9216, "float16", "float16"), Matmul), - (matmul_nt, (1, 9216, 36864, "float16", "float16"), Matmul), - (matmul_nt, (1, 22016, 8192, "float16", "float16"), Matmul), - # # LLAMA-70B/65B - (matmul_nt, (1, 8192, 22016, "float16", "float16"), Matmul), - (matmul_nt, (1, 8192, 8192, "float16", "float16"), Matmul), - (matmul_nt, (1, 28672, 8192, "float16", "float16"), Matmul), - (matmul_nt, (1, 8192, 28672, "float16", "float16"), Matmul), - - # square test - (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16"), Matmul), - # BLOOM-176B - (matmul_nt_propagate_a_propagate_b, (8192, 43008, 14336, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 14336, 14336, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 57344, 14336, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 14336, 57344, "float16", "float16"), Matmul), - # # OPT-65B - (matmul_nt_propagate_a_propagate_b, (8192, 9216, 9216, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 36864, 9216, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 9216, 36864, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 22016, 8192, "float16", "float16"), Matmul), - # # LLAMA-70B/65B - (matmul_nt_propagate_a_propagate_b, (8192, 8192, 22016, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 8192, 8192, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 28672, 8192, "float16", "float16"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 8192, 28672, "float16", "float16"), Matmul), - - # square test - (matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "int8", "int8", "int32"), Matmul), - # BLOOM-176B - (matmul_nt_propagate_a_propagate_b, (8192, 43008, 14336, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 14336, 14336, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 57344, 14336, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 14336, 57344, "int8", "int8", "int32"), Matmul), - # OPT-65B - (matmul_nt_propagate_a_propagate_b, (8192, 9216, 9216, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 36864, 9216, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 9216, 36864, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 22016, 8192, "int8", "int8", "int32"), Matmul), - # LLAMA-70B/65B - (matmul_nt_propagate_a_propagate_b, (8192, 8192, 22016, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 8192, 8192, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 28672, 8192, "int8", "int8", "int32"), Matmul), - (matmul_nt_propagate_a_propagate_b, (8192, 8192, 28672, "int8", "int8", "int32"), Matmul), -] - -benchmark_sets = [] -benchmark_sets.extend(llm_shapes) - -# fmt:on - -target = tvm.target.Target(get_target_from_env()) - -benchmark_results = {} -for get_prim_func, input_args, d_schedule in benchmark_sets: - ir_module = get_prim_func(*input_args) - func = ir_module["main"] - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - tune_start = time.time() - cpresults, best = apply_and_build(func, configs, arch, parallel_build=False) - fast_tune_time = time.time() - tune_start - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) - ) - - # evaluate the performance of the default schedule - - rule = d_schedule() - default_tune_start = time.time() - with arch.target: - mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable - bitblas.gpu.Matmul(), - bitblas.gpu.GEMV(), - bitblas.gpu.Reduction(), - bitblas.gpu.GeneralReduction(), - bitblas.gpu.Fallback(), - )(ir_module) - try: - with tvm.transform.PassContext(config={"tir.use_async_copy": True}): - mod_default = tvm.build(mod, target="cuda") - except Exception: - mod_default = None - default_tune_time = time.time() - default_tune_start - - args = func.buffer_map.values() - - profile_tensors = best.profile_tensors - if mod_default is not None: - timer_cuda_mod = mod_default.time_evaluator( - mod_default.entry_name, arch.device, number=5 - ) - t = timer_cuda_mod(*profile_tensors).mean - else: - t = 1e4 - 1 - - print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) - - profile_config = { - f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { - "fast_dlight_top20_tune_time": fast_tune_time, - "fast_dlight_top1_latency": cpresults[0].latency * 1e3, - "fast_dlight_top20_latency": best.latency * 1e3, - "default_dlight_tune_time": default_tune_time, - "default_dlight_latency": t * 1e3 if t is not None else "Failed", - } - } - - benchmark_results.update(profile_config) - -headers = [ - "PrimFunc", - "Input Arguments", - "BitBLAS Top20 Tune Time", - "BitBLAS Top1 Latency", - "BitBLAS Top20 Latency", - "DefaultDLight Tune Time", - "DefaultDLight Latency", -] - -col_width = ( - max(len(word) for row in [headers] + list(profile_config.values()) for word in row) - + 2 -) # padding - -print("".join(word.ljust(col_width) for word in headers)) - -print("-" * col_width * len(headers)) - -for config, values in benchmark_results.items(): - args = config.split("-") - func_name = args[0] - input_args = "-".join(args[1:]) - row = [ - func_name, - input_args, - f" {str(values['fast_dlight_top20_tune_time'])} s", - f"{values['fast_dlight_top1_latency']:.3f} ms", - f"{values['fast_dlight_top20_latency']:.3f} ms", - str(values["default_dlight_tune_time"]), - f"{values['default_dlight_latency']:.3e} ms", - ] - print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/operators/matmul_dequantize_af.py b/benchmark/operators/matmul_dequantize_af.py deleted file mode 100644 index 1c7f9e272..000000000 --- a/benchmark/operators/matmul_dequantize_af.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import numpy as np -import tvm -from tvm.script import tir as T -import bitblas -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.gpu import Matmul -from bitblas.utils import get_target_from_env -from bitblas.base.utils import apply_and_build -from bitblas.ops.impl.matmul_dequantize_impl import ( - matmul_nt_dequantize_b, - matmul_nt_dequantize_b_propagate_a_propagate_b, -) -import time - - -# fmt:off -llm_shapes = [ - # square test - (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - # BLOOM-176B - (matmul_nt_dequantize_b, (1, 43008, 14336, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 14336, 14336, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 57344, 14336, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 14336, 57344, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - # # OPT-65B - (matmul_nt_dequantize_b, (1, 9216, 9216, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 36864, 9216, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 9216, 36864, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 22016, 8192, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - # LLAMA-70B/65B - (matmul_nt_dequantize_b, (1, 8192, 22016, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 8192, 8192, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 28672, 8192, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b, (1, 8192, 28672, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - - # square test - (matmul_nt_dequantize_b_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - # # BLOOM-176B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 43008, 14336, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 14336, 14336, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 57344, 14336, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 14336, 57344, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - # # OPT-65B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 9216, 9216, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 36864, 9216, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 9216, 36864, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 22016, 8192, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - # # LLAMA-70B/65B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 22016, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 8192, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 28672, 8192, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 28672, "float16", "float16", "float16", 4, "int8", "af", True, 128, False, False), Matmul), -] - -benchmark_sets = [] -benchmark_sets.extend(llm_shapes) - -# fmt:on - -target = tvm.target.Target(get_target_from_env()) - -benchmark_results = {} -for get_prim_func, input_args, d_schedule in benchmark_sets: - ir_module = get_prim_func(*input_args) - func = ir_module["main"] - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - tune_start = time.time() - cpresults, best = apply_and_build(func, configs, arch, parallel_build=False) - fast_tune_time = time.time() - tune_start - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) - ) - - # evaluate the performance of the default schedule - - rule = d_schedule() - default_tune_start = time.time() - with arch.target: - mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable - bitblas.gpu.Matmul(), - bitblas.gpu.GEMV(), - bitblas.gpu.Reduction(), - bitblas.gpu.GeneralReduction(), - bitblas.gpu.Fallback(), - )(ir_module) - try: - with tvm.transform.PassContext(config={"tir.use_async_copy": True}): - mod_default = tvm.build(mod, target="cuda") - except Exception: - mod_default = None - - default_tune_time = time.time() - default_tune_start - - args = func.buffer_map.values() - - profile_tensors = best.profile_tensors - if mod_default is not None: - timer_cuda_mod = mod_default.time_evaluator( - mod_default.entry_name, arch.device, number=5 - ) - t = timer_cuda_mod(*profile_tensors).mean - else: - t = 1e4 - 1 - - print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) - - profile_config = { - f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { - "fast_dlight_top20_tune_time": fast_tune_time, - "fast_dlight_top1_latency": cpresults[0].latency * 1e3, - "fast_dlight_top20_latency": best.latency * 1e3, - "default_dlight_tune_time": default_tune_time, - "default_dlight_latency": t * 1e3 if t is not None else "Failed", - } - } - - benchmark_results.update(profile_config) - -headers = [ - "PrimFunc", - "Input Arguments", - "BitBLAS Top20 Tune Time", - "BitBLAS Top1 Latency", - "BitBLAS Top20 Latency", - "DefaultDLight Tune Time", - "DefaultDLight Latency", -] - -col_width = ( - max(len(word) for row in [headers] + list(profile_config.values()) for word in row) - + 2 -) # padding - -print("".join(word.ljust(col_width) for word in headers)) - -print("-" * col_width * len(headers)) - -for config, values in benchmark_results.items(): - args = config.split("-") - func_name = args[0] - input_args = "-".join(args[1:]) - row = [ - func_name, - input_args, - f" {str(values['fast_dlight_top20_tune_time'])} s", - f"{values['fast_dlight_top1_latency']:.3f} ms", - f"{values['fast_dlight_top20_latency']:.3f} ms", - str(values["default_dlight_tune_time"]), - f"{values['default_dlight_latency']:.3e} ms", - ] - print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/operators/matmul_dequantize_int4.py b/benchmark/operators/matmul_dequantize_int4.py deleted file mode 100644 index e3ae814df..000000000 --- a/benchmark/operators/matmul_dequantize_int4.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import numpy as np -import tvm -from tvm.script import tir as T -import bitblas -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.gpu import Matmul -from bitblas.utils import get_target_from_env -from bitblas.base.utils import apply_and_build -from bitblas.ops.impl.matmul_dequantize_impl import ( - matmul_nt_dequantize_b, - matmul_nt_dequantize_b_propagate_a_propagate_b, -) -import time - -group_size = -1 -# fmt:off -llm_shapes = [ - # square test - (matmul_nt_dequantize_b, (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - # BLOOM-176B - (matmul_nt_dequantize_b, (1, 43008, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 14336, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 57344, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 14336, 57344, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - # # OPT-65B - (matmul_nt_dequantize_b, (1, 9216, 9216, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 36864, 9216, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 9216, 36864, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 22016, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - # LLAMA-70B/65B - (matmul_nt_dequantize_b, (1, 8192, 22016, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 8192, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 28672, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 8192, 28672, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - - # square test - (matmul_nt_dequantize_b_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - # BLOOM-176B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 43008, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 14336, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 57344, 14336, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 14336, 57344, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - # OPT-65B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 9216, 9216, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 36864, 9216, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 9216, 36864, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 22016, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - # LLAMA-70B/65B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 22016, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 28672, 8192, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 28672, "float16", "float16", "float16", 4, "int8", "int", True, False, group_size, True, False), Matmul), - - # square test - (matmul_nt_dequantize_b, (1, 16384, 16384, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - # BLOOM-176B - (matmul_nt_dequantize_b, (1, 43008, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 14336, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 57344, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 14336, 57344, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - # # OPT-65B - (matmul_nt_dequantize_b, (1, 9216, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 36864, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 9216, 36864, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 22016, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - # LLAMA-70B/65B - (matmul_nt_dequantize_b, (1, 8192, 22016, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 8192, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 28672, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b, (1, 8192, 28672, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - - # square test - (matmul_nt_dequantize_b_propagate_a_propagate_b, (16384, 16384, 16384, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - # BLOOM-176B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 43008, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 14336, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 57344, 14336, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 14336, 57344, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - # # OPT-65B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 9216, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 36864, 9216, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 9216, 36864, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 22016, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - # LLAMA-70B/65B - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 22016, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 28672, 8192, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - (matmul_nt_dequantize_b_propagate_a_propagate_b, (8192, 8192, 28672, "int8", "int8", "int32", 2, "int8", "uint", False, False, -1, True, False), Matmul), - -] - -benchmark_sets = [] -benchmark_sets.extend(llm_shapes) - -# fmt:on - -target = tvm.target.Target(get_target_from_env()) - -benchmark_results = {} -for get_prim_func, input_args, d_schedule in benchmark_sets: - ir_module = get_prim_func(*input_args) - func = ir_module["main"] - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - tune_start = time.time() - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - fast_tune_time = time.time() - tune_start - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) - ) - - # evaluate the performance of the default schedule - - rule = d_schedule() - default_tune_start = time.time() - with arch.target: - mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable - bitblas.gpu.Matmul(), - bitblas.gpu.GEMV(), - bitblas.gpu.Reduction(), - bitblas.gpu.GeneralReduction(), - bitblas.gpu.Fallback(), - )(ir_module) - try: - with tvm.transform.PassContext(config={"tir.use_async_copy": True}): - mod_default = tvm.build(mod, target="cuda") - except Exception: - mod_default = None - - default_tune_time = time.time() - default_tune_start - - args = func.buffer_map.values() - - profile_tensors = best.profile_tensors - if mod_default is not None: - timer_cuda_mod = mod_default.time_evaluator( - mod_default.entry_name, arch.device, number=5 - ) - t = timer_cuda_mod(*profile_tensors).mean - else: - t = 1e4 - 1 - - print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) - - profile_config = { - f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { - "fast_dlight_top20_tune_time": fast_tune_time, - "fast_dlight_top1_latency": cpresults[0].latency * 1e3, - "fast_dlight_top20_latency": best.latency * 1e3, - "default_dlight_tune_time": default_tune_time, - "default_dlight_latency": t * 1e3 if t is not None else "Failed", - } - } - - benchmark_results.update(profile_config) - -headers = [ - "PrimFunc", - "Input Arguments", - "BitBLAS Top20 Tune Time", - "BitBLAS Top1 Latency", - "BitBLAS Top20 Latency", - "DefaultDLight Tune Time", - "DefaultDLight Latency", -] - -col_width = ( - max(len(word) for row in [headers] + list(profile_config.values()) for word in row) - + 2 -) # padding - -print("".join(word.ljust(col_width) for word in headers)) - -print("-" * col_width * len(headers)) - -for config, values in benchmark_results.items(): - args = config.split("-") - func_name = args[0] - input_args = "-".join(args[1:]) - row = [ - func_name, - input_args, - f" {str(values['fast_dlight_top20_tune_time'])} s", - f"{values['fast_dlight_top1_latency']:.3f} ms", - f"{values['fast_dlight_top20_latency']:.3f} ms", - str(values["default_dlight_tune_time"]), - f"{values['default_dlight_latency']:.3e} ms", - ] - print("".join(word.ljust(col_width) for word in row)) diff --git a/docs/ExtendOperatorsWithDSL.md b/docs/ExtendOperatorsWithDSL.md new file mode 100644 index 000000000..1c01602b9 --- /dev/null +++ b/docs/ExtendOperatorsWithDSL.md @@ -0,0 +1,168 @@ +### Using BitBLAS from DSL +```python +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.base.utils import apply_and_build +@tvm.script.ir_module +class MatmulNT: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [M, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + C = T.match_buffer(c, [M, N], dtype=out_dtype) + + for i, j, k in T.grid(M, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vj, vk + ].astype(out_dtype) + +ir_module = MatmulNT +func = ir_module["main"] +target = tvm.target.Target("nvidia/nvidia-a100") +arch = CUDA(target) +``` + +Get tuning policy and candidates: + +```python +# Tune with SIMT Cuda Core +policy = DefaultPolicy(func=func, arch=arch) +try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) +except Exception: + tags = None +# Tune with Tensor Core if possible +if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + +configs = policy.emit_config(topk=20) +''' +[BitBLAS] Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.032 ms +[BitBLAS] Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.021 ms +[BitBLAS] Evaluation with config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.023 ms +[BitBLAS] Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.023 ms +[BitBLAS] Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.027 ms +[BitBLAS] Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.025 ms +[BitBLAS] Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.023 ms +[BitBLAS] Evaluation with config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.025 ms +[BitBLAS] Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.037 ms +[BitBLAS] Evaluation with config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.037 ms +[BitBLAS] Evaluation with config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.026 ms +[BitBLAS] Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.043 ms +[BitBLAS] Evaluation with config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.042 ms +[BitBLAS] Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.025 ms +[BitBLAS] Evaluation with config {'block': [256, 32], 'warp': [128, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.029 ms +[BitBLAS] Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.028 ms +[BitBLAS] Evaluation with config {'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.027 ms +[BitBLAS] Evaluation with config {'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.044 ms +[BitBLAS] Evaluation with config {'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.040 ms +[BitBLAS] Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.047 ms +''' +``` + +Apply and build and get best code generation result: +```python +cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) +# get the best code generation result. +print(best.code) +''' +extern "C" __global__ void __launch_bounds__(128) default_function_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) { + ... +} +''' +``` + +we also provide something interesting with DSL. + +#### Auto Tensorization + +Say we currently have two policies, one is for SIMT Cuda Core, another is for TensorCore. The decision to utilize a TensorCore policy over a SIMT Cuda Core policy can be enhanced by the integration of an auto-tensorization strategy, it allows BitBLAS to automatically select if the DSL Expression can uitlize TensorCore. + +![Auto Tensorization](./images/auto_tensorize.png) + +```python +# Assume func is conv2d, after this invocation, the tensorized_func is the tensorized version of the conv2d, otherwise, the tags is None. +tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) +``` + +#### Tune with dynamic symbolic + +As in LLM Serving, the input shape is dynamic, we can use the dynamic symbolic to generate high performance kernel with dynamic shape. + +```python +@tvm.script.ir_module +class MatmulNT: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vj, vk + ].astype(out_dtype) + +from bitblas import fast_tune_with_dynamic_range +# Tune with dynamic symbolic +optimized_mod = fast_tune_with_dynamic_range( + func, target, topk=topk, parallel_build=True, + dynamic_range={ + "M": [1, 1024] + } +) + +# fianlly, we will generate a dispatch func to dispatch the kernel with dynamic symbolic. +''' +@IRModule +class MatmulNT: + + def matmul_nt_opt_m_1(A: Tensor, T_reshape: Tensor, m: int): + ... + + def matmul_nt_opt_m_256(A: Tensor, T_reshape: Tensor, m: int): + ... + + def dispatcher(args): + if m <= 1: + matmul_nt_opt_m_1(A.data, T_reshape.data, m) + if m > 1 and m <= 256: + matmul_nt_opt_m_256(A.data, T_reshape.data, m) + if m > 256: + matmul_nt_m_256(A.data, T_reshape.data, m) +''' +``` + +You can find some example dsl implementation in `python/bitblas/ops/impl` and `benchmark/dsl`, see more examples and tutorials in [apache/tvm](https://github.com/apache/tvm) + diff --git a/docs/Installation.md b/docs/Installation.md new file mode 100644 index 000000000..fd90f6d45 --- /dev/null +++ b/docs/Installation.md @@ -0,0 +1,58 @@ +# Installation Guide + +## Prerequisites + + **Operating System**: Linux (Ubuntu 20.04 or later recommended for installation via wheel or PyPI or you may need to checkout the [Building from Source](#building-from-source) section for other Linux distributions.) +- **Python Version**: >= 3.7 +- **CUDA Version**: >= 10.0 + +## Installing with pip + +The easiest way to install BitBLAS is direcly from the PyPi using pip. To install the latest version, run the following command in your terminal. + +**Note**: Currently, bitblas whl is only supported on Linux systems. We recommend using Ubuntu 20.04 or later version as we build the whl files on this platform. + +```bash +pip install bitblas +``` + +Alternatively, you may choose to install BitBLAS using prebuilt packages available on the Release Page: + +```bash +pip install bitblas-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl +``` + +After installing BitBLAS, you can verify the installation by running: + +```bash +python -c "import bitblas; print(bitblas.__version__)" +``` + +## Building from Source + +We recommend using a docker container with the necessary dependencies to build BitBLAS from source. You can use the following command to run a docker container with the necessary dependencies: + +```bash +docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3 +``` + +To build and install BitBLAS directly from source, follow the steps below. This process requires certain pre-requisites from apache tvm, which can be installed on Ubuntu/Debian-based systems using the following commands: + +```bash +sudo apt-get update +sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev +``` + +After installing the prerequisites, you can clone the BitBLAS repository and install it using pip: + +```bash +git clone --recursive https://github.com/Microsoft/BitBLAS.git +cd BitBLAS +pip install . # Please be patient, this may take some time. +``` + +if you want to install BitBLAS with the development mode, you can run the following command: + +```bash +pip install -e . +``` diff --git a/docs/PythonAPI.md b/docs/PythonAPI.md new file mode 100644 index 000000000..e2925a147 --- /dev/null +++ b/docs/PythonAPI.md @@ -0,0 +1,150 @@ +## Matmul + +`Matmul` is an operator class that performs matrix multiplication, supporting various optimizations and quantization strategies. + +### MatmulConfig: + +`MatmulConfig` is a configuration class for the `Matmul` operator, specifying the matrix multiplication operation's parameters and behaviors. + +### Parameters: + +- **M** *(Union[int, Tuple[int]])*: The size of the first dimension of the matrix A, or a range of sizes if dynamic shape support is needed. Can be an integer or a tuple representing the dynamic range. + - If `int`, the bitblas matmul will generate a static shape kernel, which can only be used for the input shape of the specified value. + - If `List[int]`, the bitblas matmul will generate a dynamic shape kernel, which can be used for the input shape of the specified values. While the input shape represents the target optimized range. + - If `None`, the bitblas matmul will use a default value [1, 16, 32, 64, 128, 256, 512, 1024]. +- **N** *(int)*: The size of the second dimension of matrix W and the output matrix. +- **K** *(int)*: The common dimension of matrices A and W. +- **A_dtype** *(str, default='float16')*: The data type of matrix A. + - Choices: `'float16'`, `'int8'`. +- **W_dtype** *(str, default='float16')*: The data type of matrix W. Also acts as a wrapper for source_format and bit. + - Choices: `'float16'`, `'int8'`, `'int4'`, `'int2'`, `'int1'`, `'fp4_e2m1'`, `'nf4'`. +- **accum_dtype** *(str, default='float16')*: The data type used for accumulation during the matrix multiplication. + - Choices: `'float16'`, `'int32'`. +- **out_dtype** *(str, default='float16')*: The data type of the output matrix. + - Choices: `'float32'`, `'float16'`, `'int8'`, `'int32'`. +- **layout** *(Literal['nn', 'nt', 'tn', 'tt'], default='nt')*: The layout of the matrix multiplication operation. The matrix is stored in row-major. + - `'nn'`: Both matrices are non-transposed. + - `'nt'`: Matrix A is non-transposed, and matrix W is transposed. + - `'tn'`: Matrix A is transposed, and matrix W is non-transposed. + - `'tt'`: Both matrices are transposed. +- **with_bias** *(bool, default=False)*: Indicates whether a bias vector is added to the output. +- **group_size** *(int, default=-1)*: The group size for quantization, -1 indicates no grouping. +- **with_scaling** *(bool, default=False)*: Indicates whether scaling is applied during quantization. +- **with_zeros** *(bool, default=False)*: Indicates whether zero optimization is applied. +- **zeros_mode** *(Literal['original', 'rescale', 'quantized'], default='original')*: The mode of zero optimization. + - Choices: `None`, `'original'`, `'rescale'`, `'quantized'`. + - `'original'`: Subtract zero-point before scaling. Formula: `target = (dequantize_weight - zero_point) * scale`. where `zero_point` has the same datatype with scale. + - `'rescale'`: Apply scale factor directly to dequantized weight and then subtract zero-point. Formula: `target = dequantize_weight * scale - zero_point`. + - `'quantized'`: Apply zero-point adjustment after dequantization and additional dequantization of zero values. Formula: `target = (dequantize_weight - dequantize_qzeros) * scale`, where `dequantize_zeros` represents the dequantized representation of zero values, which can be adapted to qzeros params. + +### Initialization: + +```python +Matmul(config: MatmulConfig) +``` + +- **config** *(MatmulConfig)*: The configuration for the matrix multiplication operation. + +### Methods: + +#### `forward(A, W, scale=None, seros=None, bias=None, output=None) -> Any` + +Performs the matrix multiplication operation with the given input tensors and optional scaling, zeros, and bias. + +- **A** *(Tensor)*: The input tensor A. +- **W** *(Tensor)*: The input tensor W. +- **scale** *(Optional[Tensor], default=None)*: The scaling tensor. +- **zeros** *(Optional[Tensor], default=None)*: The zeros tensor. +- **bias** *(Optional[Tensor], default=None)*: The bias tensor. +- **output** *(Optional[Tensor], default=None)*: The pre-allocated output tensor. + +#### `transform_weight(weight, scale=None, zeros=None, bias=None)` + +Transforms the given weight tensor based on the specified quantization parameters. + +- **weight** *(Tensor)*: The input weight tensor to be transformed. +- **scale** *(Optional[Tensor], default=None)*: Scaling factor for the weight tensor. +- **zeros** *(Optional[Tensor], default=None)*: Zero-point adjustment for the weight tensor. +- **bias** *(Optional[Tensor], default=None)*: Bias to be added to the weight tensor. + +#### `__call__(*args: Any) -> Any` + +Allows the object to be called like a function, forwarding the call to the `forward` method. + +### Properties: + +- **M**, **N**, **K**, **A_dtype**, **W_dtype**, **out_dtype**, **accum_dtype**, **storage_dtype**, **with_scaling**, **with_zeros**, **group_size**, **fast_decoding**, **with_bias**, **layout**, **zeros_mode**: These properties correspond to the parameters defined in `MatmulConfig`, providing easy access to the configuration details. + + +## Linear + +`Linear(in_features: int, out_features: int, bias: bool = False, A_dtype: str = 'float16', W_dtype: str = 'float16', accum_dtype: str = 'float16', out_dtype: str = 'float16', group_size: int = -1, with_scaling: bool = None, with_zeros: bool = False, zeros_mode: str = None, opt_M: Union[int, List[int]] = [1, 16, 32, 64, 128, 256, 512])` + +Applies a linear transformation to the incoming data: $out[M, N] = A[M, K] \times W[N, K]$ . This module supports quantization and optimization for NVIDIA GPUs using the BitBLAS library. + +### Parameters: + +- **in_features** *(int)*: size of each input sample. +- **out_features** *(int)*: size of each output sample. +- **bias** *(bool, optional)*: If set to `False`, the layer will not learn an additive bias. Default: `False`. +- **A_dtype** *(str, optional)*: Data type of the input tensor. Default: `'float16'`. + - Choices: `'float16'`, `'int8'`. +- **W_dtype** *(str, optional)*: Data type of the weights. Default: `'float16'`. + - Choices: `'float16'`, `'int8'`, `'int4'`, `'int2'`, `'int1'`, `'fp4_e2m1'`, `'af4'`. +- **accum_dtype** *(str, optional)*: Data type for accumulation. Default: `'float16'`. + - Choices: `'float16'`, `'int32'`. +- **out_dtype** *(str, optional)*: Data type of the output tensor. Default: `'float16'`. + - Choices: `'float32'`, `'float16'`, `'int8'`, `'int32'`. +- **group_size** *(int, optional)*: Group size for quantization. Default: `-1` (no grouping). +- **with_scaling** *(bool, optional)*: Whether to use scaling during quantization. Default: `False`. +- **with_zeros** *(bool, optional)*: Whether to use zeropoints . Default: `False`. +- **zeros_mode** *(str, optional)*: Mode for zero zeropoints. Default: `None`. + - Choices: `None`, `'original'`, `'rescale'`, `'quantized'`. + - `'original'`: Subtract zero-point before scaling. Formula: `target = (dequantize_weight - zero_point) * scale`. where `zero_point` has the same datatype with scale. + - `'rescale'`: Apply scale factor directly to dequantized weight and then subtract zero-point. Formula: `target = dequantize_weight * scale - zero_point`. + - `'quantized'`: Apply zero-point adjustment after dequantization and additional dequantization of zero values. Formula: `target = (dequantize_weight - dequantize_qzeros) * scale`, where `dequantize_zeros` represents the dequantized representation of zero values, which can be adapted to qzeros params. +- **opt_M** *(Union[int, List[int]], optional)*: Optimize range of the input shape for dynamic symbolic. Default: `[1, 16, 32, 64, 128, 256, 512]`. + - If `int`, the bitblas matmul will generate a static shape kernel, which can only be used for the input shape of the specified value. + - If `List[int]`, the bitblas matmul will generate a dynamic shape kernel, which can be used for the input shape of the specified values. While the input shape represents the target optimized range. It is important to note that if an input size is provided that is not explicitly listed, such as 15, bitblas matmul will select the nearest larger kernel available. In the case where opt_M is `[1, 16, 32, 64, 128, 256, 512]`, an input size of 15 would utilize the kernel optimized for size 16. T + +### Methods: + +#### `forward(A, output=None)` + +Defines the computation performed at every call. + +- **A** *(Tensor)*: Input tensor. +- **Output** *(Tensor, optional)*: Pre-allocated output tensor. Default: `None`. + - If `None`, the module will allocate a new tensor for the output. + - If not `None`, the module will use the pre-allocated tensor for the output. + +Returns: The output tensor. + +#### `init_params()` + +Initializes parameters handles (convert constant params into ctypes void pointer) for the computation. We currently put this fuction in the forward function, so you do not need to call it manually. But if you lift this function out of the forward function, you can call it manually to aoid the transformation. + +#### `load_and_transform_weight(weight, scales=None, zeros=None, bias=None)` + +This method is designed to load and optionally transform the weight matrix along with scales, zeros, and bias for use in quantized computations. It is particularly useful when transitioning from a floating-point model to a quantized model, allowing for the adjustment of model parameters to fit the requirements of quantization and optimization processes. + +- **Parameters:** + - **weight** *(Tensor)*: The weight tensor to be loaded into the layer. This tensor should have dimensions that match the expected input features and output features of the layer. The method will also apply any necessary transformations to the weight tensor to align with the quantization and optimization configurations of the layer. + - **scales** *(Tensor, optional)*: A tensor containing scale factors for quantization. These scales are used to adjust the weight values during the quantization process, ensuring that the dynamic range of the weights is appropriately represented in the quantized format. If not provided, the method assumes that either scaling is not required or has already been applied to the weights. + - **zeros** *(Tensor, optional)*: A tensor indicating the optimized representation of zeros, particularly useful in sparse models where zero values can be efficiently encoded. This parameter is only relevant if zero points (`with_zeros`) is enabled for the layer. Providing this tensor allows for further memory and computation optimizations during the forward pass. + - **bias** *(Tensor, optional)*: The bias tensor to be loaded into the layer. If the layer is configured to use a bias (`bias=True` during initialization), this tensor provides the bias values for each output feature. If `None`, it is assumed that the layer does not use a bias or that the bias is already incorporated into another parameter. +`load_and_transform_weight(weight, scales=None, zeros=None, bias=None)` + +Loads and transforms the weight matrix and optional scales, zeros, and bias for quantized computation. + +- **weight** *(Tensor)*: Weight tensor. +- **scales** *(Tensor, optional)*: Scales tensor for quantization. Default: `None`. +- **zeros** *(Tensor, optional)*: Zeros tensor for zeropoints. Default: `None`. +- **bias** *(Tensor, optional)*: Bias tensor. Default: `None`. + +### `repack_from_gptq(gptq_module)` + +This method facilitates the integration of parameters from a module that has undergone Generalized Post Training Quantization (GPTQ), repacking and transforming these parameters as necessary for compatibility with the BitBLAS-optimized `Linear` layer. The `gptq_module` must have its parameters in a format that is compatible with the expectations of the `Linear` layer's quantization and optimization configurations. This includes the shape and data type of the quantized weights, scales, and zeros. The method automatically handles the transformation and repacking of these parameters, including transposing weights if necessary, converting quantized zeros into the expected format, and adjusting scales and biases for direct use in the optimized forward pass of the `Linear` layer. + +- **Parameters:** + - **gptq_module** *(Module)*: A module that contains quantized parameters following the GPTQ process. This module should have attributes corresponding to quantized weights (`qweight`), scales (`scales`), optimized zeros (`qzeros`), and optionally biases (`bias`). The method extracts these parameters, applies any necessary transformations for compatibility with the BitBLAS optimizations, and loads them into the `Linear` layer. diff --git a/docs/QuickStart.md b/docs/QuickStart.md index c0f427026..2285a2313 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -1,267 +1,234 @@ -## Quick Start +# Quick Start -We provide two primary ways to do the code generation: using a high-level DSL (TensorIR Script), or using packed Operators. +BitBLAS provides two Python APIs to perform mixed-precision matrix multiplication: + - ```bitblas.Matmul``` implements the $W_{wdtype}A_{adtype}$ mixed-precision matrix multiplication of $C_{cdtype}[M, N] = A_{adtype}[M, K] \times W_{wdtype}[N, K]$ where $W_{wdtype}$ indicates the weight of $wtype$, A_{adtype} indicates the activation of $adtype$, and C_{cdtype} indicates the output of $cdtype$. + - ```bitblas.Linear``` is a PyTorch ```nn.Linear```-like module to support a Linear of mixed-precision. -You can find some example dsl implementation in `python/bitblas/ops/impl` and `benchmark/dsl`, see more examples and tutorials in [apache/tvm](https://github.com/apache/tvm) +## Example: $W_{INT4}A_{FP16}$ mixed-precision matrix multiplication -### Using BitBLAS from DSL -```python -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.base.utils import apply_and_build -@tvm.script.ir_module -class MatmulNT: - @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [M, K], dtype=in_dtype) - B = T.match_buffer(b, [N, K], dtype=in_dtype) - C = T.match_buffer(c, [M, N], dtype=out_dtype) - - for i, j, k in T.grid(M, N, K): - with T.block("B"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = tvm.tir.const(0, out_dtype) - C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ - vj, vk - ].astype(out_dtype) - -ir_module = MatmulNT -func = ir_module["main"] -target = tvm.target.Target("nvidia/nvidia-a100") -arch = CUDA(target) -``` - -Get tuning policy and candidates: - -```python -# Tune with SIMT Cuda Core -policy = DefaultPolicy(func=func, arch=arch) -try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) -except: - tags = None -# Tune with Tensor Core if possible -if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - -configs = policy.emit_config(topk=20) -''' -[BitBLAS] Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.032 ms -[BitBLAS] Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.021 ms -[BitBLAS] Evaluation with config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.023 ms -[BitBLAS] Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.023 ms -[BitBLAS] Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.027 ms -[BitBLAS] Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.025 ms -[BitBLAS] Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.023 ms -[BitBLAS] Evaluation with config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.025 ms -[BitBLAS] Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.037 ms -[BitBLAS] Evaluation with config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.037 ms -[BitBLAS] Evaluation with config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.026 ms -[BitBLAS] Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.043 ms -[BitBLAS] Evaluation with config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.042 ms -[BitBLAS] Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.025 ms -[BitBLAS] Evaluation with config {'block': [256, 32], 'warp': [128, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.029 ms -[BitBLAS] Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.028 ms -[BitBLAS] Evaluation with config {'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.027 ms -[BitBLAS] Evaluation with config {'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.044 ms -[BitBLAS] Evaluation with config {'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.040 ms -[BitBLAS] Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} -[BitBLAS] Time cost of this config: 0.047 ms -''' -``` +Here is an example for a $W_{INT4}A_{FP16}$ mixed-precision matrix multiplication: $out_{FP16}[M, N] = A_{FP16}[M, K] \times W_{INT4}[N, K]$, the example includes the creation of input matrices, quantization of weight matrices, and execution of the multiplication. The result is then compared against a reference result obtained through conventional methods to ensure accuracy. -Apply and build and get best code generation result: ```python -cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) -# get the best code generation result. -print(best.code) -''' -extern "C" __global__ void __launch_bounds__(128) default_function_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) { - ... -} -''' -``` +import bitblas +import torch + +matmul_config = bitblas.MatmulConfig( + M=1, # M dimension + N=1024, # N dimension + K=1024, # K dimension + A_dtype="float16", # activation A dtype + W_dtype="int4", # weight W dtype + accum_dtype="float16", # accumulation dtype + out_dtype="float16", # output dtype + layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose + with_bias=False, # bias + # configs for weight only quantization + group_size=None, # setting for grouped quantization + with_scaling=False, # setting for scaling factor + with_zeros=False, # setting for zeros + zeros_mode=None, # setting for how to calculating zeros +) -we also provide something interesting with DSL. +matmul = bitblas.Matmul(config=matmul_config) -#### Auto Tensorization +# Create input matrices +input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda() +weight_tensor = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda() -Say we currently have two policies, one is for SIMT Cuda Core, another is for TensorCore. The decision to utilize a TensorCore policy over a SIMT Cuda Core policy can be enhanced by the integration of an auto-tensorization strategy, it allows BitBLAS to automatically select if the DSL Expression can uitlize TensorCore. +# Transform weight tensor to int4 data type +weight_tensor_int4 = matmul.transform_weight(weight_tensor) -![Auto Tensorization](./images/auto_tensorize.png) +# Perform mixed-precision matrix multiplication +output_tensor = matmul(input_tensor, weight_tensor_int4) -```python -# Assume func is conv2d, after this invocation, the tensorized_func is the tensorized version of the conv2d, otherwise, the tags is None. -tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) +# Reference result using PyTorch matmul for comparison +ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16)) +# Assert that the results are close within a specified tolerance, note that the int4 randint value is a little bigger than the float16 value, so we set the atol to 1.0 +print("Ref output:", ref_result) +print("BitBLAS output:", output_tensor) +torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0) ``` -#### Tune with dynamic symbolic - -As in LLM Serving, the input shape is dynamic, we can use the dynamic symbolic to generate high performance kernel with dynamic shape. +The same example can be extended to include the quantization of the weight tensor with scaling and zeros. The following code snippet demonstrates how to quantize the weight tensor with scaling and zeros and execute the mixed-precision matrix multiplication. ```python -@tvm.script.ir_module -class MatmulNT: - @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - m = T.int32() - A = T.match_buffer(a, [m, K], dtype=in_dtype) - B = T.match_buffer(b, [N, K], dtype=in_dtype) - C = T.match_buffer(c, [m, N], dtype=out_dtype) - - for i, j, k in T.grid(m, N, K): - with T.block("B"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = tvm.tir.const(0, out_dtype) - C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ - vj, vk - ].astype(out_dtype) - -from bitblas import fast_tune_with_dynamic_range -# Tune with dynamic symbolic -optimized_mod = fast_tune_with_dynamic_range( - func, target, topk=topk, parallel_build=True, - dynamic_range={ - "M": [1, 1024] - } +import bitblas +import torch + +in_features = 1024 +out_features = 1024 +group_size = 128 + +matmul_config = bitblas.MatmulConfig( + M=1, # M dimension + N=out_features, # N dimension + K=in_features, # K dimension + A_dtype="float16", # activation A dtype + W_dtype="uint4", # weight W dtype + accum_dtype="float16", # accumulation dtype + out_dtype="float16", # output dtype + layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose + with_bias=False, # bias + # configs for weight only quantization + group_size=group_size, # setting for grouped quantization + with_scaling=True, # setting for scaling factor + with_zeros=True, # setting for zeros + zeros_mode="original", # setting for how to calculating zeros ) - -# fianlly, we will generate a dispatch func to dispatch the kernel with dynamic symbolic. -''' -@IRModule -class MatmulNT: - - def matmul_nt_opt_m_1(A: Tensor, T_reshape: Tensor, m: int): - ... - - def matmul_nt_opt_m_256(A: Tensor, T_reshape: Tensor, m: int): - ... - - def dispatcher(args): - if m <= 1: - matmul_nt_opt_m_1(A.data, T_reshape.data, m) - if m > 1 and m <= 256: - matmul_nt_opt_m_256(A.data, T_reshape.data, m) - if m > 256: - matmul_nt_m_256(A.data, T_reshape.data, m) -''' - +matmul = bitblas.Matmul(config=matmul_config) + +# Define shapes for tensors +input_shape = (1, 1024) +weight_shape = (1024, 1024) +scaling_shape = (1024, 1024 // 128) +zeros_shape = (1024, 1024 // 128) +output_shape = (1, 1024) + +# Create scaling and zeros tensors for quantization +scaling = torch.rand(scaling_shape, dtype=torch.float16).cuda() +zeros = torch.rand(zeros_shape, dtype=torch.float16).cuda() + +# Create input tensor +input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda() + +# Create and transform weight tensor +weight_tensor = torch.randint(0, 7, weight_shape, dtype=torch.int8).cuda() +weight_tensor_int4 = matmul.transform_weight(weight_tensor) + +# Perform mixed-precision matrix multiplication with quantization +output_tensor = matmul(input_tensor, weight_tensor_int4, scale=scaling, zeros=zeros) + +rescaling_tensor = torch.zeros_like(weight_tensor, dtype=torch.float16).cuda() +# Compute reference result with manual scaling and zero-point adjustment +# rescale = (weight - zeros) * scaling +for i in range(in_features // group_size): + for j in range(group_size): + rescaling_tensor[:, i * group_size + j] = ( + weight_tensor[:, i * group_size + j].to(torch.float16) - zeros[:, i] + ) * scaling[:, i] +ref_result = torch.matmul(input_tensor, rescaling_tensor.t().to(torch.float16)) +# Assert that the results are close within a specified tolerance +print("Ref output:", ref_result) +print("BitBLAS output:", output_tensor) +torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-2) ``` +The init stage of the ```bitblas.Matmul``` class will take minutes to finish, as it will use hardware informations to do a one-time kernel library initialization. +## Example: bitblas.Linear module for PyTorch -### Using BitBLAS from packed Operators +BitBLAS also implemented a variant PyTorch ```nn.Linear``` module, i.e., ```bitblas.Linear```, to support a Linear of mixed-precision. See code [implementation](../python/bitblas/module/__init__.py) -We packed some operators in `bitblas/ops/impl` with configs, you can use them directly. Please see more examples in `testing/python/operators` +Here is an example to define a ```bitblas.Linear``` of $W_{INT4}A_{FP16}$: ```python -from bitblas.ops.matmul import Matmul, MatmulConfig -matmul_config = MatmulConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, -) -matmul = Matmul( - config=matmul_config, - target=target, +import bitblas +import torch + +model = bitblas.Linear( + in_features=1024, + out_features=1024, + bias=False, + A_dtype="float16", # activation A dtype + W_dtype="int4", # weight W dtype + accum_dtype="float16", # accumulation dtype + out_dtype="float16", # output dtype + # configs for weight only quantization + group_size=None, # setting for grouped quantization + with_scaling=False, # setting for scaling factor + with_zeros=False, # setting for zeros + zeros_mode=None, # setting for how to calculating zeros + # Target optimization var for dynamic symbolic. + # For detailed information please checkout docs/PythonAPI.md + # By default, the optimization var is [1, 16, 32, 64, 128, 256, 512] + opt_M=[1, 16, 32, 64, 128], ) -``` -By default, we will apply a default schedule into the operator, you can also get code generation result by calling matmul.codegen(). +# Create an integer weight tensor +intweight = torch.randint(-7, 7, (1024, 1024), dtype=torch.int8) -```python -print(matmul.codegen()) -''' -extern "C" __global__ void __launch_bounds__(128) default_function_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) { - ... -} -''' -``` +# Load and transform weights into the BitBLAS linear module +model.load_and_transform_weight(intweight) -If you want to tune the operator to get better performance, you can use the api `hardware_aware_finetune`. +# Save the state of the model +torch.save(model.state_dict(), "./model.pth") -```python -print(matmul.profile_latency()) -matmul.hardware_aware_finetune(topk=20) -print(matmul.profile_latency()) +# Load the model state +model.load_state_dict(torch.load("./model.pth")) + +# Set the model to evaluation mode +model.eval() + +# Create a dummy input tensor +dummpy_input = torch.randn(1, 1024, dtype=torch.float16) + +# Perform inference +output = model(dummpy_input) +print("BitBLAS output:", output) +# Please checkout the correctness evaluation code in `testing/python/module/test_bitblas_linear.py` ``` -The latency will be reduced after tuning. We re-implement OSDI'22 paper Roller to do fast tuning with hardware information. Typically, the 20 candidates is good enough. -#### Tune with Dynamic Symbolic +we also provide repack interface to repack the pretrained weight of AutoGPTQ into the format of BitBLAS. Here is an example to repack the pretrained weight of AutoGPTQ: ```python -matmul_config = MatmulConfig( - M=[1, 1024], - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, +# !pip install auto-gptq +import bitblas +import torch +from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( + QuantLinear as CudaOldQuantLinear, ) -``` -#### Tune with FPA INTB Operators -Generate High Performance Kernel for WeightOnly Quantization. +in_features = 1024 +out_features = 1024 +group_size = 128 -```python -from bitblas.ops.matmul_dequantize import ( - MatmulWeightOnlyDequantize, - MatmulWeightOnlyDequantizeConfig, +original_w, linear, s, qw = bitblas.quantization.gen_quant4( + in_features, out_features, group_size ) -matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, +zeros = torch.full((in_features // group_size, out_features), 7, dtype=torch.int32) + +cuda_old_linear = CudaOldQuantLinear( + bits=4, group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, + infeatures=in_features, + outfeatures=out_features, + bias=False, ) -matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, +cuda_old_linear.pack(linear, s.T, zeros.T, g_idx=None) + +bitblas_linear = bitblas.Linear( + in_features=in_features, + out_features=out_features, + bias=False, + A_dtype="float16", # activation A dtype + W_dtype="uint4", # weight W dtype + accum_dtype="float16", # accumulation dtype + out_dtype="float16", # output dtype + # configs for weight only quantization + group_size=group_size, # setting for grouped quantization + with_scaling=True, # setting for scaling factor + with_zeros=True, # setting for zeros + zeros_mode="quantized", # setting for how to calculating zeros ) -``` +# Repack weights from CudaOldQuantLinear to BitBLAS linear module +bitblas_linear.repack_from_gptq(cuda_old_linear) + +# Prepare input data +m = 1 # Batch size +inp = torch.rand(m, in_features, dtype=torch.float16, device="cuda") + +# Move models to CUDA for execution +cuda_old_linear = cuda_old_linear.to("cuda") +bitblas_linear = bitblas_linear.to("cuda") + +# Perform inference without gradient calculations +with torch.no_grad(): + res_cuda_old = cuda_old_linear(inp) + res_bitblas = bitblas_linear(inp) + +print("CudaOldQuantLinear output:", res_cuda_old) +print("BitBLAS output:", res_bitblas) + +# Verify the outputs are close within specified tolerances +torch.testing.assert_close(res_bitblas, res_cuda_old, rtol=1e-0, atol=1e-1) +``` \ No newline at end of file diff --git a/format.sh b/format.sh new file mode 100755 index 000000000..b6974fa60 --- /dev/null +++ b/format.sh @@ -0,0 +1,185 @@ +#!/usr/bin/env bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Usage: +# # Do work and commit your work. + +# # Format files that differ from origin/main. +# bash format.sh + +# # Commit changed files with message 'Run yapf and ruff' +# +# +# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. + +# Cause the script to exit if a single command fails +set -eo pipefail + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" || exit 1 + +YAPF_VERSION=$(yapf --version | awk '{print $2}') +RUFF_VERSION=$(ruff --version | awk '{print $2}') +CODESPELL_VERSION=$(codespell --version) + +# # params: tool name, tool version, required version +tool_version_check() { + if [[ $2 != $3 ]]; then + echo "Wrong $1 version installed: $3 is required, not $2." + exit 1 + fi +} + +tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" + +echo 'bitblas yapf: Check Start' + +YAPF_FLAGS=( + '--recursive' + '--parallel' +) + +YAPF_EXCLUDES=( + '--exclude' 'build/**' +) + +# Format specified files +format() { + yapf --in-place "${YAPF_FLAGS[@]}" "$@" +} + +# Format files that differ from main branch. Ignores dirs that are not slated +# for autoformat yet. +format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause yapf to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \ + yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" + fi + +} + +# Format all files +format_all() { + yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" . +} + +## This flag formats individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + format "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is formatted. +elif [[ "$1" == '--all' ]]; then + format_all +else + # Format only the files that changed in last commit. + format_changed +fi +echo 'bitblas yapf: Done' + +echo 'bitblas codespell: Check Start' +# check spelling of specified files +spell_check() { + codespell "$@" +} + +spell_check_all(){ + codespell --toml pyproject.toml +} + +# Spelling check of files that differ from main branch. +spell_check_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause ruff to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ + codespell + fi +} + +# Run Codespell +## This flag runs spell check of individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + spell_check "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is linted. +elif [[ "$1" == '--all' ]]; then + spell_check_all +else + # Check spelling only of the files that changed in last commit. + spell_check_changed +fi +echo 'BitBLAS codespell: Done' + +echo 'bitblas ruff: Check Start' +# Lint specified files +lint() { + ruff "$@" +} + +# Lint files that differ from main branch. Ignores dirs that are not slated +# for autolint yet. +lint_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause ruff to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ + ruff + fi + +} + +# Run Ruff +### This flag lints individual files. --files *must* be the first command line +### arg to use this option. +if [[ "$1" == '--files' ]]; then + lint "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is linted. +elif [[ "$1" == '--all' ]]; then + lint BitBLAS tests +else + # Format only the files that changed in last commit. + lint_changed +fi + +if ! git diff --quiet &>/dev/null; then + echo 'Reformatted files. Please review and stage the changes.' + echo 'Changes not staged for commit:' + echo + git --no-pager diff --name-only + + exit 1 +fi + +echo 'bitblas ruff: Done' + +echo 'bitblas: All checks passed' diff --git a/images/figures/end2end_llama_13b_auto_gptq.png b/images/figures/end2end_llama_13b_auto_gptq.png new file mode 100644 index 000000000..ff003f089 Binary files /dev/null and b/images/figures/end2end_llama_13b_auto_gptq.png differ diff --git a/images/figures/end2end_llama_13b_vllm.png b/images/figures/end2end_llama_13b_vllm.png new file mode 100644 index 000000000..68dc8f2c7 Binary files /dev/null and b/images/figures/end2end_llama_13b_vllm.png differ diff --git a/images/figures/end2end_llama_70B_vllm.png b/images/figures/end2end_llama_70B_vllm.png new file mode 100644 index 000000000..33b237d00 Binary files /dev/null and b/images/figures/end2end_llama_70B_vllm.png differ diff --git a/images/figures/end2end_llama_70b_auto_gptq.png b/images/figures/end2end_llama_70b_auto_gptq.png new file mode 100644 index 000000000..f3942a158 Binary files /dev/null and b/images/figures/end2end_llama_70b_auto_gptq.png differ diff --git a/images/figures/op_benchmark_3090_af4_gemv.png b/images/figures/op_benchmark_3090_af4_gemv.png deleted file mode 100644 index 1d03da5dd..000000000 Binary files a/images/figures/op_benchmark_3090_af4_gemv.png and /dev/null differ diff --git a/images/figures/op_benchmark_3090_fp16_gemm.png b/images/figures/op_benchmark_3090_fp16_gemm.png index a21e64567..f4ca05628 100644 Binary files a/images/figures/op_benchmark_3090_fp16_gemm.png and b/images/figures/op_benchmark_3090_fp16_gemm.png differ diff --git a/images/figures/op_benchmark_3090_fp16_gemm_e8.png b/images/figures/op_benchmark_3090_fp16_gemm_e8.png new file mode 100644 index 000000000..106f113d5 Binary files /dev/null and b/images/figures/op_benchmark_3090_fp16_gemm_e8.png differ diff --git a/images/figures/op_benchmark_3090_nf4_gemm.png b/images/figures/op_benchmark_3090_nf4_gemm.png new file mode 100644 index 000000000..a1e755634 Binary files /dev/null and b/images/figures/op_benchmark_3090_nf4_gemm.png differ diff --git a/images/figures/op_benchmark_3090_nf4_gemv.png b/images/figures/op_benchmark_3090_nf4_gemv.png new file mode 100644 index 000000000..88fd8f134 Binary files /dev/null and b/images/figures/op_benchmark_3090_nf4_gemv.png differ diff --git a/images/figures/op_benchmark_3090_s8_gemm.png b/images/figures/op_benchmark_3090_s8_gemm.png index 2d114e550..9f46ec50f 100644 Binary files a/images/figures/op_benchmark_3090_s8_gemm.png and b/images/figures/op_benchmark_3090_s8_gemm.png differ diff --git a/images/figures/op_benchmark_3090_s8_gemm_e8.png b/images/figures/op_benchmark_3090_s8_gemm_e8.png new file mode 100644 index 000000000..a62bee325 Binary files /dev/null and b/images/figures/op_benchmark_3090_s8_gemm_e8.png differ diff --git a/images/figures/op_benchmark_a100_int1_scaling.png b/images/figures/op_benchmark_a100_int1_scaling.png index 09fbf921d..6f14cfc70 100644 Binary files a/images/figures/op_benchmark_a100_int1_scaling.png and b/images/figures/op_benchmark_a100_int1_scaling.png differ diff --git a/images/figures/op_benchmark_a100_wq_gemm_e7.png b/images/figures/op_benchmark_a100_wq_gemm_e7.png new file mode 100644 index 000000000..1388dcc66 Binary files /dev/null and b/images/figures/op_benchmark_a100_wq_gemm_e7.png differ diff --git a/images/figures/op_benchmark_a100_wq_gemm_e8.png b/images/figures/op_benchmark_a100_wq_gemm_e8.png new file mode 100644 index 000000000..fc869dba6 Binary files /dev/null and b/images/figures/op_benchmark_a100_wq_gemm_e8.png differ diff --git a/images/figures/op_benchmark_a100_wq_gemv_e7.png b/images/figures/op_benchmark_a100_wq_gemv_e7.png new file mode 100644 index 000000000..21933677e Binary files /dev/null and b/images/figures/op_benchmark_a100_wq_gemv_e7.png differ diff --git a/images/figures/op_benchmark_a100_wq_gemv_e8.png b/images/figures/op_benchmark_a100_wq_gemv_e8.png new file mode 100644 index 000000000..21b369bd1 Binary files /dev/null and b/images/figures/op_benchmark_a100_wq_gemv_e8.png differ diff --git a/images/figures/op_benchmark_consistent_gemm_fp16.png b/images/figures/op_benchmark_consistent_gemm_fp16.png new file mode 100644 index 000000000..a833b9b84 Binary files /dev/null and b/images/figures/op_benchmark_consistent_gemm_fp16.png differ diff --git a/images/figures/op_benchmark_consistent_gemm_int8.png b/images/figures/op_benchmark_consistent_gemm_int8.png new file mode 100644 index 000000000..1b5869f3f Binary files /dev/null and b/images/figures/op_benchmark_consistent_gemm_int8.png differ diff --git a/images/gif/FasterTransformer.gif b/images/gif/FasterTransformer.gif new file mode 100644 index 000000000..f82a67930 Binary files /dev/null and b/images/gif/FasterTransformer.gif differ diff --git a/integration/bitdistiller/kernel_generator_dynzeros.py b/integration/bitdistiller/kernel_generator_dynzeros.py index 499a0d8ce..3d1ea4ecb 100644 --- a/integration/bitdistiller/kernel_generator_dynzeros.py +++ b/integration/bitdistiller/kernel_generator_dynzeros.py @@ -6,7 +6,7 @@ import tvm from tvm import IRModule from tvm.target import Target -from bitblas.utils import match_global_kernel, get_target_from_env +from bitblas.utils import match_global_kernel, auto_detect_nvidia_target from bitblas.base.analysis import get_reduction_blocks from bitblas.ops import Operator from bitblas.ops.matmul_dequantize import ( @@ -19,38 +19,35 @@ decode_i2_to_f16_scale_zeros, decode_i4_to_f16, decode_i4_to_f16_scale, - decode_i4_to_f16_scale_zeros + decode_i4_to_f16_scale_zeros, ) + bit = 2 mask = (1 << bit) - 1 group_size = 128 - ft_shapes = [ [1, 15360, 5120], [128, 15360, 5120], ] - -target = tvm.target.Target(get_target_from_env()) +target = tvm.target.Target(auto_detect_nvidia_target()) def get_template_path(): cur_dir = os.path.dirname(os.path.abspath(__file__)) - return os.path.join( - cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template" - ) + return os.path.join(cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template") template_path = get_template_path() def get_codegen_result(ops: Operator, target: Target): - code = ops.codegen(target=target) + code = ops.get_source(target=target) return code -def get_thread_block_infomation(mod: IRModule): +def get_thread_block_information(mod: IRModule): sch = tvm.tir.Schedule(mod) root_block = sch.get_block("root") child_blocks = sch.get_child_blocks(root_block) @@ -80,6 +77,7 @@ def get_thread_block_infomation(mod: IRModule): grid_info[2] = extent return block_info, grid_info + kernel_body = "" kernel_call = "" for M, N, K in ft_shapes: @@ -115,13 +113,11 @@ def get_thread_block_infomation(mod: IRModule): index = code.index("{", index) function_body = declarations + code[index:] - # get block infomation from mod - block_size, grid_size = get_thread_block_infomation(matmul.optimized_func) + # get block information from mod + block_size, grid_size = get_thread_block_information(matmul.optimized_func) if M != 1 and block_size[0] == 1: block_size[0] = 32 - new_kernel_name = ( - f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt" - ) + new_kernel_name = f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt" Qweight_bytes = N * K // 8 * bit Scale_bytes = N * K // group_size * 2 function_body = function_body.replace("main_kernel", new_kernel_name) @@ -148,18 +144,17 @@ def get_thread_block_infomation(mod: IRModule): """ kernel_call += real_call - # make output cur_dir = os.path.dirname(os.path.abspath(__file__)) -ladder_path = os.path.join(cur_dir, f"kenrel_output") +ladder_path = os.path.join(cur_dir, "kenrel_output") if not os.path.exists(ladder_path): os.makedirs(ladder_path) -ladder_kernel_path = os.path.join(ladder_path, f"ladder_kernel.cu") -ladder_header_path = os.path.join(ladder_path, f"ladder_kernel.h") +ladder_kernel_path = os.path.join(ladder_path, "ladder_kernel.cu") +ladder_header_path = os.path.join(ladder_path, "ladder_kernel.h") -with open(template_path, mode="r", encoding="utf-8") as r_f, open( - ladder_kernel_path, mode="w", encoding="utf8" -) as w_f: +with open( + template_path, mode="r", encoding="utf-8") as r_f, open( + ladder_kernel_path, mode="w", encoding="utf8") as w_f: template_content = r_f.read() template = Template(template_content) data = template.substitute(kernel_body=kernel_body, kernel_call=kernel_call) @@ -174,9 +169,7 @@ def get_thread_block_infomation(mod: IRModule): return (v1 << 16) | v0; } """ -with open( - ladder_header_path, mode="w", encoding="utf8" -) as w_f: +with open(ladder_header_path, mode="w", encoding="utf8") as w_f: headers = f"""// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #ifndef __LADDER_KERNEL_H__ diff --git a/integration/bitdistiller/kernel_generator_dynzeros_original.py b/integration/bitdistiller/kernel_generator_dynzeros_original.py index e4c435904..201395837 100644 --- a/integration/bitdistiller/kernel_generator_dynzeros_original.py +++ b/integration/bitdistiller/kernel_generator_dynzeros_original.py @@ -6,52 +6,45 @@ import tvm from tvm import IRModule from tvm.target import Target -from bitblas.utils import match_global_kernel, get_target_from_env +from bitblas.utils import match_global_kernel, auto_detect_nvidia_target from bitblas.base.analysis import get_reduction_blocks from bitblas.ops import Operator from bitblas.ops.matmul_dequantize import ( MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig, ) -from bitblas.gpu.intrin.lop3 import ( - decode_i2_to_f16, - decode_i2_to_f16_scale, - decode_i2_to_f16_scale_zeros_original, - decode_i2_to_f16_scale_zeros_rescale, - decode_i4_to_f16, - decode_i4_to_f16_scale_zeros_original, - decode_i4_to_f16_scale_zeros_rescale -) +from bitblas.gpu.intrin.lop3 import (decode_i2_to_f16, decode_i2_to_f16_scale, + decode_i2_to_f16_scale_zeros_original, + decode_i2_to_f16_scale_zeros_rescale, decode_i4_to_f16, + decode_i4_to_f16_scale_zeros_original, + decode_i4_to_f16_scale_zeros_rescale) + bit = 2 mask = (1 << bit) - 1 group_size = 128 - ft_shapes = [ [1, 15360, 5120], [128, 15360, 5120], ] - -target = tvm.target.Target(get_target_from_env()) +target = tvm.target.Target(auto_detect_nvidia_target()) def get_template_path(): cur_dir = os.path.dirname(os.path.abspath(__file__)) - return os.path.join( - cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template" - ) + return os.path.join(cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template") template_path = get_template_path() def get_codegen_result(ops: Operator, target: Target): - code = ops.codegen(target=target) + code = ops.get_source(target=target) return code -def get_thread_block_infomation(mod: IRModule): +def get_thread_block_information(mod: IRModule): sch = tvm.tir.Schedule(mod) root_block = sch.get_block("root") child_blocks = sch.get_child_blocks(root_block) @@ -81,6 +74,7 @@ def get_thread_block_infomation(mod: IRModule): grid_info[2] = extent return block_info, grid_info + kernel_body = "" kernel_call = "" for M, N, K in ft_shapes: @@ -99,7 +93,7 @@ def get_thread_block_infomation(mod: IRModule): group_size=group_size, fast_decoding=True, with_bias=False, - zeros_type="original", + zeros_mode="original", propagate_a=False, propagate_b=False, layout="nt", @@ -117,13 +111,11 @@ def get_thread_block_infomation(mod: IRModule): index = code.index("{", index) function_body = declarations + code[index:] - # get block infomation from mod - block_size, grid_size = get_thread_block_infomation(matmul.optimized_func) + # get block information from mod + block_size, grid_size = get_thread_block_information(matmul.optimized_func) if M != 1 and block_size[0] == 1: block_size[0] = 32 - new_kernel_name = ( - f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt" - ) + new_kernel_name = (f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt") Qweight_bytes = N * K // 8 * bit Scale_bytes = N * K // group_size * 2 function_body = function_body.replace("main_kernel", new_kernel_name) @@ -150,18 +142,17 @@ def get_thread_block_infomation(mod: IRModule): """ kernel_call += real_call - # make output cur_dir = os.path.dirname(os.path.abspath(__file__)) -ladder_path = os.path.join(cur_dir, f"kenrel_output") +ladder_path = os.path.join(cur_dir, "kenrel_output") if not os.path.exists(ladder_path): os.makedirs(ladder_path) -ladder_kernel_path = os.path.join(ladder_path, f"ladder_kernel.cu") -ladder_header_path = os.path.join(ladder_path, f"ladder_kernel.h") +ladder_kernel_path = os.path.join(ladder_path, "ladder_kernel.cu") +ladder_header_path = os.path.join(ladder_path, "ladder_kernel.h") -with open(template_path, mode="r", encoding="utf-8") as r_f, open( - ladder_kernel_path, mode="w", encoding="utf8" -) as w_f: +with open( + template_path, mode="r", encoding="utf-8") as r_f, open( + ladder_kernel_path, mode="w", encoding="utf8") as w_f: template_content = r_f.read() template = Template(template_content) data = template.substitute(kernel_body=kernel_body, kernel_call=kernel_call) @@ -176,9 +167,7 @@ def get_thread_block_infomation(mod: IRModule): return (v1 << 16) | v0; } """ -with open( - ladder_header_path, mode="w", encoding="utf8" -) as w_f: +with open(ladder_header_path, mode="w", encoding="utf8") as w_f: headers = f"""// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #ifndef __LADDER_KERNEL_H__ diff --git a/integration/fastertransformer/kernel_generator.py b/integration/fastertransformer/kernel_generator.py index 3d0e8ce4b..f633a4509 100644 --- a/integration/fastertransformer/kernel_generator.py +++ b/integration/fastertransformer/kernel_generator.py @@ -6,8 +6,7 @@ import tvm from tvm import IRModule from tvm.target import Target -from bitblas.utils import match_global_kernel, get_target_from_env -from bitblas.base.analysis import get_reduction_blocks +from bitblas.utils import match_global_kernel, auto_detect_nvidia_target from bitblas.ops import Operator from bitblas.ops.matmul_dequantize import ( MatmulWeightOnlyDequantize, @@ -19,37 +18,34 @@ decode_i4_to_f16, decode_i4_to_f16_scale, ) + bit = 2 mask = (1 << bit) - 1 group_size = 128 - ft_shapes = [ # [1, 5120, 5120], # [1, 15360, 5120], [128, 15360, 5120], ] - -target = tvm.target.Target(get_target_from_env()) +target = tvm.target.Target(auto_detect_nvidia_target()) def get_template_path(): cur_dir = os.path.dirname(os.path.abspath(__file__)) - return os.path.join( - cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template" - ) + return os.path.join(cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template") template_path = get_template_path() def get_codegen_result(ops: Operator, target: Target): - code = ops.codegen(target=target) + code = ops.get_source(target=target) return code -def get_thread_block_infomation(mod: IRModule): +def get_thread_block_information(mod: IRModule): sch = tvm.tir.Schedule(mod) root_block = sch.get_block("root") child_blocks = sch.get_child_blocks(root_block) @@ -77,6 +73,7 @@ def get_thread_block_infomation(mod: IRModule): grid_info[2] = extent return block_info, grid_info + kernel_body = "" kernel_call = "" for M, N, K in ft_shapes: @@ -111,13 +108,11 @@ def get_thread_block_infomation(mod: IRModule): index = code.index("{", index) function_body = declarations + code[index:] - # get block infomation from mod - block_size, grid_size = get_thread_block_infomation(matmul.optimized_func) + # get block information from mod + block_size, grid_size = get_thread_block_information(matmul.optimized_func) if M != 1 and block_size[0] == 1: block_size[0] = 32 - new_kernel_name = ( - f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt" - ) + new_kernel_name = (f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt") Qweight_bytes = N * K // 8 * bit function_body = function_body.replace("main_kernel", new_kernel_name) call = f""" @@ -143,18 +138,17 @@ def get_thread_block_infomation(mod: IRModule): """ kernel_call += real_call - # make output cur_dir = os.path.dirname(os.path.abspath(__file__)) -ladder_path = os.path.join(cur_dir, f"kenrel_output") +ladder_path = os.path.join(cur_dir, "kenrel_output") if not os.path.exists(ladder_path): os.makedirs(ladder_path) -ladder_kernel_path = os.path.join(ladder_path, f"ladder_kernel.cu") -ladder_header_path = os.path.join(ladder_path, f"ladder_kernel.h") +ladder_kernel_path = os.path.join(ladder_path, "ladder_kernel.cu") +ladder_header_path = os.path.join(ladder_path, "ladder_kernel.h") -with open(template_path, mode="r", encoding="utf-8") as r_f, open( - ladder_kernel_path, mode="w", encoding="utf8" -) as w_f: +with open( + template_path, mode="r", encoding="utf-8") as r_f, open( + ladder_kernel_path, mode="w", encoding="utf8") as w_f: template_content = r_f.read() template = Template(template_content) data = template.substitute(kernel_body=kernel_body, kernel_call=kernel_call) @@ -169,9 +163,7 @@ def get_thread_block_infomation(mod: IRModule): return (v1 << 16) | v0; } """ -with open( - ladder_header_path, mode="w", encoding="utf8" -) as w_f: +with open(ladder_header_path, mode="w", encoding="utf8") as w_f: headers = f"""// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #ifndef __LADDER_KERNEL_H__ diff --git a/integration/fastertransformer/kernel_generator_dynzeros.py b/integration/fastertransformer/kernel_generator_dynzeros.py index 9b813bffa..fc319f218 100644 --- a/integration/fastertransformer/kernel_generator_dynzeros.py +++ b/integration/fastertransformer/kernel_generator_dynzeros.py @@ -6,7 +6,7 @@ import tvm from tvm import IRModule from tvm.target import Target -from bitblas.utils import match_global_kernel, get_target_from_env +from bitblas.utils import match_global_kernel, auto_detect_nvidia_target from bitblas.base.analysis import get_reduction_blocks from bitblas.ops import Operator from bitblas.ops.matmul_dequantize import ( @@ -19,36 +19,33 @@ decode_i4_to_f16, decode_i4_to_f16_scale, ) + bit = 2 mask = (1 << bit) - 1 group_size = 128 - ft_shapes = [ [1, 15360, 5120], [128, 15360, 5120], ] - -target = tvm.target.Target(get_target_from_env()) +target = tvm.target.Target(auto_detect_nvidia_target()) def get_template_path(): cur_dir = os.path.dirname(os.path.abspath(__file__)) - return os.path.join( - cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template" - ) + return os.path.join(cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template") template_path = get_template_path() def get_codegen_result(ops: Operator, target: Target): - code = ops.codegen(target=target) + code = ops.get_source(target=target) return code -def get_thread_block_infomation(mod: IRModule): +def get_thread_block_information(mod: IRModule): sch = tvm.tir.Schedule(mod) root_block = sch.get_block("root") child_blocks = sch.get_child_blocks(root_block) @@ -78,6 +75,7 @@ def get_thread_block_infomation(mod: IRModule): grid_info[2] = extent return block_info, grid_info + kernel_body = "" kernel_call = "" for M, N, K in ft_shapes: @@ -113,14 +111,12 @@ def get_thread_block_infomation(mod: IRModule): index = code.index("{", index) function_body = declarations + code[index:] - # get block infomation from mod - block_size, grid_size = get_thread_block_infomation(matmul.optimized_func) + # get block information from mod + block_size, grid_size = get_thread_block_information(matmul.optimized_func) if M != 1 and block_size[0] == 1: block_size[0] = 32 - new_kernel_name = ( - f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt" - ) + new_kernel_name = (f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt") Qweight_bytes = N * K // 8 * bit Scale_bytes = N * K // group_size * 2 function_body = function_body.replace("main_kernel", new_kernel_name) @@ -147,18 +143,17 @@ def get_thread_block_infomation(mod: IRModule): """ kernel_call += real_call - # make output cur_dir = os.path.dirname(os.path.abspath(__file__)) -ladder_path = os.path.join(cur_dir, f"kenrel_output") +ladder_path = os.path.join(cur_dir, "kenrel_output") if not os.path.exists(ladder_path): os.makedirs(ladder_path) -ladder_kernel_path = os.path.join(ladder_path, f"ladder_kernel.cu") -ladder_header_path = os.path.join(ladder_path, f"ladder_kernel.h") +ladder_kernel_path = os.path.join(ladder_path, "ladder_kernel.cu") +ladder_header_path = os.path.join(ladder_path, "ladder_kernel.h") -with open(template_path, mode="r", encoding="utf-8") as r_f, open( - ladder_kernel_path, mode="w", encoding="utf8" -) as w_f: +with open( + template_path, mode="r", encoding="utf-8") as r_f, open( + ladder_kernel_path, mode="w", encoding="utf8") as w_f: template_content = r_f.read() template = Template(template_content) data = template.substitute(kernel_body=kernel_body, kernel_call=kernel_call) @@ -173,9 +168,7 @@ def get_thread_block_infomation(mod: IRModule): return (v1 << 16) | v0; } """ -with open( - ladder_header_path, mode="w", encoding="utf8" -) as w_f: +with open(ladder_header_path, mode="w", encoding="utf8") as w_f: headers = f"""// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #ifndef __LADDER_KERNEL_H__ diff --git a/integration/pytorch/bitblas_linear.py b/integration/pytorch/bitblas_linear.py index 3e3cb67ca..c315cf6c4 100644 --- a/integration/pytorch/bitblas_linear.py +++ b/integration/pytorch/bitblas_linear.py @@ -7,11 +7,12 @@ import torch import torch.nn as nn +from typing import List, Union, Literal, Optional logger = getLogger(__name__) try: - import bitblas + import bitblas # noqa: F401 except ImportError as e: bitblas_import_exception = e @@ -22,18 +23,17 @@ def error_raiser_bitblas(*args, **kwargs): autogptq_bitblas_cuda = bitblas_import_exception -from bitblas.utils import get_target_from_env -from bitblas.ops.matmul import MatmulConfig, Matmul -from typing import List, Union, Literal, Optional +from bitblas.utils import auto_detect_nvidia_target # noqa: E402 +from bitblas.ops.matmul import MatmulConfig, Matmul # noqa: E402 class Linear(nn.Module): def __init__( self, - infeatures: int, - outfeatures: int, - opt_features: Union[int, List[int]] = 1, + in_features: int, + out_features: int, + opt_M: Union[int, List[int]] = 1, bias: bool = False, dtype: torch.dtype = torch.float16, propagate_a: bool = False, @@ -44,7 +44,7 @@ def __init__( target: Optional[str] = None, ): """ - @opt_features: optimze range of the input shape for dynamic symbolic + @opt_M: optimize range of the input shape for dynamic symbolic if the input shape is a range, we will optimize the matmul with dynamic symbolic. if the input shape is int, we will optimize the matmul with static symbolic. """ @@ -52,16 +52,16 @@ def __init__( if trainable: raise NotImplementedError("Bitblas does not support train.") - self.infeatures = infeatures - self.outfeatures = outfeatures - self.opt_features = opt_features + self.in_features = in_features + self.out_features = out_features + self.opt_M = opt_M self.dtype = dtype self.propagate_a = propagate_a self.propagate_b = propagate_b self.enable_tuning = enable_tuning - self.weight = nn.Parameter(torch.empty((outfeatures, infeatures), dtype=dtype)) + self.weight = nn.Parameter(torch.empty((out_features, in_features), dtype=dtype)) if bias: - self.bias = nn.Parameter(torch.empty(outfeatures, dtype=dtype)) + self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype)) else: self.register_parameter("bias", None) @@ -73,11 +73,11 @@ def __init__( assert dtype in BITBLAS_DTYPES, f"Unsupported dtype: {dtype}" bitblas_dtype = BITBLAS_DTYPES[dtype] - self.target = target or get_target_from_env() + self.target = target or auto_detect_nvidia_target() matmul_config = MatmulConfig( - M=self.opt_features, - N=self.outfeatures, - K=self.infeatures, + M=self.opt_M, + N=self.out_features, + K=self.in_features, in_dtype=bitblas_dtype, out_dtype=bitblas_dtype, accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, @@ -104,22 +104,21 @@ def reset_parameters(self): if self.bias is not None: self.bias.uniform_(-stdv, stdv) - def forward(self, A, Output=None): + def forward(self, A, output=None): args = [ A, self.weight, ] if self.bias is not None: args.append(self.bias) - if Output is None: - Output = torch.empty( - A.shape[:-1] + (self.outfeatures,), dtype=A.dtype, device=A.device - ) - args.append(Output) + if output is None: + output = torch.empty( + A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) + args.append(output) self.bitblas_matmul(*args) - return Output + return output __all__ = ["Linear"] diff --git a/integration/pytorch/bitblas_quant_linear.py b/integration/pytorch/bitblas_quant_linear.py index a46e58959..c0cdac611 100644 --- a/integration/pytorch/bitblas_quant_linear.py +++ b/integration/pytorch/bitblas_quant_linear.py @@ -7,11 +7,10 @@ import torch import torch.nn as nn - logger = getLogger(__name__) try: - import bitblas + import bitblas # noqa: F401 except ImportError as e: bitblas_import_exception = e @@ -27,7 +26,7 @@ def error_raiser_bitblas(*args, **kwargs): MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize, ) -from bitblas.utils import get_target_from_env +from bitblas.utils import auto_detect_nvidia_target from typing import List, Union, Literal, Optional @@ -38,63 +37,64 @@ def __init__( self, bits: int, group_size: int, - infeatures: int, - outfeatures: int, + in_features: int, + out_features: int, bias: bool, enable_tuning: bool = False, fast_decoding: bool = False, propagate_a: bool = False, propagate_b: bool = False, - opt_features: Union[int, List[int]] = [1, 16, 32], + opt_M: Optional[Union[int, List[int]]] = None, layout: Literal["nt"] = "nt", trainable=False, **kwargs, ): super().__init__() if group_size == -1: - group_size = infeatures - if infeatures % 128 != 0 or outfeatures % 256 != 0: - raise ValueError( - "`infeatures` must be divisible by 128 and `outfeatures` by 256." - ) + group_size = in_features + if in_features % 128 != 0 or out_features % 256 != 0: + raise ValueError("`in_features` must be divisible by 128 and `out_features` by 256.") if bits not in [1, 2, 4]: raise NotImplementedError("Only 1/2/4 bits are supported.") - if infeatures % group_size != 0: - raise ValueError("`infeatures` must be divisible by `group_size`.") + if in_features % group_size != 0: + raise ValueError("`in_features` must be divisible by `group_size`.") if trainable: raise NotImplementedError("Bitblas does not support train.") + if opt_M is None: + opt_M = [1, 32, 64] self.bits = bits storage_nbit = 8 # assume int8 storage n_float_per_elem = storage_nbit // bits - self.opt_features = opt_features - self.infeatures = infeatures - self.outfeatures = outfeatures - self.group_size = group_size if group_size != -1 else infeatures + self.opt_M = opt_M + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size if group_size != -1 else in_features self.register_buffer( "qweight", torch.empty( - (self.outfeatures, self.infeatures // n_float_per_elem), + (self.out_features, self.in_features // n_float_per_elem), dtype=torch.int8, ), ) self.register_buffer( "scales", torch.empty( - (self.outfeatures, self.infeatures // self.group_size), dtype=torch.half + (self.out_features, self.in_features // self.group_size), + dtype=torch.half, ), ) self.register_buffer( "zeros", torch.full( - (self.outfeatures, self.infeatures // self.group_size), + (self.out_features, self.in_features // self.group_size), 0, dtype=torch.float16, ), ) if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.half)) + self.register_buffer("bias", torch.zeros((out_features), dtype=torch.half)) else: self.bias = None @@ -110,11 +110,11 @@ def __init__( } assert dtype in BITBLAS_DTYPES, f"Unsupported dtype: {dtype}" bitblas_dtype = BITBLAS_DTYPES[dtype] - self.target = get_target_from_env() + self.target = auto_detect_nvidia_target() matmul_config = MatmulWeightOnlyDequantizeConfig( - M=self.opt_features, - N=self.outfeatures, - K=self.infeatures, + M=self.opt_M, + N=self.out_features, + K=self.in_features, in_dtype=bitblas_dtype, out_dtype=bitblas_dtype, accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, @@ -129,12 +129,10 @@ def __init__( propagate_a=propagate_a, propagate_b=propagate_b, layout=layout, - zeros_type="original", + zeros_mode="original", ) # optimize target shapes for dynamic symbolic - self.bitblas_matmul = MatmulWeightOnlyDequantize( - matmul_config, target=self.target - ) + self.bitblas_matmul = MatmulWeightOnlyDequantize(matmul_config, target=self.target) if enable_tuning: self.bitblas_matmul.hardware_aware_finetune(topk=20) @@ -145,7 +143,7 @@ def reset_parameters(self): self.qweight = torch.randint_like( self.qweight, 0, - 2 ** (self.bits - 1) - 1, + 2**(self.bits - 1) - 1, dtype=torch.int8, device=self.qweight.device, ) @@ -160,7 +158,7 @@ def post_init(self): def pack(self, linear, scales, zeros=None): """Pack a fake-quantized linear layer into this actual Bitblas representation. @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`) - @scales: corresponding quantization scales of shape `(infeatures, groups)` + @scales: corresponding quantization scales of shape `(in_features, groups)` """ if linear.weight.dtype != torch.half: raise ValueError("Only `torch.half` weights are supported.") @@ -177,20 +175,16 @@ def pack(self, linear, scales, zeros=None): # do permutation on weight intweight = [] - for idx in range(self.infeatures): + for idx in range(self.in_features): g_idx = idx // self.group_size intweight.append( - torch.round((w[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to( - torch.int - )[:, None] - ) + torch.round( + (w[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None]) intweight = torch.cat(intweight, dim=1) intweight = intweight.contiguous() intweight = intweight.cpu().numpy().astype(np.int8) # quantize to 4bit - qw_np = general_compress( - intweight, source_bits=self.bits, storage_dtype=np.int8 - ) + qw_np = general_compress(intweight, source_bits=self.bits, storage_dtype=np.int8) # do interleave for fast type conversion if self.fast_type_conversion: qw_np = interleave_weight(qw_np, nbits=self.bits, target_dtype="float16") @@ -205,19 +199,18 @@ def pack(self, linear, scales, zeros=None): if self.bias is not None: self.bias[:] = linear.bias.data.to(self.bias.device).contiguous() - def forward(self, A, Output=None): + def forward(self, A, output=None): args = [A, self.qweight, self.scales, self.zeros] if self.bias is not None: args.append(self.bias) - if Output is None: - Output = torch.empty( - A.shape[:-1] + (self.qweight.shape[0],), dtype=A.dtype, device=A.device - ) - args.append(Output) + if output is None: + output = torch.empty( + A.shape[:-1] + (self.qweight.shape[0],), dtype=A.dtype, device=A.device) + args.append(output) self.bitblas_matmul(*args) - return Output + return output __all__ = ["QuantLinear"] diff --git a/integration/pytorch/test_bitblas_linear.py b/integration/pytorch/test_bitblas_linear.py index 2b9dc92d2..18a969fc5 100644 --- a/integration/pytorch/test_bitblas_linear.py +++ b/integration/pytorch/test_bitblas_linear.py @@ -1,18 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import bitblas from bitblas_linear import Linear as BitBLASLinear import torch import time import numpy as np import torch.nn as nn -import bitblas import pytest torch.manual_seed(0) @pytest.mark.parametrize( - "m, infeatures, outfeatures, bias", + "m, in_features, out_features, bias", [ (1, 1024, 1024, False), (1, 1024, 1024, True), @@ -20,16 +20,14 @@ (1024, 1024, 1024, True), ], ) -def test_correctness_static_shape(m, infeatures, outfeatures, bias): - linear_torch = ( - nn.Linear(infeatures, outfeatures, bias=bias).to(torch.float16).cuda() - ) +def test_correctness_static_shape(m, in_features, out_features, bias): + linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( - infeatures, - outfeatures, + in_features, + out_features, bias=bias, dtype=torch.float16, - opt_features=m, + opt_M=m, enable_tuning=False, ).cuda() @@ -39,7 +37,7 @@ def test_correctness_static_shape(m, infeatures, outfeatures, bias): linear_bitblas.bias = nn.Parameter(linear_torch.bias.clone()) with torch.no_grad(): - input_data = torch.randn(m, infeatures, dtype=torch.float16).cuda() + input_data = torch.randn(m, in_features, dtype=torch.float16).cuda() output_torch = linear_torch(input_data) output_bitblas = linear_bitblas(input_data) @@ -50,7 +48,7 @@ def profile(model, input_data): model = model.cuda() model.eval() output = torch.empty( - input_data.shape[:-1] + (model.outfeatures,), + input_data.shape[:-1] + (model.out_features,), dtype=input_data.dtype, device=input_data.device, ) @@ -74,31 +72,29 @@ def get_runtime(num_repeats=1): @pytest.mark.parametrize( - "m, infeatures, outfeatures, bias", + "m, in_features, out_features, bias", [ (1, 1024, 1024, False), (1024, 1024, 1024, False), ], ) -def test_profile_performance(m, infeatures, outfeatures, bias): +def test_profile_performance(m, in_features, out_features, bias): linear_bitblas = BitBLASLinear( - infeatures, - outfeatures, + in_features, + out_features, bias=bias, dtype=torch.float16, - opt_features=m, + opt_M=m, enable_tuning=False, ).cuda() with torch.no_grad(): - input_data = torch.randn(m, infeatures, dtype=torch.float16).cuda() + input_data = torch.randn(m, in_features, dtype=torch.float16).cuda() torch_latency = profile(linear_bitblas, input_data) bitblas_latency = linear_bitblas.bitblas_matmul.profile_latency() print(f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}") - assert ( - abs(torch_latency - bitblas_latency) / torch_latency < 0.1 - ), f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" + assert (abs(torch_latency - bitblas_latency) / torch_latency < + 0.1), f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" if __name__ == "__main__": - # bitblas.testing.main() - test_profile_performance(1, 16384, 16384, False) + bitblas.testing.main() diff --git a/integration/pytorch/test_bitblas_quant_linear.py b/integration/pytorch/test_bitblas_quant_linear.py index bb0381e7e..1db9faa00 100644 --- a/integration/pytorch/test_bitblas_quant_linear.py +++ b/integration/pytorch/test_bitblas_quant_linear.py @@ -11,8 +11,7 @@ # !pip install auto-gptq from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( - QuantLinear as CudaOldQuantLinear, -) + QuantLinear as CudaOldQuantLinear,) torch.manual_seed(0) @@ -64,33 +63,33 @@ def reshape(w): @pytest.mark.parametrize( - "m, infeatures, outfeatures, bits, group_size, bias", + "m, in_features, out_features, bits, group_size, bias", [ (1, 1024, 4096, 4, -1, False), (1, 1024, 4096, 4, 128, False), (1, 1024, 4096, 4, 128, True), ], ) -def test_quantization_accuracy(m, infeatures, outfeatures, bits, group_size, bias): - original_w, linear, s, qw = gen_quant4(infeatures, outfeatures, group_size) +def test_quantization_accuracy(m, in_features, out_features, bits, group_size, bias): + original_w, linear, s, qw = gen_quant4(in_features, out_features, group_size) if group_size == -1: - group_size = infeatures - zeros = torch.full((infeatures // group_size, outfeatures), 7, dtype=torch.int32) + group_size = in_features + zeros = torch.full((in_features // group_size, out_features), 7, dtype=torch.int32) bitblas_zeros = zeros.clone().T cuda_old_linear = CudaOldQuantLinear( bits=bits, group_size=group_size, - infeatures=infeatures, - outfeatures=outfeatures, + in_features=in_features, + out_features=out_features, bias=bias, ) cuda_old_linear.pack(linear, s.T, zeros.T, g_idx=None) linear_module = torch.nn.Linear( - in_features=infeatures, - out_features=outfeatures, + in_features=in_features, + out_features=out_features, bias=bias, dtype=torch.float16, device="cuda", @@ -98,7 +97,8 @@ def test_quantization_accuracy(m, infeatures, outfeatures, bits, group_size, bia linear_module.weight.data.copy_(linear.weight.data) scales = s.to("cuda") - bitblas_qlinear = QuantLinear(bits, group_size, infeatures, outfeatures, bias, opt_features=m, enable_tuning=True) + bitblas_qlinear = QuantLinear( + bits, group_size, in_features, out_features, bias, opt_M=m, enable_tuning=True) bitblas_qlinear.pack( linear_module.to("cuda"), @@ -106,7 +106,7 @@ def test_quantization_accuracy(m, infeatures, outfeatures, bits, group_size, bia zeros=bitblas_zeros.contiguous().to("cuda"), ) - inp = torch.rand(m, infeatures, dtype=torch.float16, device="cuda") + inp = torch.rand(m, in_features, dtype=torch.float16, device="cuda") cuda_old_linear = cuda_old_linear.to("cuda") bitblas_qlinear = bitblas_qlinear.to("cuda") @@ -123,7 +123,7 @@ def profile(model, input_data): model = model.cuda() model.eval() output = torch.empty( - input_data.shape[:-1] + (model.outfeatures,), + input_data.shape[:-1] + (model.out_features,), dtype=input_data.dtype, device=input_data.device, ) @@ -147,28 +147,30 @@ def get_runtime(num_repeats=1): @pytest.mark.parametrize( - "m, infeatures, outfeatures, bits, group_size, bias", + "m, in_features, out_features, bits, group_size, bias", [ (1, 16384, 16384, 4, -1, False), ], ) -def test_profile_performance(m, infeatures, outfeatures, bits, group_size, bias): +def test_profile_performance(m, in_features, out_features, bits, group_size, bias): bitblas_qlinear = QuantLinear( bits, group_size, - infeatures, - outfeatures, + in_features, + out_features, bias, - opt_features=m, + opt_M=m, enable_tuning=True, ).cuda() with torch.no_grad(): - input_data = torch.randn(m, infeatures, dtype=torch.float16).cuda() + input_data = torch.randn(m, in_features, dtype=torch.float16).cuda() torch_latency = profile(bitblas_qlinear, input_data) bitblas_latency = bitblas_qlinear.bitblas_matmul.profile_latency() - assert abs(torch_latency - bitblas_latency) / torch_latency < 0.1, f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" + assert abs( + torch_latency - bitblas_latency + ) / torch_latency < 0.1, f"torch_latency: {torch_latency}, bitblas_latency: {bitblas_latency}" if __name__ == "__main__": diff --git a/integration/vLLM/README.md b/integration/vLLM/README.md new file mode 100644 index 000000000..22d773636 --- /dev/null +++ b/integration/vLLM/README.md @@ -0,0 +1 @@ +Please checkout https://github.com/LeiWang1999/vllm-bitblas for details currently. The relative pull request to the official vLLM is still under construction. diff --git a/maint/scripts/check_mit_license.sh b/maint/scripts/check_mit_license.sh index 64cd699e5..f758c2bfc 100755 --- a/maint/scripts/check_mit_license.sh +++ b/maint/scripts/check_mit_license.sh @@ -3,19 +3,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -echo "Check MIT Liscense boilerplate..." +echo "Check MIT License boilerplate..." PWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -# TO source code root +# To source code root pushd "${PWD}/../../" > /dev/null EXITCODE=0 -for SRC_FILE in $(find . -path './3rdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name '*apply_mit_liscense.sh' \ - -not -name '*check_mit_liscense.sh' -and \( -name 'CMakeLists.txt' -or -name '*.cpp' -or -name '*.cu' -or -name '*.h' -or -name '*.hpp' \ - -or -name '*.in' -or -name '*.py' -or -name '*.sh' -or -name '*.dockerfile' -or -name '*.yaml' \) ); do +for SRC_FILE in $(find . -path './3rdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name '*apply_mit_license.sh' \ + -not -name '*check_mit_license.sh' -and \( -name 'CMakeLists.txt' -or -name '*.cpp' -or -name '*.cu' -or -name '*.h' -or -name '*.hpp' \ + -or -name '*.py' -or -name '*.sh' -or -name '*.dockerfile' -or -name '*.yaml' \) ); do + + # Skip files that already contain the Apache License + if grep -q "Apache License" "${SRC_FILE}"; then + continue + fi + if !(grep -q "Copyright (c) Microsoft Corporation." "${SRC_FILE}") || !(grep -q "Licensed under the MIT License." "${SRC_FILE}") \ - || (grep -q -i -P "Microsoft( |)\(c\)" "${SRC_FILE}") || (grep -q "Apache License" "${SRC_FILE}"); then - echo "[ERROR] Require: MIT Liscense biolerplate" "${SRC_FILE}" + || (grep -q -i -P "Microsoft( |)\(c\)" "${SRC_FILE}"); then + echo "[ERROR] Require: MIT License boilerplate" "${SRC_FILE}" EXITCODE=1 fi done diff --git a/maint/scripts/installation.sh b/maint/scripts/installation.sh index 6d426a2f1..8e083326a 100755 --- a/maint/scripts/installation.sh +++ b/maint/scripts/installation.sh @@ -21,6 +21,6 @@ echo "set(USE_LLVM llvm-config-10)" >> config.cmake && echo "set(USE_CUDA ON)" > cmake .. && make -j && cd ../../.. echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc -echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python:$PYTHONPATH" >> ~/.bashrc +echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python:\$PYTHONPATH" >> ~/.bashrc source ~/.bashrc diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..85fd7db04 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,39 @@ +[tool.yapf] +based_on_style = "yapf" +column_limit = 100 +indent_width = 4 + +[tool.codespell] +ignore-words-list = "nd, te" + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + # "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + # "I", +] +ignore = [ + # Module level import not at top of file + "E402", + # star imports + "F405", "F403", + # ambigous name + "E741", + # line too long + "E501", + # key in dict.keys() + "SIM118", + # memory leaks + "B019", + # No such file or directory + "E902", +] diff --git a/python/bitblas/__init__.py b/python/bitblas/__init__.py index 56191b4e9..41606e006 100644 --- a/python/bitblas/__init__.py +++ b/python/bitblas/__init__.py @@ -1,41 +1,53 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -"""DLight package provides efficient schedules out-of-box for deep learning workloads.""" -from . import gpu -from .base import ( - fast_tune, - ApplyDefaultSchedule, - ApplyFastTuning, - BlockInfo, - IterInfo, - ScheduleRule, - normalize_prim_func, - try_inline, - try_inline_contiguous_spatial, -) import sys import os -# tvm path is under the root of the project -tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "..", "3rdparty", "tvm", "python" +# installing tvm +install_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm", "python") +if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = install_tvm_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, install_tvm_path) + +develop_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "..", "3rdparty", "tvm", "python") +if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = develop_tvm_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, develop_tvm_path) + +from . import gpu # noqa: F401 +from .base import ( + TileDevice, # noqa: F401 + fast_tune, # noqa: F401 + ApplyDefaultSchedule, # noqa: F401 + ApplyFastTuning, # noqa: F401 + BlockInfo, # noqa: F401 + IterInfo, # noqa: F401 + ScheduleRule, # noqa: F401 + normalize_prim_func, # noqa: F401 + try_inline, # noqa: F401 + try_inline_contiguous_spatial, # noqa: F401 ) -if tvm_path not in sys.path: - sys.path.append(tvm_path) -from . import testing +from . import testing # noqa: F401 +from .utils import auto_detect_nvidia_target # noqa: F401 +from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 +from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401 +from .module import Linear # noqa: F401 import logging from tqdm import tqdm - -# target logger into tqdm.write class TqdmLoggingHandler(logging.Handler): + """ Custom logging handler that directs log output to tqdm progress bar to avoid interference. """ + def __init__(self, level=logging.NOTSET): + """ Initialize the handler with an optional log level. """ super().__init__(level) def emit(self, record): + """ Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted. """ try: msg = self.format(record) tqdm.write(msg) @@ -44,20 +56,29 @@ def emit(self, record): def set_log_level(level): + """ Set the logging level for the module's logger. + + Args: + level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO). + """ + if isinstance(level, str): + level = getattr(logging, level.upper(), logging.INFO) logger = logging.getLogger(__name__) logger.setLevel(level) def _init_logger(): + """ Initialize the logger specific for this module with custom settings and a Tqdm-based handler. """ logger = logging.getLogger(__name__) handler = TqdmLoggingHandler() formatter = logging.Formatter( - fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%F %T" - ) + fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") handler.setFormatter(formatter) logger.addHandler(handler) logger.propagate = False - set_log_level(logging.INFO) + set_log_level('WARNING') _init_logger() + +__version__ = "0.0.1" diff --git a/python/bitblas/base/analysis.py b/python/bitblas/base/analysis.py index 56cbfff4e..eb9c19415 100644 --- a/python/bitblas/base/analysis.py +++ b/python/bitblas/base/analysis.py @@ -1,46 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """Analysis on TIR blocks, loops and functions.""" -from typing import List, Optional, Set, Union, Tuple, Dict +from typing import List, Optional, Set, Union from typing_extensions import Literal -from dataclasses import dataclass -from enum import Enum from tvm import ir, tir, DataType -from tvm.ir import Range -from tvm.tir.analysis import undefined_vars from tvm._ffi import get_global_func from tvm.target.target import Target -from tvm.tir import Schedule, IterVar, Var, PrimExpr +from tvm.tir import Schedule, IterVar from tvm.tir.schedule import BlockRV -def get_reduction_blocks(sch, blocks) -> bool: - # Get the main computation block - def is_reduction(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.CommReduce, IterVar.DataPar} - - def is_spatial(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.DataPar} - - # NOTE: We assume there is only one reduction block in the function - # all blocks are required to be spatial or reduction - if not all([is_reduction(block) or is_spatial(block) for block in blocks]): - return None - - # There is only one reduction block - reduction_blocks = [block for block in blocks if is_reduction(block)] - if len(reduction_blocks) != 1: - return None - - return reduction_blocks - - class IterInfo: """Information about a loop/iter var.""" @@ -123,9 +93,7 @@ def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool: if len(r_region) != len(w_region): return False for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region): - if not _check_unit_var_range(var, r_dom) or not _check_unit_var_range( - var, w_dom - ): + if not _check_unit_var_range(var, r_dom) or not _check_unit_var_range(var, w_dom): return False return True @@ -178,13 +146,11 @@ def _iter_kind(i: tir.IterVar) -> str: var=iter.var, dom=iter.dom, loop_rv=loop, - ) - for loop, iter in zip(loops, iters) + ) for loop, iter in zip(loops, iters) ], block_rv=block, reduction_block=is_reduction, - ) - ) + )) return blocks @@ -225,25 +191,21 @@ def get_max_shared_memory_per_block(target: Target) -> int: max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) if max_shared_memory_per_block is None: raise ValueError( - f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually" - ) + f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") return int(max_shared_memory_per_block) def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: try: block = sch.mod[func_name].body.block - except: - raise ValueError( - f"The function body is expected to be the root block, but got:\n" - f"{sch.mod[func_name].body}" - ) + except Exception: + raise ValueError(f"The function body is expected to be the root block, but got:\n" + f"{sch.mod[func_name].body}") from None return sch.get_block(block.name_hint) -def collect_block_iter_vars_used_in_access_region( - block: tir.Block, region: List[ir.Range] -) -> Set[tir.Var]: +def collect_block_iter_vars_used_in_access_region(block: tir.Block, + region: List[ir.Range]) -> Set[tir.Var]: """Collect the block iter variables used in the access region of a buffer region.""" tir_vars = set() for expr in region: @@ -271,9 +233,7 @@ def detect_dominant_read(block: tir.Block) -> tir.PrimExpr: dominant_read = None num_read_iters = -1 for buffer_region in block.reads: - tir_vars = collect_block_iter_vars_used_in_access_region( - block, buffer_region.region - ) + tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region) if num_read_iters < len(tir_vars): num_read_iters = len(tir_vars) dominant_read = buffer_region @@ -294,16 +254,14 @@ def is_broadcast_epilogue( if buffer_region.buffer not in write_buffers: continue tir_vars = collect_block_iter_vars_used_in_access_region( - sch.get(epilogue), buffer_region.region - ) + sch.get(epilogue), buffer_region.region) if len(tir_vars) < len(epilogue_iters): return True return False -def get_reduction_blocks( - sch: tir.Schedule, blocks: List[tir.schedule.BlockRV] -) -> List[tir.schedule.BlockRV]: +def get_reduction_blocks(sch: tir.Schedule, + blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]: # Get the main computation block def is_reduction(block: BlockRV) -> bool: block_stmt = sch.get(block) @@ -330,7 +288,6 @@ def is_spatial(block: BlockRV) -> bool: def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int: # gpu memory prefer 128 bits coalesced access (e.g. four banks) # 128 bits - block_stmt buffers: List[tir.Buffer] = [] for read in block_stmt.reads: buffers.append(read.buffer) diff --git a/python/bitblas/base/common_schedules.py b/python/bitblas/base/common_schedules.py index e1852e81c..7d528c70a 100644 --- a/python/bitblas/base/common_schedules.py +++ b/python/bitblas/base/common_schedules.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm common_schedules.py in dlight. """Common schedule strategies for TIR.""" @@ -110,7 +110,7 @@ def _trial(func: Callable): for i, block in enumerate(blocks): try: func(block.block_rv) - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except continue return i return None diff --git a/python/bitblas/base/roller/__init__.py b/python/bitblas/base/roller/__init__.py index 39180bf1a..7ca6f15c2 100644 --- a/python/bitblas/base/roller/__init__.py +++ b/python/bitblas/base/roller/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .node import PrimFuncNode -from .config import Config -from .policy import DefaultPolicy, TensorCorePolicy -from .arch import Arch, CUDA +from .node import PrimFuncNode # noqa: F401 +from .hint import Hint # noqa: F401 +from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401 +from .arch import TileDevice, CUDA # noqa: F401 diff --git a/python/bitblas/base/roller/arch/__init__.py b/python/bitblas/base/roller/arch/__init__.py index 80293a133..9cb036792 100644 --- a/python/bitblas/base/roller/arch/__init__.py +++ b/python/bitblas/base/roller/arch/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .arch_base import Arch +from .arch_base import TileDevice from .cuda import * from .cpu import * -def get_arch(target: tvm.target.Target) -> Arch: +def get_arch(target: tvm.target.Target) -> TileDevice: if target.kind.name == "cuda": return CUDA(target) elif target.kind.name == "llvm": diff --git a/python/bitblas/base/roller/arch/arch_base.py b/python/bitblas/base/roller/arch/arch_base.py index 8628badce..6e98838c7 100644 --- a/python/bitblas/base/roller/arch/arch_base.py +++ b/python/bitblas/base/roller/arch/arch_base.py @@ -4,7 +4,7 @@ from typing import List -class Arch: +class TileDevice: """ Represents the architecture of a computing device, capturing various hardware specifications. """ @@ -38,4 +38,3 @@ def __init__(self) -> None: def get_avaliable_tensorintrin_shapes(self): raise NotImplementedError() - \ No newline at end of file diff --git a/python/bitblas/base/roller/arch/cpu.py b/python/bitblas/base/roller/arch/cpu.py index a90015717..98fb14af5 100644 --- a/python/bitblas/base/roller/arch/cpu.py +++ b/python/bitblas/base/roller/arch/cpu.py @@ -3,13 +3,13 @@ import tvm from tvm.target import Target -from .arch_base import Arch -from typing import List, Dict +from .arch_base import TileDevice # For LLVM Backend, we do not provide the detailed information of the CPU # As the LLVM backend do not required tuning, just maintain the consistency -class CPU(Arch): +class CPU(TileDevice): + def __init__(self, target: Target): self.target = target device = tvm.runtime.cpu(0) diff --git a/python/bitblas/base/roller/arch/cuda.py b/python/bitblas/base/roller/arch/cuda.py index ebe1a2ee8..63775ecbe 100644 --- a/python/bitblas/base/roller/arch/cuda.py +++ b/python/bitblas/base/roller/arch/cuda.py @@ -3,14 +3,17 @@ import tvm from tvm.target import Target -from .arch_base import Arch +from .arch_base import TileDevice from typing import List, Dict + def check_sm_version(arch: str) -> int: sm_version = arch.replace("sm_", "") return int(sm_version) if sm_version.isdigit() else -1 + class TensorInstruction(object): + def __init__( self, name: str, @@ -19,10 +22,12 @@ def __init__( ): self.name: str = name self.intrin_group: Dict = intrin_group - # only mantain the shape of M and N + # only maintain the shape of M and N self.shape: List[int] = shape -class CUDA(Arch): + +class CUDA(TileDevice): + def __init__(self, target: Target): self.target = target self.sm_version = check_sm_version(self.target.arch) @@ -57,4 +62,4 @@ def get_avaliable_tensorintrin_shapes(self): TensorInstruction("mma", get_mma_intrin_group, [16, 16]), TensorInstruction("wmma", get_wmma_intrin_group, [16, 16]), ) - return [t.shape for t in self.available_tensor_instructions] \ No newline at end of file + return [t.shape for t in self.available_tensor_instructions] diff --git a/python/bitblas/base/roller/config.py b/python/bitblas/base/roller/hint.py similarity index 81% rename from python/bitblas/base/roller/config.py rename to python/bitblas/base/roller/hint.py index f3ce60847..c5fcda366 100644 --- a/python/bitblas/base/roller/config.py +++ b/python/bitblas/base/roller/hint.py @@ -1,11 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -"""Config definition for schedule""" -from typing import Dict, List, Optional, Tuple -from ..roller import PrimFuncNode +"""Hint definition for schedule""" +from typing import Dict, List, Tuple +from . import PrimFuncNode import numpy as np - +from .rasterization import * class TensorCoreExtraConfig: """ @@ -62,7 +61,7 @@ def compute_elements_from_shape(self, shape: List[int]) -> int: strided_elem = original_shape else: assert self.ax < len(shape) - strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride + strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride assert strided_elem >= original_shape return int(strided_elem) @@ -107,7 +106,7 @@ def __hash__(self) -> int: class IntrinInfo: """ - The information of tensorcore intrinsic related infomation + The information of tensorcore intrinsic related information """ def __init__( @@ -115,21 +114,37 @@ def __init__( in_dtype: str, out_dtype: str, trans_b: bool, - smooth_a: bool = False, - smooth_b: bool = False, + input_transform_kind: int = 0, + weight_transform_kind: int = 0, ) -> None: self.in_dtype = in_dtype self.out_dtype = out_dtype self.trans_a = False self.trans_b = trans_b - self.smooth_a = smooth_a - self.smooth_b = smooth_b + self.input_transform_kind = input_transform_kind + self.weight_transform_kind = weight_transform_kind def __repr__(self) -> str: return f"" + @property + def smooth_a(self) -> bool: + return self.input_transform_kind >= 2 + + @property + def smooth_b(self) -> bool: + return self.weight_transform_kind >= 2 + + @property + def inter_transform_a(self) -> bool: + return self.input_transform_kind >= 1 -class Config(object): + @property + def inter_transform_b(self) -> bool: + return self.weight_transform_kind >= 1 + + +class Hint(object): """ Central configuration class for managing various parameters of computational tasks. """ @@ -138,7 +153,7 @@ def __init__(self) -> None: self.arch = None self.use_tc = None # todo(lei): this should be renamed. - # spacial axes tiling info + # special axes tiling info self.block = [] self.thread = [] # special axes for tensorCore @@ -146,7 +161,7 @@ def __init__(self) -> None: # reduce axes tiling info self.rstep = [] self.reduce_thread = [] - self.rasterization_plan = None + self.rasterization_plan = NoRasterization() self.cached_tensors = [] self.output_strides = {} self.schedule_stages = None @@ -189,28 +204,10 @@ def to_dict(self) -> Dict: dic["vectorize"] = self.vectorize return dic - def from_dict(self, dic: Dict) -> "Config": + def from_dict(self, dic: Dict) -> "Hint": self.__init__() - if "use_tc" in dic: - self.use_tc = dic["use_tc"] - self.block = dic["block"] - if self.use_tc: - self.warp = dic["warp"] - else: - self.thread = dic["thread"] - self.rstep = dic["rstep"] - if "reduce_thread" in dic: - self.reduce_thread = dic["reduce_thread"] - else: - self.reduce_thread = [1 for _ in self.rstep] - if "strides" in dic: - self.output_strides = dic["strides"] - if "step" in dic: - self._step = dic["step"] - if "raxis_order" in dic: - self._raxis_order = dic["raxis_order"] - if "vectorize" in dic: - self.vectorize = dic["vectorize"] + for k, v in dic.items(): + setattr(self, k, v) return self @property @@ -228,7 +225,7 @@ def step(self) -> List[int]: def __repr__(self) -> str: return str(self.to_dict()) - def complete_config(self, node:PrimFuncNode): + def complete_config(self, node: PrimFuncNode): # analysis pass context, for int8 mma, we should merge static shared memory merge_static_smem = False if self.use_tc and self.intrin_info.in_dtype == "int8": diff --git a/python/bitblas/base/roller/node.py b/python/bitblas/base/roller/node.py index 97f7be917..8e20440bb 100644 --- a/python/bitblas/base/roller/node.py +++ b/python/bitblas/base/roller/node.py @@ -1,15 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -"""PrimFunc Warpper and Block Infomation Analaysis""" +"""PrimFunc Wrapper and Block information Analaysis""" import tvm from tvm import tir -from tvm.tir import IterVar, Var, PrimFunc -from typing import Any, Iterable, Dict, List, Tuple -import functools -import numpy as np +from tvm.tir import IterVar, PrimFunc +from typing import Any, Dict, List, Tuple, Optional from tvm.tir.schedule.schedule import BlockRV +import numpy as np +import functools from ..analysis import BlockInfo, get_reduction_blocks from .. import analysis from .. import normalize_prim_func @@ -32,6 +31,7 @@ def _traverse(block): class BlockAnalyzer(object): + def __init__(self, sch) -> None: self.sch: tir.Schedule = sch self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch) @@ -84,7 +84,10 @@ def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]: class Node(object): - def __init__(self, tags: Dict = {}) -> None: + + def __init__(self, tags: Optional[Dict] = None) -> None: + if tags is None: + tags = {} self._dtypes = [] self._tag: Dict = {} for tag in tags: @@ -103,7 +106,8 @@ def get_tag(self, k: str) -> Any: class PrimFuncNode(Node): - def __init__(self, prim_func: PrimFunc, tags: Dict = {}) -> None: + + def __init__(self, prim_func: PrimFunc, tags: Optional[Dict] = None) -> None: super().__init__(tags) self.prim_func = self._specialize_func(prim_func) self.sch: tir.Schedule = tir.Schedule(self.prim_func) @@ -181,7 +185,7 @@ def get_opt_shape(self, name) -> int: return None return opt_shapes[name] - def extent_warpper(self, value) -> int: + def extent_wrapper(self, value) -> int: if isinstance(value, tvm.tir.Var): return self.get_opt_shape(value.name) elif isinstance(value, tvm.tir.IntImm): @@ -224,19 +228,21 @@ def get_dtype(self, id=0) -> tvm.DataType: def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType: return tvm.DataType(buffer.dtype) - def propogate(self, tile, rstep={}, targets=None): + def propagate(self, tile, rstep: Optional[Dict] = None, targets=None): + if rstep is None: + rstep = {} shape = { - self.block_analyzer.get_output_buffers(block)[0].name: [ - tvm.arith.ConstIntBound(0, val - 1) for val in tile - ] - for block in self.schedule_stages + self.block_analyzer.get_output_buffers(block)[0].name: + [tvm.arith.ConstIntBound(0, val - 1) for val in tile] for block in self.schedule_stages } return self.ana.infer(shape, rstep, targets) - def propogate_inputs(self, tile, rstep={}) -> List[List[int]]: + def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: + if rstep is None: + rstep = {} read_idx_offset = len(self.input_buffers) targets = [t.name for t in self.args[:read_idx_offset]] - shapes, intermediate_bind = self.propogate(tile, rstep, targets) + shapes, intermediate_bind = self.propagate(tile, rstep, targets) results = [] for i, arg in enumerate(self.args[:read_idx_offset]): if arg.name in intermediate_bind: @@ -244,16 +250,42 @@ def propogate_inputs(self, tile, rstep={}) -> List[List[int]]: continue # should not exceed original shape trimmed_shape = [ - self.extent_warpper(i) + self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) ] results.append(trimmed_shape) return results - def propogate_outputs(self, tile, rstep={}) -> List[List[int]]: + # Propagate inputs only on reduction block + def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: + if rstep is None: + rstep = {} + reduction_block = self.reduction_block + args = self.block_analyzer.get_input_buffers(reduction_block) + targets = [t.name for t in args] + shapes, intermediate_bind = self.propagate(tile, rstep, targets) + results = [] + for i, arg in enumerate(args): + if arg.name in intermediate_bind: + results.append(shapes[arg.name]) + continue + # should not exceed original shape + propagate_shape = shapes[arg.name] + buffer_shape = args[i].shape + if len(buffer_shape) > len(propagate_shape): + buffer_shape = buffer_shape[-len(propagate_shape):] + trimmed_shape = [ + self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape))) + ] + results.append(trimmed_shape) + return results + + def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: + if rstep is None: + rstep = {} read_idx_offset = len(self.input_buffers) targets = [t.name for t in self.args[read_idx_offset:]] - shapes, _ = self.propogate(tile, rstep, targets) + shapes, _ = self.propagate(tile, rstep, targets) results = [] for i, arg in enumerate(self.args[read_idx_offset:]): # should not exceed original shape @@ -261,11 +293,15 @@ def propogate_outputs(self, tile, rstep={}) -> List[List[int]]: results.append(trimmed_shape) return results - def propogate_reduction_inputs(self, shape, rstep={}) -> Dict[str, List[int]]: + def propagate_reduction_inputs(self, + shape, + rstep: Optional[Dict] = None) -> Dict[str, List[int]]: + if rstep is None: + rstep = {} if self.reduction_block is None: return {} targets = [b.name for b in self.block_analyzer.get_input_buffers(self.reduction_block)] - results, _ = self.propogate(shape, rstep, targets) + results, _ = self.propagate(shape, rstep, targets) return results def get_reduce_inputs_dtype(self): @@ -284,52 +320,51 @@ def infer_tensorcore_axis(self) -> Tuple[int]: C_ax_m, C_ax_n = self.get_tag("tensorcore_config") wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok - def get_cl_shapes(c_ax_m, c_ax_n): - output_buffer_shape = ( - self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape - ) - valid_region = [] - for region in output_buffer_shape: - if region.value == 1: - continue - valid_region.append(region) + output_buffer_shape = ( + self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape) + valid_region = [] + for region in output_buffer_shape: + if region.value == 1: + continue + valid_region.append(region) - num_nvalid_regions = len(output_buffer_shape) - len(valid_region) + num_nvalid_regions = len(output_buffer_shape) - len(valid_region) + self.set_tag("num_nvalid_regions", num_nvalid_regions) + def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions): spatial_dim = self.get_space_dim() assert len(valid_region) == len( - spatial_dim - ), f" {valid_region} mismatch with {spatial_dim}" + spatial_dim), f" {valid_region} mismatch with {spatial_dim}" cl_shapes = [1] * len(spatial_dim) cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n - self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [c_ax_m, c_ax_n]]) return cl_shapes - CL_shape = get_cl_shapes(C_ax_m, C_ax_n) - shapes = self.propogate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis}) + CL_shape = get_cl_shapes(C_ax_m, C_ax_n, num_nvalid_regions) + self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [C_ax_m, C_ax_n]]) + shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis}) A_deps, B_deps = shapes.values() A_ax_m = A_deps.index(wmma_m) B_ax_n = B_deps.index(wmma_n) CL_shape = [1] * len(self.get_space_dim()) - shapes = self.propogate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis}) + shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis}) A_deps, B_deps = shapes.values() A_ax_k = len(A_deps) - 1 - A_deps[::-1].index(wmma_k) B_ax_k = len(B_deps) - 1 - B_deps[::-1].index(wmma_k) tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n) return tc_axis - def footprint(self, shape, rstep, stride_map={}) -> int: + def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int: + if stride_map is None: + stride_map = {} result = 0 - shapes, _ = self.propogate(shape, rstep) + shapes, _ = self.propagate(shape, rstep) def is_broadcast_pattern(buffer, output_buffer): - return ( - buffer in self.args - and len(shapes[output_buffer.name]) > len(shapes[buffer.name]) - and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]) - ) + return (buffer in self.args and + len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and + np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])) def is_after_reduce_stage(block): if not self.reduction_block: @@ -351,9 +386,8 @@ def is_after_reduce_stage(block): output_buffer = self.block_analyzer.get_output_buffers(block)[0] for buffer in self.block_analyzer.get_input_buffers(block): cache = buffer.name not in cached_tensor and ( - is_broadcast_pattern(buffer, output_buffer) - or self.block_analyzer.get_block_info(block).is_reduction - ) + is_broadcast_pattern(buffer, output_buffer) or + self.block_analyzer.get_block_info(block).is_reduction) if not cache: continue cached_tensor.append(buffer.name) @@ -362,11 +396,13 @@ def is_after_reduce_stage(block): if buffer.name in stride_map: num_elem = stride_map[buffer.name].compute_elements_from_shape( - shapes[buffer.name] - ) + shapes[buffer.name]) else: num_elem = np.prod(shapes[buffer.name]) buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) buffer_len = (buffer_len + 31) // 32 * 32 result += buffer_len return result, cached_tensor + + def get_input_buffers(self) -> List[tir.Buffer]: + return self.block_analyzer.input_buffers diff --git a/python/bitblas/base/roller/policy/default.py b/python/bitblas/base/roller/policy/default.py index ac85921ce..5526e1316 100644 --- a/python/bitblas/base/roller/policy/default.py +++ b/python/bitblas/base/roller/policy/default.py @@ -1,22 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """Policy for cuda core schedule""" import functools import math from queue import PriorityQueue -from typing import Iterable, Dict, List +from typing import Iterable, Dict, List, Optional import numpy as np import tvm - -from ..arch import Arch +from ..arch import TileDevice from ..bestfit import BestFit -from ..config import Config, Stride, TileDict +from ..hint import Hint, Stride, TileDict from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors from ..node import PrimFuncNode -from ..rasterization import * +from ..rasterization import NoRasterization class DefaultPolicy: @@ -25,13 +23,18 @@ class DefaultPolicy: minimize memory traffic and maximize parallelism.for Dlight Schedule. """ - def __init__(self, func: tvm.tir.PrimFunc, arch: Arch, tags: Dict = {}) -> None: + def __init__(self, + func: tvm.tir.PrimFunc, + arch: TileDevice, + tags: Optional[Dict] = None) -> None: + if tags is None: + tags = {} self.arch = arch self.prim_func_node = PrimFuncNode(func, tags) self.ordered_nodes = [self.prim_func_node] self.output_nodes = [self.prim_func_node] - def emit_config(self, topk: int) -> List[Config]: + def emit_config(self, topk: int) -> List[Hint]: base_tile = self.get_base_tile() if base_tile is None: return [] @@ -54,14 +57,13 @@ def emit_config(self, topk: int) -> List[Config]: def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: _steps = [get_all_factors(n) for n in self.prim_func_node.get_space_dim()] - steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)] + steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)] for i in range(len(steps)): added = list( filter( lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], [2, 4, 8, 16, 32], - ) - ) + )) steps[i].extend(added) steps[i] = sorted(steps[i]) visited_tiles = {} @@ -125,10 +127,8 @@ def _get_output_tile_map(self, tile): """ tile_map = {} tile_map[self.prim_func_node] = [ - tile[i] - * self.prim_func_node.get_space_dim()[i] - // self.output_nodes[0].get_space_dim()[i] - for i in range(len(tile)) + tile[i] * self.prim_func_node.get_space_dim()[i] // + self.output_nodes[0].get_space_dim()[i] for i in range(len(tile)) ] return tile_map @@ -226,11 +226,10 @@ def sim(a: int, b: int): def _score(rstep_id): rstep = {k: all_steps[k][rstep_id[k]] for k in rstep_id} score = 0 - shape = node.propogate_inputs(tile, rstep=rstep) + shape = node.propagate_inputs(tile, rstep=rstep) for i, input_buffer in enumerate(node.input_buffers): read_transaction_elements = self.arch.transaction_size[1] // ( - (node.get_buffer_dtype(input_buffer).bits + 7) // 8 - ) + (node.get_buffer_dtype(input_buffer).bits + 7) // 8) score += sim( int(coalesced_factor(shape[i], input_buffer.shape)), read_transaction_elements, @@ -290,7 +289,7 @@ def _score(rstep_id): k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis } score = 0 - shape = node.propogate_inputs(td.get_tile(node), rstep=rstep) + shape = node.propagate_inputs(td.get_tile(node), rstep=rstep) for i, input_buffer in enumerate(node.input_buffers): score += coalesced_factor(shape[i], input_buffer.shape) return score @@ -356,24 +355,20 @@ def _compute_memory_traffic(self, output_tile): traffic = 0 for node in reversed(self.ordered_nodes): tile = op_tile_map[node] - input_shapes = node.propogate_inputs(tile) - output_shapes = node.propogate_outputs(tile) + input_shapes = node.propagate_inputs(tile) + output_shapes = node.propagate_outputs(tile) for i, buffer in enumerate(node.input_buffers): nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 read_transaction_elements = self.arch.transaction_size[1] // nbytes traffic += ( coalesced_tensor_shape(input_shapes[i], buffer.shape, read_transaction_elements) - * nbytes - ) + * nbytes) for i, buffer in enumerate(node.output_buffers): nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 write_transaction_elements = self.arch.transaction_size[0] // nbytes traffic += ( - coalesced_tensor_shape( - output_shapes[i], buffer.shape, write_transaction_elements - ) - * nbytes - ) + coalesced_tensor_shape(output_shapes[i], buffer.shape, + write_transaction_elements) * nbytes) return traffic, op_tile_map def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): @@ -416,8 +411,7 @@ def _compute_shared_memory_usage(self, td: TileDict): cached_tensors_map = {} node_internal_bytes, cached_tensors_map[self.prim_func_node] = self.infer_node_smem_usage( - td, self.prim_func_node - ) + td, self.prim_func_node) block = allocator.malloc(node_internal_bytes) allocator.free(block) assert len(block_map) == 0 @@ -463,8 +457,7 @@ def _compute_stride_map(self, td: TileDict): tensor_strides_map = {} for node in self.ordered_nodes: output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map( - node, td - ) + node, td) td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict: @@ -494,15 +487,9 @@ def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict: output_shape = self.output_nodes[0].get_space_dim() td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) # estimated reg usage - reg_usage = int( - 2 - * max( - [ - np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 - for node in self.ordered_nodes - ] - ) - ) + reg_usage = int(2 * max([ + np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes + ])) if reg_usage > self.arch.reg_cap: td.valid = False return td @@ -527,16 +514,13 @@ def check_tile_shape_isvalid(self, td: TileDict) -> bool: for node in self.ordered_nodes: if np.prod(td.get_tile(node)) == 0: return False - node_grid_size = np.prod( - [(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())] - ) + node_grid_size = np.prod([ + (y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim()) + ]) if node_grid_size != td.grid_size: return False - if ( - hasattr(node, "reduce_op") - and node.reduce_op is not None - and len(node.reduce_op.axis) == len(td.output_tile) - ): + if (hasattr(node, "reduce_op") and node.reduce_op is not None and + len(node.reduce_op.axis) == len(td.output_tile)): for i, tile_extent in enumerate(td.output_tile): if node.reduce_op.axis[i].dom.extent % tile_extent: return False @@ -561,8 +545,7 @@ def recommend_block_size(self, td: TileDict) -> List[int]: max_block_size = functools.reduce(math.gcd, node_space_sizes) if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min( - node_space_sizes - ): + node_space_sizes): node_reduce_sizes = [ int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes ] @@ -572,14 +555,12 @@ def recommend_block_size(self, td: TileDict) -> List[int]: filter( lambda x: x % max_block_size == 0 and x <= 1024, get_all_factors(max_possible_size), - ) - ) + )) possible_block_sizes = list( filter( # either be a factor of space or cover fully cover the space lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), possible_block_sizes, - ) - ) + )) factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) return factor_ordered else: @@ -635,8 +616,8 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): Returns ------- - Config - A Config object containing the assigned block size and other related settings. + Hint + A Hint object containing the assigned block size and other related settings. """ tile, rsteps = td.get_tile(node), td.get_rstep(node) factors = factorize(block_size) @@ -647,8 +628,8 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): def _score(node, thread): # small is better score = 0 block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] - shape = node.propogate_inputs(block_tile) - for i, buffer in enumerate(node.input_buffers): + shape = node.propagate_inputs(block_tile) + for i, _ in enumerate(node.input_buffers): score += np.prod(shape[i]) / self.arch.bandwidth[1] for buffer in node.output_buffers: score += coalesced_tensor_shape(thread, buffer.shape, 8) / self.arch.bandwidth[0] @@ -678,7 +659,7 @@ def _score(node, thread): # small is better assert target_ax reduce_thread[target_ax] *= factor - codegen_dict = Config() + codegen_dict = Hint() codegen_dict.block = tile codegen_dict.thread = cur_threads codegen_dict.rstep = [rsteps[ax.var.name] for ax in node.raxis] @@ -740,15 +721,12 @@ def is_type_allowed(dtype, vec): vectorize_sizes = [16, 8, 4, 2] dtypes = node.get_reduce_inputs_dtype() - shapes = node.propogate_reduction_inputs(td.get_tile(node), td.get_rstep(node)) + shapes = node.propagate_reduction_inputs(td.get_tile(node), td.get_rstep(node)) vectorize_result = {} for tensor, shape in shapes.items(): for v in vectorize_sizes: - if ( - is_shape_aligned(shape, block_size * v) - and is_cont(shape, v) - and is_type_allowed(dtypes[tensor], v) - ): + if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and + is_type_allowed(dtypes[tensor], v)): vectorize_result[tensor] = v break return vectorize_result diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py index 7f4620fa4..eb8aa0600 100644 --- a/python/bitblas/base/roller/policy/tensorcore.py +++ b/python/bitblas/base/roller/policy/tensorcore.py @@ -1,21 +1,24 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """Policy for tensorcore schedule""" import tvm -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import numpy as np -from ..arch import Arch -from ..config import Config, Stride, TileDict, IntrinInfo +from ..arch import TileDevice +from ..hint import Hint, Stride, TileDict, IntrinInfo from ..node import PrimFuncNode from .common import coalesced_factor, factorize, get_all_factors from .default import DefaultPolicy -from ..rasterization import * +from ..rasterization import NoRasterization, Rasterization2DColumn class TensorCorePolicy(DefaultPolicy): - def __init__(self, func: tvm.tir.PrimFunc, arch: Arch, tags: Dict = {}) -> None: + + def __init__(self, + func: tvm.tir.PrimFunc, + arch: TileDevice, + tags: Optional[Dict] = None) -> None: super().__init__(func, arch, tags) # this is the trick for wmma. # However, for int8 mma, the wmma_k should be 32. @@ -43,11 +46,16 @@ def _legalize_info(self): self.use_async_copy = False def _compute_tc_strides( - self, node: PrimFuncNode, tile: List[int], rstep: Dict[str, int] = {} + self, + node: PrimFuncNode, + tile: List[int], + rstep: Optional[Dict[str, int]] = None, ) -> Tuple[Stride, Stride, Stride]: + if rstep is None: + rstep = {} # strides was used for shared memory padding. which is necessary for avoiding # shared memory load bank conflict when we do not applying tensorcore layout. - shapes = node.propogate_reduction_inputs(tile, rstep) + shapes = node.propagate_reduction_inputs(tile, rstep) AS_shape, BS_shape = shapes.values() CS_shape = tile A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n = node.infer_tensorcore_axis() @@ -58,15 +66,9 @@ def _compute_tc_strides( A_high_ax = min(A_ax_m, A_ax_k) B_high_ax = min(B_ax_n, B_ax_k) C_high_ax = min(C_ax_m, C_ax_n) - A_stride = Stride( - stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax - ) - B_stride = Stride( - stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax - ) - C_stride = Stride( - stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax - ) + A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax) + B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax) + C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax) return A_stride, B_stride, C_stride def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): @@ -81,8 +83,7 @@ def _assign_reduce_step(self, node): target_transaction = self.arch.transaction_size[0] * 2 # 512 bytes // type bits reduce_input_dtype = node.get_buffer_dtype( - node.block_analyzer.get_input_buffers(node.reduction_block)[0] - ) + node.block_analyzer.get_input_buffers(node.reduction_block)[0]) basic = (target_transaction * 8) // reduce_input_dtype.bits result = {} @@ -90,9 +91,7 @@ def _assign_reduce_step(self, node): iter_name = iter_info.var.name iter_dom = iter_info.dom.extent if iter_dom % 16 > 0: - result[iter_name] = ( - 16 if iter_dom < basic else basic - ) # for the case of padding + result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding elif iter_dom % basic == 0: result[iter_name] = basic else: @@ -113,32 +112,28 @@ def _check_small_tile(td: TileDict): if not _check_small_tile(td): return None - smem_limit = min( - self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap - ) + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) rstep_map = td.rstep_map.copy() def _optimize(node, rstep): all_steps = self.get_node_reduce_step_candidates(node) - # todo(lei): optimzie the all_steps enlarge policy to be a multiple of the original all_steps[k] + # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] for k in all_steps: all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) if any([v == [] for v in all_steps.values()]): return rstep def _shared_memory_usage(td: TileDict): - return node.footprint( - td.output_tile, new_rstep_map, td.tensor_strides_map[node] - ) + return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) def _score(rstep_id): rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] - for k in node.raxis + k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis } score = 0 - shape = node.propogate_inputs(td.get_tile(node), rstep=rstep) - for i, input_buffer in enumerate(node.input_buffers): + shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, input_buffer in enumerate(input_buffers): score += coalesced_factor(shape[i], input_buffer.shape) return score @@ -155,8 +150,7 @@ def _enlarge(rstep_id): return max(candidates, key=lambda x: x[1])[0] cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) - for k in node.raxis + k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis } new_rstep_map = rstep_map.copy() while True: @@ -164,8 +158,7 @@ def _enlarge(rstep_id): if new_rstep_id is None: break new_rstep_map = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] - for k in node.raxis + k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis } old_rstep_map = td.rstep_map td.rstep_map = new_rstep_map @@ -176,8 +169,7 @@ def _enlarge(rstep_id): else: cur_rstep_id = new_rstep_id rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] - for k in node.raxis + k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis } return rstep @@ -195,10 +187,8 @@ def get_node_reduce_step_candidates(self, node): else: # must be a a multiple of wmma_k return { - k.var.name: [ - x * self.wmma_k - for x in get_all_factors(int(k.dom.extent) // self.wmma_k) - ] + k.var.name: + [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] for k in node.raxis } @@ -206,7 +196,10 @@ def check_tile_shape_isvalid(self, td: TileDict): for node in self.ordered_nodes: if node.get_tag("tensorcore_config"): ax_m, ax_n = node.get_tag("tensorcore_config") - block_m, block_n = td.tile_map[node][ax_m], td.tile_map[node][ax_n] + block_m, block_n = ( + td.tile_map[node][ax_m], + td.tile_map[node][ax_n], + ) # check the tile size is valid wmma_invalid = [ block_m < wmma_m or block_n < wmma_n @@ -214,9 +207,7 @@ def check_tile_shape_isvalid(self, td: TileDict): ] if all(wmma_invalid): return False - if any( - [y % x for x, y in zip(td.tile_map[node], node.get_space_dim())] - ): + if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]): return False return super().check_tile_shape_isvalid(td) @@ -231,20 +222,16 @@ def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): return super().compute_node_stride_map(node, td) use_layout = self._can_implement_layout(node, td) - AS_stride, BS_stride, C_stride = self._compute_tc_strides( - node, td.get_tile(node), td.get_rstep(node) - ) + AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), + td.get_rstep(node)) A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) tensor_strides = {} output_strides = { - int(i + len(node.input_buffers)): Stride() - for i, _ in enumerate(node.output_buffers) + int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) } tensor_strides = {} # when connected to shared input, should use full stride without rstep - for i, (stride, stride_full) in enumerate( - zip([AS_stride, BS_stride], [A_stride, B_stride]) - ): + for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])): if use_layout: continue _ = node.block_analyzer.get_input_buffers(node.reduction_block)[i].name @@ -278,8 +265,9 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): def _score(node, thread): # small is better score = 0 block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] - shape = node.propogate_inputs(block_tile) - for i, _ in enumerate(node.input_buffers): + shape = node.propagate_inputs_on_reduction(block_tile) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, _ in enumerate(input_buffers): score += np.prod(shape[i]) / self.arch.bandwidth[1] return score @@ -297,7 +285,7 @@ def _score(node, thread): # small is better dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) warp_tile[dim_order[0]] *= factor - codegen_dict = Config() + codegen_dict = Hint() codegen_dict.block = tile codegen_dict.warp = warp_tile codegen_dict.use_tc = True @@ -315,9 +303,7 @@ def _score(node, thread): # small is better codegen_dict.shared_scope = "shared.dyn" codegen_dict.complete_config(node) - codegen_dict.vectorize = self._plan_vectorize( - self.prim_func_node, td, block_size - ) + codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size) codegen_dict.arch = self.arch codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") return codegen_dict @@ -326,19 +312,16 @@ def plan_rasterization(self, td: TileDict): conditions = [] # only support single node for now conditions.append(len(self.ordered_nodes) > 1) - # small op don't need imporve l2 cache - conditions.append(td.num_wave < 4) # only on Ampere+ arch conditions.append(self.arch.compute_capability < "80") def _check_memory_size(): overall_gmem_size_in_bytes: int = 0 for node in self.ordered_nodes: - for arg in node.args: + for buffer in node.input_buffers: overall_gmem_size_in_bytes += ( - int(np.prod(arg.shape)) * tvm.DataType(arg.dtype).bits // 8 - ) - return overall_gmem_size_in_bytes < (self.arch.l2_cache_size_bytes * 4) + int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8) + return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes conditions.append(_check_memory_size()) if any(conditions): diff --git a/python/bitblas/base/roller/rasterization.py b/python/bitblas/base/roller/rasterization.py index a15b0d8dc..4fb779069 100644 --- a/python/bitblas/base/roller/rasterization.py +++ b/python/bitblas/base/roller/rasterization.py @@ -1,12 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """Rasteration Plan For L2 Cache Locality""" from typing import List class Rasterization: + def __init__(self) -> None: pass @@ -15,6 +15,7 @@ def get_code(self) -> List[str]: class NoRasterization(Rasterization): + def __init__(self) -> None: super().__init__() @@ -63,7 +64,7 @@ def __repr__(self) -> str: def get_device_function(self) -> str: return """ -__device__ dim3 rasterization2DColumn(const int panel_width) { +__device__ __inline__ dim3 rasterization2DColumn(const int panel_width) { const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; const auto totalPanel = (gridDim.x * gridDim.y +panel_width * gridDim.x - 1) / (panel_width * gridDim.x); const auto totalBlock = gridDim.x * gridDim.y; @@ -78,8 +79,10 @@ def get_device_function(self) -> str: } """ - def get_code(self) -> List[str]: + def get_code(self, panel_width: int = None) -> List[str]: + if panel_width is None: + panel_width = self.panel_width_ return [ self.get_device_function(), - "const dim3 blockIdx(rasterization2DColumn({});".format(self.panel_width_), + "const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width), ] diff --git a/python/bitblas/base/transform.py b/python/bitblas/base/transform.py index 6f4f6adce..647efa772 100644 --- a/python/bitblas/base/transform.py +++ b/python/bitblas/base/transform.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """ Apply ScheduleRules onto an IRModule to generate default schedules without tuning, or a space for MetaSchedule tuning @@ -16,12 +15,12 @@ from tvm.ir import IRModule from tvm.ir.transform import PassContext, module_pass from tvm.target import Target -from .roller.policy import DefaultPolicy, TensorCorePolicy -from .roller.arch import CUDA from .schedule_rule import ScheduleRule -from ..gpu.matmul_analysis import get_tensorized_func_and_tags from ..base.analysis import check_func_with_dynamic -from .utils import apply_and_build, fast_tune, fast_tune_with_dynamic_range +from .utils import fast_tune, fast_tune_with_dynamic_range +import logging + +logger = logging.getLogger(__name__) def _is_scheduled(func: tir.PrimFunc) -> bool: @@ -61,9 +60,7 @@ def transform_module( # pylint: disable=missing-function-docstring sch = _apply_rules(func, target, self.rules, tunable=False) if sch is not None: assert len(sch) == 1 - updated_functions[g_var] = ( - sch[0].mod["main"].with_attr("tir.is_scheduled", 1) - ) + updated_functions[g_var] = (sch[0].mod["main"].with_attr("tir.is_scheduled", 1)) for g_var, func in updated_functions.items(): mod[g_var] = func return mod @@ -79,8 +76,8 @@ def __init__( target: Optional[Target] = None, parallel_build: bool = True, meta_database_dir: str = None, - whitelist: List[str] = [], - dynamic_range: Dict[str, List[int]] = {}, + whitelist: Optional[List[str]] = None, + dynamic_range: Optional[Dict[str, List[int]]] = None, ): """Construct a new ApplyFastTuning pass. @@ -91,6 +88,10 @@ def __init__( dynamic_range : Dict[str, List[int]] Use for generate kernel based on dynamic range. """ + if whitelist is None: + whitelist = [] + if dynamic_range is None: + dynamic_range = {} self.topk = topk self.target = Target.current() if target is None else target self.parallel_build = parallel_build @@ -98,20 +99,15 @@ def __init__( self.whitelist = whitelist self.dynamic_range = dynamic_range self.temp_dir = tempfile.TemporaryDirectory() - print(f"[BitBLAS] Using meta database dir {self.temp_dir}") path_workload = osp.join(self.temp_dir.name, "database_workload.json") path_tuning_record = osp.join(self.temp_dir.name, "database_tuning_record.json") self.cache_meta_database = ms.database.JSONDatabase( - path_workload, path_tuning_record, module_equality="structural" - ) + path_workload, path_tuning_record, module_equality="structural") def _in_white_list(self, func_name: str) -> bool: if len(self.whitelist) == 0: return True - for name in self.whitelist: - if name in func_name: - return True - return False + return any([name in func_name for name in self.whitelist]) def transform_module( # pylint: disable=missing-function-docstring self, @@ -125,10 +121,7 @@ def transform_module( # pylint: disable=missing-function-docstring if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): if not self._in_white_list(g_var.name_hint): continue - print(f"[BitBLAS] Start to apply fast tuning for {g_var}") - normalize_mod_func_ = tvm._ffi.get_global_func( - "tvm.meta_schedule.normalize_mod" - ) + normalize_mod_func_ = tvm._ffi.get_global_func("tvm.meta_schedule.normalize_mod") _normalized_func_mod = normalize_mod_func_(func) if self.cache_meta_database.has_workload(_normalized_func_mod): @@ -141,10 +134,7 @@ def transform_module( # pylint: disable=missing-function-docstring trace = tuning_record.trace sch = tvm.tir.Schedule(func) trace.apply_to_schedule(sch, remove_postproc=False) - print(f"[BitBLAS] Find Cache for {g_var}") - updated_functions[g_var] = sch.mod["main"].with_attr( - "tir.is_scheduled", 1 - ) + updated_functions[g_var] = sch.mod["main"].with_attr("tir.is_scheduled", 1) continue if check_func_with_dynamic(func): @@ -163,16 +153,11 @@ def transform_module( # pylint: disable=missing-function-docstring if g.name_hint == g_var.name_hint: # avoid duplicated global symbol updated_functions[g_var] = f.without_attr( - "global_symbol" - ).with_attr("tir.is_scheduled", 1) + "global_symbol").with_attr("tir.is_scheduled", 1) else: - updated_functions[g] = f.with_attr( - "tir.is_scheduled", 1 - ) - # cannot reuse meta database as it canot be recorvered from the trace - workload = self.cache_meta_database.commit_workload( - _normalized_func_mod - ) + updated_functions[g] = f.with_attr("tir.is_scheduled", 1) + # cannot reuse meta database as it cannot be recorvered from the trace + workload = self.cache_meta_database.commit_workload(_normalized_func_mod) else: # otherwise is static shape analysis _, best = fast_tune( @@ -184,11 +169,8 @@ def transform_module( # pylint: disable=missing-function-docstring if best is not None: updated_functions[g_var] = best.sch.mod["main"].with_attr( - "tir.is_scheduled", 1 - ) - workload = self.cache_meta_database.commit_workload( - _normalized_func_mod - ) + "tir.is_scheduled", 1) + workload = self.cache_meta_database.commit_workload(_normalized_func_mod) # only record the best schedule self.cache_meta_database.commit_tuning_record( ms.database.TuningRecord( @@ -196,11 +178,8 @@ def transform_module( # pylint: disable=missing-function-docstring workload, [best.latency], target, - ms.arg_info.ArgInfo.from_prim_func( - func=best.sch.mod["main"] - ), - ) - ) + ms.arg_info.ArgInfo.from_prim_func(func=best.sch.mod["main"]), + )) for g_var, func in updated_functions.items(): mod[g_var] = func @@ -210,9 +189,7 @@ def transform_module( # pylint: disable=missing-function-docstring if not osp.exists(self.meta_database_dir): os.makedirs(self.meta_database_dir) # TODO(lei): maybe another way to copy the database - shutil.copytree( - self.temp_dir.name, self.meta_database_dir, dirs_exist_ok=True - ) + shutil.copytree(self.temp_dir.name, self.meta_database_dir, dirs_exist_ok=True) return mod @@ -231,7 +208,7 @@ def _apply_rules( try: space = rule.apply(func, target, tunable) except Exception: - print(f"[BitBLAS][Error] applying rule {rule} failed") + logger.debug(f"[BitBLAS][Error] applying rule {rule} failed") space = None if space is None: continue diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py index 278b36e53..ac5763487 100644 --- a/python/bitblas/base/utils.py +++ b/python/bitblas/base/utils.py @@ -3,7 +3,7 @@ import tvm import os -from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind, MapResult +from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np from typing import List, Tuple, Optional, Dict, Union, Literal @@ -16,15 +16,15 @@ from bitblas.base.roller.arch import CUDA from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.base.roller.rasterization import NoRasterization import tempfile import itertools from tvm.ir.supply import GlobalVarSupply -from bitblas.utils import match_global_kernel, tensor_replace_dp4a +from bitblas.utils import tensor_replace_dp4a import logging logger = logging.getLogger(__name__) + def get_rasterization_code(pannel_width: int = 8) -> str: return f""" const int MAX_BLOCK_N = {pannel_width}; @@ -56,12 +56,12 @@ def __init__(self, config, sch, mod: Module): def profile(self): profile_tensors = self.profile_tensors - return self.time_evaluator(*profile_tensors).mean + return self.time_evaluator(*profile_tensors).mean * 1e3 def _apply_config( - func: tir.PrimFunc, - config=None, # todo(lei): update typing + func: tir.PrimFunc, + config=None, # todo(lei): update typing ) -> Optional[tir.Schedule]: """ find rules: @@ -112,6 +112,7 @@ def get_dummy_input_arrays( device: tvm.runtime.Device, distribution: Literal["uniform", "onefill"] = "uniform", ): + def var_wrapper(v): if isinstance(v, tvm.tir.Var): assert "opt_shapes" in func.attrs @@ -137,48 +138,44 @@ def var_wrapper(v): if distribution == "uniform": profile_tensors.append( tvm.nd.array( - np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype( - arg.dtype - ), + np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(arg.dtype), device=device, - ) - ) + )) elif distribution == "onefill": profile_tensors.append( tvm.nd.array( np.ones([var_wrapper(i) for i in arg.shape]).astype(arg.dtype), device=device, - ) - ) + )) else: raise ValueError("Not supported distribution: ", distribution) return profile_tensors -def apply_and_build_parallel( - func, configs, arch, num_repeats=3, max_workers=10, data_distribution="uniform" -) -> CompileResult: +def apply_and_build_parallel(func, + configs, + arch, + num_repeats=3, + max_workers=10, + data_distribution="uniform") -> CompileResult: cpresults = [] - profile_tensors = get_dummy_input_arrays( - func, arch.device, distribution=data_distribution - ) + profile_tensors = get_dummy_input_arrays(func, arch.device, distribution=data_distribution) max_workers = min(len(configs), os.cpu_count(), max_workers) # apply config in thread parallel _sched: List[Schedule] = [] + def _apply_schedule(f, c): try: sch = _apply_config(f, c) except Exception as apply_schedule_error: - logger.debug("Apply schedule failed: ", apply_schedule_error) + logger.debug("Apply schedule failed: {}".format(apply_schedule_error)) sch = None return sch - with ThreadPoolExecutor(max_workers=4) as schduler: - futures = { - schduler.submit(_apply_schedule, func, config) - for config in configs - } + + with ThreadPoolExecutor(max_workers=4) as scheduler: + futures = {scheduler.submit(_apply_schedule, func, config) for config in configs} for future in as_completed(futures): _sched.append(future.result()) @@ -195,24 +192,15 @@ def _build(context) -> str: @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) def tvm_callback_cuda_postproc(code, _): - index = code.index("{", match_global_kernel(code)) - if not isinstance(config.rasterization_plan, NoRasterization): - factor = config.rasterization_plan.panel_width_ - rasterization_code = get_rasterization_code(factor) - code = code[: index + 2] + rasterization_code + code[index + 2 :] code = tensor_replace_dp4a(code) return code - with tvm.transform.PassContext( - config={"tir.use_async_copy": True, **config.pass_context} - ): + with tvm.transform.PassContext(config={"tir.use_async_copy": True, **config.pass_context}): rt_mod = tvm.build(mod, target=arch.target) from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel - artifact_path = os.path.join( - tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format - ) + artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format) code = rt_mod.imported_modules[0].get_source() rt_mod.export_library(artifact_path, fcompile=tar) return idx, code, artifact_path @@ -220,14 +208,14 @@ def tvm_callback_cuda_postproc(code, _): _mods = [sch.mod if sch is not None else None for sch in _sched] for map_result in builder.map_with_error_catching( - _build, + _build, [(i, mod, arch) for i, mod in enumerate(_mods)], ): if map_result.status == StatusKind.TIMEOUT: logger.debug("LocalBuilder: Timeout") elif map_result.status == StatusKind.EXCEPTION: # TODO(lei): redirect the exception to file if needed - logger.debug("LocalBuilder: An exception occurred ", map_result.value) + logger.debug("LocalBuilder: An exception occurred {}".format(map_result.value)) continue elif map_result.status == StatusKind.COMPLETE: idx, code, artifact_path = map_result.value @@ -239,8 +227,7 @@ def tvm_callback_cuda_postproc(code, _): rt_mod = tvm.runtime.load_module(artifact_path) cpresult = CompileResult(config, sch, rt_mod) timer_cuda_mod = rt_mod.time_evaluator( - rt_mod.entry_name, arch.device, number=num_repeats - ) + rt_mod.entry_name, arch.device, number=num_repeats) cpresult.profile_tensors = profile_tensors cpresult.time_evaluator = timer_cuda_mod cpresult.code = code @@ -260,7 +247,7 @@ def tvm_callback_cuda_postproc(code, _): logger.debug("Evaluation with config failed: ", e_mesg) continue logger.info("Evaluation with config {}".format(config)) - logger.info("Time cost of this config: {:.3f} ms".format(latency * 1e3)) + logger.info("Time cost of this config: {:.3f} ms".format(latency)) cpresult.latency = latency if latency < best_latency: @@ -278,7 +265,8 @@ def apply_and_build( data_distribution="uniform", ) -> Tuple[List[CompileResult], CompileResult]: max_workers = 10 if parallel_build else 1 - return apply_and_build_parallel(func, configs, arch, max_workers=max_workers, data_distribution=data_distribution) + return apply_and_build_parallel( + func, configs, arch, max_workers=max_workers, data_distribution=data_distribution) def fast_tune( @@ -288,6 +276,10 @@ def fast_tune( parallel_build: bool = True, data_distribution: Literal["uniform", "onefill"] = "uniform", ): + # check the function is a primfunc + if not isinstance(func, tir.PrimFunc): + raise ValueError("Only support func is PrimFunc") # pragma: no cover + if target.kind.name != "cuda": logger.error("Only support CUDA target") return None, None @@ -299,32 +291,28 @@ def fast_tune( if not all([isinstance(v.value, int) for v in opt_shapes.values()]): logger.error("The opt_shapes should be int value") return None, None - # currently only support one dynmaic range + # currently only support one dynamic range if len(opt_shapes) > 1: logger.error("Currently only support one dynamic range") return None, None for buffer in func.buffer_map.values(): for axis in buffer.shape: - if isinstance(axis, tvm.tir.Var): - if axis.name not in opt_shapes: - raise NotImplementedError( - "Currently do not support fast tune with none-dynamic range set" - ) + if isinstance(axis, tvm.tir.Var) and axis.name not in opt_shapes: + raise NotImplementedError( + "Currently do not support fast tune with none-dynamic range set") if opt_shapes: for name, shape in opt_shapes.items(): var = find_var_from_func(func, name) - specilized_func = func.specialize( - {var: shape.astype(var.dtype)} - ).with_attr("is_specialized") + specilized_func = func.specialize({ + var: shape.astype(var.dtype) + }).with_attr("is_specialized") arch = CUDA(target) policy = DefaultPolicy(func=func, arch=arch) try: - specilized_func, tags = get_tensorized_func_and_tags( - specilized_func, arch.target - ) + specilized_func, tags = get_tensorized_func_and_tags(specilized_func, arch.target) except Exception as e_msg: logger.debug("Get tensorized func and tags failed: ", e_msg) tags = None @@ -380,10 +368,9 @@ def serialize_name(opt_shapes: Dict): for buf in buffers_to_declare: body = tvm.tir.DeclBuffer(buf, body=body) - # devide func must be private - device_func = tvm.tir.PrimFunc(params, body, ret_type, attrs=attrs).without_attr( - "global_symbol" - ) + # device func must be private + device_func = tvm.tir.PrimFunc( + params, body, ret_type, attrs=attrs).without_attr("global_symbol") return global_symbol, device_func @@ -415,7 +402,7 @@ def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[ global_symbols.append(g_var) # TODO(lei): general the dispatch function to support multiple dynamic symbolics - assert len(dyn_symbolic) == 1, "Only support one dyanmic symbolics currently" + assert len(dyn_symbolic) == 1, "Only support one dynamic symbolics currently" ib = tvm.tir.ir_builder.create() syb = list(dyn_symbolic)[-1] @@ -431,29 +418,26 @@ def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[ with ib.if_scope(syb > last_range): ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) stmt = ib.get() - dispatch_func = tvm.tir.PrimFunc( - params, stmt, ret_type, buffer_map, attrs - ).with_attrs({"tir.is_global_func": True, "global_symbol": global_symbol}) + dispatch_func = tvm.tir.PrimFunc(params, stmt, ret_type, buffer_map, attrs).with_attrs({ + "tir.is_global_func": True, + "global_symbol": global_symbol + }) return dispatch_func -def create_dispatch_mod( - g_var: str, original_func: tir.PrimFunc, specialized_funcs: List[tir.PrimFunc] -) -> IRModule: +def create_dispatch_mod(g_var: str, original_func: tir.PrimFunc, + specialized_funcs: List[tir.PrimFunc]) -> IRModule: dispatch_mod: IRModule = tvm.IRModule() g_var_supply = GlobalVarSupply(dispatch_mod) refactored_funcs = [] for func in specialized_funcs: params, buffers_to_declare = collect_buffers_to_declare(func) - global_symbol, device_func = refactor_specialized_func( - g_var, func, params, buffers_to_declare - ) + global_symbol, device_func = refactor_specialized_func(g_var, func, params, + buffers_to_declare) global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False) dispatch_mod[global_symbol] = device_func refactored_funcs.append((global_symbol, device_func)) - dispatch_func = create_dispatch_func( - g_var, original_func, refactored_funcs=refactored_funcs - ) + dispatch_func = create_dispatch_func(g_var, original_func, refactored_funcs=refactored_funcs) dispatch_mod.update(tvm.IRModule.from_expr(dispatch_func)) return dispatch_mod @@ -464,15 +448,17 @@ def fast_tune_with_dynamic_range( topk: int = 10, parallel_build: bool = True, global_symbol: Optional[str] = None, - dynamic_range: Dict[str, List[int]] = {}, + dynamic_range: Optional[Dict[str, List[int]]] = None, ) -> IRModule: + if dynamic_range is None: + dynamic_range = {} if target.kind.name != "cuda": logger.error("Only support CUDA target") return None if not global_symbol: global_symbol = func.attrs["global_symbol"] - # set opt_shapes for the primfunc with dynamc symbolic + # set opt_shapes for the primfunc with dynamic symbolic opt_shapes: Dict[str, List[int]] = {} for buffer in func.buffer_map.values(): for axis in buffer.shape: @@ -480,21 +466,16 @@ def fast_tune_with_dynamic_range( if axis.name in dynamic_range: opt_shapes[axis.name] = dynamic_range[axis.name] else: - raise ValueError( - f"[BitBLAS] The axis {axis.name} is not in dynamic_range" - ) + raise ValueError(f"[BitBLAS] The axis {axis.name} is not in dynamic_range") func = func.with_attr("opt_shapes", opt_shapes) if "opt_shapes" not in func.attrs: - print( - "[BitBLAS] The primfunc has no opt_shapes, please set opt_shapes for the primfunc" - ) + logger.error( + "[BitBLAS] The primfunc has no opt_shapes, please set opt_shapes for the primfunc") return None else: # should be list value - if not all( - [isinstance(v, tvm.ir.Array) for v in func.attrs["opt_shapes"].values()] - ): + if not all([isinstance(v, tvm.ir.Array) for v in func.attrs["opt_shapes"].values()]): logger.error("The opt_shapes should be list value") return None @@ -505,9 +486,7 @@ def fast_tune_with_dynamic_range( product_list = list(itertools.product(*(opt_shapes[key] for key in opt_shapes))) # Convert the Cartesian product to a list of dictionaries - specialize_items: List[Dict] = [ - dict(zip(opt_shapes.keys(), values)) for values in product_list - ] + specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list] specilized_tuned_funcs: List[tir.PrimFunc] = [] for item in specialize_items: diff --git a/python/bitblas/cache/__init__.py b/python/bitblas/cache/__init__.py index ff40f4c2a..0c8fd3b9c 100644 --- a/python/bitblas/cache/__init__.py +++ b/python/bitblas/cache/__init__.py @@ -1,4 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .operator import global_operator_cache +from .operator import ( + global_operator_cache, # noqa: F401 + load_global_ops_cache, # noqa: F401 + get_database_path, # noqa: F401 + set_database_path, # noqa: F401 +) diff --git a/python/bitblas/cache/operator.py b/python/bitblas/cache/operator.py index da7566843..08c3ac185 100644 --- a/python/bitblas/cache/operator.py +++ b/python/bitblas/cache/operator.py @@ -1,54 +1,174 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +import bitblas from bitblas.ops.operator import OperatorConfig, Operator +from dataclasses import asdict +import os +import json +import tempfile +from hashlib import sha256 +import shutil +import tvm +from tvm.contrib.tar import tar +import logging + +logger = logging.getLogger(__name__) +BITBLAS_DATABASE_PATH = os.path.expanduser("~/.cache/bitblas") -class OperatorCache(object): + +class OperatorCache: """ - A cache manager for operator instances, such as Matmul, Convolution, etc., - keyed by their configuration objects. Supports adding, retrieving, and - checking the existence of operator instances based on their unique configurations. + Manages a cache for operator instances (e.g., Matmul, Convolution) based on their configurations. """ def __init__(self): - """ - Initializes the cache. - """ self.cache = {} def add(self, config: OperatorConfig, op_inst: Operator): - """ - Adds an operator instance to the cache with the given configuration. - - Parameters: - - config: A hashable configuration object that uniquely identifies the operator instance. - - op_inst: The instance of the operator to cache. - """ self.cache[config] = op_inst def get(self, config: OperatorConfig): - """ - Retrieves an operator instance from the cache based on the given configuration. + return self.cache.get(config) - Parameters: - - config: The configuration object that uniquely identifies the operator instance. + def exists(self, config): + return config in self.cache - Returns: - The cached operator instance if present; otherwise, None. - """ - return self.cache.get(config, None) + def clear(self): + self.cache.clear() - def exists(self, config): - """ - Checks if an operator instance with the given configuration exists in the cache. + def size(self): + return len(self.cache) - Parameters: - - config: The configuration object that uniquely identifies the operator instance. + def save_into_database(self, database_path=None, target=None): + database_path = self._ensure_database_path(database_path) + for config, op_inst in self.cache.items(): + arch_str = self._determine_arch_str(op_inst, target) + arch_path = os.path.join(database_path, arch_str) + self._ensure_directory(arch_path) + hash_str = sha256(repr(config).encode()).hexdigest() + config_path = os.path.join(arch_path, hash_str) + # if the config already exists, skip saving + if os.path.exists(config_path): + continue + self._ensure_directory(config_path) + self._save_operator_config_and_artifact(config, op_inst, config_path) + + def load_from_database(self, database_path, target=None): + if not os.path.exists(database_path): + logger.info( + f"Database path {database_path} does not exist, skipping loading operators from the database" + ) + return + arch_str = self._determine_target_arch_str(target) + arch_path = os.path.join(database_path, arch_str) + if not os.path.exists(arch_path): + logger.info( + f"Target {arch_str} does not exist in the database, skipping loading operators from the database" + ) + return + self._load_operators_from_arch_path(arch_path, target) + + def _ensure_database_path(self, database_path): + if database_path is None: + return tempfile.mkdtemp() + os.makedirs(database_path, exist_ok=True) + return database_path + + def _determine_arch_str(self, op_inst, target): + return (target if target else "-".join(list(op_inst.target.keys) + [op_inst.target.arch])) + + def _ensure_directory(self, path): + os.makedirs(path, exist_ok=True) + + def _save_operator_config_and_artifact(self, config, op_inst, config_path): + config_type, operator_type = type(config).__name__, type(op_inst).__name__ + with open(os.path.join(config_path, f"{config_type}.json"), "w") as json_file: + json.dump(asdict(config), json_file) + artifact_path = os.path.join(config_path, "tvm_rt_mod." + tar.output_format) + try: + op_inst.rt_mod.export_library(artifact_path, fcompile=tar) + except Exception as e: + # library does not support export_library + export_error = e # noqa: F841 + pass + json_data = {"config_type": config_type, "operator_type": operator_type} + json_file_path = os.path.join(config_path, "mapping.json") + with open(json_file_path, "w") as json_file: + json.dump(json_data, json_file) + + # For writing source.cu file + source_file_path = os.path.join(config_path, "source.cu") + with open(source_file_path, "w") as source_file: + source_file.write(op_inst.get_source()) + + # For writing optimized.py file + optimized_file_path = os.path.join(config_path, "optimized.py") + with open(optimized_file_path, "w") as optimized_file: + optimized_file.write(op_inst.optimized_func.script(show_meta=False)) + if op_inst.wrapper.lib_name is not None: + # copy lib name to the same directory as the artifact + src_name = op_inst.wrapper.src_name + shutil.copy( + src_name, + os.path.join(config_path, os.path.basename("wrapper_source.cu")), + ) + lib_name = op_inst.wrapper.lib_name + shutil.copy( + lib_name, + os.path.join(config_path, os.path.basename("wrapper_compiled.so")), + ) + + def _determine_target_arch_str(self, target): + return (target if isinstance(target, str) else "-".join(list(target.keys) + [target.arch])) + + def _load_operators_from_arch_path(self, arch_path, target): + for root, dirs, _ in os.walk(arch_path): + for directory in dirs: + config_path = os.path.join(root, directory) + self._load_operator(config_path, target) + + def _load_operator(self, config_path, target): + mapping, config, rt_mod, lib_name = None, None, None, None + for file in os.listdir(config_path): + full_path = os.path.join(config_path, file) + if file == "mapping.json": + with open(full_path) as f: + mapping = json.load(f) + elif file.endswith(".json"): + with open(full_path) as f: + config = json.load(f) + elif file.endswith(".tar"): + rt_mod = tvm.runtime.load_module(full_path) + elif file == "wrapper_compiled.so": + lib_name = full_path + + if mapping and config and rt_mod: + self._instantiate_and_add_operator(mapping, config, rt_mod, lib_name, target) + + def _instantiate_and_add_operator(self, mapping, config, rt_mod, lib_name, target): + config_cls = getattr(bitblas, mapping["config_type"]) + operator_cls = getattr(bitblas, mapping["operator_type"]) + op_inst = operator_cls(config=config_cls(**config), target=target, enable_tuning=False) + op_inst.update_runtime_module(rt_mod, lib_name=lib_name) + self.add(config_cls(**config), op_inst) - Returns: - True if the instance exists in the cache; otherwise, False. - """ - return config in self.cache global_operator_cache = OperatorCache() + + +def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None): + if target is None: + target = bitblas.auto_detect_nvidia_target() + global_operator_cache.load_from_database(database_path, target) + return global_operator_cache + + +def get_database_path(): + return BITBLAS_DATABASE_PATH + + +def set_database_path(path): + global BITBLAS_DATABASE_PATH + BITBLAS_DATABASE_PATH = path + return BITBLAS_DATABASE_PATH diff --git a/python/bitblas/generator.py b/python/bitblas/generator.py index 4cbe697e2..4ac6f2be2 100644 --- a/python/bitblas/generator.py +++ b/python/bitblas/generator.py @@ -1,17 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + class BitBLASGenerator: - def __init__(self, input_size, data_type='float', optimization_level=1): - self.input_size = input_size - self.data_type = data_type - self.optimization_level = optimization_level - # 其他初始化代码 + + def __init__(self): + # Initialize the generator with configuration + pass def generate_cuda_code(self): - # 生成CUDA代码的逻辑 pass def generate_header(self): - # 生成Header文件的逻辑 pass diff --git a/python/bitblas/gpu/__init__.py b/python/bitblas/gpu/__init__.py index 9fbe8ba93..df0635b3c 100644 --- a/python/bitblas/gpu/__init__.py +++ b/python/bitblas/gpu/__init__.py @@ -1,21 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - """ GPU-generic schedule rules. For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead """ -from .fallback import Fallback -from .element_wise import ElementWise -from .gemv import GEMV -from .general_reduction import GeneralReduction +from .fallback import Fallback # noqa: F401 +from .element_wise import ElementWise # noqa: F401 +from .gemv import GEMV # noqa: F401 +from .gemv_dequantize import GEMVWithDequantizeInfo # noqa: F401 +from .general_reduction import GeneralReduction # noqa: F401 from .matmul import ( - Matmul, - MatmulTensorizationMMA, - MatmulTensorizationWMMA, - MatmulTensorizationLegacy, + Matmul, # noqa: F401 + MatmulTensorizationMMA, # noqa: F401 + MatmulTensorizationWMMA, # noqa: F401 +) +from .matmul_mma_dequantize import ( + MatmulTensorizationMMAWithDequantizeInfo, # noqa: F401 ) -from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo +from .matmul_wmma import MatmulTensorizationLegacy # noqa: F401 -from .reduction import Reduction -from .transpose import Transpose +from .reduction import Reduction # noqa: F401 +from .transpose import Transpose # noqa: F401 diff --git a/python/bitblas/gpu/gemv.py b/python/bitblas/gpu/gemv.py index 81a3d48af..33388bffe 100644 --- a/python/bitblas/gpu/gemv.py +++ b/python/bitblas/gpu/gemv.py @@ -17,7 +17,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm gemv.py in dlight. """A rule for GEMV and DecodeGEMV.""" @@ -25,7 +25,6 @@ from functools import reduce from typing import List, Optional, Union, Dict -from tvm.tir.function import PrimFunc from tvm import DataType, arith, ir, tir from tvm.target import Target @@ -41,9 +40,7 @@ ) from .base import GPUScheduleRule from .gemv_dequantize import GEMVWithDequantizeInfo -from ..base.analysis import ( - get_coalesced_veclen -) + def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: # Detect and return `Y` in `X[...] = X[...] + Y` @@ -53,9 +50,9 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: if not isinstance(buffer_store.value, tir.Add): return None if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, ): return None return buffer_store.value.b @@ -99,13 +96,8 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe conditions.append(len(block_stmt.writes) == 1) conditions.append(_get_reduction_expr(block_stmt) is not None) conditions.append( - len( - collect_block_iter_vars_used_in_access_region( - block_stmt, block_stmt.writes[0].region - ) - ) - > 0 - ) + len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) + > 0) if not all(conditions): return None @@ -113,10 +105,8 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe ret = [ read.buffer for read in block_stmt.reads - if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) - < iter_num - and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) - > 0 + if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num + and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 ] if len(ret) == len(block_stmt.reads): func = sch.mod["main"] @@ -145,15 +135,12 @@ def normalize( collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) for buf in block_stmt.writes ] - buffers_use_vars.extend( - [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.reads - ] - ) + buffers_use_vars.extend([ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.reads + ]) if collect_vars_used_in_prim_expr(access.base) & set( - iter_var.var for iter_var in block_stmt.iter_vars - ): + iter_var.var for iter_var in block_stmt.iter_vars): return None iter_to_info = {i.var: i for i in block_info.iters} batch_loops, s_loops, r_loops, c_loops = [], [], [], [] @@ -204,6 +191,9 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- ) -> Union[None, tir.Schedule, List[tir.Schedule]]: if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None + if "dequantize_info" in func.attrs: + dequantize_rule = GEMVWithDequantizeInfo() + return dequantize_rule.apply(func, target, False) sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) block_infos = try_inline_contiguous_spatial(sch, block_infos) @@ -236,9 +226,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) return sch else: - return self.sch_outer_reduction( - sch, target, block, vector_input_buffers, epilogue - ) + return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, @@ -276,28 +264,21 @@ def apply( _, s, r, c = sch.get_loops(block=gemv) s = sch.fuse(_, s) r = sch.fuse(r, c) - bx, ts, tile_s = sch.split( - s, factors=[None, TS, TILE_S], preserve_unit_iters=True - ) + bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True) r, tr, tile_r_vec_n, vec_c = sch.split( - r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True - ) + r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True) sch.reorder(r, tile_r_vec_n, tr, vec_c) tr_vec_c = sch.fuse(tr, vec_c) rf = sch.rfactor(tr_vec_c, 0) # rfactor: reduce to tx bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv) - tr, vec_c = sch.split( - tr_vec_c, factors=[TR, None], preserve_unit_iters=True - ) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) rf2 = sch.rfactor(tr, 0) # bind, vectorize compute bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf) - tr, vec_c = sch.split( - tr_vec_c, factors=[TR, None], preserve_unit_iters=True - ) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c) sch.bind(bx, "blockIdx.x") sch.bind(ts, TAG_S) @@ -306,19 +287,16 @@ def apply( shared_mem_usage = 0 for buf in vector_input_buffers: - buf_size = reduce( - lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1) - ) * get_bytes(buf.dtype) + buf_size = reduce(lambda x, y: x * y, buf.shape, tir.IntImm( + buf.shape[0].dtype, 1)) * get_bytes(buf.dtype) shared_mem_usage += buf_size try: max_shared_memory_per_block = target.max_shared_memory_per_block - except: + except Exception: max_shared_memory_per_block = 49152 LOAD_V_SHARED = ( - LOAD_V_SHARED - and isinstance(shared_mem_usage, tir.IntImm) - and shared_mem_usage.value <= max_shared_memory_per_block - ) + LOAD_V_SHARED and isinstance(shared_mem_usage, tir.IntImm) and + shared_mem_usage.value <= max_shared_memory_per_block) # vectorize load A # (TODO) this is now actually problematic since the number of loops is dependent on the @@ -327,20 +305,15 @@ def apply( sch.compute_at(Aq_local, r, preserve_unit_loops=True) s_local, r_local = sch.get_loops(block=Aq_local)[-2:] s_local, vec_load = sch.split( - s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True - ) - sch.reorder( - s_local, r_local, vec_load - ) # either s_local or r_local should be 1 + s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True) + sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 sch.vectorize(vec_load) # load vector into shared memory, shape should be the whole vector if LOAD_V_SHARED: - V_shared = sch.cache_read( - rf, read_buffer_index=0, storage_scope="shared" - ) + V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") sch.compute_at(V_shared, tr, preserve_unit_loops=True) - l = sch.get_loops(block=V_shared)[-1] + l = sch.get_loops(block=V_shared)[-1] # noqa: E741 loop: tir.For = sch.get(l) if isinstance(loop.extent, tir.IntImm): # avoid introducing predicates when vector length is too large @@ -349,9 +322,7 @@ def apply( get_max_factor( (int)(loop.extent), [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8], - ) - // TS - // TR, + ) // TS // TR, LOAD_V_VEC, ), 1, @@ -360,12 +331,10 @@ def apply( vec_length = LOAD_V_VEC if TAG_R == "threadIdx.x": _, ty, tx, vec = sch.split( - l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True - ) + l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True) else: _, ty, tx, vec = sch.split( - l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True - ) + l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") sch.vectorize(vec) @@ -374,9 +343,7 @@ def apply( sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split( - ts_tile_s, factors=[TS, None], preserve_unit_iters=True - ) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) tile_s, vec_s = sch.split( tile_s, factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], @@ -391,9 +358,7 @@ def apply( sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) tr, *ts_tile_s = sch.get_loops(block=gemv)[1:] ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split( - ts_tile_s, factors=[TS, None], preserve_unit_iters=True - ) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) sch.reorder(tile_s, ts, tr) sch.bind(ts, TAG_S) sch.bind(tr, TAG_R) @@ -447,15 +412,13 @@ def apply( sch.reverse_compute_at(epilogue, bx) sch.set_scope(block, 0, "shared") _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name - _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) + _, tx = sch.split(sch.fuse(*s), factors=[None, TS]) sch.bind(tx, "threadIdx.x") else: sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) ts_tile_s = sch.get_loops(epilogue)[-1] - ts, tile_s = sch.split( - ts_tile_s, factors=[TS, None], preserve_unit_iters=True - ) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) sch.bind(ts, TAG_S) sch.set_scope(block, 0, "local") # pylint: enable=invalid-name @@ -546,13 +509,8 @@ def apply( TILE_S, TILE_R = ( 1, - ( - len_c - if len_c > 1 - else max( - get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1 - ) - ), + (len_c if len_c > 1 else max( + get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1)), ) VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) VEC_LOAD = 1 @@ -639,15 +597,9 @@ def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many dom_kind = block.dom_kind() block = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): continue for loop, iter_type in zip(sch.get_loops(block), dom_kind): @@ -665,7 +617,7 @@ def prod(iterable): vec = 1 if len(config.vectorize): - vec = list(config.vectorize.values())[-1] + vec = list(config.vectorize.values())[-1] num_warps = int(prod(config.thread)) warp_size = int(prod(config.reduce_thread)) @@ -679,13 +631,13 @@ def prod(iterable): sch.compute_inline(block) try: i, j, k = sch.get_loops(block_b) - except: + except Exception: j, k = sch.get_loops(block_b) block_local_A = sch.cache_read(block_b, 0, "local") block_local_B = sch.cache_read(block_b, 1, "local") block_local_C = sch.cache_write(block_b, 0, "local") # reverse inline - if reduction_block != None and reduction_block != output_blocks[0]: + if reduction_block is not None and reduction_block != output_blocks[0]: sch.reverse_compute_inline(output_blocks[0]) bx, j = sch.split(j, factors=[None, num_warps]) @@ -729,15 +681,9 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many dom_kind = block.dom_kind() block = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): continue for loop, iter_type in zip(sch.get_loops(block), dom_kind): @@ -758,17 +704,11 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many thrd_axis = [] tile_axis = [] # for gemv, we should skip dynamic symbolic in s_loops - s_loops = [ - loop for loop in s_loops if isinstance(sch.get(loop).extent, tir.IntImm) - ] - assert len(s_loops) == len( - config.block - ), f"{len(s_loops)} != {len(config.block)}" + s_loops = [loop for loop in s_loops if isinstance(sch.get(loop).extent, tir.IntImm)] + assert len(s_loops) == len(config.block), f"{len(s_loops)} != {len(config.block)}" for i, loop in enumerate(s_loops): if sch.get(loop).extent % config.block[i]: - raise NotImplementedError( - "Undivisible block in TIR schedule is still buggy." - ) + raise NotImplementedError("Undivisible block in TIR schedule is still buggy.") bx, _t = sch.split(loop, factors=[None, config.block[i]]) blck_axis.append(bx) if config.step[i] > 1: @@ -791,13 +731,7 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many vthd_axis = list(reversed(vthd_axis)) # inner virtual thread first axis_order = ( - blck_axis - + vthd_axis - + thrd_axis - + reduce_outer_axis - + reduce_inner_axis - + tile_axis - ) + blck_axis + vthd_axis + thrd_axis + reduce_outer_axis + reduce_inner_axis + tile_axis) sch.reorder(*axis_order) blck_fused = sch.fuse(*blck_axis) @@ -813,7 +747,7 @@ def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many sch.reverse_compute_at(CL, thrd_fused) if len(tile_axis) > 0: - for ax in sch.get_loops(CL)[-len(tile_axis) :]: + for ax in sch.get_loops(CL)[-len(tile_axis):]: sch.unroll(ax) sch.decompose_reduction(C, reduce_outer_axis[0]) diff --git a/python/bitblas/gpu/gemv_dequantize.py b/python/bitblas/gpu/gemv_dequantize.py index ddbdaf3b8..fbdee9c9c 100644 --- a/python/bitblas/gpu/gemv_dequantize.py +++ b/python/bitblas/gpu/gemv_dequantize.py @@ -1,13 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """A rule for GEMV and DecodeGEMV.""" -import re from functools import reduce -from typing import List, Optional, Union, Dict - -from tvm.tir.function import PrimFunc -from tvm import DataType, arith, ir, tir +from typing import List, Dict from tvm.target import Target +from tvm.tir.function import PrimFunc +from tvm import DataType, tir import logging from ..base import ( normalize_prim_func, @@ -15,6 +13,7 @@ get_block, ) from .base import GPUScheduleRule +from .matmul_analysis import auto_inline_producers, auto_inline_consumers logger = logging.getLogger(__name__) @@ -22,13 +21,175 @@ class GEMVWithDequantizeInfo(GPUScheduleRule): """A rule for Dequantized GEMV.""" + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ): + sch = tir.Schedule(func) + from .intrin import get_lop3_intrin_group + + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + if not check_dequantize_info(dequantize_info): + logger.debug("Dequantize info is not valid") + return None + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append( + weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + if not check_weight_decode_info(weight_decode_info): + logger.debug("Weight Dequantize info is not valid") + return None + + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + def get_vectorize_factor(target_format): + # coalesced access requires the vectorize factor to be the same as the transaction size + return 128 // DataType(target_format).bits + + vec = get_vectorize_factor(weight_decode_info["target_format"]) + num_warps = 1 + warp_size = 32 + + block_b = reduction_block + output_blocks = get_output_blocks(sch, block_infos) # noqa: F841 + B_decode_block = get_block(sch, block_infos, weight_decode_info["decode_block"]) + + block_decode_B = sch.cache_read(block_b, 1, "local") + sch.compute_inline(B_decode_block) + + j, k = sch.get_loops(block_b)[-2:] + + # get target dequantize buffer's idx + def get_idx(weight_decode_info: Dict): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + block_shared_local_A = sch.cache_read(block_b, 0, "local") + block_shared_local_B = sch.cache_read(block_decode_B, get_idx(weight_decode_info), "local") + block_local_C = sch.cache_write(block_b, 0, "local") + + auto_inline_producers(sch, block_shared_local_B) + auto_inline_consumers(sch, block_local_C) + + bx, j = sch.split(j, factors=[None, num_warps]) + k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) + # for dp4a/hfma2 + inst_factor = 2 if weight_decode_info["target_format"] == "float16" else 4 + _, vk = sch.split(vk, factors=[None, inst_factor]) + sch.reorder(bx, j, k, tx) + + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.bind(j, "threadIdx.y") + + self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] + self.grid_size = [sch.get(bx).extent, 1, 1] + + sch.compute_at(block_decode_B, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_A, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_B, tx, preserve_unit_loops=True) + sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) + + block_local_a_v = sch.get_loops(block_shared_local_A)[-1] + sch.vectorize(block_local_a_v) + block_local_b_v = sch.get_loops(block_shared_local_B)[-1] + sch.vectorize(block_local_b_v) + + skip_blocks = [block_shared_local_B] + + if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized": + if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: + block_local_scales = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 1, "local") + sch.compute_at(block_local_scales, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + skip_blocks.append(block_local_scales) + + if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]: + block_local_zeros = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 2, "local") + sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + skip_blocks.append(block_local_zeros) + + auto_inline_producers(sch, block_decode_B, skip_blocks) + + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"]) + sch.annotate(block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"]) + return sch + def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, func: tir.PrimFunc, config, ): sch = tir.Schedule(func) - from .intrin.lop3 import get_lop3_intrin_group + from .intrin import get_lop3_intrin_group dequantize_info = func.attrs["dequantize_info"] @@ -47,21 +208,15 @@ def check_dequantize_info(dequantize_info): def check_weight_decode_info(weight_decode_info): conditions = [] - # check source format in ["int", "fp", "af"] + # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) conditions.append( - weight_decode_info["source_format"]["format"] - in ["uint", "int", "fp", "af"] - ) + weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) # check source bits in [1, 2, 4, 8] - conditions.append( - weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8] - ) + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] conditions.append("target_format" in weight_decode_info) - conditions.append( - weight_decode_info["target_format"] in ["float16", "int8"] - ) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) return all(conditions) if not check_weight_decode_info(weight_decode_info): @@ -81,15 +236,9 @@ def check_weight_decode_info(weight_decode_info): dom_kind = block.dom_kind() block = block.block_rv - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): continue for loop, iter_type in zip(sch.get_loops(block), dom_kind): @@ -104,7 +253,7 @@ def prod(iterable): return reduce(lambda x, y: x * y, iterable, 1) def get_vectorize_factor(target_format): - # coalseced access requires the vectorize factor to be the same as the transaction size + # coalesced access requires the vectorize factor to be the same as the transaction size return config.arch.transaction_size[-1] // DataType(target_format).bits vec = get_vectorize_factor(weight_decode_info["target_format"]) @@ -112,15 +261,9 @@ def get_vectorize_factor(target_format): warp_size = int(prod(config.reduce_thread)) block_b = reduction_block - output_blocks = get_output_blocks(sch, block_infos) + output_blocks = get_output_blocks(sch, block_infos) # noqa: F841 B_decode_block = get_block(sch, block_infos, weight_decode_info["decode_block"]) - # compute inline - for block_info in reversed(block_infos): - block = block_info.block_rv - if block not in (reduction_block, *output_blocks, B_decode_block): - sch.compute_inline(block) - block_decode_B = sch.cache_read(block_b, 1, "local") sch.compute_inline(B_decode_block) @@ -129,20 +272,18 @@ def get_vectorize_factor(target_format): # get target dequantize buffer's idx def get_idx(weight_decode_info: Dict): # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structual based way + # maybe we can use a more general and structural based way # to analysis the idx - if weight_decode_info["source_format"]["format"] == "af": + if weight_decode_info["source_format"]["format"] == "nf": return 1 return 0 block_shared_local_A = sch.cache_read(block_b, 0, "local") - block_shared_local_B = sch.cache_read( - block_decode_B, get_idx(weight_decode_info), "local" - ) + block_shared_local_B = sch.cache_read(block_decode_B, get_idx(weight_decode_info), "local") block_local_C = sch.cache_write(block_b, 0, "local") - # reverse inline - if reduction_block != None and reduction_block != output_blocks[0]: - sch.reverse_compute_inline(output_blocks[0]) + + auto_inline_producers(sch, block_shared_local_B) + auto_inline_consumers(sch, block_local_C) bx, j = sch.split(j, factors=[None, num_warps]) k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) @@ -167,10 +308,27 @@ def get_idx(weight_decode_info: Dict): sch.vectorize(block_local_a_v) block_local_b_v = sch.get_loops(block_shared_local_B)[-1] sch.vectorize(block_local_b_v) - if ( - "fast_decoding" in weight_decode_info - and weight_decode_info["fast_decoding"] - ): + + skip_blocks = [block_shared_local_B] + + if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized": + if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: + block_local_scales = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 1, "local") + sch.compute_at(block_local_scales, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + skip_blocks.append(block_local_scales) + + if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]: + block_local_zeros = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 2, "local") + sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + skip_blocks.append(block_local_zeros) + + auto_inline_producers(sch, block_decode_B, skip_blocks) + + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): source_bit = weight_decode_info["source_format"]["bits"] out_dtype = weight_decode_info["target_format"] intrin_info = get_lop3_intrin_group( @@ -180,12 +338,10 @@ def get_idx(weight_decode_info: Dict): source_bit=source_bit, with_scaling=weight_decode_info["with_scaling"], with_zeros=weight_decode_info["with_zeros"], - zeros_type=weight_decode_info["zeros_type"], + zeros_mode=weight_decode_info["zeros_mode"], ) sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"]) - sch.annotate( - block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"] - ) + sch.annotate(block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"]) return sch def apply_config(self, func: PrimFunc, config): diff --git a/python/bitblas/gpu/intrin/__init__.py b/python/bitblas/gpu/intrin/__init__.py new file mode 100644 index 000000000..d9d9ba942 --- /dev/null +++ b/python/bitblas/gpu/intrin/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .lop3 import get_lop3_intrin_group # noqa: F401 diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py index 4ee35daaa..c9d7e5fd1 100644 --- a/python/bitblas/gpu/intrin/lop3.py +++ b/python/bitblas/gpu/intrin/lop3.py @@ -7,9 +7,9 @@ from bitblas.quantization import ( _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, + _tir_packed_to_unsigned_convert_with_zeros, ) - decode_i4_to_f16 = """ template __device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) @@ -46,17 +46,22 @@ """ decode_i4_to_f16_scale = """ -template -__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale, const int N = 8) +template +__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr) { uint *h = reinterpret_cast(B_local_decode); static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; static constexpr uint BOTTOM_MASK = 0x000f000f; static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + #pragma unroll + // decode 2 elems at one time. for (int i = 0; i < (N / 2); i++) { @@ -64,30 +69,27 @@ : "=r"(h[i]) : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - unsigned v0 = *((unsigned short *)scale); - unsigned v1 = *((unsigned short *)scale); - unsigned __packed_scale = (v1 << 16) | v0; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__packed_scale), "r"(0)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); } - } template -__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale, const int N = 8) +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) { - decode_i4b_to_f16_scale(_i4s, B_local_decode, scale, N); + decode_i4b_to_f16_scale(_i4s, B_local_decode, N, scale); } template -__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale, const int N = 8) +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) { - decode_i4b_to_f16_scale(_i4u, B_local_decode, scale, N); + decode_i4b_to_f16_scale(_i4u, B_local_decode, N, scale); } + """ decode_i4_to_f16_scale_zeros_original = """ -template -__device__ void decode_i4b_to_f16_scale_zeros_original(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) +template +__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) { uint *h = reinterpret_cast(B_local_decode); @@ -97,6 +99,13 @@ // Minus 7 to scale the value to signed static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + + #pragma unroll // decode 2 elems at one time. for (int i = 0; i < (N / 2); i++) @@ -105,22 +114,24 @@ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" : "=r"(h[i]) : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); } } -template -__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) + +template +__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) { - decode_i4b_to_f16_scale_zeros_original(_i4u, B_local_decode, scale, zeros, N); + decode_i4b_to_f16_zeros_original(_i4u, B_local_decode, N, scale, zeros); } - """ decode_i4_to_f16_scale_zeros_rescale = """ -template -__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) +template +__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) { uint *h = reinterpret_cast(B_local_decode); @@ -130,6 +141,11 @@ // Minus 7 to scale the value to signed static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + #pragma unroll // decode 2 elems at one time. for (int i = 0; i < (N / 2); i++) @@ -138,19 +154,60 @@ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" : "=r"(h[i]) : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); } } -template -__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) + +template +__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) { - decode_i4b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, scale, zeros, N); + decode_i4b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); } """ +decode_i4_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_quantized(_i4u, B_local_decode, N, scale, zeros); +} +""" + decode_i2_to_f16 = """ template __device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) @@ -271,7 +328,6 @@ } """ - decode_i2_to_f16_scale_zeros_rescale = """ template __device__ void decode_i2b_to_f16_scale_zeros_rescale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) @@ -309,7 +365,164 @@ } """ -decode_i1s_to_i8s_l16 = """template +decode_i1_to_f16 = """ +template +__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1s, B_local_decode, N); +} + +template +__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1u, B_local_decode, N); +} +""" + +decode_i1_to_f16_scale = """ +template +__device__ void decode_i1b_to_f16_scale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale(_i1s, B_local_decode, N, scale); +} +template +__device__ void decode_i1u_to_f16_scale(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale(_i1u, B_local_decode, N, scale); +} +""" +decode_i1_to_f16_scale_zeros_original = """ +template +__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_zeros_original(_i1u, B_local_decode, N, scale, zeros); +} +""" +decode_i1_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1s_to_i8s = """template __device__ void decode_i1s_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) { int *i8s = reinterpret_cast(_i8s); @@ -469,7 +682,7 @@ def get_fast_decode_intrin( loops_extent=8, with_scale=False, with_zeros=False, - zeros_type="original", + zeros_mode="original", ): """ loops extent is the number of elements to be decoded in one stage @@ -487,14 +700,15 @@ def get_fast_decode_intrin( if with_scale: func_name += "_scale" if with_zeros: - func_name += f"_zeros_{zeros_type}" + func_name += f"_zeros_{zeros_mode}" assert storage_dtype in ["int8", "int32", "uint32"] storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) elem_per_unit = storage_nbit // source_bit n_storage_elems = loops_extent // elem_per_unit - - if source_format == "int": + if with_zeros and zeros_mode == "quantized": + decode_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) + elif source_format == "int": decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) elif source_format == "uint": decode_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) @@ -568,9 +782,7 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: elif with_zeros is False: @T.prim_func - def fast_decode_desc( - compressed: T.handle, decompressed: T.handle, scale: T.handle - ) -> None: + def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.handle) -> None: Compressed = T.match_buffer( compressed, [ @@ -593,6 +805,7 @@ def fast_decode_desc( 1, ], dtype=target_dtype, + scope="global", ) with T.block("root"): T.reads(Compressed[0:n_storage_elems], Scale[0:1]) @@ -606,16 +819,145 @@ def fast_decode_desc( Compressed[vi // elem_per_unit], vi % elem_per_unit, dtype=target_dtype, - ) - * Scale[0] + ) * Scale[0]) + + @T.prim_func + def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.handle) -> None: + s0 = T.int32() + + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + offset_factor=1, + strides=[s0], + scope="global", + ) + with T.block("root"): + T.reads(Compressed[0:n_storage_elems], Scale[0:1]) + T.writes(Decompressed[0:loops_extent]) + T.call_extern( + "handle", + func_name, + Compressed.data, + Decompressed.data, + Scale.access_ptr("r"), + loops_extent, + ) + + elif zeros_mode == "quantized": + + def get_dequantize_buffers_list(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": + return [weight, zeros, scale] + elif zeros_mode == "rescale": + return [weight, scale, zeros] + elif zeros_mode == "quantized": + return [weight, zeros, scale] + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + def get_dequantize_func(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": + return (weight - zeros) * scale + elif zeros_mode == "rescale": + return weight * scale - zeros + elif zeros_mode == "quantized": + return weight * scale + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + # Scale with Zeros + @T.prim_func + def fast_decode_desc( + compressed: T.handle, + decompressed: T.handle, + scale: T.handle, + zeros: T.handle, + ) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + scope="local", + ) + Zeros = T.match_buffer( + zeros, + [ + 1, + ], + dtype=storage_dtype, + scope="local", + ) + with T.block("root"): + T.reads(*get_dequantize_buffers_list( + Compressed[0:n_storage_elems], + Scale[0:1], + Zeros[0:1], + zeros_mode=zeros_mode, + )) + T.writes(Decompressed[0:loops_extent]) + for i in T.grid(loops_extent): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = get_dequantize_func( + decode_func( + source_bit, + Compressed[vi // elem_per_unit], + vi % elem_per_unit, + Zeros[0], + dtype=target_dtype, + ), + Scale[0], + Zeros[0], + zeros_mode, ) @T.prim_func def fast_decode_impl( - compressed: T.handle, decompressed: T.handle, scale: T.handle + compressed: T.handle, + decompressed: T.handle, + scale: T.handle, + zeros: T.handle, ) -> None: s0 = T.int32() - + s1 = T.int32() Compressed = T.match_buffer( compressed, [ @@ -640,9 +982,20 @@ def fast_decode_impl( dtype=target_dtype, offset_factor=1, strides=[s0], + scope="local", + ) + Zeros = T.match_buffer( + zeros, + [ + 1, + ], + dtype=storage_dtype, + offset_factor=1, + strides=[s1], + scope="local", ) with T.block("root"): - T.reads(Compressed[0:n_storage_elems], Scale[0:1]) + T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) T.writes(Decompressed[0:loops_extent]) T.call_extern( "handle", @@ -650,26 +1003,27 @@ def fast_decode_impl( Compressed.data, Decompressed.data, Scale.access_ptr("r"), + Zeros.access_ptr("r"), loops_extent, ) else: - def get_dequantize_buffers_list(weight, scale, zeros, zeros_type="original"): - if zeros_type == "original": + def get_dequantize_buffers_list(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": return [weight, zeros, scale] - elif zeros_type == "rescale": + elif zeros_mode == "rescale": return [weight, scale, zeros] else: - raise ValueError(f"Unsupported zeros_type: {zeros_type}") + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - def get_dequantize_func(weight, scale, zeros, zeros_type="original"): - if zeros_type == "original": + def get_dequantize_func(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": return (weight - zeros) * scale - elif zeros_type == "rescale": + elif zeros_mode == "rescale": return weight * scale - zeros else: - raise ValueError(f"Unsupported zeros_type: {zeros_type}") + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") # Scale with Zeros @T.prim_func @@ -701,6 +1055,7 @@ def fast_decode_desc( 1, ], dtype=target_dtype, + scope="global", ) Zeros = T.match_buffer( zeros, @@ -708,16 +1063,15 @@ def fast_decode_desc( 1, ], dtype=target_dtype, + scope="global", ) with T.block("root"): - T.reads( - *get_dequantize_buffers_list( - Compressed[0:n_storage_elems], - Scale[0:1], - Zeros[0:1], - zeros_type=zeros_type, - ) - ) + T.reads(*get_dequantize_buffers_list( + Compressed[0:n_storage_elems], + Scale[0:1], + Zeros[0:1], + zeros_mode=zeros_mode, + )) T.writes(Decompressed[0:loops_extent]) for i in T.grid(loops_extent): with T.block("decode"): @@ -731,7 +1085,7 @@ def fast_decode_desc( ), Scale[0], Zeros[0], - zeros_type, + zeros_mode, ) @T.prim_func @@ -767,6 +1121,7 @@ def fast_decode_impl( dtype=target_dtype, offset_factor=1, strides=[s0], + scope="global", ) Zeros = T.match_buffer( zeros, @@ -776,6 +1131,7 @@ def fast_decode_impl( dtype=target_dtype, offset_factor=1, strides=[s1], + scope="global", ) with T.block("root"): T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) @@ -793,41 +1149,36 @@ def fast_decode_impl( return fast_decode_desc, fast_decode_impl -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_" -) +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_f16_l8_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN, *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="float16", loops_extent=8 - ), + source_bit=4, storage_dtype="int8", target_dtype="float16", loops_extent=8), ) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_" -) +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u2_to_int8_to_f16_l8_") TensorIntrin.register( LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN, *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="float16", loops_extent=8 - ), + source_bit=2, storage_dtype="int8", target_dtype="float16", loops_extent=8), ) -LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN = ( - "lop3_fast_decode_u4_to_int32_to_f16_l8_" +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u1_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=1, storage_dtype="int8", target_dtype="float16", loops_extent=8), ) + +LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int32_to_f16_l8_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN, *get_fast_decode_intrin( - source_bit=4, storage_dtype="int32", target_dtype="float16", loops_extent=8 - ), + source_bit=4, storage_dtype="int32", target_dtype="float16", loops_extent=8), ) - LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int32_to_f16_l8_scale_" -) + "lop3_fast_decode_u4_to_int32_to_f16_l8_scale_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN, *get_fast_decode_intrin( @@ -839,20 +1190,15 @@ def fast_decode_impl( ), ) -LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN = ( - "lop3_fast_decode_u4_to_uint32_to_f16_l8_" -) +LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_uint32_to_f16_l8_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN, *get_fast_decode_intrin( - source_bit=4, storage_dtype="uint32", target_dtype="float16", loops_extent=8 - ), + source_bit=4, storage_dtype="uint32", target_dtype="float16", loops_extent=8), ) - LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_uint32_to_f16_l8_scale_" -) + "lop3_fast_decode_u4_to_uint32_to_f16_l8_scale_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN, *get_fast_decode_intrin( @@ -864,10 +1210,8 @@ def fast_decode_impl( ), ) - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_" -) + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, *get_fast_decode_intrin( @@ -880,8 +1224,7 @@ def fast_decode_impl( ) LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_original_" -) + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_original_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, *get_fast_decode_intrin( @@ -891,13 +1234,12 @@ def fast_decode_impl( loops_extent=8, with_scale=True, with_zeros=True, - zeros_type="original", + zeros_mode="original", ), ) LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_rescale_" -) + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_rescale_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, *get_fast_decode_intrin( @@ -907,13 +1249,27 @@ def fast_decode_impl( loops_extent=8, with_scale=True, with_zeros=True, - zeros_type="rescale", + zeros_mode="rescale", ), ) -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_" +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_quantized_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="quantized", + ), ) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_") TensorIntrin.register( LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, *get_fast_decode_intrin( @@ -926,8 +1282,7 @@ def fast_decode_impl( ) LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_original_" -) + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_original_") TensorIntrin.register( LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, *get_fast_decode_intrin( @@ -937,13 +1292,12 @@ def fast_decode_impl( loops_extent=8, with_scale=True, with_zeros=True, - zeros_type="original", + zeros_mode="original", ), ) LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_rescale_" -) + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_rescale_") TensorIntrin.register( LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, *get_fast_decode_intrin( @@ -953,64 +1307,89 @@ def fast_decode_impl( loops_extent=8, with_scale=True, with_zeros=True, - zeros_type="rescale", + zeros_mode="rescale", ), ) +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_i8_l8_" +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( + "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_original_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="original", + ), ) + +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( + "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_rescale_") TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN, + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=8 + source_bit=1, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="rescale", ), ) -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_i8_l16_" +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=8), ) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l16_") TensorIntrin.register( LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN, *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=16 - ), + source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=16), ) -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_i8_l16_" -) +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u2_to_int8_to_i8_l16_") TensorIntrin.register( LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN, *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16 - ), + source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16), ) -LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN = ( - "lop3_fast_decode_i2_to_int8_to_i8_l16_" -) +LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i2_to_int8_to_i8_l16_") TensorIntrin.register( LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16 - ), + source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16), ) -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_i8_l16_" -) +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_") TensorIntrin.register( LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN, *get_fast_decode_intrin( - source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16 - ), + source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16), ) -LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ( - "lop3_fast_decode_i4_to_int8_to_f16_l8_" -) +LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_") TensorIntrin.register( LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN, *get_fast_decode_intrin( @@ -1023,8 +1402,7 @@ def fast_decode_impl( ) LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i4_to_int8_to_f16_l8_scale_" -) + "lop3_fast_decode_i4_to_int8_to_f16_l8_scale_") TensorIntrin.register( LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, *get_fast_decode_intrin( @@ -1037,9 +1415,7 @@ def fast_decode_impl( ), ) -LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN = ( - "lop3_fast_decode_i2_to_int8_to_f16_l8_" -) +LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i2_to_int8_to_f16_l8_") TensorIntrin.register( LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN, *get_fast_decode_intrin( @@ -1052,8 +1428,7 @@ def fast_decode_impl( ) LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i2_to_int8_to_f16_l8_scale_" -) + "lop3_fast_decode_i2_to_int8_to_f16_l8_scale_") TensorIntrin.register( LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, *get_fast_decode_intrin( @@ -1074,7 +1449,7 @@ def get_lop3_intrin_group( storage_dtype: Literal["int32", "int8"] = "int8", with_scaling: bool = False, with_zeros: bool = False, - zeros_type: Literal["original", "rescale", "quantized"] = "original", + zeros_mode: Literal["original", "rescale", "quantized"] = "original", ) -> Dict[str, str]: """ This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. @@ -1114,18 +1489,23 @@ def get_lop3_intrin_group( if with_scaling: _intrin += "scale_" if with_zeros: - _intrin += f"zeros_{zeros_type}_" + _intrin += f"zeros_{zeros_mode}_" import_c_map = { "i4_to_f16": decode_i4_to_f16, "i2_to_f16": decode_i2_to_f16, + "i1_to_f16": decode_i1_to_f16, "i4_to_f16_scale": decode_i4_to_f16_scale, "i2_to_f16_scale": decode_i2_to_f16_scale, + "i1_to_f16_scale": decode_i1_to_f16_scale, "i4_to_f16_scale_zeros_original": decode_i4_to_f16_scale_zeros_original, "i2_to_f16_scale_zeros_original": decode_i2_to_f16_scale_zeros_original, + "i1_to_f16_scale_zeros_original": decode_i1_to_f16_scale_zeros_original, "i4_to_f16_scale_zeros_rescale": decode_i4_to_f16_scale_zeros_rescale, "i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale, - "i1_to_i8": decode_i1s_to_i8s_l16, + "i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale, + "i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized, + "i1_to_i8": decode_i1s_to_i8s, "i2_to_i8": decode_i2s_to_i8s, "i4_to_i8": decode_i4s_to_i8s, } @@ -1133,7 +1513,7 @@ def get_lop3_intrin_group( if with_scaling: key += "_scale" if with_zeros: - key += f"_zeros_{zeros_type}" + key += f"_zeros_{zeros_mode}" return { "c_source": import_c_map[key], diff --git a/python/bitblas/gpu/matmul.py b/python/bitblas/gpu/matmul.py index 4147e44d0..ad450eff2 100644 --- a/python/bitblas/gpu/matmul.py +++ b/python/bitblas/gpu/matmul.py @@ -22,8 +22,15 @@ get_reduction_blocks, ) from .matmul_mma import MatmulTensorizationMMA -from .matmul_wmma import MatmulInt8Tensorization, MatmulTensorizationWMMA, MatmulTensorizationLegacy +from .matmul_wmma import ( + MatmulInt8Tensorization, + MatmulTensorizationWMMA, +) from functools import reduce +import logging + +logger = logging.getLogger(__name__) + class Matmul(GPUScheduleRule): """The schedule rule for matmul-like computation""" @@ -115,9 +122,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring apply_tensorization = False for item_var in block_stmt.iter_vars[1:]: extent = item_var.dom.extent - if isinstance(extent, tir.expr.IntImm): - if extent.value <= minimal_tensorize_threshold: - apply_tensorization = False + if isinstance(extent, + tir.expr.IntImm) and extent.value <= minimal_tensorize_threshold: + apply_tensorization = False if apply_tensorization: if in_dtype == "int8" and out_dtype == "int32": tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) @@ -149,11 +156,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring ) batch, x, y, k = sch.get_loops(main_block) by, vy, ty, yi = sch.split( - y, [None, config.vthread_y, config.block_size_y, config.micro_size_y] - ) + y, [None, config.vthread_y, config.block_size_y, config.micro_size_y]) bx, vx, tx, xi = sch.split( - x, [None, config.vthread_x, config.block_size_x, config.micro_size_x] - ) + x, [None, config.vthread_x, config.block_size_x, config.micro_size_x]) ko, ki = sch.split(k, factors=[None, config.micro_size_k]) sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) by = sch.fuse(batch, by) @@ -214,7 +219,7 @@ def is_scheduled(block: tir.schedule.BlockRV) -> bool: return loop_kinds != {ForKind.SERIAL} blocks = sch.get_child_blocks(root_block) - max_threads_per_block = utils.max_threads_per_block(target) + max_threads_per_block = utils.max_threads_per_block(target) # noqa: F841 for block in blocks: if is_scheduled(block): continue @@ -248,7 +253,7 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring # in some case conv template will use this rule, but the tile config is not # analyzed by matmul expr. if len(config.block) != 2: - print(f"Warning: block config {config.block} is not valid for matmul, skip.") + logger.debug(f"Warning: block config {config.block} is not valid for matmul, skip.") return None main_block = reduction_blocks[0] @@ -275,8 +280,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring block_col_warps = config.block[1] // (config.thread[1] * config.step[1]) thread_row_tiles = config.thread[1] // (config.step[0] * 2) thread_col_tiles = config.thread[1] // (config.step[1] * 2) - vthread_row_tiles = config.step[0] * 2 # expand vtrhead to avoid load band conflict - vthread_col_tiles = config.step[1] * 2 # expand vtrhead to avoid load band conflict + vthread_row_tiles = (config.step[0] * 2) # expand vtrhead to avoid load band conflict + vthread_col_tiles = (config.step[1] * 2) # expand vtrhead to avoid load band conflict chunk = config.rstep[0] # Step 3. Schedule matmul @@ -328,8 +333,7 @@ def is_trivial_load(block): if len(reads) != 1 or len(writes) != 1: return False return all( - read.region[-1] == write.region[-1] for read, write in zip(reads, writes) - ) + read.region[-1] == write.region[-1] for read, write in zip(reads, writes)) if is_trivial_load(block): sch.vectorize(vec) @@ -348,24 +352,20 @@ def is_trivial_load(block): for i, input_region in enumerate(sch.get(main_block).reads): _buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "") if _buffer_name not in config.cached_tensors: - print( + logger.warning( f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip." ) continue # otherwise cooperative fetch in shared memory. - if _buffer_name in config.vectorize: - vectorize = config.vectorize[_buffer_name] - else: - vectorize = 1 + vectorize = config.vectorize.get(_buffer_name, 1) _cooperative_fetch(i, vec_len=vectorize) auto_inline_consumer_chain(sch, l2g) _, vec = sch.split( - sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)] - ) + sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)]) sch.vectorize(vec) sch.decompose_reduction(main_block, ko) diff --git a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py index a6f452170..df50d283c 100644 --- a/python/bitblas/gpu/matmul_analysis.py +++ b/python/bitblas/gpu/matmul_analysis.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from enum import Enum from typing import List, Optional, Set, Union, Tuple, Dict -from tvm import tir, DataType +from tvm import tir from tvm.ir import Range from tvm.tir import IterVar, PrimExpr, Var, BufferRegion from tvm.tir.analysis import undefined_vars @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) + def _is_one(x: PrimExpr) -> bool: return isinstance(x, tir.IntImm) and x.value == 1 @@ -52,14 +53,12 @@ def auto_inline_producers( inlined_cnt = 0 producers = _collect_producers(sch, block) for producer in producers: - if any( - sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks - ): + if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks): continue try: sch.compute_inline(producer) inlined_cnt += 1 - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except continue if inlined_cnt == 0: return @@ -76,13 +75,13 @@ def auto_inline_consumers( try: sch.compute_inline(consumer) inlined_cnt += 1 - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except continue for consumer in consumers: try: sch.reverse_compute_inline(consumer) inlined_cnt += 1 - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except continue if inlined_cnt == 0: return @@ -106,24 +105,25 @@ def auto_inline_consumer_chain( # Try inlining into the cache-write stage again, this time it should succeed. auto_inline_consumers(sch, block) + # used to match the similar region with dequantize op. -def find_first_similar_region(regions:List[BufferRegion], buffer: tir.Buffer): +def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer): for region in regions: if len(region.buffer.shape) == len(buffer.shape): return region return None + # used to match the similar buffer with dequantize op. -def find_first_similar_buffer(regions:List[BufferRegion], buffer: tir.Buffer): +def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer): for region in regions: if len(region.buffer.shape) == len(buffer.shape): return region.buffer return None + # find the block that required to be reindex and scope. -def find_last_producer_from_buffer( - sch, main_block, buffer: tir.Buffer -) -> Optional[BlockRV]: +def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]: # block that most near to the arguments block = main_block buffer = buffer @@ -146,9 +146,8 @@ def find_last_producer_from_buffer( return block -def find_arg_idx_from_buffer_chain( - sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer -) -> int: +def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, + buffer: tir.Buffer) -> int: """traverse to find the arg index from the buffer""" producers = sch.get_producers(main_block) @@ -217,8 +216,7 @@ def make_iter_fusion_index_map( fused_iters[trait.kind] = v_i final_indices: List[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) - for kind in kind_order + fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order ] return tir.IndexMap(input_iters, final_indices, None) @@ -291,22 +289,15 @@ def get_access_axes(region: List[Range]) -> Set[Var]: if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: return None - A_traits = [ - traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes - ] - B_traits = [ - traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes - ] - C_traits = [ - traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes - ] + A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes] + B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes] + C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes] block_traits = [traits[i.var] for i in block.iter_vars] return A_traits, B_traits, C_traits, block_traits -def get_index_map( - block: tir.Block, layout: List[str] = ["n", "t", "n"] -) -> Optional[Tuple[tir.IndexMap, ...]]: +def get_index_map(block: tir.Block, + layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]: """Get index maps for the block Parameters @@ -325,6 +316,8 @@ def get_index_map( index_maps : Optional[Tuple[tir.IndexMap]] The index maps for the block, or None if the block is not a gemm-liked kernel """ + if layout is None: + layout = ["n", "t", "n"] traits = detect_iter_traits(block) if traits is None: return None @@ -376,23 +369,17 @@ def infer_layout(layout: str, region: List[Range], kind: str = "A"): if kind == "C": return [IterKind.kIter_S, primary_iter, secondary_iter] else: - return ( - [IterKind.kIter_S, spatial_iter, reduction_iter] - if check_last_trait(region) - else [IterKind.kIter_S, reduction_iter, spatial_iter] - ) + return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter]) else: raise ValueError(f"Unknown layout {layout}") A_index_map = make_iter_fusion_index_map( - A_traits, infer_layout(layout[0], block.reads[0].region, kind="A") - ) + A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) B_index_map = make_iter_fusion_index_map( - B_traits, infer_layout(layout[1], block.reads[1].region, kind="B") - ) + B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) C_index_map = make_iter_fusion_index_map( - C_traits, infer_layout(layout[2], block.writes[0].region, kind="C") - ) + C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) matmul_index_map = make_iter_fusion_index_map( block_traits, @@ -424,14 +411,10 @@ def is_dequantize(block: BlockRV) -> bool: block_stmt = sch.get(block) if len(block_stmt.reads) < 2: return False - has_uint_input = any( - "uint" in str(region.buffer.dtype) for region in block_stmt.reads - ) + has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) if not has_uint_input: return False - if len(block_stmt.writes) != 1 or "float" not in str( - block_stmt.writes[0].buffer.dtype - ): + if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype): return False return True @@ -456,10 +439,7 @@ def get_access_vars(region: List[Range]) -> List[Var]: axes.extend(undefined_vars(r.min)) # remove trivial axis trivial_vars = set( - iter_var.var - for iter_var in block_stmt.iter_vars - if _is_one(iter_var.dom.extent) - ) + iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) axes = [axis for axis in axes if axis not in trivial_vars] # remove duplicate axis axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] @@ -468,9 +448,8 @@ def get_access_vars(region: List[Range]) -> List[Var]: lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] is_identity = list(lhs_access_vars) == list(rhs_access_vars) - is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set( - lhs_access_vars - ) == set(rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( + rhs_access_vars) return is_identity, is_transpose @@ -490,23 +469,25 @@ def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV] continue try: sch.compute_inline(block) - except: + except Exception: try: sch.reverse_compute_inline(block) - except: + except Exception: result_blocks.append(block) return result_blocks -def normalize_to_matmul( - sch: tir.Schedule, main_block: BlockRV, layout: List[str] = ["n", "t", "n"] -) -> Optional[tir.Schedule]: +def normalize_to_matmul(sch: tir.Schedule, + main_block: BlockRV, + layout: Optional[List[str]] = None) -> Optional[tir.Schedule]: + if layout is None: + layout = ["n", "t", "n"] block_stmt = sch.get(main_block) # let layout be 'a' to auto inference the layout index_maps = get_index_map(block_stmt, layout=layout) if index_maps is None: - print("[WARNING] Cannot find the appropriate index map for tensorcore") + logger.debug("Cannot find the appropriate index map for tensorcore") return None matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps @@ -526,17 +507,17 @@ def normalize_to_matmul( def get_tensorized_func_and_tags( func: tir.PrimFunc, target: Target, - layout: List[str] = ["a", "a", "a"], + layout: Optional[List[str]] = None, skip_normalize: bool = False, allow_gemv: bool = False, ) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) - + get_wmma_intrin_group,) """ transform function to matmul if necessary (e.g. transform conv2d with im2col) """ + if layout is None: + layout = ["a", "a", "a"] # step1. detect whether the function can utilize tensorcore sch = tir.Schedule(func) root_block = get_root_block(sch) @@ -552,12 +533,8 @@ def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: conditions.append(len(block_stmt.writes) == 1) conditions.append( len( - collect_block_iter_vars_used_in_access_region( - block_stmt, block_stmt.writes[0].region - ) - ) - > 0 - ) + collect_block_iter_vars_used_in_access_region(block_stmt, + block_stmt.writes[0].region)) > 0) if not all(conditions): return False return True @@ -567,9 +544,7 @@ def check_sm_version(arch: str) -> int: sm_version = arch.replace("sm_", "") return int(sm_version) if sm_version.isdigit() else -1 - def analysis_tensorcore_tags( - sch: tir.Schedule, block: BlockRV, target: Target - ) -> bool: + def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool: tags: Dict[str, Union[List[int], int]] = {} block_stmt = sch.get(block) @@ -583,7 +558,7 @@ def analysis_tensorcore_tags( # todo(lei): maybe we can integrate this into policy in the future tags["pipeline_stage"] = 1 if target.kind.name == "cuda" and check_sm_version(target.arch) == 80: - # enable pipleline stage only for sm_80 devices + # enable pipeline stage only for sm_80 devices tags["pipeline_stage"] = 2 # analysis async copy @@ -593,7 +568,7 @@ def analysis_tensorcore_tags( # async copy only works in software pipeline. tags["use_async_copy"] = True - # analysis intrin infomation + # analysis intrin information def get_ordered_axes(region: List[Range]) -> Set[Var]: axes: List[Var] = [] for r in region: @@ -618,10 +593,10 @@ def check_last_trait(region: List[Range]): intrin_info["out_dtype"] = out_dtype # if the last dimension is reduce axis, the B is transposed intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region) - if func.attrs is not None and "smooth_a" in func.attrs: - intrin_info["smooth_a"] = func.attrs["smooth_a"] - if func.attrs is not None and "smooth_b" in func.attrs: - intrin_info["smooth_b"] = func.attrs["smooth_b"] + if func.attrs is not None and "input_transform_kind" in func.attrs: + intrin_info["input_transform_kind"] = func.attrs["input_transform_kind"] + if func.attrs is not None and "weight_transform_kind" in func.attrs: + intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"] tags["intrin_info"] = intrin_info return tags @@ -638,7 +613,7 @@ def check_last_trait(region: List[Range]): in_dtype=in_dtype, out_dtype=out_dtype, ) - except: + except Exception: logger.debug("Cannot find the corresponding wmma intrin group") return func, None @@ -656,28 +631,24 @@ def check_last_trait(region: List[Range]): minimal_tensorize_threshold = 16 # the batch dimension is not taken into consideration. extent = block_stmt.iter_vars[1].dom.extent - if isinstance(extent, tir.expr.IntImm): - if extent.value < (1 if allow_gemv else minimal_tensorize_threshold): - return func, None + if isinstance(extent, + tir.expr.IntImm) and (extent.value < + (1 if allow_gemv else minimal_tensorize_threshold)): + return func, None for item_var in block_stmt.iter_vars[2:]: extent = item_var.dom.extent - if isinstance(extent, tir.expr.IntImm): - if extent.value < minimal_tensorize_threshold: - return func, None + if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): + return func, None tags = analysis_tensorcore_tags(sch, main_block, target) return sch.mod["main"], tags return func, None -def get_propagate_map( - trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32" -): +def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - ldmatrix_32x8_to_shared_16x16_layout, - ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, - ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, ) assert dtype in ["float16", "int8"], "Only support float16 for now" @@ -686,9 +657,9 @@ def get_propagate_map( ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout elif dtype == "int8": # int8 mma only support 32x16 to 16x32 layout - if matrix_name == "A" and trans == False: + if matrix_name == "A" and trans is False: ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a - elif matrix_name == "B" and trans == True: + elif matrix_name == "B" and trans is True: ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_b else: raise ValueError("Unknown matrix name ", matrix_name) @@ -712,9 +683,7 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): if dtype == "float16": ldmatrix_index_map = ( ldmatrix_trans_permutation_16x16_32x8_16x16 - if trans - else ldmatrix_permutation_16x16_32x8_16x16 - ) + if trans else ldmatrix_permutation_16x16_32x8_16x16) else: ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 @@ -757,9 +726,7 @@ def layout_propagate_chain( read_indices = [r.min for r in read.region] # reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout tmp_index_map = IndexMap(write_indices, read_indices, None) - tmp_index_map = tmp_index_map.non_surjective_inverse( - write.buffer.shape - )[0] + tmp_index_map = tmp_index_map.non_surjective_inverse(write.buffer.shape)[0] # if dequantize like ops are used, the scaling factor should be considered # to be applied to the final indices @@ -767,8 +734,7 @@ def layout_propagate_chain( for i, j in zip(write.buffer.shape, read.buffer.shape): scaling_factor *= i // j final_indices = list( - index_map.map_indices(tmp_index_map.map_indices(write_indices)) - ) + index_map.map_indices(tmp_index_map.map_indices(write_indices))) final_indices[-1] = final_indices[-1] // scaling_factor index_map = IndexMap( write_indices, diff --git a/python/bitblas/gpu/matmul_mma.py b/python/bitblas/gpu/matmul_mma.py index 9e26712fa..143ca9b45 100644 --- a/python/bitblas/gpu/matmul_mma.py +++ b/python/bitblas/gpu/matmul_mma.py @@ -28,7 +28,8 @@ ) -def get_index_map_3d(index_map, l=16, r=16): +def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 + def index_map_3d(b, i, j): return ( b, @@ -56,7 +57,7 @@ def index_map_5d(b, i, j, ii, jj): return index_map_5d -def get_warp_index_map(index_map, l=16, r=16, is_5d=False): +def get_warp_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 if is_5d: return get_index_map_5d(index_map) return get_index_map_3d(index_map, l, r) @@ -74,6 +75,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring target: Target, _: bool, ) -> Optional[tir.Schedule]: + if "dequantize_info" in func.attrs: + dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() + return dequantize_rule.apply(func, target, False) sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) @@ -114,8 +118,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring # Tensorization by hardware intrinsics from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - shared_16x16_to_mma_32x8_layout, + get_mma_intrin_group, shared_16x16_to_mma_32x8_layout, ) # tile size @@ -129,7 +132,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring thread_z, thread_y, thread_x = 2, 2, 32 vector_size = 8 - unroll_depth = 4 + unroll_depth = 4 # noqa: F841 # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] block = sch.reindex(main_block, ("read", 0)) @@ -193,33 +196,24 @@ def apply( # pylint: disable=too-many-locals,missing-docstring sch.bind(j2, "threadIdx.y") # Step 4. Read/write to shared mem and register - def fetch_input( - block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose - ): + def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose): # 1) Read to shared memory block_read_smem = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") sch.compute_at(block_read_smem, k0) - auto_inline_producers( - sch, block_read_smem, [dequantize_block] if dequantize_block else [] - ) + auto_inline_producers(sch, block_read_smem, + [dequantize_block] if dequantize_block else []) # For transposed read, we directly load transposed tensor from global # Then use ldmatrix.trans to handle transpose later - if (tensor_name == "A" and is_transpose) or ( - tensor_name == "B" and not is_transpose - ): + if (tensor_name == "A" and is_transpose) or (tensor_name == "B" and not is_transpose): # specifical handle transpose read (for NN matmul or TT matmul) v0, v1 = sch.get_loops(block_read_smem)[-2:] sch.reorder(v1, v0) - sch.transform_layout( - block_read_smem, ("write", 0), lambda b, i, j: (b, j, i) - ) + sch.transform_layout(block_read_smem, ("write", 0), lambda b, i, j: (b, j, i)) # bind loops fused = sch.fuse(*sch.get_loops(block_read_smem)[-2:]) - f0, f1, f2, f3, f4 = sch.split( - fused, [None, thread_z, thread_y, thread_x, vector_size] - ) + f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size]) sch.bind(f1, "threadIdx.z") sch.bind(f2, "threadIdx.y") sch.bind(f3, "threadIdx.x") @@ -234,17 +228,11 @@ def fetch_input( # bind_loops micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n - micro_size_1, micro_size_2 = ( - (micro_size_spatial, micro_size_k) - if not is_transpose - else (micro_size_k, micro_size_spatial) - ) - v00, v01 = sch.split( - sch.get_loops(block_read_reg)[-2], [None, micro_size_1] - ) - v10, v11 = sch.split( - sch.get_loops(block_read_reg)[-1], [None, micro_size_2] - ) + micro_size_1, micro_size_2 = ((micro_size_spatial, + micro_size_k) if not is_transpose else + (micro_size_k, micro_size_spatial)) + v00, v01 = sch.split(sch.get_loops(block_read_reg)[-2], [None, micro_size_1]) + v10, v11 = sch.split(sch.get_loops(block_read_reg)[-1], [None, micro_size_2]) sch.reorder(v00, v10, v01, v11) # reorder read axis to match the layout of ldmatrix @@ -255,9 +243,7 @@ def fetch_input( v0, v1 // micro_size_1, v2 // micro_size_2, - *shared_16x16_to_mma_32x8_layout( - v1 % micro_size_1, v2 % micro_size_2 - ), + *shared_16x16_to_mma_32x8_layout(v1 % micro_size_1, v2 % micro_size_2), ), ) @@ -267,19 +253,13 @@ def fetch_input( return block_read_smem, block_read_reg - block_read_a, block_read_reg_a = fetch_input( - block_outer, 0, "A", is_transpose_a - ) - block_read_b, block_read_reg_b = fetch_input( - block_outer, 1, "B", is_transpose_b - ) + block_read_a, block_read_reg_a = fetch_input(block_outer, 0, "A", is_transpose_a) + block_read_b, block_read_reg_b = fetch_input(block_outer, 1, "B", is_transpose_b) # Write to register, and then smem def store_output(block_outer, write_buffer_idx): # 1) Write to shared memory - block_write_smem = sch.cache_write( - block_outer, write_buffer_idx, "shared.dyn" - ) + block_write_smem = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") sch.reverse_compute_at(block_write_smem, block_axis) auto_inline_consumer_chain(sch, block_write_smem) @@ -308,9 +288,7 @@ def store_output(block_outer, write_buffer_idx): v0, v1 // micro_size_m, v2 // micro_size_n, - *shared_16x16_to_mma_32x8_layout( - v1 % micro_size_m, v2 % micro_size_n - ), + *shared_16x16_to_mma_32x8_layout(v1 % micro_size_m, v2 % micro_size_n), ), ) @@ -366,8 +344,9 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring return dequantize_rule.apply_config(func, config) from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - ) + get_mma_intrin_group,) + + import_source: List[str] = [] sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) @@ -381,10 +360,6 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring return None main_block = reduction_blocks[0] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] @@ -415,20 +390,6 @@ def check_has_dynamic(func: tir.PrimFunc): if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not ( - func.attrs is not None - and "dlight.tensorcore_prenormlized" in func.attrs.keys() - ): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not ( - func.attrs is not None - and "dlight.tensorcore_prenormlized" in func.attrs.keys() - ): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - shared_scope = config.shared_scope intrin_info = config.intrin_info @@ -460,8 +421,8 @@ def check_has_dynamic(func: tir.PrimFunc): micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] # get the axis for layout transform - def get_axis(l, r, trans): - return (r, l) if trans else (l, r) + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) @@ -473,8 +434,8 @@ def can_enable_swizzle(dtype: str, smooth: bool): return not smooth return False - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_b) + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) warp_size = 32 @@ -527,32 +488,14 @@ def can_enable_swizzle(dtype: str, smooth: bool): thread_idy = i2 thread_idz = j2 - # plan rasteration - if ( - not isinstance(config.rasterization_plan, NoRasterization) - and sch.get(batch).extent.value == 1 - ): - device_func, invoke_func = config.rasterization_plan.get_code() - factor = config.rasterization_plan.panel_width_ - - # TODO(lei): this is a trick for rasterization implementation - # is not optimal. (5% performance loss) - # require a solution for general block rasterization - factor = 8 # should be divisible by block_idx - if sch.get(block_idx).extent.value % factor == 0: - block_k, block_idx = sch.split(block_idx, factors=[None, factor]) - sch.reorder(block_k, block_idy, block_idx) - sch.bind(block_k, "blockIdx.z") - else: - sch.bind(batch, "blockIdx.z") - + sch.bind(batch, "blockIdx.z") sch.bind(block_idx, "blockIdx.x") sch.bind(block_idy, "blockIdx.y") sch.bind(thread_idy, "threadIdx.y") sch.bind(thread_idz, "threadIdx.z") # rewrite smooth layout of shared memory - def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): + def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 if not enable: return sch.transform_layout( @@ -568,24 +511,19 @@ def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): ) smooth_smem_layout_rewrite( - block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a - ) + block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) smooth_smem_layout_rewrite( - block_outer, ("read", 1), *b_lr, enable=intrin_info.smooth_b - ) + block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) - def fetch_to_shared( - block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False - ): + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): block_read = sch.cache_read(block, idx, shared_scope) sch.compute_at(block_read, k0, preserve_unit_loops=True) ndim = len(sch.get(block_read).iter_vars) fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[num_ty, num_tz, None, warp_size, vec_len] - ) + fused, factors=[num_ty, num_tz, None, warp_size, vec_len]) sch.bind(f_3, "threadIdx.x") sch.bind(f_1, "threadIdx.z") @@ -622,9 +560,7 @@ def fetch_to_shared( ) # rewrite global smooth layout - def smooth_gmem_layout_rewrite( - sch, block, enable=True, trans=False, matrix_name="A" - ): + def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): if not enable: return # step1: find the first producer block @@ -633,10 +569,11 @@ def smooth_gmem_layout_rewrite( # read and write buffer producers = _collect_producers(sch, block) g2s_block = a_g2s if matrix_name == "A" else b_g2s - propagate_block: tir.Block = producers[-1] if len(producers) > 0 else g2s_block + propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block) # step2: transform the layout with inverse permutation - _, inverse_indexmap = get_propagate_map(trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) def inverse_permutation(i, j, ii, jj): return (i, j, *intra_indexmap.map_indices([ii, jj])) @@ -644,11 +581,9 @@ def inverse_permutation(i, j, ii, jj): sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) smooth_gmem_layout_rewrite( - sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A" - ) + sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") smooth_gmem_layout_rewrite( - sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B" - ) + sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") auto_inline_producers(sch, a_g2s) auto_inline_producers(sch, b_g2s) @@ -697,12 +632,12 @@ def inverse_permutation(i, j, ii, jj): sch.transform_layout( A_mat, ("write", 0), - get_warp_index_map(index_map_a, *a_lr, intrin_info.smooth_a), + get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), ) sch.transform_layout( B_mat, ("write", 0), - get_warp_index_map(index_map_b, *b_lr, intrin_info.smooth_b), + get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), ) sch.transform_layout( store, @@ -734,11 +669,24 @@ def tensorize_init_store_compute(): tensorize_init_store_compute() if stage > 1: - sch.annotate( - k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1] - ) + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) if use_async: sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) return sch diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py index 4bec0a972..eca5fd2e2 100644 --- a/python/bitblas/gpu/matmul_mma_dequantize.py +++ b/python/bitblas/gpu/matmul_mma_dequantize.py @@ -3,11 +3,13 @@ # pylint: disable=missing-docstring, invalid-name """A GEMM schedule rule for GPU operators.""" -from typing import Literal, Optional +from typing import Optional, List +from contextlib import suppress from tvm import tir -from tvm.target import Target +from ..base.roller.hint import Hint, IntrinInfo +from tvm.target import Target from ..base.roller.rasterization import NoRasterization from ..base import analysis from .base import GPUScheduleRule @@ -20,10 +22,13 @@ get_propagate_map, layout_propagate_chain, find_last_producer_from_buffer, + _collect_producers, + get_in_out_dtypes, ) -def get_index_map_3d(index_map, l=16, r=16): +def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 + def index_map_3d(b, i, j): return ( b, @@ -51,7 +56,7 @@ def index_map_5d(b, i, j, ii, jj): return index_map_5d -def get_index_map(index_map, l=16, r=16, is_5d=False): +def get_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 if is_5d: return get_index_map_5d(index_map) return get_index_map_3d(index_map, l, r) @@ -63,6 +68,510 @@ class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule): func with attr 'dlight.do_not_tensorize' will not be tensorized. """ + def apply( + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ): + """ + For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch. + quantized weight + | + V + dequantized in register + | + V + save into shared memory + | + V + compute + """ + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append( + weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" + + # Start Schedule + # Step 1. Get default schedule config. + + # tensor core intrinsic size + in_dtype, out_dtype = get_in_out_dtypes(sch.get(main_block)) + intrin_info = IntrinInfo( + in_dtype=in_dtype, + out_dtype=out_dtype, + trans_b=True, + ) + if "weight_transform_kind" in func.attrs: + intrin_info.weight_transform_kind = int(func.attrs["weight_transform_kind"]) + + if "input_transform_kind" in func.attrs: + intrin_info.input_transform_kind = int(func.attrs["input_transform_kind"]) + # default Hint + config = Hint().from_dict({ + "block": [128, 128], + "warp": [64, 64], + "rstep": [32], + "pipeline_stage": 1, + "use_async": False, + "intrin_info": intrin_info, + "shared_scope": "shared.dyn", + }) + shared_scope = config.shared_scope + + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + in_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + # rewrite global smooth layout, for dequantize, currently only support weight only recover. + def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"): + if not enable: + return + + # normalized block may have three read buffers, while the first one is the write buffer. + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) + buffer_idx = 0 if matrix_name == "A" else 1 + source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer + + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + propagate_block: tir.Block = find_last_producer_from_buffer( + sch, main_block, source_buffer) + # some trick impl may not have reindex block + (weight_dequantize_info,) = dequantize_info.values() + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): + return + + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # step3: propagate the matmul layout to the first reindex block + + intra_indexmap = layout_propagate_chain( + sch, + start_block=main_block, + start_buffer=source_buffer, + end_block=propagate_block, + index_map=intra_indexmap, + ) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["n", "t", "n"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.inter_transform_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_2, "threadIdx.z") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_0) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=4, + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + decode_factor = get_coalesced_veclen(sch.get(block_shared)) + _, B_shared_vi, _ = sch.split( + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) + block_shared_local = sch.cache_read(block_shared, 0, "local") + # global -> dequantzed_local -> shared + # step2. inline to local block + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = _collect_producers(sch, weight_dequantize_block) + auto_inline_producers(sch, block_shared_local, weight_producers) + + # get target dequantize buffer's idx + def get_idx(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + b_idx = get_idx() + # global -> prefetch_local -> dequantzed_local -> shared + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") + + sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) + + dequantize_block_local = block_shared_local + if ("zeros_mode" in weight_decode_info and + weight_decode_info["zeros_mode"] == "quantized"): + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + # pop the scale block + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + + # fast type conversion + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize( + sch.get_loops(dequantize_block_local)[-1], + lop3_intrin_info["compute"], + ) + import_source.append(lop3_intrin_info["c_source"]) + + sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) + union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) + B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) + _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( + B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) + if not (can_swizzle_b or intrin_info.smooth_b): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) + sch.bind(B_shared_tx, "threadIdx.x") + sch.bind(B_shared_ty, "threadIdx.y") + sch.bind(B_shared_tz, "threadIdx.z") + sch.vectorize(sch.get_loops(block_shared)[-1]) + sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) + + # cache small tensors, e.g. LUT + if b_idx: + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) + sch.reverse_compute_at(block_shared_lut, j2) + _, B_shared_tx = sch.split( + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + return block_shared_local + + _ = decode_fetch_to_shared(block_outer, 1) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1, preserve_unit_loops=True) + sch.compute_at(B_mat, k1, preserve_unit_loops=True) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate( + k0, + ann_key="software_pipeline_stage", + ann_val=[0, 0, stage - 1], + ) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + def sch_dequantize_in_register_with_config( self, func: tir.PrimFunc, @@ -82,9 +591,10 @@ def sch_dequantize_in_register_with_config( compute """ from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - ) - from .intrin.lop3 import get_lop3_intrin_group + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) @@ -117,26 +627,18 @@ def check_dequantize_info(dequantize_info): def check_weight_decode_info(weight_decode_info): conditions = [] - # check source format in ["int", "fp", "af"] + # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) conditions.append( - weight_decode_info["source_format"]["format"] - in ["uint", "int", "fp", "af"] - ) + weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) # check source bits in [1, 2, 4, 8] - conditions.append( - weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8] - ) + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] conditions.append("target_format" in weight_decode_info) - conditions.append( - weight_decode_info["target_format"] in ["float16", "int8"] - ) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) return all(conditions) - assert check_weight_decode_info( - weight_decode_info - ), "Invalid Weight Decode Info" + assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" # Start Schedule # Step 0. Get schedule config. @@ -146,7 +648,6 @@ def check_weight_decode_info(weight_decode_info): intrin_info = config.intrin_info shared_scope = config.shared_scope - intrin_info = config.intrin_info intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", @@ -170,8 +671,8 @@ def check_weight_decode_info(weight_decode_info): micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] # get the axis for layout transform - def get_axis(l, r, trans): - return (r, l) if trans else (l, r) + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) @@ -183,23 +684,17 @@ def can_enable_swizzle(dtype: str, smooth: bool): return not smooth return False - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_b) + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) # rewrite global smooth layout, for dequantize, currently only support weight only recover. - def smooth_gmem_layout_rewrite( - sch, main_block, enable=True, trans=False, matrix_name="A" - ): + def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"): if not enable: return # normalized block may have three read buffers, while the first one is the write buffer. - buffer_offset = ( - 1 - if sch.get(main_block).reads[0].buffer - == sch.get(main_block).writes[0].buffer - else 0 - ) + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) buffer_idx = 0 if matrix_name == "A" else 1 source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer @@ -208,43 +703,36 @@ def smooth_gmem_layout_rewrite( # otherwise, the layout transform will have no effect as it will transform both # read and write buffer propagate_block: tir.Block = find_last_producer_from_buffer( - sch, main_block, source_buffer - ) + sch, main_block, source_buffer) # some trick impl may not have reindex block (weight_dequantize_info,) = dequantize_info.values() - if ( - sch.get(propagate_block).name_hint - == weight_dequantize_info["decode_block"] - ): + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): return # step2: transform the layout with inverse permutation - _, inverse_indexmap = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name - ) + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) # step3: propagate the matmul layout to the first reindex block - inverse_indexmap = layout_propagate_chain( + intra_indexmap = layout_propagate_chain( sch, start_block=main_block, start_buffer=source_buffer, end_block=propagate_block, - index_map=inverse_indexmap, + index_map=intra_indexmap, ) def inverse_permutation(i, j, ii, jj): - return (i, j, *inverse_indexmap.map_indices([ii, jj])) + return (i, j, *intra_indexmap.map_indices([ii, jj])) sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A" - ) + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B" - ) + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") warp_size = 32 @@ -261,10 +749,7 @@ def inverse_permutation(i, j, ii, jj): k_pad_factor = k_factors[1] # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not ( - func.attrs is not None - and "dlight.tensorcore_prenormlized" in func.attrs.keys() - ): + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) # Step 2. Padding for dynamic shape kernels @@ -304,30 +789,13 @@ def inverse_permutation(i, j, ii, jj): thread_idy = i2 thread_idz = j2 - # plan rasteration - if ( - not isinstance(config.rasterization_plan, NoRasterization) - and sch.get(batch).extent.value == 1 - ): - device_func, invoke_func = config.rasterization_plan.get_code() - factor = config.rasterization_plan.panel_width_ - - # TODO(lei): this is a trick for rasterization implementation - # is not optimal. - # require a solution for general block rasterization - # factor = 8 # should be divisible by block_idy - # if sch.get(block_idx).extent.value % factor == 0: - # block_k, block_idx = sch.split(block_idx, factors=[None, factor]) - # sch.bind(block_k, "blockIdx.z") - else: - sch.bind(batch, "blockIdx.z") - + sch.bind(batch, "blockIdx.z") sch.bind(block_idx, "blockIdx.x") sch.bind(block_idy, "blockIdx.y") sch.bind(thread_idy, "threadIdx.y") sch.bind(thread_idz, "threadIdx.z") - def smooth_layout_recover(block, scope, l=16, r=16, enable=True): + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 if not enable: return sch.transform_layout( @@ -342,14 +810,12 @@ def smooth_layout_recover(block, scope, l=16, r=16, enable=True): ), ) - smooth_layout_recover( - block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a - ) + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) smooth_layout_recover( block_outer, ("read", 1), *b_lr, - enable=intrin_info.smooth_b, + enable=intrin_info.inter_transform_b, ) smooth_layout_recover(block_outer, ("write", 0), enable=True) @@ -360,8 +826,7 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[None, num_ty, num_tz, warp_size, vec_len] - ) + fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) sch.bind(f_3, "threadIdx.x") sch.bind(f_2, "threadIdx.z") @@ -395,39 +860,51 @@ def decode_fetch_to_shared(block, idx): decode_factor = get_coalesced_veclen(sch.get(block_shared)) _, B_shared_vi, _ = sch.split( - sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor] - ) + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) block_shared_local = sch.cache_read(block_shared, 0, "local") # global -> dequantzed_local -> shared # step2. inline to local block - auto_inline_producers(sch, block_shared_local) + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = _collect_producers(sch, weight_dequantize_block) + auto_inline_producers(sch, block_shared_local, weight_producers) # get target dequantize buffer's idx def get_idx(): # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structual based way + # maybe we can use a more general and structural based way # to analysis the idx - if weight_decode_info["source_format"]["format"] == "af": + if weight_decode_info["source_format"]["format"] == "nf": return 1 return 0 b_idx = get_idx() # global -> prefetch_local -> dequantzed_local -> shared - block_shared_local_local = sch.cache_read( - block_shared_local, b_idx, "local" - ) + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) - sch.compute_at( - block_shared_local_local, B_shared_vi, preserve_unit_loops=True - ) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) dequantize_block_local = block_shared_local + if ("zeros_mode" in weight_decode_info and + weight_decode_info["zeros_mode"] == "quantized"): + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + # pop the scale block + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + # fast type conversion - if ( - "fast_decoding" in weight_decode_info - and weight_decode_info["fast_decoding"] - ): + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): source_bit = weight_decode_info["source_format"]["bits"] out_dtype = weight_decode_info["target_format"] lop3_intrin_info = get_lop3_intrin_group( @@ -437,28 +914,22 @@ def get_idx(): source_bit=source_bit, with_scaling=weight_decode_info["with_scaling"], with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], ) sch.tensorize( sch.get_loops(dequantize_block_local)[-1], lop3_intrin_info["compute"], ) - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=lop3_intrin_info["c_source"], - ) + import_source.append(lop3_intrin_info["c_source"]) sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( - B_shared_fused, factors=[None, num_ty, num_tz, warp_size] - ) + B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) if not (can_swizzle_b or intrin_info.smooth_b): pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align( - block_shared, 0, axis=-2, factor=16, offset=pad_offset - ) + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) sch.bind(B_shared_tx, "threadIdx.x") sch.bind(B_shared_ty, "threadIdx.y") sch.bind(B_shared_tz, "threadIdx.z") @@ -467,13 +938,10 @@ def get_idx(): # cache small tensors, e.g. LUT if b_idx: - block_shared_lut = sch.cache_read( - dequantize_block_local, 0, shared_scope - ) + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) sch.reverse_compute_at(block_shared_lut, j2) _, B_shared_tx = sch.split( - sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size] - ) + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) sch.bind(B_shared_tx, "threadIdx.x") return block_shared_local @@ -523,10 +991,14 @@ def get_idx(): index_map_a, index_map_b, index_map_c = intrin_group["index_map"] sch.transform_layout( - A_mat, ("write", 0), get_index_map(index_map_a, *a_lr, intrin_info.smooth_a) + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), ) sch.transform_layout( - B_mat, ("write", 0), get_index_map(index_map_b, *b_lr, intrin_info.smooth_b) + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), ) sch.transform_layout( store, @@ -566,7 +1038,22 @@ def tensorize_init_store_compute(): sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) if use_async: sch.annotate(k0, "software_pipeline_async_stages", [0]) - + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) return sch def sch_shared_memory_prefetch_with_config( @@ -589,9 +1076,10 @@ def sch_shared_memory_prefetch_with_config( compute """ from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - ) - from .intrin.lop3 import get_lop3_intrin_group + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) @@ -609,7 +1097,7 @@ def sch_shared_memory_prefetch_with_config( cache_write_required = True # Check Dequantize Info - # TODO(leiwang): this is a hack to get the configuaration, can be improved by writing a pass to analysis the dequantize block. + # TODO(leiwang): this is a hack to get the configuration, can be improved by writing a pass to analysis the dequantize block. dequantize_info = func.attrs["dequantize_info"] def check_dequantize_info(dequantize_info): @@ -625,21 +1113,15 @@ def check_dequantize_info(dequantize_info): def check_weight_decode_info(weight_decode_info): conditions = [] - # check source format in ["int", "fp", "af"] + # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) conditions.append( - weight_decode_info["source_format"]["format"] - in ["uint", "int", "fp", "af"] - ) + weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) # check source bits in [1, 2, 4, 8] - conditions.append( - weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8] - ) + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] conditions.append("target_format" in weight_decode_info) - conditions.append( - weight_decode_info["target_format"] in ["float16", "int8"] - ) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) return all(conditions) assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" @@ -649,7 +1131,6 @@ def check_weight_decode_info(weight_decode_info): # NOTE: we can analyze the config by the hardware spec in the future # tensor core intrinsic size - intrin_info = config.intrin_info shared_scope = config.shared_scope intrin_info = config.intrin_info @@ -676,8 +1157,8 @@ def check_weight_decode_info(weight_decode_info): micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] # get the axis for layout transform - def get_axis(l, r, trans): - return (r, l) if trans else (l, r) + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) @@ -689,23 +1170,24 @@ def can_enable_swizzle(dtype: str, smooth: bool): return not smooth return False - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_b) + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) # rewrite global smooth layout, for dequantize, currently only support weight only recover. def smooth_gmem_layout_rewrite( - sch, main_block, enable=True, trans=False, matrix_name="A" + sch, + main_block, + enable=True, + trans=False, + matrix_name="A", + intrin_group=intrin_group, ): if not enable: return # normalized block may have three read buffers, while the first one is the write buffer. - buffer_offset = ( - 1 - if sch.get(main_block).reads[0].buffer - == sch.get(main_block).writes[0].buffer - else 0 - ) + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) buffer_idx = 0 if matrix_name == "A" else 1 source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer @@ -714,19 +1196,14 @@ def smooth_gmem_layout_rewrite( # otherwise, the layout transform will have no effect as it will transform both # read and write buffer propagate_block: tir.Block = find_last_producer_from_buffer( - sch, main_block, source_buffer - ) + sch, main_block, source_buffer) # some trick impl may not have reindex block (weight_dequantize_info,) = dequantize_info.values() - if ( - sch.get(propagate_block).name_hint - == weight_dequantize_info["decode_block"] - ): + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): return # step2: transform the layout with inverse permutation intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name - ) + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) # step3: propagate the matmul layout to the first reindex block @@ -743,13 +1220,61 @@ def inverse_permutation(i, j, ii, jj): sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + intra_index_map, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # get target dequantize buffer's offset + def get_offset(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_dequantize_info["source_format"]["format"] == "nf": + return 1 + return 0 + + offset = get_offset() + dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) + group_size = weight_dequantize_info["group_size"] + + _, mn, mk = intrin_group["micro_kernel"] + + def get_param_indices( + indexmap, + l=mn, + r=mk, + group_size=group_size # noqa: E741 + ): # noqa: E741 + # assume the param layout is n, k + rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars] + warp_i, warp_j = rl % l, rr % r + spatial_i, spatial_j = rl // l, rr // r + warp_i, warp_j = indexmap.map_indices([warp_i, warp_j]) + new_indices = ( + spatial_i * l + warp_i, + (spatial_j * r + warp_j) // group_size, + ) + return new_indices + + with_scaling = bool(weight_dequantize_info["with_scaling"]) + if with_scaling: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 1), + get_param_indices(intra_index_map), + ) + with_zeros = bool(weight_dequantize_info["with_zeros"]) + if with_zeros: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 2), + get_param_indices(intra_index_map), + ) + smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A" - ) + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B" - ) + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") warp_size = 32 @@ -766,10 +1291,7 @@ def inverse_permutation(i, j, ii, jj): k_pad_factor = k_factors[1] # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not ( - func.attrs is not None - and "dlight.tensorcore_prenormlized" in func.attrs.keys() - ): + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) # Step 2. Padding for dynamic shape kernels @@ -809,30 +1331,13 @@ def inverse_permutation(i, j, ii, jj): thread_idy = i2 thread_idz = j2 - # plan rasteration - if ( - not isinstance(config.rasterization_plan, NoRasterization) - and sch.get(batch).extent.value == 1 - ): - device_func, invoke_func = config.rasterization_plan.get_code() - factor = config.rasterization_plan.panel_width_ - - # TODO(lei): this is a trick for rasterization implementation - # is not optimal. - # require a solution for general block rasterization - # factor = 8 # should be divisible by block_idy - # if sch.get(block_idx).extent.value % factor == 0: - # block_k, block_idx = sch.split(block_idx, factors=[None, factor]) - # sch.bind(block_k, "blockIdx.z") - else: - sch.bind(batch, "blockIdx.z") - + sch.bind(batch, "blockIdx.z") sch.bind(block_idx, "blockIdx.x") sch.bind(block_idy, "blockIdx.y") sch.bind(thread_idy, "threadIdx.y") sch.bind(thread_idz, "threadIdx.z") - def smooth_layout_recover(block, scope, l=16, r=16, enable=True): + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 if not enable: return sch.transform_layout( @@ -847,14 +1352,12 @@ def smooth_layout_recover(block, scope, l=16, r=16, enable=True): ), ) - smooth_layout_recover( - block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a - ) + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) smooth_layout_recover( block_outer, ("read", 1), *b_lr, - enable=intrin_info.smooth_b, + enable=intrin_info.inter_transform_b, ) smooth_layout_recover(block_outer, ("write", 0), enable=True) @@ -865,8 +1368,7 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[None, num_ty, num_tz, warp_size, vec_len] - ) + fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) sch.bind(f_3, "threadIdx.x") sch.bind(f_2, "threadIdx.z") @@ -898,45 +1400,57 @@ def decode_fetch_to_shared(block, idx): block_shared = sch.cache_read(block, idx, shared_scope) sch.compute_at(block_shared, k0, preserve_unit_loops=True) - # TODO(lei): the factor shoule be analyzed more deeper. + # TODO(lei): the factor should be analyzed more deeper. decode_factor = get_coalesced_veclen(sch.get(block_shared)) _, B_shared_vi, _ = sch.split( - sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor] - ) + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) block_shared_local = sch.cache_read(block_shared, 0, "local") # global -> dequantzed_local -> shared - # step2. inline to local block - auto_inline_producers(sch, block_shared_local) + # step2. inline to local block, should skip qzeros + is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and + weight_decode_info["zeros_mode"] == "quantized") + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = ( + _collect_producers(sch, weight_dequantize_block) if is_qzeros else []) + auto_inline_producers(sch, block_shared_local, weight_producers) # get target dequantize buffer's idx def get_idx(): # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structual based way + # maybe we can use a more general and structural based way # to analysis the idx - if weight_decode_info["source_format"]["format"] == "af": + if weight_decode_info["source_format"]["format"] == "nf": return 1 return 0 b_idx = get_idx() # global -> prefetch_local -> dequantzed_local -> shared - block_shared_local_local = sch.cache_read( - block_shared_local, b_idx, "local" - ) + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") # global -> prefetch_shared -> vector load -> dequantzed_local -> shared - block_shared_local_local_shared = sch.cache_read( - block_shared_local_local, 0, shared_scope - ) + block_shared_local_local_shared = sch.cache_read(block_shared_local_local, 0, + shared_scope) sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) - sch.compute_at( - block_shared_local_local, B_shared_vi, preserve_unit_loops=True - ) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) dequantize_block_local = block_shared_local + if is_qzeros: + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + # fast type conversion - if ( - "fast_decoding" in weight_decode_info - and weight_decode_info["fast_decoding"] - ): + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): source_bit = weight_decode_info["source_format"]["bits"] out_dtype = weight_decode_info["target_format"] lop3_intrin_info = get_lop3_intrin_group( @@ -945,38 +1459,30 @@ def get_idx(): source_format=weight_decode_info["source_format"]["format"], source_bit=source_bit, with_scaling=weight_decode_info["with_scaling"], - with_zeros=weight_decode_info["with_zeros"] + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], ) sch.tensorize( sch.get_loops(dequantize_block_local)[-1], lop3_intrin_info["compute"], ) - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=lop3_intrin_info["c_source"], - ) + import_source.append(lop3_intrin_info["c_source"]) sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( - B_shared_fused, factors=[None, num_ty, num_tz, warp_size] - ) + B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) if not (can_swizzle_b or intrin_info.smooth_b): pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align( - block_shared, 0, axis=-2, factor=16, offset=pad_offset - ) + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) sch.bind(B_shared_tx, "threadIdx.x") sch.bind(B_shared_ty, "threadIdx.y") sch.bind(B_shared_tz, "threadIdx.z") sch.vectorize(sch.get_loops(block_shared)[-1]) sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) - sch.compute_at( - block_shared_local_local_shared, k0, preserve_unit_loops=True - ) + sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) ndim = len(sch.get(block_shared_local_local_shared).iter_vars) fused = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) @@ -1000,13 +1506,10 @@ def get_idx(): # cache small tensors, e.g. LUT if b_idx: - block_shared_lut = sch.cache_read( - dequantize_block_local, 0, shared_scope - ) + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) sch.reverse_compute_at(block_shared_lut, j2) _, B_shared_tx = sch.split( - sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size] - ) + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) sch.bind(B_shared_tx, "threadIdx.x") return block_shared_local @@ -1056,10 +1559,14 @@ def get_idx(): index_map_a, index_map_b, index_map_c = intrin_group["index_map"] sch.transform_layout( - A_mat, ("write", 0), get_index_map(index_map_a, *a_lr, intrin_info.smooth_a) + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), ) sch.transform_layout( - B_mat, ("write", 0), get_index_map(index_map_b, *b_lr, intrin_info.smooth_b) + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), ) sch.transform_layout( store, @@ -1099,6 +1606,22 @@ def tensorize_init_store_compute(): sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) if use_async: sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) return sch def apply_config( # pylint: disable=too-many-locals,missing-docstring @@ -1106,6 +1629,7 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring func: tir.PrimFunc, config, ) -> Optional[tir.Schedule]: + def check_sm_version(arch: str) -> int: sm_version = arch.replace("sm_", "") return int(sm_version) if sm_version.isdigit() else -1 @@ -1114,10 +1638,8 @@ def check_sm_version(arch: str) -> int: """MMA Template only support sm_80 and above""" return None - if ( - config.arch.target.kind.name == "cuda" - and check_sm_version(config.arch.target.arch) == 80 - ): + if (config.arch.target.kind.name == "cuda" and + check_sm_version(config.arch.target.arch) == 80): return self.sch_shared_memory_prefetch_with_config(func, config) else: return self.sch_dequantize_in_register_with_config(func, config) diff --git a/python/bitblas/gpu/matmul_wmma.py b/python/bitblas/gpu/matmul_wmma.py index 765860f88..60817258f 100644 --- a/python/bitblas/gpu/matmul_wmma.py +++ b/python/bitblas/gpu/matmul_wmma.py @@ -3,19 +3,16 @@ # pylint: disable=missing-docstring, invalid-name """A GEMM schedule rule for GPU operators.""" -import math from typing import Literal, Optional from tvm import DataType, tir from tvm.target import Target -from tvm.tir.stmt import ForKind from ..base.roller.rasterization import NoRasterization from ..base import analysis from .base import GPUScheduleRule from .matmul_analysis import ( auto_inline_consumer_chain, - auto_inline_consumers, auto_inline_producers, get_index_map, get_reduction_blocks, @@ -69,10 +66,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring thread_z = 2 thread_y = 2 warp_size = 32 - thread_cnt = thread_y * thread_z * warp_size vector_size = 8 - unroll_depth = 256 # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] block = sch.reindex(main_block, ("read", 0)) @@ -168,9 +163,8 @@ def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], wm sch.compute_at(block_read, k0) fused = sch.fuse(*sch.get_loops(block_read)[-2:]) - f0, f1, f2, f3, f4 = sch.split( - fused, [None, thread_z, thread_y, warp_size, vector_size] - ) + f0, f1, f2, f3, f4 = sch.split(fused, + [None, thread_z, thread_y, warp_size, vector_size]) sch.bind(f1, "threadIdx.z") sch.bind(f2, "threadIdx.y") @@ -189,12 +183,10 @@ def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], wm return wmma_read - wmma_read_a = fetch_input( - block_outer, 0, [block_m, block_k, micro_size_m, micro_size_k], "wmma.matrix_a" - ) - wmma_read_b = fetch_input( - block_outer, 1, [block_n, block_k, micro_size_n, micro_size_k], "wmma.matrix_b" - ) + wmma_read_a = fetch_input(block_outer, 0, [block_m, block_k, micro_size_m, micro_size_k], + "wmma.matrix_a") + wmma_read_b = fetch_input(block_outer, 1, [block_n, block_k, micro_size_n, micro_size_k], + "wmma.matrix_b") def store_output(block_outer, write_buffer_idx, wmma_name): block_write = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") @@ -202,9 +194,8 @@ def store_output(block_outer, write_buffer_idx, wmma_name): fused = sch.fuse(*sch.get_loops(block_write)[-2:]) - f0, f1, f2, f3, f4 = sch.split( - fused, [None, thread_z, thread_y, warp_size, vector_size] - ) + f0, f1, f2, f3, f4 = sch.split(fused, + [None, thread_z, thread_y, warp_size, vector_size]) sch.bind(f1, "threadIdx.z") sch.bind(f2, "threadIdx.y") @@ -233,8 +224,7 @@ def store_output(block_outer, write_buffer_idx, wmma_name): # Step 5. Schedule tensor core computation from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) + get_wmma_intrin_group,) intrin_group = get_wmma_intrin_group( load_scope="shared.dyn", @@ -266,8 +256,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring _: bool, ) -> Optional[tir.Schedule]: from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) + get_wmma_intrin_group,) sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) @@ -433,7 +422,7 @@ def fetch_to_shared(block, idx, ndim): sch.unroll(i0) sch.unroll(j0) sch.tensorize(i1, intrin_group["load_b"]) - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except return None def tensorize_init_store_compute(): @@ -443,7 +432,7 @@ def tensorize_init_store_compute(): try: tensorize_init_store_compute() - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except return None auto_inline_consumer_chain(sch, accumulator_shared_to_global) @@ -469,8 +458,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring _: bool, ) -> Optional[tir.Schedule]: from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) + get_wmma_intrin_group,) sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) @@ -636,7 +624,7 @@ def fetch_to_shared(block, idx, ndim): sch.unroll(i0) sch.unroll(j0) sch.tensorize(i1, intrin_group["load_b"]) - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except return None # Try to tensorize the init, store and compute block with f16 or f32 intrinsics @@ -650,7 +638,7 @@ def tensorize_init_store_compute(): try: tensorize_init_store_compute() tensorize_success = True - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except intrin_group = get_wmma_intrin_group( load_scope="shared.dyn", store_scope="shared.dyn", @@ -663,7 +651,7 @@ def tensorize_init_store_compute(): try: tensorize_init_store_compute() tensorize_success = True - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except return None auto_inline_consumer_chain(sch, accumulator_shared_to_global) @@ -680,8 +668,7 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring config, ) -> Optional[tir.Schedule]: from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) + get_wmma_intrin_group,) sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) @@ -767,10 +754,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring block_idy = sch.fuse(i1, j1) thread_idy = sch.fuse(j2, i2) # plan rasteration - if ( - not isinstance(config.rasterization_plan, NoRasterization) - and sch.get(batch).extent.value == 1 - ): + if (not isinstance(config.rasterization_plan, NoRasterization) and + sch.get(batch).extent.value == 1): device_func, invoke_func = config.rasterization_plan.get_code() factor = config.rasterization_plan.panel_width_ @@ -873,7 +858,7 @@ def fetch_to_shared(block, idx, ndim, vec_len, dtype="float16"): sch.unroll(i0) sch.unroll(j0) sch.tensorize(i1, intrin_group["load_b"]) - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except return None # Try to tensorize the init, store and compute block with f16 or f32 intrinsics @@ -887,15 +872,14 @@ def tensorize_init_store_compute(): try: tensorize_init_store_compute() tensorize_success = True - except: # pylint: disable=bare-except + except Exception: # pylint: disable=bare-except return None auto_inline_consumer_chain(sch, accumulator_shared_to_global) fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) _, f1, f2 = sch.split( - fused, factors=[None, warp_size, max(list(config.vectorize.values()))] - ) + fused, factors=[None, warp_size, max(list(config.vectorize.values()))]) sch.bind(f1, "threadIdx.x") sch.vectorize(f2) @@ -906,4 +890,3 @@ def tensorize_init_store_compute(): sch.annotate(k0, "software_pipeline_async_stages", [0]) return sch if tensorize_success else None - diff --git a/python/bitblas/module/__init__.py b/python/bitblas/module/__init__.py new file mode 100644 index 000000000..e6e393e6c --- /dev/null +++ b/python/bitblas/module/__init__.py @@ -0,0 +1,300 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import ctypes +import operator +from functools import reduce +from logging import getLogger + +import torch +import torch.nn as nn + +logger = getLogger(__name__) + +from typing import List, Union + +from bitblas.cache import global_operator_cache, get_database_path +from bitblas import Matmul, MatmulConfig +from bitblas.quantization.utils import general_compress +from bitblas import auto_detect_nvidia_target + +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = get_database_path() + + +def unpack_qzeros(qzeros, bits): + qzeros = qzeros.view(torch.int32) + elems_per_int32 = 32 // bits + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + + for col in range(unpacked_zeros.shape[1]): + i = col % elems_per_int32 + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF + + return unpacked_zeros + 1 + + +class Linear(nn.Module): + opt_M = [1, 16, 32, 64, 128, 256, 512] + STORAGE_DTYPE = "int8" # assume int8 storage + TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + A_dtype: str = "float16", + W_dtype: str = "float16", + accum_dtype: str = "float16", + out_dtype: str = "float16", + # configs for weight only quantization + group_size: int = -1, + with_scaling: bool = None, + with_zeros: bool = False, + zeros_mode: str = None, + opt_M: Union[int, List[int]] = opt_M, + # performance related configs + enable_tuning: bool = True, + fast_decoding: bool = True, + propagate_b: bool = False, + ): + """ + @opt_M: optimize range of the input shape for dynamic symbolic + if the input shape is a range, we will optimize the matmul with dynamic symbolic. + if the input shape is int, we will optimize the matmul with static symbolic. + """ + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.opt_M = opt_M + self.group_size = self._set_group_size(group_size, in_features) + self.torch_dtype = getattr(torch, A_dtype) + self.is_consitent = A_dtype == W_dtype + self.zeros_mode = zeros_mode + self._validate_parameters(self.group_size, in_features, out_features) + self._configure_bitblas_matmul( + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_scaling, + with_zeros, + zeros_mode, + enable_tuning, + fast_decoding, + bias, + propagate_b, + ) + self._initialize_buffers(in_features, out_features, bias) + + def init_params(self): + # eliminate runtime overhead like exllama state + if self.is_consitent: + param_list = [self.weight] + if self.bitblas_matmul.config.with_bias: + param_list.append(self.bias) + self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list] + else: + param_list = [self.qweight] + if self.bitblas_matmul.config.with_scaling: + param_list.append(self.scales) + if self.bitblas_matmul.config.with_zeros: + param_list.append(self.zeros) + if self.bitblas_matmul.config.with_bias: + param_list.append(self.bias) + self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list] + + def _validate_parameters(self, group_size, in_features, out_features): + if in_features % 16 != 0 or out_features % 16 != 0: + raise ValueError("`in_features` and `out_features` must be divisible by 16.") + if in_features % group_size != 0: + raise ValueError("`in_features` must be divisible by `group_size`.") + + def _set_group_size(self, group_size, in_features): + return in_features if (group_size == -1 or group_size is None) else group_size + + def _initialize_buffers(self, in_features, out_features, bias): + if self.consistent: + self.register_buffer( + "weight", + torch.zeros((out_features, in_features // self.group_size), dtype=self.torch_dtype), + ) + else: + self.register_buffer( + "qweight", + torch.zeros( + self.bitblas_matmul.retrieve_weight_shape(), + dtype=self.TORCH_STORAGE_DTYPE, + ), + ) + self.register_buffer( + "scales", + torch.zeros((out_features, in_features // self.group_size), dtype=self.torch_dtype), + ) + if self.zeros_mode == "quantized": + storage_nbit = int("".join(c for c in self.STORAGE_DTYPE if c.isdigit())) + self.register_buffer( + "zeros", + torch.zeros( + ( + in_features // self.group_size, + out_features // storage_nbit * self.bits, + ), + dtype=self.TORCH_STORAGE_DTYPE, + ), + ) + else: + self.register_buffer( + "zeros", + torch.zeros( + (out_features, in_features // self.group_size), + dtype=self.torch_dtype, + ), + ) + if bias: + self.register_buffer("bias", torch.zeros((out_features), dtype=self.torch_dtype)) + else: + self.bias = None + + def _configure_bitblas_matmul( + self, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_scaling, + with_zeros, + zeros_mode, + enable_tuning, + fast_decoding, + bias, + propagate_b, + ): + matmul_config = MatmulConfig( + M=self.opt_M, + N=self.out_features, + K=self.in_features, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + storage_dtype=self.STORAGE_DTYPE, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=self.group_size, + fast_decoding=fast_decoding, + with_bias=bias, + propagate_b=propagate_b, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, enable_tuning) + self.bits = self.bitblas_matmul.bit + self.source_format = self.bitblas_matmul.source_format + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info(f"Loaded {global_operator_cache.size()} operators from database.") + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + # should disable tuning for the first time because we may require loading bitblas operator from database. + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + print("BitBLAS Tuning done, appended operator to global_operator_cache.") + else: + print("BitBLAS Operator created.") + else: + print("BitBLAS Operator found in global_operator_cache.") + return bitblas_matmul + + def warmup(self, topk=20): + self.bitblas_matmul.hardware_aware_finetune(topk=topk) + + def forward(self, A, output=None): + if A.dtype != torch.float16: + A = A.half() + # can be lifted to post init. + self.init_params() + + if output is None: + output = torch.empty( + A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) + m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) + A = self.bitblas_matmul.transform_input(A) + A_void = ctypes.c_void_p(A.data_ptr()) + # m is the product of the last n - 1 dimensions of A + self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m) + + return output + + def load_and_transform_weight( + self, + weight: torch.Tensor, + scales: torch.Tensor = None, + zeros: torch.Tensor = None, + bias: torch.Tensor = None, + ): + if self.consistent: + assert scales is None, "scales should be None for consistent mode." + assert zeros is None, "zeros should be None for consistent mode." + weight = self.bitblas_matmul.transform_weight(weight) + self.weight = nn.Parameter(weight) + if bias is not None: + self.bias = bias + else: + weight = self.bitblas_matmul.transform_weight(weight) + self.qweight = weight + if scales is not None: + self.scales = scales + if zeros is not None: + self.zeros = zeros + if bias is not None: + self.bias = bias + + def repack_from_gptq(self, gptq_module): + # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. + qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) + if self.bitblas_matmul.weight_transform is not None: + qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() + self.qweight = qweight + # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed. + scales = gptq_module.scales.T.contiguous().view(self.torch_dtype) + self.scales = scales + # qzeros should be dequantized to int zeros. + intzeros = unpack_qzeros(gptq_module.qzeros, self.bits).T.contiguous() + if self.bitblas_matmul.config.zeros_mode == "original": + self.zeros = intzeros.to(torch.float16).contiguous() + elif self.bitblas_matmul.config.zeros_mode == "rescale": + self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :] + elif self.bitblas_matmul.config.zeros_mode == "quantized": + self.zeros = ( + torch.Tensor(general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits)).to( + self.qweight.device).to(self.zeros.dtype).contiguous()) + else: + raise ValueError(f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}") + if self.bias is not None: + self.bias = gptq_module.bias.data.to(torch.float16).contiguous() + + @property + def consistent(self): + return self.is_consitent + + +__all__ = ["Linear"] diff --git a/python/bitblas/ops/__init__.py b/python/bitblas/ops/__init__.py index 08fd3d5d8..cdacc5bad 100644 --- a/python/bitblas/ops/__init__.py +++ b/python/bitblas/ops/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .operator import Operator -from .matmul import Matmul, MatmulConfig -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig +from .operator import Operator # noqa: F401 +from .matmul import Matmul, MatmulConfig # noqa: F401 +from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401 +from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401 +from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401 diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py new file mode 100644 index 000000000..fbdf8058e --- /dev/null +++ b/python/bitblas/ops/general_matmul.py @@ -0,0 +1,517 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.target import Target +import operator +from functools import reduce +from bitblas.base.roller.arch.cuda import CUDA +from typing import Any, List, Literal, Optional, Tuple, Union +from .operator import Operator, TransformKind +from .impl.matmul_dequantize_impl import ( + select_implementation as weight_dequantize_implementation,) +from .impl.matmul_impl import select_implementation as consistent_implementation +from ..base.utils import tensor_replace_dp4a +from bitblas.utils.target_detector import auto_detect_nvidia_target +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch +from dataclasses import dataclass +from .ladder_permutate import LadderPermutate, LadderPermutateConfig +from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig +import logging +import torch + +logger = logging.getLogger(__name__) + +WORKSPACE_SIZE = 1024 * 1024 * 256 + + +class OPExecutorCPU: + + def __init__(self, operators: Optional[List[Operator]] = None): + if operators is None: + operators = [] + self.operators = operators + + def append(self, op): + self.operators.append(op) + + def is_none(self): + return len(self.operators) == 0 + + def forward(self, weight): + inputs = [weight] + for op in self.operators: + inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) + inputs = [op.forward(*inputs)] + return inputs[-1] + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def size(self): + return len(self.operators) + + +@dataclass(frozen=True) +class MatmulConfig: + M: Union[int, Tuple[int]] + N: int + K: int + A_dtype: str = "float16" + # is a wrapper for source_format and bit + W_dtype: str = A_dtype # W_dtype is the same as A_dtype by default + out_dtype: str = "float16" + accum_dtype: str = "float16" + layout: Literal["nn", "nt", "tn", "tt"] = "nt" + with_bias: bool = False + group_size: int = -1 + with_scaling: bool = False + with_zeros: bool = False + # documents for zeros_mode: + # original: target = (dequantize_weight - zero_point) * scale + # rescale: target = dequantize_weight * scale - zero_point + # quantized: target = (dequantize_weight - dequantize_zeros) * scale + # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. + zeros_mode: Literal["original", "rescale", "quantized"] = "original" + storage_dtype: str = "int8" + + # weight transform related flags + fast_decoding: bool = True # enable fast decoding by default + propagate_a: TransformKind = TransformKind.NonTransform + propagate_b: TransformKind = TransformKind.NonTransform + + def __post_init__(self): + # set M to default dynamic range if it is None + if self.M is None: + object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024]) + # set M to tuple if it is list + # otherwise, M is not hashable + object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) + if isinstance(self.propagate_a, bool): + object.__setattr__( + self, + "propagate_a", + (TransformKind.IntraWarpTransform + if self.propagate_a else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_a, int): + object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) + + if isinstance(self.propagate_b, bool): + object.__setattr__( + self, + "propagate_b", + (TransformKind.IntraWarpTransform + if self.propagate_b else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_b, int): + object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) + + # This is hack to legalize propagate_a and b + # TODO(lei): should be removed in the future when tc+br template is ready. + MICRO_KERNEL_SIZE = 16 + if isinstance( + self.M, + int) and (self.M % MICRO_KERNEL_SIZE) == 0 and (self.K % MICRO_KERNEL_SIZE) == 0: + object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform) + else: + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + + if self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or ( + self.K % MICRO_KERNEL_SIZE) != 0 or isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized"): + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + object.__setattr__(self, "propagate_b", TransformKind.NonTransform) + else: + object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform) + + if self.zeros_mode is None: + object.__setattr__(self, "zeros_mode", "original") + + if "int" not in self.W_dtype: + object.__setattr__(self, "fast_decoding", False) + else: + object.__setattr__(self, "fast_decoding", self.fast_decoding) + + if self.with_bias is None: + object.__setattr__(self, "with_bias", False) + + if self.group_size is None: + object.__setattr__(self, "group_size", -1) + + if self.with_scaling is None: + object.__setattr__(self, "with_scaling", False) + + if self.with_zeros is None: + object.__setattr__(self, "with_zeros", False) + + if self.A_dtype == self.W_dtype and self.W_dtype in ["float16", "int8"]: + object.__setattr__(self, "storage_dtype", self.W_dtype) + + +class Matmul(Operator): + + # TODO(lei): This should be improved into a general datatype. + BITBLAS_TRICK_DTYPE_MAP = { + "float64": ("fp", 64), + "float32": ("fp", 32), + "float16": ("fp", 16), + "int32": ("int", 32), + "uint32": ("uint", 32), + "int16": ("int", 16), + "uint16": ("uint", 16), + "int8": ("int", 8), + "uint8": ("uint", 8), + "int4": ("int", 4), + "uint4": ("uint", 4), + "int2": ("int", 2), + "uint2": ("uint", 2), + "int1": ("int", 1), + "uint1": ("uint", 1), + "nf4": ("nf", 4), + "fp8_e5m2": ("fp", 8), + "fp4_e2m1": ("fp", 4), + } + + def __init__( + self, + config: MatmulConfig, + name: str = "matmul", + target: Optional[Union[str, Target]] = None, + enable_tuning: bool = True, + ): + if target is None: + target = auto_detect_nvidia_target() + assert (config.A_dtype + in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}" + source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype] + + self.source_format = source_format + self.bit = bit + super().__init__(name, config, target) + + if source_format == "int" and self.with_zeros: + logger.warning( + "[BitBLAS][Warning] with_zeros is not supported for int source format as int has a constant zeropoints already." + ) + + target = self.target + if target.kind.name != "cuda": + raise ValueError("Currently only support cuda target") + + self.arch = CUDA(target) + + try: + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + except Exception: + self.optimized_func = None + logger.warnning( + "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." + ) + + if isinstance(self.M, Tuple): + self.dynamic_range = {"m": self.M} + self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( + {"opt_shapes": self.dynamic_range}) + else: + self.dynamic_range = None + + self._build_runtime_module(target) + + self.workspace = None + if self.propagate_a: + # for general purpose, we use propagate_a to control the ladder permutation. + ladder_permutate_config = LadderPermutateConfig( + M=self.M, + N=self.K, + datatype=self.A_dtype, + storage_dtype=self.A_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=self.propagate_a, + ) + self.ladder_permutate_a = LadderPermutate( + config=ladder_permutate_config, + target=target, + enable_tuning=enable_tuning, + ) + self.workspace = torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() + else: + self.ladder_permutate_a = None + + if self.propagate_b: + ladder_permutate_config = LadderPermutateConfig( + M=self.N, + N=self.K, + datatype=self.A_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + propagate_kind="B", + transpose_matrix=self.layout == "nt", + transform_kind=self.propagate_b, + ) + self.ladder_permutate_b = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_b = None + + if self.fast_decoding: + lop3_permutate_config = LOP3PermutateConfig( + M=self.N, + N=self.K, + datatype=self.A_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + ) + self.lop3_permutate = LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.lop3_permutate = None + + input_executors = OPExecutorCPU() + if self.ladder_permutate_a is not None: + input_executors.append(self.ladder_permutate_a) + self.input_executors = input_executors + + weight_executors = OPExecutorCPU() + if self.lop3_permutate is not None: + weight_executors.append(self.lop3_permutate) + + if self.ladder_permutate_b is not None: + weight_executors.append(self.ladder_permutate_b) + + self.weight_executors = weight_executors + + if enable_tuning: + self.hardware_aware_finetune() + + if source_format == "nf": + self.lut = torch.Tensor(([ + -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, + -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, + 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, + 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0 + ]), + dtype=getattr(torch, self.A_dtype)).cuda() + else: + self.lut = None + + def _select_implementation(self): + if self.A_dtype == self.W_dtype: + return consistent_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + else: + return weight_dequantize_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + bit=self.bit, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + layout=self.layout, + zeros_mode=self.zeros_mode, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def post_process(self, code: str) -> str: + code = tensor_replace_dp4a(code) + return code + + def retrieve_weight_shape(self): + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] + + def transform_weight(self, weight, scale=None, zeros=None, bias=None): + """ + Transforms the given weight tensor based on the specified quantization parameters and + returns the transformed weight along with optional scale, zeros, and bias. + + Parameters: + - weight: The input weight tensor to be transformed. + - scale: Optional scaling factor for the weight tensor. + - zeros: Optional zero-point adjustment for the weight tensor. + - bias: Optional bias to be added to the weight tensor. + + Returns: + A list containing the transformed weight tensor and optionally the scale, zeros, and bias. + """ + if self.W_dtype == self.A_dtype: + if self.weight_transform is not None: + return self.weight_transform(weight.cpu()).cuda().contiguous() + return weight + + from bitblas.quantization import general_compress + import torch + import numpy as np + + source_format, bit = self.source_format, self.bit + + # Process integer source format + if source_format == "int": + assert not self.with_scaling, "scale should be False for int source format" + assert not self.with_zeros, "zeros should be False for int source format" + maxq = 2**(bit - 1) - 1 + # Clamp weight values to be within the quantizable range and adjust + weight = torch.clamp(weight, -maxq, maxq).int() + maxq + else: + # For non-integer formats, simply convert weights to integers + weight = weight.int() + + np_storage_dtype = getattr(np, self.storage_dtype) + + weight = general_compress( + weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) + + weight = torch.from_numpy(weight).cuda().contiguous() + + # Apply an optional weight transformation if specified + if self.weight_transform is not None: + weight = self.weight_transform(weight.cpu()).cuda().contiguous() + + # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias + result = [weight] + if scale is not None: + result.append(scale) + if zeros is not None: + result.append(zeros) + if bias is not None: + result.append(bias) + + return next(iter(result), result) + + def transform_input(self, input_tensor): + if self.propagate_a is not TransformKind.NonTransform: + # check workspace size + if input_tensor.numel() > WORKSPACE_SIZE: + raise ValueError( + f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." + ) + self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) + return self.workspace + return input_tensor + + def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: + args = [] + args.append(self.transform_input(A)) + if self.lut is not None: + args.append(self.lut) + args.append(W) + + if output is None: + output = torch.empty(A.shape[:-1] + (self.N,), dtype=A.dtype, device=A.device) + if scale is not None: + args.append(scale) + if zeros is not None: + args.append(zeros) + if bias is not None: + args.append(bias) + args.append(output) + + m = reduce(operator.mul, A.shape[:-1], 1) + args.append(m) + + if self.lib is None: + self._forward_from_torch_func(*args) + self._forward_from_prebuild_lib(*args) + + return output + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def K(self): + return self.config.K + + @property + def A_dtype(self): + return self.config.A_dtype + + @property + def W_dtype(self): + return self.config.W_dtype + + @property + def out_dtype(self): + return self.config.out_dtype + + @property + def accum_dtype(self): + return self.config.accum_dtype + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def with_scaling(self): + return self.config.with_scaling + + @property + def with_zeros(self): + return self.config.with_zeros + + @property + def group_size(self): + return self.config.group_size + + @property + def fast_decoding(self): + return self.config.fast_decoding + + @property + def with_bias(self): + return self.config.with_bias + + @property + def propagate_a(self): + return self.config.propagate_a + + @property + def propagate_b(self): + return self.config.propagate_b + + @property + def layout(self): + return self.config.layout + + @property + def zeros_mode(self): + return self.config.zeros_mode + + @property + def input_transform(self): + return self.input_executors if self.input_executors.size else None + + @property + def weight_transform(self): + return self.weight_executors if self.weight_executors.size else None diff --git a/python/bitblas/ops/impl/convolution2d_impl.py b/python/bitblas/ops/impl/convolution2d_impl.py new file mode 100644 index 000000000..d77d8f573 --- /dev/null +++ b/python/bitblas/ops/impl/convolution2d_impl.py @@ -0,0 +1,190 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te, tir + + +def conv2d_nhwc_ohwi( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype="float16", + accum_dtype="float16", + out_dtype="float16", +): + + A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) + B = te.placeholder((f, kh, kw, c), name="weight", dtype=in_dtype) + + pad_shape = (n, h + 2 * p, w + 2 * p, c) + pad_value = tir.const(0.0, A.dtype) + pad = te.compute( + pad_shape, + lambda n, h, w, c: te.if_then_else( + tir.all( + h >= p, + w >= p, + h < pad_shape[1] - p, + w < pad_shape[2] - p, + ), + A[n, h - p, w - p, c], + pad_value, + ), + name="pad", + ) + kernel_h, kernel_w = kh, kw + stride_h, stride_w = s, s + dilation_h, dilation_w = d, d + out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + out_shape = (n, out_h, out_w, f) + kh = te.reduce_axis((0, kernel_h), name="kh") + kw = te.reduce_axis((0, kernel_w), name="kw") + c = te.reduce_axis((0, c), name="c") + C = te.compute( + out_shape, + lambda n, h, w, f: te.sum( + pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), + c,].astype(accum_dtype) * B[f, kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( + dilation_w), c].astype(accum_dtype), + axis=[kh, kw, c], + ), + name="C", + ) + args = [A, B] + last_output = C + if accum_dtype != out_dtype: + D = te.compute(out_shape, lambda n, h, w, c: C[n, h, w, c].astype(out_dtype), name="D") + last_output = D + args.append(last_output) + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def conv2d_nhwc_hwio( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype="float16", + accum_dtype="float16", + out_dtype="float16", +): + + A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) + B = te.placeholder((kh, kw, c, f), name="weight", dtype=in_dtype) + + pad_shape = (n, h + 2 * p, w + 2 * p, c) + pad_value = tir.const(0.0, A.dtype) + pad = te.compute( + pad_shape, + lambda n, h, w, c: te.if_then_else( + tir.all( + h >= p, + w >= p, + h < pad_shape[1] - p, + w < pad_shape[2] - p, + ), + A[n, h - p, w - p, c], + pad_value, + ), + name="pad", + ) + kernel_h, kernel_w = kh, kw + stride_h, stride_w = s, s + dilation_h, dilation_w = d, d + out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + out_shape = (n, out_h, out_w, f) + kh = te.reduce_axis((0, kernel_h), name="kh") + kw = te.reduce_axis((0, kernel_w), name="kw") + c = te.reduce_axis((0, c), name="c") + C = te.compute( + out_shape, + lambda n, h, w, f: te.sum( + pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), + c,].astype(accum_dtype) * B[kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( + dilation_w), c, f].astype(accum_dtype), + axis=[kh, kw, c], + ), + name="C", + ) + args = [A, B] + last_output = C + if accum_dtype != out_dtype: + D = te.compute(out_shape, lambda n, h, w, c: C[n, h, w, c].astype(out_dtype), name="D") + last_output = D + args.append(last_output) + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def select_implementation( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype="float16", + accum_dtype="float16", + out_dtype="float16", + input_layout="nhwc", + weight_layout="ohwi", +): + assert input_layout in ["nhwc", "nchw"] + if input_layout == "nhwc" and weight_layout == "ohwi": + return conv2d_nhwc_ohwi( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype, + accum_dtype, + out_dtype, + ) + elif input_layout == "nhwc" and weight_layout == "hwio": + return conv2d_nhwc_hwio( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype, + accum_dtype, + out_dtype, + ) + else: + raise ValueError("Unsupported input_layout: {} and weight_layout: {}".format( + input_layout, weight_layout)) diff --git a/python/bitblas/ops/impl/ladder_permutate_impl.py b/python/bitblas/ops/impl/ladder_permutate_impl.py index ce7543243..5ac4a5334 100644 --- a/python/bitblas/ops/impl/ladder_permutate_impl.py +++ b/python/bitblas/ops/impl/ladder_permutate_impl.py @@ -22,27 +22,21 @@ def select_implementation( # This is trick to get the basic tile size for the current datatype # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 - l = r = 16 + l = r = 16 # noqa: E741 if datatype == "int8": - l, r = 16, 32 - + l, r = 16, 32 # noqa: E741 intra_index_map, _ = get_propagate_map( - transpose_matrix, dtype=datatype, matrix_name=propagate_kind - ) + transpose_matrix, dtype=datatype, matrix_name=propagate_kind) target_dtype = DataType(datatype) scaling_factor = 1 if dequantize_bits > 0 and dequantize_bits < target_dtype.bits: - scaling_factor = ( - (target_dtype.bits // dequantize_bits) - * DataType(storage_dtype).bits - // target_dtype.bits - ) + scaling_factor = ((target_dtype.bits // dequantize_bits) * DataType(storage_dtype).bits // + target_dtype.bits) r = r // scaling_factor initial_indices = intra_index_map.initial_indices - scaling_final_indices = intra_index_map.map_indices( - initial_indices[:-1] + [initial_indices[-1] * scaling_factor] - ) + scaling_final_indices = intra_index_map.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) scaling_final_indices = scaling_final_indices[:-1] + [ scaling_final_indices[-1] // scaling_factor ] diff --git a/python/bitblas/ops/impl/lop3_permutate_impl.py b/python/bitblas/ops/impl/lop3_permutate_impl.py index bf7b83709..893458409 100644 --- a/python/bitblas/ops/impl/lop3_permutate_impl.py +++ b/python/bitblas/ops/impl/lop3_permutate_impl.py @@ -5,6 +5,7 @@ from tvm.ir import GlobalVar from tvm.script import tir as T + # fmt: off # TIR interleave weight impl-> 2D implementation def tir_interleave_weight( @@ -23,24 +24,17 @@ def tir_interleave_weight( elems_per_group = bits_stride // bits @T.prim_func - def interleave_weight( - A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) - ): + def interleave_weight(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype)): for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): with T.block("B"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) @T.prim_func - def interleave_weight_f16_2b( - A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) - ): + def interleave_weight_f16_2b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") @@ -48,12 +42,8 @@ def interleave_weight_f16_2b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -64,9 +54,8 @@ def interleave_weight_f16_2b( B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] @T.prim_func - def interleave_weight_f16_1b( - A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) - ): + def interleave_weight_f16_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") @@ -78,12 +67,8 @@ def interleave_weight_f16_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -102,13 +87,11 @@ def interleave_weight_f16_1b( | B_tmp_4[v0, v1] | B_tmp_5[v0, v1] | B_tmp_6[v0, v1] - | B_tmp_7[v0, v1] - ) + | B_tmp_7[v0, v1]) @T.prim_func - def interleave_weight_int8_1b( - A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) - ): + def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") @@ -118,12 +101,8 @@ def interleave_weight_int8_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -138,8 +117,7 @@ def interleave_weight_int8_1b( | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] | B_tmp_4[v0, v1] - | B_tmp_5[v0, v1] - ) + | B_tmp_5[v0, v1]) if target_dtype == "float16" and bits == 2: return interleave_weight_f16_2b @@ -149,6 +127,8 @@ def interleave_weight_int8_1b( return interleave_weight_int8_1b return interleave_weight + + # fmt: on diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py index 14da708e9..f0f59e035 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -2,13 +2,15 @@ # Licensed under the MIT License. # pre-transformed tir expression of matmul import tvm -from tvm.script import tir as T from tvm import te, DataType from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import ( _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, ) @@ -27,7 +29,7 @@ def matmul_nt_dequantize_b( group_size=-1, fast_decoding=False, with_bias=False, - zeros_type="original", + zeros_mode="original", ): if not isinstance(M, int): M = tvm.te.var("m") @@ -42,26 +44,50 @@ def matmul_nt_dequantize_b( LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((K // group_size), N // storage_nbit * bit), + name="QZeros", + dtype=storage_dtype) Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + def qzeros_dequantize(k, n): + return _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + QZeros[k, n // n_float_per_elem], + n % n_float_per_elem, + dtype=storage_dtype, + ) + + Dequantize_qzeros = te.compute( + (K // group_size, N), + qzeros_dequantize, + name="Dequantize_zeros", + ) + def decode_func(n, k): - if source_format == "uint": - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype + if with_zeros and zeros_mode == "quantized": + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + Dequantize_qzeros[k // group_size, n], + dtype=in_dtype, ) + elif source_format == "uint": + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "int": w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype - ) - elif source_format == "af": - w = LUT[ - _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - ) - ] + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] else: raise ValueError("Unsupported source_format: {}".format(source_format)) @@ -71,12 +97,14 @@ def decode_func(n, k): if not with_zeros: return w * Scale[n, k // group_size] - if zeros_type == "original": + if zeros_mode == "original": w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_type == "rescale": + elif zeros_mode == "rescale": w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + elif zeros_mode == "quantized": + w = w * Scale[n, k // group_size] else: - raise ValueError("Unsupported zeros_type: {}".format(zeros_type)) + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) return w @@ -86,19 +114,21 @@ def decode_func(n, k): C = te.compute( (M, N), lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k - ), + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") args = [A, B] last_output = D - if source_format == "af": + if source_format == "nf": args.append(LUT) if with_scaling: args.append(Scale) if with_zeros: - args.append(Zeros) + if zeros_mode == "quantized": + args.append(QZeros) + else: + args.append(Zeros) if with_bias: E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") last_output = E @@ -119,7 +149,7 @@ def decode_func(n, k): "target_format": in_dtype, "with_scaling": with_scaling, "with_zeros": with_zeros, - "zeros_type": zeros_type, + "zeros_mode": zeros_mode, "group_size": group_size, } }, @@ -142,28 +172,25 @@ def matmul_nt_dequantize_b_propagate_b( group_size=-1, fast_decoding=False, with_bias=False, - zeros_type="original", + zeros_mode="original", + transform_kind: TransformKind = TransformKind.IntraWarpTransform, ): if not isinstance(M, int): M = tvm.te.var("m") - l = r = 16 + l = r = 16 # noqa: E741 if in_dtype == "int8": - l, r = 16, 32 + l, r = 16, 32 # noqa: E741 _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") target_dtype = DataType(in_dtype) scaling_factor = 1 if bit > 0 and bit < target_dtype.bits: - scaling_factor = ( - (target_dtype.bits // bit) - * DataType(storage_dtype).bits - // target_dtype.bits - ) + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) initial_indices = inverse_indexmap.initial_indices - scaling_final_indices = inverse_indexmap.map_indices( - initial_indices[:-1] + [initial_indices[-1] * scaling_factor] - ) + scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) scaling_final_indices = scaling_final_indices[:-1] + [ scaling_final_indices[-1] // scaling_factor ] @@ -180,9 +207,7 @@ def matmul_nt_dequantize_b_propagate_b( group_size = K qr = r * bit // storage_nbit A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder( - (N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype - ) + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) @@ -191,8 +216,9 @@ def matmul_nt_dequantize_b_propagate_b( def fcompute(i, j): warp_i, warp_j = i % l, j % qr spatial_args = i // l, j // qr - permutate_i, permutate_j = inverse_indexmap.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) return B[new_index] B_reindex = te.compute( @@ -216,15 +242,20 @@ def decode_func(n, k): k % n_float_per_elem, dtype=in_dtype, ) - elif source_format == "af": - w = LUT[ - _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - ) - ] + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] else: raise ValueError("Unsupported source_format: {}".format(source_format)) @@ -234,12 +265,12 @@ def decode_func(n, k): if not with_zeros: return w * Scale[n, k // group_size] - if zeros_type == "original": + if zeros_mode == "original": w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_type == "rescale": + elif zeros_mode == "rescale": w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] else: - raise ValueError("Unsupported zeros_type: {}".format(zeros_type)) + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) return w @@ -250,17 +281,18 @@ def decode_func(n, k): C = te.compute( (M, N), lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k - ), + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") args = [A, B] last_output = D - if source_format == "af": + if source_format == "nf": args.append(LUT) if with_scaling: args.append(Scale) + if with_zeros: + args.append(Zeros) if with_bias: E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") last_output = E @@ -280,13 +312,13 @@ def decode_func(n, k): "storage_dtype": storage_dtype, "target_format": in_dtype, "with_zeros": with_zeros, - "zeros_type": zeros_type, + "zeros_mode": zeros_mode, "with_scaling": with_scaling, "group_size": group_size, } }, ) - func = func.with_attr("smooth_b", True) + func = func.with_attr("weight_transform_kind", transform_kind.value) return tvm.IRModule.from_expr(func) @@ -305,22 +337,25 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b( group_size=-1, fast_decoding=False, with_bias=False, - zeros_type="original", + zeros_mode="original", + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, ): if not isinstance(M, int): M = tvm.te.var("m") - l = r = 16 + l = r = 16 # noqa: E741 if in_dtype == "int8": - l, r = 16, 32 + l, r = 16, 32 # noqa: E741 _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) def fcompute(i, j): warp_i, warp_j = i % l, j % r spatial_args = i // l, j // r - permutate_i, permutate_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) return A[new_index] A_reindex = te.compute( @@ -333,15 +368,11 @@ def fcompute(i, j): target_dtype = DataType(in_dtype) scaling_factor = 1 if bit > 0 and bit < target_dtype.bits: - scaling_factor = ( - (target_dtype.bits // bit) - * DataType(storage_dtype).bits - // target_dtype.bits - ) + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) initial_indices = inversed_index_map.initial_indices scaling_final_indices = inversed_index_map.map_indices( - initial_indices[:-1] + [initial_indices[-1] * scaling_factor] - ) + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) scaling_final_indices = scaling_final_indices[:-1] + [ scaling_final_indices[-1] // scaling_factor ] @@ -357,18 +388,18 @@ def fcompute(i, j): if group_size == -1: group_size = K qr = r * bit // storage_nbit - B = te.placeholder( - (N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype - ) + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) def fcompute(i, j): warp_i, warp_j = i % l, j % qr spatial_args = i // l, j // qr - permutate_i, permutate_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) return B[new_index] B_reindex = te.compute( @@ -392,20 +423,36 @@ def decode_func(n, k): k % n_float_per_elem, dtype=in_dtype, ) - elif source_format == "af": - w = LUT[ - _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - ) - ] + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] else: raise ValueError("Unsupported source_format: {}".format(source_format)) - if with_scaling: - w = w * Scale[n, k // group_size] + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + return w B_decode = te.compute((N, K), decode_func, name="B_decode") @@ -423,10 +470,12 @@ def decode_func(n, k): D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") args = [A, B] last_output = D - if source_format == "af": + if source_format == "nf": args.append(LUT) if with_scaling: args.append(Scale) + if with_zeros: + args.append(Zeros) if with_bias: E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") last_output = E @@ -446,14 +495,14 @@ def decode_func(n, k): "storage_dtype": storage_dtype, "target_format": in_dtype, "with_zeros": with_zeros, - "zeros_type": zeros_type, + "zeros_mode": zeros_mode, "with_scaling": with_scaling, "group_size": group_size, } }, ) - func = func.with_attr("smooth_a", True) - func = func.with_attr("smooth_b", True) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) return tvm.IRModule.from_expr(func) @@ -473,7 +522,7 @@ def select_implementation( fast_decoding=False, with_bias=False, layout="nt", - zeros_type="original", + zeros_mode="original", propagate_a=False, propagate_b=False, ): @@ -498,7 +547,9 @@ def select_implementation( group_size, fast_decoding, with_bias, - zeros_type, + zeros_mode, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, ) elif propagate_a: raise NotImplementedError @@ -518,7 +569,8 @@ def select_implementation( group_size, fast_decoding, with_bias, - zeros_type, + zeros_mode, + transform_kind=propagate_b, ) else: return matmul_nt_dequantize_b( @@ -536,7 +588,7 @@ def select_implementation( group_size, fast_decoding, with_bias, - zeros_type, + zeros_mode, ) else: raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/matmul_impl.py b/python/bitblas/ops/impl/matmul_impl.py index c716227a0..26b748a88 100644 --- a/python/bitblas/ops/impl/matmul_impl.py +++ b/python/bitblas/ops/impl/matmul_impl.py @@ -2,9 +2,9 @@ # Licensed under the MIT License. # pre-transformed tir expression of matmul import tvm -from tvm.script import tir as T -from tvm import te, tir +from tvm import te from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.ops.operator import TransformKind def matmul_nn( @@ -26,9 +26,7 @@ def matmul_nn( k = te.reduce_axis((0, K), name="k") C = te.compute( (M, N), - lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k - ), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k), name="C", ) last_output = C @@ -40,10 +38,7 @@ def matmul_nn( E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") last_output = E - if with_bias: - args = [A, B, Bias, last_output] - else: - args = [A, B, last_output] + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] func = te.create_prim_func(args) @@ -69,9 +64,7 @@ def matmul_nt( k = te.reduce_axis((0, K), name="k") C = te.compute( (M, N), - lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k - ), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), name="C", ) last_output = C @@ -83,10 +76,7 @@ def matmul_nt( E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") last_output = E - if with_bias: - args = [A, B, Bias, last_output] - else: - args = [A, B, last_output] + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] func = te.create_prim_func(args) @@ -108,41 +98,6 @@ def matmul( return matmul_nt(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) -# always assume propagate both intra and inter layout in BitBLAS -# as we do not have to do runtime layout transform -def matmul_nt_propagate_a_dyn_m( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): ... - - -def matmul_nt_propagate_b_dyn_m( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): ... - - -def matmul_nt_propagate_a_propagate_b_dyn_m( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): ... - - def matmul_nt_propagate_a( M, N, @@ -151,16 +106,15 @@ def matmul_nt_propagate_a( out_dtype="float16", accum_dtype="float16", with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, ): if not isinstance(M, int): M = tvm.te.var("m") - l = r = 16 + l = r = 16 # noqa: E741 if in_dtype == "int8": - l, r = 16, 32 + l, r = 16, 32 # noqa: E741 - _, inversed_index_map = get_propagate_map( - trans=False, dtype=in_dtype, matrix_name="A" - ) + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) B = te.placeholder((N, K), name="B", dtype=in_dtype) @@ -169,8 +123,9 @@ def matmul_nt_propagate_a( def fcompute(i, j): warp_i, warp_j = i % l, j % r spatial_args = i // l, j // r - permutate_i, permutate_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) return A[new_index] A_reindex = te.compute( @@ -183,8 +138,7 @@ def fcompute(i, j): C = te.compute( (M, N), lambda i, j: te.sum( - A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k - ), + A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), name="C", ) last_output = C @@ -196,13 +150,10 @@ def fcompute(i, j): E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") last_output = E - if with_bias: - args = [A, B, Bias, last_output] - else: - args = [A, B, last_output] + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] func = te.create_prim_func(args) - func = func.with_attr("smooth_a", True) + func = func.with_attr("input_transform_kind", transform_kind.value) return tvm.IRModule.from_expr(func) @@ -215,16 +166,15 @@ def matmul_nt_propagate_b( out_dtype="float16", accum_dtype="float16", with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, ): if not isinstance(M, int): M = tvm.te.var("m") - l = r = 16 + l = r = 16 # noqa: E741 if in_dtype == "int8": - l, r = 16, 32 + l, r = 16, 32 # noqa: E741 - _, inversed_index_map = get_propagate_map( - trans=True, dtype=in_dtype, matrix_name="B" - ) + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") A = te.placeholder((M, K), name="A", dtype=in_dtype) B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) @@ -233,8 +183,9 @@ def matmul_nt_propagate_b( def fcompute(i, j): warp_i, warp_j = i % l, j % r spatial_args = i // l, j // r - permutate_i, permutate_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) return B[new_index] B_reindex = te.compute( @@ -247,8 +198,7 @@ def fcompute(i, j): C = te.compute( (M, N), lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k - ), + A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k), name="C", ) last_output = C @@ -260,13 +210,10 @@ def fcompute(i, j): E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") last_output = E - if with_bias: - args = [A, B, Bias, last_output] - else: - args = [A, B, last_output] + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] func = te.create_prim_func(args) - func = func.with_attr("smooth_b", True) + func = func.with_attr("weight_transform_kind", transform_kind.value) return tvm.IRModule.from_expr(func) @@ -279,26 +226,27 @@ def matmul_nt_propagate_a_propagate_b( out_dtype="float16", accum_dtype="float16", with_bias=False, + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, ): if not isinstance(M, int): M = tvm.te.var("m") - l = r = 16 + l = r = 16 # noqa: E741 if in_dtype == "int8": - l, r = 16, 32 + l, r = 16, 32 # noqa: E741 A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - _, inversed_index_map = get_propagate_map( - trans=False, dtype=in_dtype, matrix_name="A" - ) + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") def fcompute(i, j): warp_i, warp_j = i % l, j % r spatial_args = i // l, j // r - permutate_i, permutate_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) return A[new_index] A_reindex = te.compute( @@ -307,15 +255,14 @@ def fcompute(i, j): name="A_reindex", ) - _, inversed_index_map = get_propagate_map( - trans=True, dtype=in_dtype, matrix_name="B" - ) + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") def fcompute(i, j): warp_i, warp_j = i % l, j % r spatial_args = i // l, j // r - permutate_i, permutate_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) return B[new_index] B_reindex = te.compute( @@ -342,14 +289,11 @@ def fcompute(i, j): E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") last_output = E - if with_bias: - args = [A, B, Bias, last_output] - else: - args = [A, B, last_output] + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] func = te.create_prim_func(args) - func = func.with_attr("smooth_a", True) - func = func.with_attr("smooth_b", True) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) return tvm.IRModule.from_expr(func) @@ -363,27 +307,48 @@ def select_implementation( accum_dtype="float16", with_bias=False, layout="nt", - propagate_a=False, - propagate_b=False, + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, ): if layout == "nn": if propagate_a or propagate_b: raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn" - ) + "Currently only support propagate_a=False and propagate_b=False for layout=nn") return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) elif layout == "nt": if propagate_a and propagate_b: return matmul_nt_propagate_a_propagate_b( - M, N, K, in_dtype, out_dtype, accum_dtype, with_bias + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, ) elif propagate_a: return matmul_nt_propagate_a( - M, N, K, in_dtype, out_dtype, accum_dtype, with_bias + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_a, ) elif propagate_b: return matmul_nt_propagate_b( - M, N, K, in_dtype, out_dtype, accum_dtype, with_bias + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_b, ) else: return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) diff --git a/python/bitblas/ops/impl/param_permutate_impl.py b/python/bitblas/ops/impl/param_permutate_impl.py new file mode 100644 index 000000000..620212eb6 --- /dev/null +++ b/python/bitblas/ops/impl/param_permutate_impl.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.gpu.matmul_analysis import get_propagate_map +from ..operator import TransformKind +from typing import Literal +from tvm import te, IRModule + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16"] = "float16", + transpose_matrix: bool = True, + group_size: int = -1, + propagate_kind: TransformKind = TransformKind.NonTransform, + target_instruction: Literal["nvidia-mma"] = "nvidia-mma", +): + if target_instruction != "nvidia-mma": + raise ValueError("Currently only support nvidia-mma instruction") + if propagate_kind < TransformKind.IntraWarpTransform: + raise ValueError("Currently only support propagate_kind >= IntraWarpTransform") + if transpose_matrix is not True: + raise ValueError("Currently only support transpose_matrix == True") + # This is trick to get the basic tile size for the current datatype + # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 + l = r = 16 # noqa: E741 + if datatype == "int8": + l, r = 16, 32 # noqa: E741 + if group_size == -1: + group_size = N + + intra_index_map, inverse_indexmap = get_propagate_map( + transpose_matrix, dtype=datatype, matrix_name=propagate_kind) + + inp = te.placeholder((M, N // group_size), name="inp", dtype=datatype) + + def fcompute(n, k): + rl, rr = n, k + warp_i, warp_j = rl % l, rr % r + spatial_i, spatial_j = rl // l, rr // r + if propagate_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (spatial_i * l + warp_i, (spatial_j * r + warp_j) // group_size) + return inp[new_index] + + inp_prmt = te.compute( + (M, N // group_size), + fcompute, + name="intra_warp_permutate", + ) + + args = [inp, inp_prmt] + + func = te.create_prim_func(args) + + return IRModule.from_expr(func) diff --git a/python/bitblas/ops/ladder_permutate.py b/python/bitblas/ops/ladder_permutate.py index 298e95dda..f8a2be28b 100644 --- a/python/bitblas/ops/ladder_permutate.py +++ b/python/bitblas/ops/ladder_permutate.py @@ -1,8 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm from tvm.target import Target -from typing import List, Union, Literal +from typing import Literal, Union from .operator import Operator from .impl.ladder_permutate_impl import select_implementation from dataclasses import dataclass @@ -24,22 +23,22 @@ class LadderPermutateConfig: class LadderPermutate(Operator): + def __init__( self, config: LadderPermutateConfig, name: str = "permutate", - target: Target = tvm.target.Target("llvm"), # assume to do permutation on gpu. + target: Union[str, Target] = "llvm", # assume to do permutation on cpu. + enable_tuning: bool = False, ): # consider to warp the arguments to MatmulConfig - super().__init__(name, target) - self.config = config - - if target.kind.name != "llvm": - raise ValueError("Currently only support llvm target for Permutation") + super().__init__(name, config, target) - prim_func_mod = self._select_implementation() - self.prim_func_mod = prim_func_mod - self.target = target + target = self.target + if target.kind.name == "cuda": + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + if enable_tuning: + self.hardware_aware_finetune() self._build_runtime_module(target) # select implementation based on the Operator config diff --git a/python/bitblas/ops/lop3_permutate.py b/python/bitblas/ops/lop3_permutate.py index 956c1f9b4..867432a5e 100644 --- a/python/bitblas/ops/lop3_permutate.py +++ b/python/bitblas/ops/lop3_permutate.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm from tvm.target import Target -from typing import Literal +from typing import Literal, Union from .operator import Operator from .impl.lop3_permutate_impl import select_implementation from dataclasses import dataclass -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch import torch @@ -20,21 +18,19 @@ class LOP3PermutateConfig: class LOP3Permutate(Operator): + def __init__( - self, - config: LOP3PermutateConfig, - name: str = "permutate", - target: Target = tvm.target.Target("llvm"), # assume to do permutation on gpu. + self, + config: LOP3PermutateConfig, + name: str = "permutate", + target: Union[str, Target] = "llvm", # assume to do permutation on cpu. ): # consider to warp the arguments to MatmulConfig - super().__init__(name, target) - self.config = config + super().__init__(name, config, target) if target.kind.name != "llvm": raise ValueError("Currently only support llvm target for Permutation") - prim_func_mod = self._select_implementation() - self.prim_func_mod = prim_func_mod self.target = target self._build_runtime_module(target) @@ -46,11 +42,11 @@ def _select_implementation(self): dequantize_bits=self.dequantize_bits, ) - def forward_from_torch(self, weight, res): - # reintepret the input tensor to int32 format - _tvm_args = [self._tensor_adapter(arg.view(torch.int32), self.arch.device) for arg in [weight, res]] - self.rt_mod(*_tvm_args) - return tvm_tensor_to_torch(_tvm_args[-1]).view(weight.dtype) + def forward(self, weight, res): + # reinterpret the input tensor to int32 format + args = [arg.view(torch.int32) for arg in [weight, res]] + self.torch_func(*args) + return args[-1].view(weight.dtype) @property def M(self): diff --git a/python/bitblas/ops/matmul.py b/python/bitblas/ops/matmul.py index 4844bc739..4f3da8005 100644 --- a/python/bitblas/ops/matmul.py +++ b/python/bitblas/ops/matmul.py @@ -5,15 +5,15 @@ from tvm.target import Target from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from typing import List, Union, Optional, Any, Tuple -from .operator import Operator +from .operator import Operator, TransformKind from .impl.matmul_impl import select_implementation -from ..base.utils import get_rasterization_code -from bitblas.utils import match_global_kernel, tensor_replace_dp4a +from bitblas.utils import tensor_replace_dp4a from dataclasses import dataclass from .ladder_permutate import LadderPermutate, LadderPermutateConfig class TransformExecutorCPU: + def __init__(self, operators: Optional[List[Operator]] = None): if operators is None: operators = [] @@ -49,46 +49,67 @@ class MatmulConfig: out_dtype: str = "float16" accum_dtype: str = "float16" with_bias: bool = False + # layout of matrix A and B + # "nn": C[i, j] = A[i, k] * B[k, j] + # "nt": C[i, j] = A[i, k] * B[j, k] layout: str = "nt" - propagate_a: bool = False - propagate_b: bool = False + # weight transformation kind of matrix A + propagate_a: TransformKind = TransformKind.NonTransform + # weight transformation kind of matrix B + propagate_b: TransformKind = TransformKind.NonTransform def __post_init__(self): # set M to tuple if it is list # otherwise, M is not hashable - object.__setattr__( - self, "M", tuple(self.M) if isinstance(self.M, list) else self.M - ) + object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) + if isinstance(self.propagate_a, bool): + object.__setattr__( + self, + "propagate_a", + (TransformKind.IntraWarpTransform + if self.propagate_a else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_a, int): + object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) + + if isinstance(self.propagate_b, bool): + object.__setattr__( + self, + "propagate_b", + (TransformKind.IntraWarpTransform + if self.propagate_b else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_b, int): + object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) class Matmul(Operator): + def __init__( self, config: MatmulConfig, name: str = "matmul", - target: Target = tvm.target.Target("cuda"), + target: Union[str, Target] = "cuda", + enable_tuning: bool = False, ): - super().__init__(name, target) - self.config = config + super().__init__(name, config, target) target = self.target if target.kind.name != "cuda": raise ValueError("Currently only support cuda target") - prim_func_mod = self._select_implementation() - self.prim_func_mod = prim_func_mod - self.optimized_func = self.apply_default_schedule(prim_func_mod, target) + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} - self.update_func( - self.prim_func.with_attrs({"opt_shapes": self.dynamic_range}) - ) + self.update_func(self.prim_func.with_attrs({"opt_shapes": self.dynamic_range})) else: self.dynamic_range = None self._build_runtime_module(target) if self.propagate_a: + assert (self.propagate_a is + TransformKind.NonTransform), "Currently only support NonTransform for input" ladder_permutate_config = LadderPermutateConfig( M=self.M, N=self.K, @@ -96,7 +117,7 @@ def __init__( storage_dtype=self.in_dtype, propagate_kind="A", transpose_matrix=False, - transform_kind=2, + transform_kind=self.propagate_a, ) self.ladder_permutate_a = LadderPermutate( config=ladder_permutate_config, @@ -112,8 +133,8 @@ def __init__( datatype=self.in_dtype, storage_dtype=self.in_dtype, propagate_kind="B", - transpose_matrix=True if self.layout == "nt" else False, - transform_kind=2, + transpose_matrix=(self.layout == "nt"), + transform_kind=self.propagate_b, ) self.ladder_permutate_b = LadderPermutate( config=ladder_permutate_config, @@ -134,6 +155,9 @@ def __init__( self.weight_executors = weight_executors + if enable_tuning: + self.hardware_aware_finetune() + def _select_implementation(self): return select_implementation( M=self.M, @@ -149,11 +173,6 @@ def _select_implementation(self): ) def post_process(self, code: str) -> str: - index = code.index("{", match_global_kernel(code)) - # some tricky judge to decide whether to insert rasterization code - if self.N * self.K > 10**6: - rasterization_code = get_rasterization_code(10) - code = code[: index + 2] + rasterization_code + code[index + 2 :] code = tensor_replace_dp4a(code) return code @@ -181,18 +200,26 @@ def var_warpper(v, m): arg = func.buffer_map[param] profile_tensors.append( tvm.nd.array( - np.random.uniform( - 0, 1, [var_warpper(i, m) for i in arg.shape] - ).astype(arg.dtype), + np.random.uniform(0, 1, + [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), device=device, - ) - ) + )) self.profile_tensors = profile_tensors latency = self.time_evaluator(*profile_tensors).mean * 1e3 benchmark_latencies.append({"m": m, "latency": latency}) # ms return benchmark_latencies + def forward(self, *args) -> Any: + if self.lib is None: + self._forward_from_torch_func(*args) + dynamic_symbolic = [] + if self.dynamic_range is not None: + # assume we only have one dynamic range + m = args[0].shape[0] + dynamic_symbolic.append(m) + self._forward_from_prebuild_lib(*args, *dynamic_symbolic) + @property def M(self): return self.config.M diff --git a/python/bitblas/ops/matmul_dequantize.py b/python/bitblas/ops/matmul_dequantize.py index 2ff586679..ee4a46a3a 100644 --- a/python/bitblas/ops/matmul_dequantize.py +++ b/python/bitblas/ops/matmul_dequantize.py @@ -4,17 +4,20 @@ from tvm.target import Target from bitblas.base.roller.arch.cuda import CUDA from typing import Any, List, Literal, Optional, Tuple, Union -from .operator import Operator +from .operator import Operator, TransformKind from .impl.matmul_dequantize_impl import select_implementation -from ..base.utils import get_rasterization_code, tensor_replace_dp4a +from ..base.utils import tensor_replace_dp4a from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from dataclasses import dataclass -from bitblas.utils import match_global_kernel from .ladder_permutate import LadderPermutate, LadderPermutateConfig from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig +import logging +logger = logging.getLogger(__name__) + + +class OPExecutorCPU: -class WeightExecutorCPU: def __init__(self, operators: Optional[List[Operator]] = None): if operators is None: operators = [] @@ -51,29 +54,53 @@ class MatmulWeightOnlyDequantizeConfig: accum_dtype: str = "float16" bit: int = 4 storage_dtype: str = "int8" - source_format: str = "int" + # documents for source_format: + # the format of the source data, which can be "int", "uint", "fp", "nf" + # "int": dequantize_weight = (target)((int)(quantize_weight - fixed_zero_point)) * scale + # where the fixed_zero_point is 2^(bit - 1) - 1 + # "uint": dequantize_weight = (target)((uint)(quantize_weight - zero_point)) * scale + # where the zero_point is manually set by zeros tensor + # "fp": dequantize_weight = (quantize_weight - zero_point) * scale + # "nf": dequantize_weight = (lut[quantize_weight] - zero_point) * scale + source_format: Literal["int", "uint", "fp", "nf"] = "int" with_scaling: bool = False with_zeros: bool = False group_size: int = -1 fast_decoding: bool = False with_bias: bool = False - propagate_a: bool = False - propagate_b: bool = False + propagate_a: TransformKind = TransformKind.NonTransform + propagate_b: TransformKind = TransformKind.NonTransform layout: str = "nt" - # documents for zeros_type: + # documents for zeros_mode: # original: target = (dequantize_weight - zero_point) * scale # rescale: target = dequantize_weight * scale - zero_point - # quantzied: target = (dequantize_weight - dequantize_zeros) * scale - # Notice: only support "original" and "rescale" now - # The auto-gptq framework prefer "original" for alignment with cuda. - zeros_type: Literal["original", "rescale", "quantzied"] = "original" + # quantized: target = (dequantize_weight - dequantize_zeros) * scale + # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. + zeros_mode: Literal["original", "rescale", "quantized"] = "original" def __post_init__(self): # set M to tuple if it is list # otherwise, M is not hashable - object.__setattr__( - self, "M", tuple(self.M) if isinstance(self.M, list) else self.M - ) + object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) + if isinstance(self.propagate_a, bool): + object.__setattr__( + self, + "propagate_a", + (TransformKind.IntraWarpTransform + if self.propagate_a else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_a, int): + object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) + + if isinstance(self.propagate_b, bool): + object.__setattr__( + self, + "propagate_b", + (TransformKind.IntraWarpTransform + if self.propagate_b else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_b, int): + object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) class MatmulWeightOnlyDequantize(Operator): @@ -82,10 +109,10 @@ def __init__( self, config: MatmulWeightOnlyDequantizeConfig, name: str = "matmul_weight_only_dequantize", - target: Target = tvm.target.Target("cuda"), + target: Target = "cuda", + enable_tuning: bool = False, ): - super().__init__(name, target) - self.config = config + super().__init__(name, config, target) target = self.target if target.kind.name != "cuda": @@ -93,22 +120,18 @@ def __init__( self.arch = CUDA(target) - self.prim_func_mod = self._select_implementation() try: - self.optimized_func = self.apply_default_schedule( - self.prim_func_mod, target - ) + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) except Exception: self.optimized_func = None - print( - f"[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." + logger.warnning( + "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." ) if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range} - ) + {"opt_shapes": self.dynamic_range}) else: self.dynamic_range = None @@ -122,7 +145,7 @@ def __init__( storage_dtype=self.in_dtype, propagate_kind="A", transpose_matrix=False, - transform_kind=2, + transform_kind=self.propagate_a, ) self.ladder_permutate_a = LadderPermutate( config=ladder_permutate_config, @@ -139,8 +162,8 @@ def __init__( dequantize_bits=self.bit, storage_dtype=self.storage_dtype, propagate_kind="B", - transpose_matrix=True if self.layout == "nt" else False, - transform_kind=2, + transpose_matrix=self.layout == "nt", + transform_kind=self.propagate_b, ) self.ladder_permutate_b = LadderPermutate( config=ladder_permutate_config, @@ -164,7 +187,12 @@ def __init__( else: self.lop3_permutate = None - weight_executors = WeightExecutorCPU() + input_executors = OPExecutorCPU() + if self.ladder_permutate_a is not None: + input_executors.append(self.ladder_permutate_a) + self.input_executors = input_executors + + weight_executors = OPExecutorCPU() if self.lop3_permutate is not None: weight_executors.append(self.lop3_permutate) @@ -173,6 +201,9 @@ def __init__( self.weight_executors = weight_executors + if enable_tuning: + self.hardware_aware_finetune() + def _select_implementation(self): return select_implementation( M=self.M, @@ -190,22 +221,28 @@ def _select_implementation(self): fast_decoding=self.fast_decoding, with_bias=self.with_bias, layout=self.layout, - zeros_type=self.zeros_type, + zeros_mode=self.zeros_mode, propagate_a=self.propagate_a, propagate_b=self.propagate_b, ) def post_process(self, code: str) -> str: - index = code.index("{", match_global_kernel(code)) code = tensor_replace_dp4a(code) - # some tricky judge to decide whether to insert rasterization code - if self.M == 1: - return code - if self.N * self.K > 10**6: - rasterization_code = get_rasterization_code(10) - code = code[: index + 2] + rasterization_code + code[index + 2 :] return code + def retrieve_weight_shape(self): + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] + + def forward(self, *args) -> Any: + if self.lib is None: + self._forward_from_torch_func(*args) + dynamic_symbolic = [] + if self.dynamic_range is not None: + # assume we only have one dynamic range + m = args[0].shape[0] + dynamic_symbolic.append(m) + self._forward_from_prebuild_lib(*args, *dynamic_symbolic) + @property def M(self): return self.config.M @@ -275,14 +312,12 @@ def layout(self): return self.config.layout @property - def zeros_type(self): - return self.config.zeros_type + def zeros_mode(self): + return self.config.zeros_mode @property def input_transform(self): - if self.ladder_permutate_a is not None: - return self.ladder_permutate_a - return None + return self.input_executors if self.input_executors.size else None @property def weight_transform(self): diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py index 9cb42c6fa..9585fe499 100644 --- a/python/bitblas/ops/operator.py +++ b/python/bitblas/ops/operator.py @@ -6,27 +6,45 @@ from tvm.target import Target from tvm.tir import PrimFunc from tvm.contrib.dlpack import to_pytorch_func +from tvm._ffi.base import _LIB, raise_last_ffi_error +from tvm._ffi._ctypes.types import TVMValue, ArgTypeCode import bitblas -from typing import List, Dict, Any +import ctypes +from typing import List, Dict, Any, Optional import numpy as np from ..base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.roller.arch import get_arch -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -import torch +from bitblas.wrapper import CUDASourceWrapper, CUDASourceWrapperWithDynamic +from dataclasses import dataclass +from enum import IntEnum +import logging +logger = logging.getLogger(__name__) + + +class TransformKind(IntEnum): + NonTransform = 0 + InterWarpTransform = 1 + IntraWarpTransform = 2 + + +@dataclass class OperatorConfig: """Base class for operator configurations. Used for typing.""" + pass class Operator(ABC): - def __init__(self, name, target: Target = None): + + def __init__(self, name, config: OperatorConfig, target: Target = None): if isinstance(target, str): target = Target(target) self.name = name + self.config = config self.target = target - self.prim_func_mod = None + self.prim_func_mod = self._select_implementation() self.optimized_func = None self.rt_mod = None self.time_evaluator = None @@ -34,8 +52,16 @@ def __init__(self, name, target: Target = None): self.arch = get_arch(target) if target else None self.dynamic_range = None self.pass_context: Dict = {} + self.num_args = len(self.prim_func.params) + self.function_handle = None + self.num_output_args: int = ( + 1 # todo(lei): should be analyzed from the prim_func. + ) + self.wrapper = None + self.lib_name = None + self.lib = None - def codegen(self, target: Target = None) -> str: + def get_source(self, target: Target = None) -> str: if target is None: target = self.target if self.rt_mod is None: @@ -73,13 +99,16 @@ def tvm_callback_cuda_postproc(code, _): try: # Use a specific TVM pass context for CUDA platforms - with tvm.transform.PassContext(config={"tir.use_async_copy": True, **self.pass_context}): - rt_mod = tvm.build( - self.optimized_func, target=target, name=self.name - ) - except Exception: - # Log the exception for debugging purposes. Replace 'print' with logging if necessary. - print(f"Failed to build optimized function for CUDA target") + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + **self.pass_context + }): + rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) + except Exception as e: + rt_build_error = e # noqa + logger.debug( + "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" + ) else: # For non-CUDA platforms or when no optimized function is available, build with the primary function rt_mod = tvm.build(self.prim_func, target=target, name=self.name) @@ -89,9 +118,26 @@ def tvm_callback_cuda_postproc(code, _): self.rt_mod = rt_mod # Initialize a time evaluator with the built module, specifying the device and the number of runs self.time_evaluator = rt_mod.time_evaluator( - rt_mod.entry_name, self.arch.device, number=10 - ) + rt_mod.entry_name, self.arch.device, number=10) + self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) + if self.arch.platform == "CUDA": + try: + if (self.dynamic_range is not None and len(self.optimized_func.functions) > 1): + wrapper = CUDASourceWrapperWithDynamic(self.optimized_func, + self.get_source(target), self.arch) + else: + wrapper = CUDASourceWrapper(self.optimized_func, self.get_source(target), + self.arch) + wrapper.compile_lib() + self.wrapper = wrapper + self.lib_name = self.wrapper.lib_name + self.lib = self.wrapper.load_lib() + self.lib.init() + except Exception as e: + build_runtime_library_error = e + logger.debug( + "Failed to build runtime library {}".format(build_runtime_library_error)) return rt_mod @@ -105,8 +151,7 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule bitblas.gpu.Reduction(), bitblas.gpu.GeneralReduction(), bitblas.gpu.Fallback(), - )(mod_for_opt) - ) + )(mod_for_opt)) if optimized_mod is not None: return optimized_mod @@ -115,9 +160,11 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule def post_process(self, code: str) -> str: return code - def apply_fast_tuning( - self, func: PrimFunc, target: Target, topk: int = 20, parallel_build=True - ) -> IRModule: + def apply_fast_tuning(self, + func: PrimFunc, + target: Target, + topk: int = 20, + parallel_build=True) -> IRModule: _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) if best is not None: return best.sch.mod @@ -132,35 +179,37 @@ def apply_fast_tuning_with_dynamic_range( dynamic_range: Dict[str, List[int]] = None, ): optimized_mod = fast_tune_with_dynamic_range( - func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range - ) + func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range) if optimized_mod is not None: return optimized_mod return None - def hardware_aware_finetune( - self, topk: int = 20, target: tvm.target.Target = None, parallel_build=True - ): + def hardware_aware_finetune(self, + topk: int = 20, + target: tvm.target.Target = None, + parallel_build=True): if target is None: target = self.target dynamic_range = self.dynamic_range func = self.prim_func if dynamic_range is not None: self.optimized_func = self.apply_fast_tuning_with_dynamic_range( - func, target, topk, dynamic_range - ) + func, target, topk, dynamic_range) else: self.optimized_func = self.apply_fast_tuning( - func, target, topk, parallel_build=parallel_build - ) + func, target, topk, parallel_build=parallel_build) self._build_runtime_module(self.target) - def get_profile_tensors(self): + def get_profile_tensors(self, dynamic_symbolic_constrains: Optional[Dict] = None): + if dynamic_symbolic_constrains is None: + dynamic_symbolic_constrains = {} func = self.prim_func device = self.arch.device def var_warpper(v): if isinstance(v, tvm.tir.Var): + if v.name in dynamic_symbolic_constrains: + return dynamic_symbolic_constrains[v.name] assert "opt_shapes" in func.attrs assert v.name in func.attrs["opt_shapes"] return func.attrs["opt_shapes"][v.name].value @@ -177,26 +226,19 @@ def var_warpper(v): arg = func.buffer_map[param] profile_tensors.append( tvm.nd.array( - np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype( - arg.dtype - ), + np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype(arg.dtype), device=device, - ) - ) + )) self.profile_tensors = profile_tensors return profile_tensors - def profile_latency(self) -> str: - if self.dynamic_range is not None: - return self._profile_latency_with_dynamic_range() - - profile_tensors = self.get_profile_tensors() + def profile_latency(self, dynamic_symbolic_constrains: Optional[Dict] = None) -> str: + if dynamic_symbolic_constrains is None: + dynamic_symbolic_constrains = {} + profile_tensors = self.get_profile_tensors(dynamic_symbolic_constrains) latency = self.time_evaluator(*profile_tensors).mean * 1e3 return latency - def _profile_latency_with_dynamic_range(self) -> List: - raise NotImplementedError - def _tensor_adapter(self, tensor, device): import torch from torch.utils.dlpack import to_dlpack @@ -210,22 +252,55 @@ def _tensor_adapter(self, tensor, device): else: raise RuntimeError("Not supported type: ", type(tensor)) - def forward_from_torch(self, *args): - # convert tensor from torch to tvm + def _forward_from_tvm_args(self, *args): _tvm_args = [self._tensor_adapter(arg, self.arch.device) for arg in args] self.rt_mod(*_tvm_args) - return tvm_tensor_to_torch(_tvm_args[-1]) - def forward(self, *args): - # "Currently only support forward from torch tensor" - return self.torch_func(*args) + def _forward_from_torch_func(self, *args): + self.torch_func(*args) + return args[-1] - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) + def forward(self, *args): + return self._forward_from_torch_func(*args) + + def _forward_from_prebuild_lib(self, *args): + ctypes_args = [ + ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args + ] + self.lib.call(*ctypes_args) + + def _forward_from_tvm_lib_func(self, values): + tcodes = (ctypes.c_int * self.num_args)() + ret_val = TVMValue() + ret_tcode = ctypes.c_int() + for i in range(self.num_args): + tcodes[i] = ArgTypeCode.NDARRAY_HANDLE + if (_LIB.TVMFuncCall( + self.function_handle, + values, + tcodes, + ctypes.c_int(self.num_args), + ctypes.byref(ret_val), + ctypes.byref(ret_tcode), + ) != 0): + raise_last_ffi_error() + + def __call__(self, *args: Any) -> Any: + return self.forward(*args) def update_func(self, func: PrimFunc): self.prim_func_mod["main"] = func + def update_runtime_module(self, rt_mod, lib_name=None): + self.rt_mod = rt_mod + self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) + self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle + self.torch_func = to_pytorch_func(rt_mod) + if lib_name is not None: + self.lib_name = lib_name + self.lib = ctypes.CDLL(lib_name) + self.lib.init() + @abstractmethod def _select_implementation(self) -> IRModule: pass diff --git a/python/bitblas/ops/param_permutate.py b/python/bitblas/ops/param_permutate.py new file mode 100644 index 000000000..ca28c86eb --- /dev/null +++ b/python/bitblas/ops/param_permutate.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.target import Target +from typing import Literal, Union +from .operator import Operator, TransformKind +from .impl.param_permutate_impl import select_implementation +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ParamPermutateConfig: + M: int + N: int + datatype: Literal["float16"] = "float16" + transpose_matrix: bool = True + group_size: int = -1 + propagate_kind: TransformKind = TransformKind.NonTransform + target_instruction: Literal["nvidia-mma"] = ( + "nvidia-mma" # maybe extend to "cdna-mfma" in future. + ) + + def __post_init__(self): + if isinstance(self.propagate_kind, bool): + object.__setattr__( + self, + "propagate_kind", + (TransformKind.IntraWarpTransform + if self.propagate_kind else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_kind, int): + object.__setattr__(self, "propagate_kind", TransformKind(self.propagate_kind)) + + +class ParamPermutate(Operator): + + def __init__( + self, + config: ParamPermutateConfig, + name: str = "permutate", + target: Union[str, Target] = "llvm", # assume to do permutation on cpu. + ): + super().__init__(name, config, target) + + if target.kind.name != "llvm": + raise ValueError("Currently only support llvm target for Permutation") + + self.target = target + self._build_runtime_module(target) + + # select implementation based on the Operator config + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + datatype=self.datatype, + transpose_matrix=self.transpose_matrix, + group_size=self.group_size, + propagate_kind=self.propagate_kind, + target_instruction=self.target_instruction, + ) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def datatype(self): + return self.config.datatype + + @property + def propagate_kind(self): + return self.config.propagate_kind + + @property + def transpose_matrix(self): + return self.config.transpose_matrix + + @property + def group_size(self): + return self.config.group_size + + @property + def target_instruction(self): + return self.config.target_instruction + + +__all__ = ["ParamPermutate", "ParamPermutateConfig"] diff --git a/python/bitblas/quantization/__init__.py b/python/bitblas/quantization/__init__.py index e32f99e33..227cf61a4 100644 --- a/python/bitblas/quantization/__init__.py +++ b/python/bitblas/quantization/__init__.py @@ -1,8 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from .quantization import ( - _tir_packed_int_to_int_to_float, - _tir_packed_uint_to_uint_to_float, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, + _tir_packed_int_to_int_to_float, # noqa: F401 + _tir_packed_uint_to_uint_to_float, # noqa: F401 + _tir_packed_to_signed_convert, # noqa: F401 + _tir_packed_to_unsigned_convert, # noqa: F401 + _tir_u32_to_f4_to_f16, # noqa: F401 + _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 ) + +from .utils import gen_quant4, general_compress # noqa: F401 diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py index f6efff036..aeecfb874 100644 --- a/python/bitblas/quantization/quantization.py +++ b/python/bitblas/quantization/quantization.py @@ -15,23 +15,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +# # Modifications Copyright (c) Microsoft. -# The code below is mostly copied from mlc.ai quantization.py in mlc-llm. +# The code below is mostly copied from mlc.ai quantization.py in mlc-llm. # pylint: disable=invalid-name,missing-function-docstring,unused-variable """TIR computation utilities for quantization.""" import tvm from tvm import tir + # fmt: off -def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool=True): +def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): mask = tir.const((1 << 16) - 1, "uint32") res = [] for data in [v0, v1]: u32_val = tir.reinterpret("uint32", data) if round_to_even: - rounding_bias = ((u32_val >> tir.const(16, "uint32")) & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") + rounding_bias = ((u32_val >> tir.const(16, "uint32")) + & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") u32_val += rounding_bias res.append((u32_val >> tir.const(16, "uint32")) & mask) return res[0] | (res[1] << tir.const(16, "uint32")) @@ -56,7 +58,8 @@ def _tir_packed_uint_to_uint_to_float(storage_nbit: int): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" max_int_value = (1 << (nbit - 1)) - 1 - return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const((1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) return f_convert @@ -68,7 +71,8 @@ def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" mask = tir.const((1 << nbit) - 1, "int32") unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask - return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + return tir.Cast( + dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) return f_convert @@ -82,7 +86,11 @@ def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") s = (val_u32 >> tir.const(31, "uint32")) - e_f4 = tir.Select(e_f32 > tir.const(120, "uint32"), tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + e_f4 = tir.Select( + e_f32 > tir.const(120, "uint32"), + tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), + tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), + tir.const(0, "uint32"))) return (s << tir.const(3, "uint32")) | e_f4 @@ -92,7 +100,10 @@ def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") s = (val_u32 >> tir.const(15, "uint32")) - e_f4 = tir.Select(e_f16 > tir.const(8, "uint32"), tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + e_f4 = tir.Select( + e_f16 > tir.const(8, "uint32"), + tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), + tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) return (s << tir.const(3, "uint32")) | e_f4 @@ -107,7 +118,8 @@ def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype s = f4 >> tir.const(3, "uint32") e_f4 = f4 & tir.const(7, "uint32") e_f32 = e_f4 | tir.const(120, "uint32") - val_f32 = tir.reinterpret("float32", (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) + val_f32 = tir.reinterpret("float32", + (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) @@ -122,27 +134,44 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype s = f4 >> tir.const(3, "uint32") e_f4 = f4 & tir.const(7, "uint32") e_f16 = e_f4 | tir.const(8, "uint32") - val_f16 = tir.reinterpret("float16", (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) + val_f16 = tir.reinterpret("float16", + (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): storage_dtype = storage_type + str(storage_nbit) def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" max_int_value = (1 << (nbit - 1)) - 1 - return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const((1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) return f_convert + def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8): storage_dtype = storage_type + str(storage_nbit) - def f_convert( - nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str - ): + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype) + return f_convert + + +def _tir_packed_to_unsigned_convert_with_zeros(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm.tir.PrimExpr, + dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return (((val >> (pos * nbit).astype(storage_dtype)) & mask) - zero).astype(dtype) + + return f_convert + + # fmt: on diff --git a/python/bitblas/quantization/utils.py b/python/bitblas/quantization/utils.py index 4ac3a60a1..3d369afe4 100644 --- a/python/bitblas/quantization/utils.py +++ b/python/bitblas/quantization/utils.py @@ -1,6 +1,54 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import numpy as np +import torch +import torch.nn as nn + + +def gen_quant4(k, n, groupsize=-1): + maxq = 2**4 - 1 + w = torch.randn((k, n), dtype=torch.half, device="cpu") + + original_w = w.clone() + + if groupsize == -1: + groupsize = k + + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + + # Quantize. + w = torch.round(w / s).int() + + # Unsigned storage. + w += (maxq) // 2 + + w = torch.clamp(w, 0, maxq) + + # Dequantize. + ref = (w - (maxq) // 2).half() * s + + if groupsize != -1: + + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((k, n)).contiguous() + return w + + ref = reshape(ref) + w = reshape(w) + + s = s.reshape((-1, n)).contiguous() + linear = nn.Linear(k, n, bias=False) + linear.weight.data = ref.t() + + return original_w, linear, s, (w - (maxq) // 2) def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): @@ -16,9 +64,7 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): ) for j in range(lowprecision_weight.shape[-1] // elems_per_byte): for k in range(elems_per_byte): - int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << ( - source_bits * k - ) + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) return int8_weight.view(storage_dtype) diff --git a/python/bitblas/relax/transform/annotate_decode_block.py b/python/bitblas/relax/transform/annotate_decode_block.py index c08f55db6..601647839 100644 --- a/python/bitblas/relax/transform/annotate_decode_block.py +++ b/python/bitblas/relax/transform/annotate_decode_block.py @@ -22,6 +22,7 @@ # Define a module pass to annotate dequantization information @module_pass(opt_level=0, name="AnnotateDecodeInformation") class AnnotateDecodeInformation: + def __init__(self, spec: str = "q4f16_0"): # Validate and store the specified quantization scheme if spec not in quantization_schemes: @@ -80,24 +81,18 @@ def transform_module(self, mod: IRModule, _: PassContext) -> IRModule: mod[g_var] = func.with_attr("dequantize_info", dequantize_info) return mod - def prepare_dequantize_info( - self, sch: tir.Schedule, dequantize_block: BlockRV - ) -> Dict: + def prepare_dequantize_info(self, sch: tir.Schedule, dequantize_block: BlockRV) -> Dict: """Generate dequantize information for a given block.""" block_stmt = sch.get(dequantize_block) block_name = block_stmt.name_hint - dequantize_info = { - block_name: {"decode_block": block_name, "fast_decoding": False} - } + dequantize_info = {block_name: {"decode_block": block_name, "fast_decoding": False}} quantize_spec = self.quantize_scheme.linear_weight if isinstance(quantize_spec, GroupQuantizationSpec): - dequantize_info[block_name].update( - { - "with_scaling": True, - "group_size": quantize_spec.group_size, - } - ) + dequantize_info[block_name].update({ + "with_scaling": True, + "group_size": quantize_spec.group_size, + }) # Determine source format based on quantization mode quantize_mod = quantize_spec.mode @@ -124,8 +119,5 @@ def parse_quantize_mode(self, quantize_mod: str) -> Tuple[int, str]: def get_storage_dtype(self, block_stmt: BlockRV, source_format: str) -> str: """Determine storage data type based on source format.""" - return ( - block_stmt.reads[0].buffer.dtype - if "af" not in source_format - else block_stmt.reads[1].buffer.dtype - ) + return (block_stmt.reads[0].buffer.dtype + if "nf" not in source_format else block_stmt.reads[1].buffer.dtype) diff --git a/python/bitblas/relax/transform/weight_only_propagate.py b/python/bitblas/relax/transform/weight_only_propagate.py index 309068f40..8240a0fd8 100644 --- a/python/bitblas/relax/transform/weight_only_propagate.py +++ b/python/bitblas/relax/transform/weight_only_propagate.py @@ -21,8 +21,7 @@ layout_propagate_chain, ) from tvm.dlight.base import ( - analysis, -) + analysis,) from dataclasses import dataclass @@ -86,6 +85,7 @@ class LayoutTransformHint: @module_pass(opt_level=0, name="InsertLayoutTransform") class WeightOnlyLayoutPropagation: + def __init__( self, transform_level: Union[int, TransformKind] = TransformKind.InterWarpTransform, @@ -109,17 +109,14 @@ def __init__( self.layout_transform_hints: Dict[str, List[LayoutTransformHint]] = {} def detect_propagate_matmul(self, func: tir.PrimFunc, target: Target): - _, tags = get_tensorized_func_and_tags( - func, target, skip_normalize=True, allow_gemv=True - ) + _, tags = get_tensorized_func_and_tags(func, target, skip_normalize=True, allow_gemv=True) if tags is None: return False, None return True, tags["intrin_info"] def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info): from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - ) + get_mma_intrin_group,) sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) @@ -147,20 +144,13 @@ def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info): # checkout whether the weight buffer has dynamic symbol def check_dynamic_symbol(buffer): - for axis in buffer.shape: - if isinstance(axis, tir.Var): - return True - return False + return any([isinstance(axis, tir.Var) for axis in buffer.shape]) if check_dynamic_symbol(weight_buffer): - print( - "[BitBLAS] Weight buffer has dynamic symbol, skip weight propagation." - ) + print("[BitBLAS] Weight buffer has dynamic symbol, skip weight propagation.") return False - transformed_block = find_last_producer_from_buffer( - sch, main_block, weight_buffer - ) + transformed_block = find_last_producer_from_buffer(sch, main_block, weight_buffer) if transformed_block is None: return False if transformed_block != main_block: @@ -170,8 +160,7 @@ def check_dynamic_symbol(buffer): # create inter-warp memory layout index map inter_warp_layout = IndexMap.from_func( - lambda i, j: (i // inter_j, j // inter_k, i % inter_j, j % inter_k) - ) + lambda i, j: (i // inter_j, j // inter_k, i % inter_j, j % inter_k)) inter_warp_layout = layout_propagate_chain( sch, @@ -186,9 +175,8 @@ def check_dynamic_symbol(buffer): ("read", 0), lambda i, j: inter_warp_layout.map_indices([i, j]), ) - arg_idx = find_arg_idx_from_buffer_chain( - sch, reindex_block, sch.get(reindex_block).reads[0].buffer - ) + arg_idx = find_arg_idx_from_buffer_chain(sch, reindex_block, + sch.get(reindex_block).reads[0].buffer) intra_warp_layout = None if self.transform_level.value >= TransformKind.IntraWarpTransform.value: @@ -230,8 +218,8 @@ def transform_module( # pylint: disable=missing-function-docstring # currently weight propagation only support nvidia gpus return mod - propogate_candidates = {} - propogated_funcs = {} # some funcs may not be able to transform + propagate_candidates = {} + propagated_funcs = {} # some funcs may not be able to transform candidates_intrin_info = {} decoded_funcs = {} for g_var, func in mod.functions_items(): @@ -243,10 +231,7 @@ def transform_module( # pylint: disable=missing-function-docstring # detect the pattern is_matmul, intrin_info = self.detect_propagate_matmul(func, self.target) - if ( - func.attrs is not None - and "dlight.do_not_tensorize" in func.attrs.keys() - ): + if (func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys()): # currently we only support tensorize propagation continue @@ -255,21 +240,17 @@ def transform_module( # pylint: disable=missing-function-docstring decoded_funcs[g_var] = func if self.transform_level != TransformKind.NonTransform: # lift tags to the function as it has intrinsic information that can be reused. - propogate_candidates[g_var] = func + propagate_candidates[g_var] = func candidates_intrin_info[g_var] = intrin_info - for g_var, func in propogate_candidates.items(): - updated_func = self.transform_matmul( - g_var, func, candidates_intrin_info[g_var] - ) + for g_var, func in propagate_candidates.items(): + updated_func = self.transform_matmul(g_var, func, candidates_intrin_info[g_var]) if updated_func: - updated_func = updated_func.with_attrs( - { - "transform_kind": self.transform_level.value, - "smooth_b": True, - } - ) - propogated_funcs[g_var] = updated_func + updated_func = updated_func.with_attrs({ + "transform_kind": self.transform_level.value, + "weight_transform_kind": True, + }) + propagated_funcs[g_var] = updated_func mod[g_var] = updated_func @relax.expr_functor.mutator @@ -279,8 +260,10 @@ class TensorCoreLayoutMutator(PyExprMutator): def __init__( self, transform_level: TransformKind = TransformKind.NonTransform, - layout_transform_hints: Dict[str, List[LayoutTransformHint]] = {}, + layout_transform_hints: Optional[Dict[str, List[LayoutTransformHint]]] = None, ): + if layout_transform_hints is None: + layout_transform_hints = {} super().__init__() self.transform_level = transform_level self.layout_transform_hints = layout_transform_hints @@ -289,7 +272,7 @@ def tc_layout_transform(self, call_node: Call) -> Call: if self.transform_level == TransformKind.NonTransform: return super().visit_call_(call_node) g_var = call_node.args[0] - if g_var not in propogated_funcs.keys(): + if g_var not in propagated_funcs: return super().visit_call_(call_node) args = list(call_node.args[1]) # assume we only have weight propagation currently @@ -299,10 +282,8 @@ def tc_layout_transform(self, call_node: Call) -> Call: relax.op.layout_transform( weight, index_map=lambda i, j: weight_layout_hint.inter_warp_layout.map_indices( - [i, j] - ), - ) - ) + [i, j]), + )) if self.transform_level.value >= TransformKind.IntraWarpTransform.value: weight = self.builder_.emit( relax.op.layout_transform( @@ -310,22 +291,17 @@ def tc_layout_transform(self, call_node: Call) -> Call: index_map=lambda i, j, ii, jj: ( i, j, - *weight_layout_hint.intra_warp_layout.map_indices( - [ii, jj] - ), + *weight_layout_hint.intra_warp_layout.map_indices([ii, jj]), ), - ) - ) + )) call_node = self.builder_.emit( relax.call_tir( g_var, - args[: weight_layout_hint.apply_arg_idx] - + [weight] - + args[weight_layout_hint.apply_arg_idx + 1 :], + args[:weight_layout_hint.apply_arg_idx] + [weight] + + args[weight_layout_hint.apply_arg_idx + 1:], out_sinfo=call_node.struct_info, - ) - ) + )) return call_node def visit_call_(self, call_node: Call): @@ -365,7 +341,7 @@ def lop3_layout_transform(self, call_node: Call) -> Call: from bitblas.ops.impl import tir_interleave_weight g_var = call_node.args[0] - if g_var not in decoded_funcs.keys(): + if g_var not in decoded_funcs: return super().visit_call_(call_node) args = list(call_node.args[1]) @@ -379,34 +355,30 @@ def lop3_layout_transform(self, call_node: Call) -> Call: sch = tir.Schedule(func) dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) - # weight is the first read buffer if format in ["int", "uint"], otherwise the second read buffer, af .etc + # weight is the first read buffer if format in ["int", "uint"], otherwise the second read buffer, nf .etc source_format = weight_dequantize_info["source_format"]["format"] source_bits = weight_dequantize_info["source_format"]["bits"] target_dtype = weight_dequantize_info["target_format"] if source_format in ["int", "uint"]: weight_buffer = sch.get(dequantize_block).reads[0].buffer - elif source_format in ["af"]: + elif source_format in ["nf"]: weight_buffer = sch.get(dequantize_block).reads[1].buffer else: raise ValueError(f"Unsupported source format {source_format}") # update func with dequantize_info dequantize_info["fast_decoding"] = True - self.builder_.update_func( - g_var, func.with_attrs({"dequantize_info": dequantize_info}) - ) + self.builder_.update_func(g_var, + func.with_attrs({"dequantize_info": dequantize_info})) - weight_idx = find_arg_idx_from_buffer_chain( - sch, dequantize_block, weight_buffer - ) + weight_idx = find_arg_idx_from_buffer_chain(sch, dequantize_block, weight_buffer) weight = args[weight_idx] weight_shape = weight_buffer.shape # reshape the weight shape to 2d reshape_weight = self.builder_.emit( - relax.op.reshape(weight, (-1, weight_shape[-1])) - ) + relax.op.reshape(weight, (-1, weight_shape[-1]))) # register g_var to the func lop3_interleave_func = tir_interleave_weight( N=reshape_weight.struct_info.shape[0], @@ -424,18 +396,15 @@ def lop3_layout_transform(self, call_node: Call) -> Call: interleave_gvar, [reshape_weight], out_sinfo=reshape_weight.struct_info, - ), - ) + ),) reshape_weight = self.builder_.emit( - relax.op.reshape(lop3_interleave_weight, weight_shape) - ) + relax.op.reshape(lop3_interleave_weight, weight_shape)) call_node = self.builder_.emit( relax.call_tir( g_var, - args[:weight_idx] + [reshape_weight] + args[weight_idx + 1 :], + args[:weight_idx] + [reshape_weight] + args[weight_idx + 1:], out_sinfo=call_node.struct_info, - ), - ) + ),) return call_node @@ -457,7 +426,6 @@ def transform( return mod mod = FastTypeConversionLayoutMutator( - faster_conversion=self.faster_conversion - ).transform(mod) + faster_conversion=self.faster_conversion).transform(mod) mod = relax.transform.LegalizeOps()(mod) return mod diff --git a/python/bitblas/testing/__init__.py b/python/bitblas/testing/__init__.py index 3f1236901..24f896bd8 100644 --- a/python/bitblas/testing/__init__.py +++ b/python/bitblas/testing/__init__.py @@ -5,18 +5,21 @@ import pytest from bitblas.base import DefaultPolicy, TensorCorePolicy from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags + + # pytest.main() wrapper to allow running single test file def main(): test_file = inspect.getsourcefile(sys._getframe(1)) sys.exit(pytest.main([test_file] + sys.argv[1:])) + def debug_with_schedule(func, arch, sch_rule): policy = DefaultPolicy(func=func, arch=arch) try: tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: + except Exception: tags = None if tags: policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) configs = policy.emit_config(1) - return sch_rule.apply_config(configs[0]) + return sch_rule.apply_config(func, configs[0]) diff --git a/python/bitblas/utils/__init__.py b/python/bitblas/utils/__init__.py index 59c7388cc..1c3748dfb 100644 --- a/python/bitblas/utils/__init__.py +++ b/python/bitblas/utils/__init__.py @@ -1,8 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .post_process import match_global_kernel, tensor_replace_dp4a -from .tensor_adapter import tvm_tensor_to_torch -import os - -def get_target_from_env() -> str: - return os.environ.get("TVM_TARGET") or "cuda" +from .post_process import match_global_kernel, tensor_replace_dp4a # noqa: F401 +from .tensor_adapter import tvm_tensor_to_torch # noqa: F401 +from .target_detector import auto_detect_nvidia_target # noqa: F401 diff --git a/python/bitblas/utils/post_process.py b/python/bitblas/utils/post_process.py index 328d5401b..12f785ee0 100644 --- a/python/bitblas/utils/post_process.py +++ b/python/bitblas/utils/post_process.py @@ -11,10 +11,8 @@ def match_global_kernel(source: str) -> int: def tensor_replace_dp4a(source: str) -> str: - # as under block reduction in tir dsl, the dp4a tensorize will fail, so we should do dp4a in post processer. + # as under block reduction in tir dsl, the dp4a tensorize will fail, so we should do dp4a in post processor. pattern = r"""for\s*\(int\s*(?P\w+)\s*=\s*0;\s*\1\s*<\s*4;\s*\+\+\1\)\s*\{\s*(?P\w+)\[0\]\s*=\s*\(\2\[0\]\s*\+\s*\(\(\(int\)(?P\w+)\[\(\((?P\w+)\s*\*\s*4\)\s*\+\s*\1\)\]\)\s*\*\s*\(\(int\)(?P\w+)\[\(\((?P\w+)\s*\*\s*4\)\s*\+\s*\1\)\]\)\)\);\s*\}""" - replacement = ( - r"""\2[0] = __dp4a(*(int *)&\3[((\4 * 4))],*(int *)&\5[((\6 * 4))], \2[0]);""" - ) + replacement = (r"""\2[0] = __dp4a(*(int *)&\3[((\4 * 4))],*(int *)&\5[((\6 * 4))], \2[0]);""") source = re.sub(pattern, replacement, source) return source diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py new file mode 100644 index 000000000..04cd4dc21 --- /dev/null +++ b/python/bitblas/utils/target_detector.py @@ -0,0 +1,68 @@ +# Import necessary libraries +import os +import subprocess +import logging +from fuzzywuzzy import process +from tvm.target import Target +from tvm.target.tag import list_tags + +logger = logging.getLogger(__name__) + + +def get_gpu_model_from_nvidia_smi(): + """ + Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU. + + Returns: + str: The name of the GPU, or None if 'nvidia-smi' command fails. + """ + try: + # Execute nvidia-smi command to get the GPU name + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"], + encoding="utf-8", + ).strip() + except subprocess.CalledProcessError as e: + logger.info("nvidia-smi failed with error: %s", e) + return None + + # Return the name of the first GPU if multiple are present + return output.split("\n")[0] + + +def find_best_match(tags, query): + """ + Finds the best match for a query within a list of tags using fuzzy string matching. + """ + MATCH_THRESHOLD = 25 + best_match, score = process.extractOne(query, tags) + + def check_target(best, default): + return best if Target(best).arch == Target(default).arch else default + + if check_target(best_match, "cuda"): + return best_match if score >= MATCH_THRESHOLD else "cuda" + else: + return "cuda" + + +def auto_detect_nvidia_target() -> str: + """ + Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target. + + Returns: + str: The detected TVM target architecture. + """ + # Return a predefined target if specified in the environment variable + # if "TVM_TARGET" in os.environ: + # return os.environ["TVM_TARGET"] + + # Fetch all available tags and filter for NVIDIA tags + all_tags = list_tags() + nvidia_tags = [tag for tag in all_tags if "nvidia" in tag] + + # Get the current GPU model and find the best matching target + gpu_model = get_gpu_model_from_nvidia_smi() + target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda" + + return target diff --git a/python/bitblas/utils/tensor_adapter.py b/python/bitblas/utils/tensor_adapter.py index 7708cc40c..0f7eeefa9 100644 --- a/python/bitblas/utils/tensor_adapter.py +++ b/python/bitblas/utils/tensor_adapter.py @@ -2,12 +2,86 @@ # Licensed under the MIT License. import tvm from typing import Union +from enum import IntEnum +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +from tvm.relay import TensorType +from tvm._ffi.base import _LIB, c_str +from tvm._ffi._ctypes.types import TVMValue, check_call +from tvm._ffi.runtime_ctypes import ( + TVMArrayHandle,) +import ctypes + +TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +_c_str_dltensor = c_str("dltensor") +_c_str_used_dltensor = c_str("used_dltensor") + + +def get_values_from_torch_tensors(tensors, num_args): + values = (TVMValue * num_args)() + dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in tensors] + for i, dltensor in enumerate(dlpack_tensors): + dltensor = ctypes.py_object(dltensor) + if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor): + ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor) + # enforce type to make sure it works for all ctypes + ptr = ctypes.cast(ptr, ctypes.c_void_p) + handle = TVMArrayHandle() + check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) + # ndarray = tvm.runtime.ndarray._make_array(handle, False, False) + ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) + ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0)) + values[i].v_handle = ctypes.cast(handle, ctypes.c_void_p) + else: + raise ValueError("Invalid DLTensor") + return values + + +class TensorSupplyType(IntEnum): + Integer = 1 + Uniform = 2 + Normal = 3 + Randn = 4 + Zero = 5 + One = 6 -def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): - import torch - from torch.utils.dlpack import from_dlpack +def get_tensor_supply(supply_type: TensorSupplyType, opt_shapes: dict = None): + + def var_wrapper(v, opt_shapes): + if isinstance(v, tvm.tir.Var): + assert opt_shapes + assert v.name in opt_shapes + return opt_shapes[v.name] + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + def get_tensor(tensor: TensorType) -> torch.Tensor: + dtype = torch.__getattribute__(str(tensor.dtype)) + device = torch.cuda.current_device() + shape = [var_wrapper(i, opt_shapes) for i in tensor.shape] + if supply_type == TensorSupplyType.Integer: + return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) + elif supply_type == TensorSupplyType.Uniform: + return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0) + elif supply_type == TensorSupplyType.Normal: + return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0) + elif supply_type == TensorSupplyType.Randn: + return torch.randn(*shape, device=device).to(dtype) + elif supply_type == TensorSupplyType.Zero: + return torch.zeros(*shape, device=device, dtype=dtype) + elif supply_type == TensorSupplyType.One: + return torch.ones(*shape, device=device, dtype=dtype) + else: + raise NotImplementedError(supply_type) + + return get_tensor + + +def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): if isinstance(tensor, tvm.te.Tensor): return torch.from_numpy(tensor.numpy()) elif isinstance(tensor, tvm.nd.NDArray): diff --git a/python/bitblas/wrapper/__init__.py b/python/bitblas/wrapper/__init__.py new file mode 100644 index 000000000..1d87f8020 --- /dev/null +++ b/python/bitblas/wrapper/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .general import CUDASourceWrapper, CUDASourceWrapperWithDynamic # noqa: F401 diff --git a/python/bitblas/wrapper/general.py b/python/bitblas/wrapper/general.py new file mode 100644 index 000000000..cca40cfb9 --- /dev/null +++ b/python/bitblas/wrapper/general.py @@ -0,0 +1,507 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from typing import Optional, List, Dict, Union +from tvm import IRModule +from bitblas import TileDevice +from tvm.runtime import ndarray +from bitblas.utils import match_global_kernel +import re +import ctypes +import os +import tempfile +import subprocess +import logging +from tvm.driver import lower +from tvm.target import Target + +logger = logging.getLogger(__name__) + +_TYPE_MAP = { + "float32": "float", + "float16": "half", + "bfloat16": "__nv_bfloat162", + "float64": "double", + "int64": "int64_t", + "int32": "int", + "uint32": "unsigned int", + "bool": "int8_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uchar": "uint8_t", +} + + +def get_annotated_device_mod(mod: IRModule, target: Target): + """ + Lower the given IRModule and create a device module for the specified target. + + Parameters: + - mod: The input IRModule. + - target: The compilation target. + + Returns: + - A device module ready for execution. + """ + input_mod = lower(mod) + target_input_mod = {target: input_mod} + annotated_mods = {} + runtime = None + target_host = None + for tgt, mod in target_input_mod.items(): + if not isinstance(tgt, (str, Target)): + raise ValueError("The key of inputs must be str or " + "Target when inputs is dict.") + if not isinstance(mod, tvm.IRModule): + raise ValueError("inputs must be Schedule, IRModule, " + "or dict of str to IRModule.") + annotated_mods[tgt] = mod.with_attr("runtime", runtime) + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + if not target_host: + for tar, _ in annotated_mods.items(): + device_type = ndarray.device(tar.kind.name, 0).device_type + if device_type == ndarray.cpu(0).device_type: + target_host = tar + break + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + for target, mod in annotated_mods.items(): + mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes") + device_mod_passes = tvm.get_global_func("driver.device_mod_passes") + mod = mixed_mod_passes(mod, target)(mod) + device_mod = device_mod_passes(mod, target)(mod) + return device_mod + + +def get_thread_block_information(mod: IRModule): + """ + Extracts the thread block and grid dimensions for the reduction block within a given IRModule. + + Parameters: + - mod: The input IRModule from which to extract thread block and grid information. + + Returns: + A tuple containing two lists: + - The first list contains the dimensions of the thread block (threadIdx.x, threadIdx.y, threadIdx.z). + - The second list contains the dimensions of the grid (blockIdx.x, blockIdx.y, blockIdx.z). + """ + + # Initialize the schedule from the IRModule + sch = tvm.tir.Schedule(mod) + + # Get the root block and its child blocks + root_block = sch.get_block("root") + child_blocks = sch.get_child_blocks(root_block) + + # Initialize default block and grid dimensions (1, 1, 1) + block_dims, grid_dims = [1, 1, 1], [1, 1, 1] + + for block in child_blocks: + # Get the loops surrounding the main block + loops = sch.get_loops(block) + + # Iterate over each loop to extract thread and block bindings + for loop in loops: + stmt = sch.get(loop) + thread_binding = stmt.thread_binding + extent = int(stmt.extent) + + # Skip loops without thread binding + if thread_binding: + if "threadIdx" in thread_binding.thread_tag: + block_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + elif "blockIdx" in thread_binding.thread_tag: + grid_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + + return block_dims, grid_dims + + +class CUDASourceWrapper(object): + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + self.mod = optimized_mod + self.arch = arch + self.source = source + self.function_name: Optional[str] = None + self.dynamic_smem_buf: Optional[int] = None + self.block_info: Union[List[int], Dict] = [1, 1, 1] + self.grid_info: Union[List[int], Dict] = [1, 1, 1] + self.parse_source_information() + self.src_name: Optional[str] = None + self.lib_name: Optional[str] = None + self.lib_code: Optional[str] = self.update_lib_code(source) + + def load_lib(self): + return ctypes.CDLL(self.lib_name) + + def remove_lib(self): + if self.lib_name: + os.remove(self.lib_name) + self.lib_name = None + + def compile_lib(self, timeout: float = None): + arch = self.arch + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + compute_version = arch.compute_capability + lib_name = src.name.replace(".cu", ".so") + + command = [ + "nvcc", + "-std=c++17", + "-Xcudafe", + "--diag_suppress=177", + "--compiler-options", + "'-fPIC'", + "-lineinfo", + "--shared", + src.name, + "-lcuda", + f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", + "-o", + lib_name, + ] + src.write(self.lib_code) + src.flush() + try: + ret = subprocess.run(command, timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning(f"Compilation Timeout! {command}") + return None + if ret.returncode != 0: + logger.warning(f"Compilation Failed! {command}") + return None + self.src_name = src.name + self.lib_name = lib_name + + def parse_source_information(self): + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + assert (len(device_mod.functions) == 1 + ), "Only support one function in the module for static shape kernel." + for g_var, func in device_mod.functions.items(): + self.function_name = g_var.name_hint + attrs = func.attrs + if "dyn_shared_memory_buf" in attrs: + self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + self.block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + self.grid_info["xyz".index(tag[-1])] = extent + + def get_dynamic_symbolic_set(self, prim_func): + # Determine the set of dynamic symbols used in the function + dynamic_symbolic_set = set() + for param in prim_func.params: + buffer = prim_func.buffer_map[param] + for dim in buffer.shape: + if isinstance(dim, tvm.tir.Var): + dynamic_symbolic_set.add(dim.name) + return dynamic_symbolic_set + + def get_cuda_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + if self.dynamic_smem_buf is not None: + call_str = """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(self.function_name, self.dynamic_smem_buf) + # Format the initialization function using the call_str + init_funcs = """ + extern "C" void init() {{ + {} + }} + """.format(call_str) + return init_funcs + + def update_lib_code(self, code: str): + # Update the library code with the given code string + self.lib_code = code + # Find the index of the global kernel function in the code + index = match_global_kernel(code) + # Extract the declaration of the function starting from the found index + declaration = code[index:].split(";")[0] + + function_name = self.function_name + # Get the CUDA initialization function + init_func = self.get_cuda_init_func() + + # Locate the opening brace of the function to insert arguments + index = code.index("{", index) + function_args = [] + # Populate the function arguments from the primary function's parameters and buffers + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + # Add dynamic symbolic parameters as integers to the function arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s, function_args): + # Extract the function call arguments matching the function definition + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(declaration, function_args)) + block_info, grid_info = self.block_info, self.grid_info + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + # Prepare the block and grid dimensions for the CUDA kernel launch + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) + # Determine the shared memory size, defaulting to 0 if not specified + smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf + # Format the CUDA kernel launch string + call_str = "{}<<<{}, {}, {}>>>({});".format(function_name, grid_str, block_str, smem_str, + call_args) + # Create the host function wrapper for the CUDA kernel + host_func = """ + extern "C" void call({}) {{ + {} + }} + """.format(def_args, call_str) + # Combine the source, initialization function, and host function to form the complete library code + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] + + +class CUDASourceWrapperWithDynamic(CUDASourceWrapper): + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + super().__init__(optimized_mod, source, arch) + + def get_cuda_init_func(self): + # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory + call_str = """""" + # Iterate over functions and their dynamic shared memory requirements + for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): + if dynamic_smem_buf is not None: + # Format the cudaFuncSetAttribute call for dynamic shared memory + call_str += """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(function_name, dynamic_smem_buf) + # Define the init function that will set the attributes for each kernel + init_funcs = """ +extern "C" void init() {{ + {} +}} + """.format(call_str) + return init_funcs + + def create_dispatch_func(self, code, function_informations): + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + # Find the location of the global kernel function in the code + index = match_global_kernel(code) + + # Analyze the function declaration to prepare for argument extraction + dummy_declaration = code[index:].split(";")[0] + + function_name = self.function_name + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + function_args = [] + # Collect function arguments based on primary function's parameters and buffer mappings + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + # Add dynamic symbols as integer arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + # Format the argument definitions for function declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s: str, function_args): + # Extract and clean the function call arguments to match the declaration + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + match = re.sub(r"\d+", "", match) # Remove numbers + match = re.sub(r"_", "", match) # Remove underscores + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(dummy_declaration, function_args)) + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + last_range = 0 + num_items = len(function_informations) + _call_str = """""" + for function_name, info in function_informations.items(): + # Prepare block and grid configurations for kernel launches + block_info, grid_info = info["block_info"], info["grid_info"] + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), + legalize_c(grid_info[1]), + legalize_c(grid_info[2]), + ) + # Handle dynamic shared memory specification + smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"]) + opt_shapes = info["opt_shapes"] + # Generate conditional kernel launch code based on dynamic symbolic ranges + (symbolic,) = list(dynamic_symbolic_set) + range_str = opt_shapes[symbolic] + if last_range == 0: + call_str = "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + else: + call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + if last_range == num_items - 1: + call_str += ("\t\telse {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format( + function_name, grid_str, block_str, smem_str, call_args)) + last_range += 1 + _call_str += call_str + + # Wrap the kernel dispatch logic in an external C function + host_func = """ +extern "C" void call({}) {{ + {} +}} + """.format(def_args, _call_str) + return host_func + + def parse_source_information(self): + # Parse device module to extract execution configurations for each function + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + block_info_map = {} + grid_info_map = {} + dynamic_smem_buf_map = {} + for g_var, func in device_mod.functions.items(): + # Default block and grid configurations + block_info = [1, 1, 1] + grid_info = [1, 1, 1] + function_name = g_var.name_hint + attrs = func.attrs + dynamic_smem_buf = None + if "dyn_shared_memory_buf" in attrs: + dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + # Extract block and grid sizes from thread extents + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + grid_info["xyz".index(tag[-1])] = extent + # Map the extracted configurations to each function + block_info_map[function_name] = block_info + grid_info_map[function_name] = grid_info + dynamic_smem_buf_map[function_name] = dynamic_smem_buf + # Store the mappings for use in code generation + self.block_info = block_info_map + self.grid_info = grid_info_map + self.dynamic_smem_buf = dynamic_smem_buf_map + + def update_lib_code(self, code: str): + # Organize function information for code generation + function_informations = {} + for g_var, func in self.mod.functions.items(): + if g_var.name_hint == "main": + continue + function_name = g_var.name_hint + attrs = func.attrs + assert "opt_shapes" in attrs + opt_shapes = attrs["opt_shapes"] + function_informations[function_name] = { + "function_name": function_name, + "opt_shapes": opt_shapes, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + } + + def compare_map_objects(map_obj): + comparable_representation = list(map_obj.values()) + return comparable_representation + + function_informations = dict( + sorted( + function_informations.items(), + key=lambda item: compare_map_objects(item[1]["opt_shapes"]))) + + self.lib_code = code + + # Generate the initialization and dispatch functions + init_func = self.get_cuda_init_func() + host_func = self.create_dispatch_func(code, function_informations) + # Concatenate source code with generated code segments + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..20d8be152 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,30 @@ +# formatting +yapf==0.32.0 +toml==0.10.2 +tomli==2.0.1 +ruff==0.1.5 +codespell==2.2.6 + +cffi +cpplint +Cython +decorator +docutils +dtlib +numpy>=1.23.5 +pylint +pytest>=6.2.4 +pytest_xdist>=2.2.1 +packaging>=21.0 +PyYAML +tqdm>=4.62.3 +typing_extensions>=4.10.0 +requests +fuzzywuzzy +python-Levenshtein +attrs +cloudpickle +ml_dtypes +psutil +scipy +tornado diff --git a/requirements.txt b/requirements.txt index fadce86e2..3747eecff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,24 +1,23 @@ -cffi==1.16.0     -cpplint==1.6.1     -Cython==0.29.32     -decorator==5.1.1     -docutils==0.19     -dtlib==0.0.0.dev2     -flatbuffers==22.12.6     -matplotlib==3.6.3     -numpy>=1.23.5     -Pillow==10.2.0     -protobuf==3.20.3     -psutil==5.9.4     -pylint==3.1.0     -pyserial==3.5     -pytest==7.2.0     -pytest_xdist==3.5.0     -PyYAML==6.0.1     -tornado==6.3.3     -tqdm==4.64.1     -transformers==4.36.0     -treelib==1.7.0     -typeguard==2.13.3              -typing_extensions==4.10.0     -xgboost==1.7.1     +cffi +cpplint +Cython +decorator +docutils +dtlib +numpy>=1.23.5 +pylint +pytest>=6.2.4 +pytest_xdist>=2.2.1 +packaging>=21.0 +PyYAML +tqdm>=4.62.3 +typing_extensions>=4.10.0 +requests +fuzzywuzzy +python-Levenshtein +attrs +cloudpickle +ml_dtypes +psutil +scipy +tornado diff --git a/setup.py b/setup.py index 3afdea26c..3b0c88305 100644 --- a/setup.py +++ b/setup.py @@ -1,31 +1,292 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import os +import io +import subprocess +import shutil from setuptools import setup, find_packages from setuptools.command.install import install +from setuptools.command.build_py import build_py +from setuptools.command.sdist import sdist +import distutils.dir_util +from typing import List +import re +import tarfile +from io import BytesIO +import os +import sys +import urllib.request +from distutils.version import LooseVersion +import platform + +PACKAGE_NAME = "bitblas" +ROOT_DIR = os.path.dirname(__file__) +MAIN_CUDA_VERSION = "12.1" + +# BitBLAS only supports Linux platform +assert sys.platform.startswith("linux"), "BitBLAS only supports Linux platform (including WSL)." + + +def get_path(*filepath) -> str: + return os.path.join(ROOT_DIR, *filepath) + + +def get_requirements() -> List[str]: + """Get Python package dependencies from requirements.txt.""" + with open(get_path("requirements.txt")) as f: + requirements = f.read().strip().split("\n") + return requirements + + +def find_version(filepath: str) -> str: + """Extract version information from the given filepath. + + Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py + """ + with open(filepath) as fp: + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M) + if version_match: + return version_match.group(1) + raise RuntimeError("Unable to find version string.") + + +def get_nvcc_cuda_version(): + """Get the CUDA version from nvcc. + + Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py + """ + nvcc_output = subprocess.check_output(["nvcc", "-V"], universal_newlines=True) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = LooseVersion(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + +def get_bitblas_version(with_cuda=True, with_system_info=True) -> str: + version = find_version(get_path("python/bitblas", "__init__.py")) + if with_system_info: + version += f"+{get_system_info()}" + if with_cuda: + cuda_version = str(get_nvcc_cuda_version()) + cuda_version_str = cuda_version.replace(".", "")[:3] + version += f".cu{cuda_version_str}" + return version + + +def get_system_info(): + system = platform.system().lower() + if system == "linux": + try: + with open("/etc/os-release") as f: + os_release = f.read() + version_id_match = re.search(r'VERSION_ID="(\d+\.\d+)"', os_release) + if version_id_match: + version_id = version_id_match.group(1) + distro = "ubuntu" + return f"{distro}-{version_id}" + except FileNotFoundError: + pass + return system + + +def read_readme() -> str: + """Read the README file if present.""" + p = get_path("README.md") + if os.path.isfile(p): + return io.open(get_path("README.md"), "r", encoding="utf-8").read() + else: + return "" + + +def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"): + """ + Downloads and extracts the specified version of LLVM for the given platform. + Args: + version (str): The version of LLVM to download. + is_aarch64 (bool): True if the target platform is aarch64, False otherwise. + extract_path (str): The directory path where the archive will be extracted. + + Returns: + str: The path where the LLVM archive was extracted. + """ + ubuntu_version = "16.04" + if version >= "16.0.0": + ubuntu_version = "20.04" + elif version >= "13.0.0": + ubuntu_version = "18.04" + base_url = (f"https://github.com/llvm/llvm-project/releases/download/llvmorg-{version}") + file_name = f"clang+llvm-{version}-{'aarch64-linux-gnu' if is_aarch64 else f'x86_64-linux-gnu-ubuntu-{ubuntu_version}'}.tar.xz" -class ApacheTVMInstallCommand(install): - """Customized setuptools install command - builds the submodule first.""" + download_url = f"{base_url}/{file_name}" + + # Download the file + print(f"Downloading {file_name} from {download_url}") + with urllib.request.urlopen(download_url) as response: + if response.status != 200: + raise Exception(f"Download failed with status code {response.status}") + file_content = response.read() + # Ensure the extract path exists + os.makedirs(extract_path, exist_ok=True) + + # if the file already exists, remove it + if os.path.exists(os.path.join(extract_path, file_name)): + os.remove(os.path.join(extract_path, file_name)) + + # Extract the file + print(f"Extracting {file_name} to {extract_path}") + with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar: + tar.extractall(path=extract_path) + + print("Download and extraction completed successfully.") + return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", ""))) + + +package_data = { + "bitblas": ["py.typed"], +} + +LLVM_VERSION = "10.0.1" +IS_AARCH64 = False # Set to True if on an aarch64 platform +EXTRACT_PATH = "3rdparty" # Default extraction path + + +def update_submodules(): + """Updates git submodules.""" + try: + subprocess.check_call(["git", "submodule", "update", "--init", "--recursive"]) + except subprocess.CalledProcessError as error: + raise RuntimeError("Failed to update submodules") from error + + +def build_tvm(llvm_config_path): + """Configures and builds TVM.""" + os.chdir("3rdparty/tvm") + if not os.path.exists("build"): + os.makedirs("build") + os.chdir("build") + # Copy the config.cmake as a baseline + if not os.path.exists("config.cmake"): + shutil.copy("../cmake/config.cmake", "config.cmake") + # Set LLVM path and enable CUDA in config.cmake + with open("config.cmake", "a") as config_file: + config_file.write(f"set(USE_LLVM {llvm_config_path})\n") + config_file.write("set(USE_CUDA ON)\n") + # Run CMake and make + try: + subprocess.check_call(["cmake", ".."]) + subprocess.check_call(["make", "-j"]) + except subprocess.CalledProcessError as error: + raise RuntimeError("Failed to build TVM") from error + finally: + # Go back to the original directory + os.chdir("../../..") + + +def setup_llvm_for_tvm(): + """Downloads and extracts LLVM, then configures TVM to use it.""" + # Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script + extract_path = download_and_extract_llvm(LLVM_VERSION, IS_AARCH64, EXTRACT_PATH) + llvm_config_path = os.path.join(extract_path, "bin", "llvm-config") + return extract_path, llvm_config_path + + +class BitBLASInstallCommand(install): + """Customized setuptools install command - builds TVM after setting up LLVM.""" def run(self): - os.system("./maint/scripts/installation.sh") + # Recursively update submodules + # update_submodules() + # Set up LLVM for TVM + _, llvm_path = setup_llvm_for_tvm() + # Build TVM + build_tvm(llvm_path) + # Continue with the standard installation process install.run(self) +class BitBLASBuilPydCommand(build_py): + """Customized setuptools install command - builds TVM after setting up LLVM.""" + + def run(self): + build_py.run(self) + # custom build tvm + update_submodules() + # Set up LLVM for TVM + _, llvm_path = setup_llvm_for_tvm() + # Build TVM + build_tvm(llvm_path) + # Copy the built TVM to the package directory + TVM_PREBUILD_ITEMS = [ + "3rdparty/tvm/build/libtvm_runtime.so", + "3rdparty/tvm/build/libtvm.so", + "3rdparty/tvm/build/config.cmake", + "3rdparty/tvm/python", + "3rdparty/tvm/licenses", + "3rdparty/tvm/conftest.py", + "3rdparty/tvm/CONTRIBUTORS.md", + "3rdparty/tvm/KEYS", + "3rdparty/tvm/LICENSE", + "3rdparty/tvm/README.md", + "3rdparty/tvm/mypy.ini", + "3rdparty/tvm/pyproject.toml", + "3rdparty/tvm/version.py", + ] + for item in TVM_PREBUILD_ITEMS: + source_dir = os.path.join(ROOT_DIR, item) + target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + if os.path.isdir(source_dir): + self.mkpath(target_dir) + distutils.dir_util.copy_tree(source_dir, target_dir) + else: + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + shutil.copy2(source_dir, target_dir) + + +class BitBLASSdistCommand(sdist): + """Customized setuptools sdist command - includes the pyproject.toml file.""" + + def make_distribution(self): + self.distribution.metadata.name = PACKAGE_NAME + self.distribution.metadata.version = get_bitblas_version( + with_cuda=False, with_system_info=False) + super().make_distribution() + + setup( - name="bitblas", - version="0.1", - packages=find_packages(), - install_requires=[], + name=PACKAGE_NAME, + version=get_bitblas_version(), + packages=find_packages(where="python"), + package_dir={"": "python"}, author="Microsoft Research", - author_email="leiwang1999@outlook.com", description="A light weight framework to generate high performance CUDA/HIP code for BLAS operators.", + long_description=read_readme(), license="MIT", keywords="BLAS, CUDA, HIP, Code Generation, TVM", url="https://github.com/microsoft/BitBLAS", + classifiers=[ + "Programming Language :: Python :: 3.8", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Mathematics, Scientific/Engineering :: Artificial Intelligence", + ], + python_requires=">=3.8", + install_requires=get_requirements(), + tests_require=[ + "yapf>=0.32.0", "toml>=0.10.2", "tomli>=2.0.1", "ruff>=0.1.5", "codespell>=2.2.6" + ], + package_data=package_data, + include_package_data=True, + data_files=[ + "requirements.txt", + ], cmdclass={ - "install": ApacheTVMInstallCommand, + "install": BitBLASInstallCommand, + "build_py": BitBLASBuilPydCommand, + "sdist": BitBLASSdistCommand, }, ) diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp index 824de1bdf..f86ca2cf0 100644 --- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp +++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp @@ -97,16 +97,16 @@ void general_interleave_fp16(int8_t *origin_arr, int8_t *interleaved, const int /* Kind 0: original Kind 1: rescale -Kind 2: quantzied -# documents for zeros_type: +Kind 2: quantized +# documents for zeros_mode: # original: target = (dequantize_weight - zero_point) * scale # rescale: target = dequantize_weight * scale - zero_point -# quantzied: target = (dequantize_weight - dequantize_zeros) * scale +# quantized: target = (dequantize_weight - dequantize_zeros) * scale # Notice: only support "original" and "rescale" now -zeros_type: Literal["original", "rescale", "quantzied"] = "original" +zeros_mode: Literal["original", "rescale", "quantized"] = "original" */ -template -__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8, const half *scale = nullptr, const half *zeros = nullptr) +template +__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) { uint *h = reinterpret_cast(B_local_decode); @@ -116,6 +116,7 @@ __device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8, // Minus 7 to scale the value to signed static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; uint const i4s = *reinterpret_cast(_i4s); + #pragma unroll // decode 2 elems at one time. for (int i = 0; i < (N / 2); i++) @@ -124,18 +125,8 @@ __device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8, asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[i]) : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - if constexpr (withZeros && ZerosKind == 0) - { - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); - } - if constexpr (withScaling) - { - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); - } - if constexpr (withZeros && ZerosKind == 1){ - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); - } } } @@ -151,40 +142,169 @@ __device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) decode_i4b_to_f16(_i4u, B_local_decode, N); } -template -__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +template +__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr) { - decode_i4b_to_f16(_i4s, B_local_decode, N, scale); + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } } -template -__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +template +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) { - decode_i4b_to_f16(_i4u, B_local_decode, N, scale); + decode_i4b_to_f16_scale(_i4s, B_local_decode, N, scale); } -template -__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, half *scale = nullptr, half *zeros = nullptr, const int N = 8) +template +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) { - decode_i4b_to_f16(_i4u, B_local_decode, N, scale, zeros); + decode_i4b_to_f16_scale(_i4u, B_local_decode, N, scale); } -template -__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, half *scale = nullptr, half *zeros = nullptr, const int N = 8) +template +__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_zeros_original(_i4u, B_local_decode, N, scale, zeros); +} + +template +__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} + +template +__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8) { - decode_i4b_to_f16(_i4u, B_local_decode, N, scale, zeros); + decode_i4b_to_f16_scale_zeros_quantized(_i4u, B_local_decode, N, scale, zeros); } /* Kind 0: original Kind 1: rescale -Kind 2: quantzied -# documents for zeros_type: +Kind 2: quantized +# documents for zeros_mode: # original: target = (dequantize_weight - zero_point) * scale # rescale: target = dequantize_weight * scale - zero_point -# quantzied: target = (dequantize_weight - dequantize_zeros) * scale +# quantized: target = (dequantize_weight - dequantize_zeros) * scale # Notice: only support "original" and "rescale" now -zeros_type: Literal["original", "rescale", "quantzied"] = "original" +zeros_mode: Literal["original", "rescale", "quantized"] = "original" */ template __device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr) @@ -261,17 +381,16 @@ __device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_deco decode_i2b_to_f16(_i2u, B_local_decode, N, scale, zeros); } - /* Kind 0: original Kind 1: rescale -Kind 2: quantzied -# documents for zeros_type: +Kind 2: quantized +# documents for zeros_mode: # original: target = (dequantize_weight - zero_point) * scale # rescale: target = dequantize_weight * scale - zero_point -# quantzied: target = (dequantize_weight - dequantize_zeros) * scale +# quantized: target = (dequantize_weight - dequantize_zeros) * scale # Notice: only support "original" and "rescale" now -zeros_type: Literal["original", "rescale", "quantzied"] = "original" +zeros_mode: Literal["original", "rescale", "quantized"] = "original" */ template __device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr, half *zeros = nullptr) @@ -323,30 +442,122 @@ __device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8) decode_i1b_to_f16(_i1u, B_local_decode, N); } -template -__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +template +__device__ void decode_i1b_to_f16_scale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr) { - decode_i1b_to_f16(_i1s, B_local_decode, N, scale); + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale(_i1s, B_local_decode, N, scale); } -template -__device__ void decode_i1u_to_f16_scale(T1 *_i1u, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +template +__device__ void decode_i1u_to_f16_scale(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) { - decode_i1b_to_f16(_i1u, B_local_decode, N, scale); + decode_i1b_to_f16_scale(_i1u, B_local_decode, N, scale); } -template -__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, half *scale = nullptr, half *zeros = nullptr, const int N = 8) +template +__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) { - decode_i1b_to_f16(_i1u, B_local_decode, N, scale, zeros); + decode_i1b_to_f16_zeros_original(_i1u, B_local_decode, N, scale, zeros); } -template -__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i1u, T2 *B_local_decode, half *scale = nullptr, half *zeros = nullptr, const int N = 8) +template +__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) { - decode_i1b_to_f16(_i1u, B_local_decode, N, scale, zeros); + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } } +template +__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} void general_interleave_int8(int8_t *origin_arr, int8_t *interleaved, const int nbit, size_t size_in_bytes, bool verbose = false) { diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu index 0d9cca26f..eda2be206 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu @@ -45,6 +45,7 @@ REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_f16_scale_zeros_original, de REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_rescale, decode_i4u_to_f16_scale_zeros_rescale) REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale_zeros_rescale, decode_i2u_to_f16_scale_zeros_rescale) REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_f16_scale_zeros_rescale, decode_i1u_to_f16_scale_zeros_rescale) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_quantized, decode_i4u_to_f16_scale_zeros_quantized) TEST(DecodeTest, DecodeInt4ToFloat16) { @@ -1022,3 +1023,57 @@ TEST(DecodeTest, DecodeUInt1ToFloat16WithScalingWithZerosRescale) free(interleaved); free(decoded); } + +TEST(DecodeTest, DecodeUInt4ToFloat16WithScalingWithZerosQuantized) +{ + constexpr int nbits = 4; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + + // create four int8_t values + int8_t in_data[N] = { + 0}; + half scale[1] = {__float2half(1.2)}; + uint qzeros[1] = {(1 << (nbits - 1)) - 1}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)); + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + uint *qzeros_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&qzeros_gpu, 1 * sizeof(uint))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(qzeros_gpu, qzeros, 1 * sizeof(uint), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + kernelWrapper_i4u_to_f16_scale_zeros_quantized<<>>(ins_gpu, decoded_gpu, scale_gpu, qzeros_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_NEAR(((int)in_data[i] - (int)qzeros[0]) * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); +} \ No newline at end of file diff --git a/testing/python/cache/test_operator_cache.py b/testing/python/cache/test_operator_cache.py index c2964849f..fcb863f9a 100644 --- a/testing/python/cache/test_operator_cache.py +++ b/testing/python/cache/test_operator_cache.py @@ -1,17 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import pytest -import tvm +import os import torch import bitblas from bitblas.ops.matmul import Matmul, MatmulConfig +from bitblas.ops.matmul_dequantize import ( + MatmulWeightOnlyDequantize, + MatmulWeightOnlyDequantizeConfig, +) from bitblas.cache import global_operator_cache -target = tvm.target.Target("nvidia/nvidia-a100") +target = bitblas.utils.auto_detect_nvidia_target() def get_codegen_result(ops, target): - code = ops.codegen(target=target) + code = ops.get_source(target=target) return code @@ -67,6 +71,7 @@ def test_config_hashable( print(hash_error) assert success + @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", [ @@ -115,10 +120,11 @@ def test_global_cache_inquery( except Exception as hash_error: print(hash_error) assert success - + matmul = global_operator_cache.get(matmul.config) assert matmul is not None + @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", [ @@ -167,7 +173,7 @@ def test_global_cache_inquery_torch_forward( except Exception as hash_error: print(hash_error) assert success - + matmul = global_operator_cache.get(matmul.config) assert matmul is not None if not isinstance(M, int): @@ -184,37 +190,300 @@ def test_global_cache_inquery_torch_forward( permuted_inputs = [] if matmul.input_transform is not None: - permuted_inputs.append( - matmul.input_transform(inputs[0].cpu()) - ).cuda() + permuted_inputs.append(matmul.input_transform(inputs[0].cpu())).cuda() else: permuted_inputs.append(inputs[0]) if matmul.weight_transform is not None: - permuted_inputs.append( - matmul.weight_transform(inputs[1].cpu()).cuda() - ) + permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda()) else: permuted_inputs.append(inputs[1]) permuted_inputs.append(inputs[2]) matmul(*permuted_inputs) torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", + [ + (1, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), + ([1, 32], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), + ], +) +def test_global_cache_save_to_database( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + propagate_a, + propagate_b, + layout, + enable_tuning, +): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = Matmul( + config=matmul_config, + target=target, + ) + if enable_tuning: + matmul.hardware_aware_finetune(topk=20) + success = False + try: + global_operator_cache.add(matmul.config, matmul) + success = True + except Exception as hash_error: + print(hash_error) + assert success + + database_path = "debug/test_database" + global_operator_cache.save_into_database(database_path, target=target) + assert os.path.exists(database_path) + global_operator_cache.clear() + assert global_operator_cache.size() == 0 + global_operator_cache.load_from_database(database_path, target=target) + assert global_operator_cache.size() > 0 + + matmul = global_operator_cache.get(matmul.config) + assert matmul is not None + if not isinstance(M, int): + M = 32 + # convert tensors to torch + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda()) + inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda()) + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + ref_result = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) + + permuted_inputs = [] + if matmul.input_transform is not None: + permuted_inputs.append(matmul.input_transform(inputs[0].cpu())).cuda() + else: + permuted_inputs.append(inputs[0]) + if matmul.weight_transform is not None: + permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda()) + else: + permuted_inputs.append(inputs[1]) + permuted_inputs.append(inputs[2]) + matmul(*permuted_inputs) + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", + [ + ( + 1, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "uint", + False, + False, + -1, + False, + False, + False, + False, + "nt", + ), + ( + 1, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + False, + False, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + False, + False, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + False, + True, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + True, + True, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + True, + False, + -1, + False, + False, + True, + True, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + True, + False, + 128, + False, + False, + True, + True, + "nt", + ), + ], +) +def test_matmul_dequantize_save_into_database( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + propagate_a, + propagate_b, + layout, +): + + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + matmul.hardware_aware_finetune(topk=20) + database_path = "debug/test_database" + success = False + + try: + global_operator_cache.add(matmul.config, matmul) + success = True + except Exception as hash_error: + print(hash_error) + assert success + global_operator_cache.save_into_database(database_path, target=target) + assert os.path.exists(database_path) + global_operator_cache.clear() + assert global_operator_cache.size() == 0 + global_operator_cache.load_from_database(database_path, target=target) + assert global_operator_cache.size() > 0 + + # fmt: on if __name__ == "__main__": - # bitblas.testing.main() - - # test_config_hashable([1, 1024], 1024, 1024, "float16", "float16", "float16", False, False, False, "nt", True) - - test_global_cache_inquery_torch_forward( - [1, 1024], - 1024, - 1024, - "float16", - "float16", - "float16", - False, - False, - False, - "nt", - True, - ) + bitblas.testing.main() diff --git a/testing/python/dsl/test_auto_normalized_tensorcore.py b/testing/python/dsl/test_auto_normalized_tensorcore.py index 171fef0fa..c7c9a7235 100644 --- a/testing/python/dsl/test_auto_normalized_tensorcore.py +++ b/testing/python/dsl/test_auto_normalized_tensorcore.py @@ -2,68 +2,23 @@ # Licensed under the MIT License. import numpy as np import tvm -from tvm.script import tir as T from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.base.roller.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul +from bitblas.ops.impl.convolution2d_impl import conv2d_nhwc_hwio, conv2d_nhwc_ohwi from bitblas.base.utils import apply_and_build import time -from tvm import te, tir - - -def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dtype="float16"): - A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) - B = te.placeholder((kh, kw, c, f), name="weight", dtype=in_dtype) - - pad_shape = (n, h + 2 * p, w + 2 * p, c) - pad_value = tir.const(0.0, A.dtype) - pad = te.compute( - pad_shape, - lambda n, h, w, c: te.if_then_else( - tir.all( - h >= p, - w >= p, - h < pad_shape[1] - p, - w < pad_shape[2] - p, - ), - A[n, h - p, w - p, c], - pad_value, - ), - name="pad", - ) - kernel_h, kernel_w = kh, kw - stride_h, stride_w = s, s - dilation_h, dilation_w = d, d - out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 - out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 - out_shape = (n, out_h, out_w, f) - kh = te.reduce_axis((0, kernel_h), name="kh") - kw = te.reduce_axis((0, kernel_w), name="kw") - c = te.reduce_axis((0, c), name="c") - C = te.compute( - out_shape, - lambda n, h, w, f: te.sum( - pad[ - n, - h * stride_h + kh * dilation_h, - w * stride_w + kw * dilation_w, - c, - ] - * B[kh, kw, c, f], - axis=[kh, kw, c], - ), - name="C", - ) - return tvm.ir.IRModule({"main": te.create_prim_func([A, B, C])}) - benchmark_sets = [ # (prim_func, input_args, default_dlight_schedule), - # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), - # (conv2d_nhwc_hwio, (128, 64, 224, 224, 64, 1, 1, 2, 1, 3, "float16", "float16"), Matmul), - # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float32", "float32"), Matmul), (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), + (conv2d_nhwc_ohwi, (128, 64, 56, 56, 64, 3, 3, 1, 1, 1, "float16", "float16"), Matmul), + (conv2d_nhwc_hwio, (128, 64, 56, 56, 64, 1, 1, 1, 1, 1, "float16", "float16"), Matmul), + (conv2d_nhwc_ohwi, (128, 64, 56, 56, 64, 1, 1, 1, 1, 1, "float16", "float16"), Matmul), + (conv2d_nhwc_ohwi, (128, 128, 28, 28, 128, 3, 3, 1, 1, 1, "float16", "float16"), Matmul), + (conv2d_nhwc_hwio, (128, 256, 14, 14, 128, 3, 3, 2, 1, 1, "float16", "float16"), Matmul), + (conv2d_nhwc_ohwi, (128, 256, 14, 14, 128, 1, 1, 2, 1, 1, "float16", "float16"), Matmul), ] benchmark_results = {} for get_prim_func, input_args, d_schedule in benchmark_sets: @@ -72,12 +27,14 @@ def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dty target = tvm.target.Target("nvidia/nvidia-a100") arch = CUDA(target) policy = DefaultPolicy(func=func, arch=arch) + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) try: - func, tags = get_tensorized_func_and_tags(func, arch.target) - except: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception as e: + print(f"Failed to get tensorized function and tags: {e}") tags = None if tags: - policy = TensorCorePolicy(func=func, arch=arch, tags=tags) + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) configs = policy.emit_config(20) @@ -104,8 +61,7 @@ def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dty tvm.nd.array( np.random.uniform(0, 1, [int(i) for i in arg.shape]).astype(arg.dtype), device=arch.device, - ) - ) + )) timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) t = timer_cuda_mod(*profile_tensors).mean @@ -133,9 +89,8 @@ def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dty "DefaultDLight Latency", ] -col_width = ( - max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 -) # padding +col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 + ) # padding print("".join(word.ljust(col_width) for word in headers)) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py new file mode 100644 index 000000000..7d1694f03 --- /dev/null +++ b/testing/python/module/test_bitblas_linear.py @@ -0,0 +1,195 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import Linear as BitBLASLinear +import torch +import time +import numpy as np +import torch.nn as nn +import pytest + +torch.manual_seed(0) + + +@pytest.mark.parametrize( + "m, in_features, out_features, bias", + [ + (1, 1024, 1024, False), + (1, 1024, 1024, True), + (1024, 1024, 1024, True), + ([1, 1024], 1024, 1024, True), + ], +) +def test_correctness_consistent(m, in_features, out_features, bias): + linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) + linear_bitblas = BitBLASLinear( + in_features, + out_features, + bias=bias, + A_dtype="float16", + W_dtype="float16", + accum_dtype="float16", + out_dtype="float16", + opt_M=m, + ).cuda() + + with torch.no_grad(): + linear_bitblas.load_and_transform_weight(linear_torch.weight.clone()) + if bias: + linear_bitblas.bias = nn.Parameter(linear_torch.bias.clone()) + + with torch.no_grad(): + if not isinstance(m, int): + # average m + m = sum(m) // len(m) + input_data = torch.randn(m, in_features, dtype=torch.float16).cuda() + output_torch = linear_torch(input_data) + output_bitblas = linear_bitblas(input_data) + torch.testing.assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2) + + +@pytest.mark.parametrize( + "m, in_features, out_features, bias, W_dtype, group_size, with_scaling, with_zeros, zeros_mode", + [ + (1, 1024, 1024, False, "uint4", -1, False, False, None), + (1, 1024, 1024, False, "uint4", -1, False, False, None), + (1024, 1024, 1024, True, "uint4", -1, False, False, None), + (1, 1024, 1024, True, "uint2", -1, True, False, None), + (1, 1024, 1024, True, "uint2", 128, True, True, "original"), + (1024, 1024, 1024, True, "uint2", 128, True, True, "original"), + (1, 1024, 1024, True, "uint2", 128, True, True, "rescale"), + ], +) +def test_correctness_weight_only_dequantize( + m, + in_features, + out_features, + bias, + W_dtype, + group_size, + with_scaling, + with_zeros, + zeros_mode, +): + import numpy as np + from bitblas.quantization.utils import general_compress + + linear_bitblas = BitBLASLinear( + in_features, + out_features, + bias=bias, + A_dtype="float16", + W_dtype=W_dtype, + accum_dtype="float16", + out_dtype="float16", + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + opt_M=m, + ).cuda() + if not isinstance(m, int): + # average m + m = sum(m) // len(m) + input_shape = (m, in_features) + weight_shape = (out_features, in_features) + output_shape = (m, out_features) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + source_format, bit = ( + linear_bitblas.bitblas_matmul.source_format, + linear_bitblas.bitblas_matmul.bit, + ) + + maxq = 2**(bit - 1) - 1 + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().numpy().astype(np.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + bias_tensor = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + ref_result = torch.matmul(inputs[0], (inputs[1].t()).to(torch.float16)) + if bias: + ref_result = ref_result + bias_tensor + + with torch.no_grad(): + qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) + qw_torch = torch.from_numpy(qw_np).cuda() + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + if linear_bitblas.bitblas_matmul.weight_transform is not None: + permuted_inputs.append( + linear_bitblas.bitblas_matmul.weight_transform(qw_torch.cpu()).cuda()) + else: + permuted_inputs.append(qw_torch) + linear_bitblas.qweight.data = permuted_inputs[-1].clone() + if with_scaling: + if group_size == -1: + group_size = in_features + permuted_inputs.append( + torch.ones([out_features, in_features // group_size], dtype=torch.float16).cuda()) + linear_bitblas.scales.data = permuted_inputs[-1].clone() + if with_zeros: + if zeros_mode == "original": + permuted_inputs.append( + torch.ones([out_features, in_features // group_size], + dtype=torch.float16).cuda() * zeros) + elif zeros_mode == "rescale": + original_zeros = ( + torch.ones([out_features, in_features // group_size], + dtype=torch.float16).cuda() * zeros) + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif zeros_mode == "quantized": + original_zeros = ( + torch.ones([in_features // group_size, out_features], dtype=torch.int8).cuda() * + zeros) + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + linear_bitblas.zeros.data = permuted_inputs[-1].clone() + if bias: + permuted_inputs.append(bias_tensor) + linear_bitblas.bias.data = bias_tensor.clone() + + with torch.no_grad(): + output_bitblas = linear_bitblas(inputs[0]) + torch.testing.assert_close(output_bitblas, ref_result, rtol=1e0, atol=1e0) + + +def profile(model, input_data): + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + # print("Warming up ...") + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py new file mode 100644 index 000000000..6eb588d9e --- /dev/null +++ b/testing/python/operators/test_general_matmul_ops.py @@ -0,0 +1,286 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pytest +import bitblas +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +# fmt: off +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + (768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + (1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + (768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + ], +) +def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + assert get_codegen_result(matmul) + + +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + (768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + (1, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + (768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + ], +) +def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + matmul.hardware_aware_finetune(topk=10) + assert get_codegen_result(matmul) + + +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 1024, 1024, "float16", "int4", "float16", "float16", "nt", None, None, None, None, + None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, None), + (1, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, False, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", True, -1, False, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, False, + None), + (768, 768, 768, "float16", "uint4", "float16", "float16", "nt", False, -1, True, True, + "original"), + ], +) +def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + import torch + torch.random.manual_seed(0) + import numpy as np + from bitblas.quantization import general_compress + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + source_format, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype] + maxq = 2**(bit - 1) - 1 + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().numpy().astype(np.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + ref_result = torch.matmul(inputs[0], + (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) + if with_bias: + ref_result = ref_result + bias + qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) + qw_torch = torch.from_numpy(qw_np).cuda() + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + if matmul.weight_transform is not None: + permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) + else: + permuted_inputs.append(qw_torch) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + if with_zeros: + if zeros_mode == "original": + permuted_inputs.append( + torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + elif zeros_mode == "rescale": + original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif zeros_mode == "quantized": + original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + if with_bias: + permuted_inputs.append(bias) + permuted_inputs.append(inputs[2]) + matmul(*permuted_inputs) + print(permuted_inputs[-1]) + print(ref_result) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e-0) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e-1) + + +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,with_bias", + [ + (1, 768, 768, "float16", "uint4", "float16", "float16", False), + (1, 768, 768, "float16", "int4", "float16", "float16", False), + (768, 768, 768, "float16", "uint4", "float16", "float16", False), + (768, 768, 768, "float16", "int4", "float16", "float16", False), + ], +) +def test_matmul_transform_weight( + M, + N, + K, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_bias, +): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + with_bias=with_bias, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + + input_shape = (M, K) + weight_shape = (N, K) + output_shape = (M, N) + + _, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype] + maxq = 2**(bit - 1) - 1 + + input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda() + intweight_tensor = torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda() + output_tensor = torch.rand(output_shape, dtype=torch.float16).cuda() + + bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + ref_result = torch.matmul(input_tensor, intweight_tensor.t().to(torch.float16)) + if with_bias: + ref_result = ref_result + bias + + bitblas_inputs = [input_tensor] + intweight_tensor = matmul.transform_weight(intweight_tensor) + bitblas_inputs.append(intweight_tensor) + if with_bias: + bitblas_inputs.append(bias) + output_tensor = matmul(*bitblas_inputs) + torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0) + + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_ladder_permutate_ops.py b/testing/python/operators/test_ladder_permutate_ops.py index a50fb11c7..e14b5c509 100644 --- a/testing/python/operators/test_ladder_permutate_ops.py +++ b/testing/python/operators/test_ladder_permutate_ops.py @@ -7,14 +7,17 @@ target = tvm.target.Target("llvm") + # fmt: off -@pytest.mark.parametrize("M,N,datatype,dequantize_bits,storage_dtype,propagate_kind,transpose_matrix,transform_kind,target_instruction", [ - (1024, 1024, "float16", -1, "float16", "B", True, 0, "nvidia-mma"), - (1024, 1024, "float16", -1, "float16", "B", True, 1, "nvidia-mma"), - (1024, 1024, "float16", -1, "float16", "B", True, 2, "nvidia-mma"), - # dequantize propagation - (1024, 1024, "float16", 4, "uint32", "B", True, 2, "nvidia-mma"), -]) +@pytest.mark.parametrize( + "M,N,datatype,dequantize_bits,storage_dtype,propagate_kind,transpose_matrix,transform_kind,target_instruction", + [ + (1024, 1024, "float16", -1, "float16", "B", True, 0, "nvidia-mma"), + (1024, 1024, "float16", -1, "float16", "B", True, 1, "nvidia-mma"), + (1024, 1024, "float16", -1, "float16", "B", True, 2, "nvidia-mma"), + # dequantize propagation + (1024, 1024, "float16", 4, "uint32", "B", True, 2, "nvidia-mma"), + ]) def test_ladder_permutate_profile_latency( M, N, @@ -45,6 +48,49 @@ def test_ladder_permutate_profile_latency( latency = ladder_permutate.profile_latency() assert latency + +@pytest.mark.parametrize( + "M,N,datatype,dequantize_bits,storage_dtype,propagate_kind,transpose_matrix,transform_kind,target_instruction", + [ + (1024, 1024, "float16", -1, "float16", "A", True, 0, "nvidia-mma"), + (1024, 1024, "float16", -1, "float16", "A", True, 1, "nvidia-mma"), + (1024, 1024, "float16", -1, "float16", "A", True, 2, "nvidia-mma"), + # dequantize propagation + (1024, 1024, "float16", 4, "uint32", "A", True, 2, "nvidia-mma"), + ]) +def test_ladder_permutate_profile_latency_cuda( + M, + N, + datatype, + dequantize_bits, + storage_dtype, + propagate_kind, + transpose_matrix, + transform_kind, + target_instruction, +): + + ladder_permutate_config = LadderPermutateConfig( + M=M, + N=N, + datatype=datatype, + dequantize_bits=dequantize_bits, + storage_dtype=storage_dtype, + propagate_kind=propagate_kind, + transpose_matrix=transpose_matrix, + transform_kind=transform_kind, + target_instruction=target_instruction, + ) + ladder_permutate = LadderPermutate( + config=ladder_permutate_config, + target="cuda", + ) + # ladder_permutate.hardware_aware_finetune() + latency = ladder_permutate.profile_latency() + print(latency) + assert latency + + # fmt: on if __name__ == "__main__": diff --git a/testing/python/operators/test_matmul_dequantize_ops.py b/testing/python/operators/test_matmul_dequantize_ops.py index 40a6012c1..018ad8256 100644 --- a/testing/python/operators/test_matmul_dequantize_ops.py +++ b/testing/python/operators/test_matmul_dequantize_ops.py @@ -3,28 +3,75 @@ import pytest import tvm import bitblas -from bitblas.utils import get_target_from_env +from bitblas.utils import auto_detect_nvidia_target from bitblas.ops.matmul_dequantize import ( MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig, ) +import logging +from bitblas import set_log_level -target = tvm.target.Target(get_target_from_env()) +set_log_level(logging.DEBUG) +target = tvm.target.Target(auto_detect_nvidia_target()) def get_codegen_result(ops, target): - code = ops.codegen(target=target) + code = ops.get_source(target=target) return code # fmt: off +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,layout,propagate_a,propagate_b,zeros_mode", + [ + (16, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, + False, "nt", False, False, "original"), + (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, + False, "nt", False, False, "original"), + (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, -1, True, + True, "nt", True, True, "original"), + ], +) +def test_matmul_dequantize_codegen_default(M, N, K, in_dtype, out_dtype, accum_dtype, bit, + storage_dtype, source_format, with_scaling, with_zeros, + group_size, fast_decoding, with_bias, layout, + propagate_a, propagate_b, zeros_mode): + + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + zeros_mode=zeros_mode, + ) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + assert get_codegen_result(matmul, target) + + @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", [ - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, False, False, False, "nt"), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, + False, False, False, "nt"), ], ) -def test_matmul_dequantize_codegen_default( +def test_matmul_dequantize_retrieve_weight_shape( M, N, K, @@ -67,13 +114,50 @@ def test_matmul_dequantize_codegen_default( config=matmul_config, target=target, ) - assert get_codegen_result(matmul, target) + assert matmul.retrieve_weight_shape() @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", [ - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, False, False, False, "nt",), + ( + 1, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "uint", + False, + False, + -1, + False, + False, + False, + False, + "nt", + ), + ( + 1, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "uint", + False, + False, + -1, + False, + False, + False, + True, + "nt", + ), ], ) def test_matmul_dequantize_codegen_finetune( @@ -126,13 +210,139 @@ def test_matmul_dequantize_codegen_finetune( @pytest.mark.parametrize( "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", [ - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, False, False, False, "nt",), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "af", False, False, -1, False, False, False, False, "nt",), - (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "af", False, False, -1, False, False, False, False, "nt",), - (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "af", False, False, -1, False, False, False, True, "nt",), - (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "af", False, False, -1, False, False, True, True, "nt",), - (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "af", True, False, -1, False, False, True, True, "nt",), - (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "af", True, False, 128, False, False, True, True, "nt",), + ( + 1, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "uint", + False, + False, + -1, + False, + False, + False, + False, + "nt", + ), + ( + 1, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + False, + False, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + False, + False, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + False, + True, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + False, + False, + -1, + False, + False, + True, + True, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + True, + False, + -1, + False, + False, + True, + True, + "nt", + ), + ( + 1024, + 1024, + 1024, + "float16", + "float16", + "float16", + 4, + "int8", + "nf", + True, + False, + 128, + False, + False, + True, + True, + "nt", + ), ], ) def test_matmul_dequantize_profile_latency( @@ -182,46 +392,46 @@ def test_matmul_dequantize_profile_latency( latency = matmul.profile_latency() assert latency + @pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,zeros_type", + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,zeros_mode", [ - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, False, -1, False, False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, False, -1, True, False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", False, False, -1, True, False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, -1, True, False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, True, False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "uint", True, True, 128, False, False, False, False, "nt", "rescale"), - (1, 1024, 4096, "float16", "float16", "float16", 2, "int8", "uint", True, True, 128, True, False, False, False, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, False, False, False, False, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, False, False, False, True, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, False, False, True, True, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, False, False, True, True, "nt", "original"), - ([1, 1024], 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, False, False, False, "nt", "original"), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, + False, False, False, "nt", "rescale"), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, + False, False, False, "nt", "rescale"), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, False, -1, False, + False, False, False, "nt", "rescale"), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, False, -1, True, + False, False, False, "nt", "rescale"), + (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", False, False, -1, True, + False, False, False, "nt", "rescale"), + (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, -1, True, + False, False, False, "nt", "rescale"), + (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, True, + False, False, False, "nt", "rescale"), + (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "uint", True, True, 128, False, + False, False, False, "nt", "rescale"), + (1, 1024, 4096, "float16", "float16", "float16", 2, "int8", "uint", True, True, 128, True, + False, False, False, "nt", "rescale"), + (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, + False, False, False, False, "nt", "rescale"), + (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, + False, False, False, True, "nt", "rescale"), + (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, + False, False, True, True, "nt", "rescale"), + (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, + False, False, True, True, "nt", "original"), + ([1, 1024], 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, + -1, False, False, False, False, "nt", "original"), ], ) -def test_matmul_dequantize_torch_forward( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - propagate_a, - propagate_b, - layout, - zeros_type -): +def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dtype, bit, + storage_dtype, source_format, with_scaling, with_zeros, + group_size, fast_decoding, with_bias, propagate_a, + propagate_b, layout, zeros_mode): import torch + torch.random.manual_seed(0) import numpy as np from bitblas.quantization.utils import general_compress @@ -243,8 +453,7 @@ def test_matmul_dequantize_torch_forward( propagate_a=propagate_a, propagate_b=propagate_b, layout=layout, - zeros_type=zeros_type - ) + zeros_mode=zeros_mode) matmul = MatmulWeightOnlyDequantize( config=matmul_config, target=target, @@ -256,8 +465,8 @@ def test_matmul_dequantize_torch_forward( weight_shape = (N, K) if layout == "nt" else (K, N) output_shape = (M, N) inputs = [] - inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda()) - maxq = 2 ** (bit - 1) - 1 + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + maxq = 2**(bit - 1) - 1 zeros = maxq if source_format == "uint": inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) @@ -267,27 +476,27 @@ def test_matmul_dequantize_torch_forward( raise NotImplementedError inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) - + intweight = inputs[1] intweight = intweight.cpu().numpy().astype(np.int8) if source_format == "int": intweight = intweight + maxq if with_zeros: inputs[1] = inputs[1] - zeros - - ref_result = torch.matmul(inputs[0], (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) - - qw_np = general_compress( - intweight, source_bits=bit, storage_dtype=np.int8 - ) + bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + ref_result = torch.matmul(inputs[0], + (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) + if with_bias: + ref_result = ref_result + bias + qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) qw_torch = torch.from_numpy(qw_np).cuda() permuted_inputs = [] - permuted_inputs.append(inputs[0]) - + if matmul.input_transform is not None: + permuted_inputs.append(matmul.input_transform(inputs[0].cpu()).cuda()) + else: + permuted_inputs.append(inputs[0]) if matmul.weight_transform is not None: - permuted_inputs.append( - matmul.weight_transform(qw_torch.cpu()).cuda() - ) + permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) else: permuted_inputs.append(qw_torch) if with_scaling: @@ -296,9 +505,296 @@ def test_matmul_dequantize_torch_forward( permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) if with_zeros: permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + if with_bias: + permuted_inputs.append(bias) permuted_inputs.append(inputs[2]) matmul(*permuted_inputs) - torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-0, atol=1e-0) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-0, atol=1e-1) + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,layout,zeros_mode", + [ + (16, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, + False, "nt", "original"), + (16, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, True, -1, True, + True, "nt", "original"), + (16, 3072, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, + False, "nt", "original"), + (16, 768, 3072, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, + False, "nt", "original"), + ], +) +def test_matmul_dequantize_propgate_comparison(M, N, K, in_dtype, out_dtype, accum_dtype, bit, + storage_dtype, source_format, with_scaling, + with_zeros, group_size, fast_decoding, with_bias, + layout, zeros_mode): + import torch + torch.random.manual_seed(0) + original_matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=False, + with_bias=with_bias, + propagate_a=False, + propagate_b=False, + layout=layout, + zeros_mode=zeros_mode) + original_matmul = MatmulWeightOnlyDequantize( + config=original_matmul_config, + target=target, + ) + if not isinstance(M, int): + M = 32 + + if group_size == -1: + group_size = K + input_shape = (M, K) + weight_shape = (N, K // 2) if layout == "nt" else (K, N) + output_shape = (M, N) + scales_shape = (N, K // group_size) + zeros_shape = (N, K // group_size) + bias_shape = (N,) + + inputs = [] + input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda() + weight_tensor = torch.randint(0, 2**(bit - 1) - 1, weight_shape, dtype=torch.int8).cuda() + scales_tensor = torch.rand(scales_shape, dtype=torch.float16).cuda() + zeros_tensor = torch.rand(zeros_shape, dtype=torch.float16).cuda() + bias_tensor = torch.rand(bias_shape, dtype=torch.float16).cuda() + output_tensor = torch.zeros(output_shape, dtype=torch.float16).cuda() + inputs.append(input_tensor) + inputs.append(weight_tensor) + if with_scaling: + inputs.append(scales_tensor) + if with_zeros: + inputs.append(zeros_tensor) + if with_bias: + inputs.append(bias_tensor) + inputs.append(output_tensor) + ref_result = original_matmul(*inputs) + + propagated_matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=False, + propagate_b=True, + layout=layout, + zeros_mode=zeros_mode) + propagated_matmul = MatmulWeightOnlyDequantize( + config=propagated_matmul_config, + target=target, + ) + + propagated_matmul.hardware_aware_finetune(topk=20) + propagated_inputs = [] + propagated_inputs.append(input_tensor) + if propagated_matmul.weight_transform is not None: + propagated_inputs.append(propagated_matmul.weight_transform(weight_tensor.cpu()).cuda()) + else: + propagated_inputs.append(weight_tensor) + if with_scaling: + propagated_inputs.append(scales_tensor) + if with_zeros: + propagated_inputs.append(zeros_tensor) + if with_bias: + propagated_inputs.append(bias_tensor) + propagated_inputs.append(torch.zeros(output_shape, dtype=torch.float16).cuda()) + + propagated_result = propagated_matmul(*propagated_inputs) + torch.testing.assert_close(ref_result, propagated_result, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,source_zeros_mode,target_zeros_mode", + [ + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, + False, False, False, "nt", "rescale", "quantized"), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, + False, False, False, "nt", "rescale", "original"), + (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, + False, False, False, False, "nt", "rescale", "quantized"), + (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, + False, False, False, False, "nt", "rescale", "original"), + ], +) +def test_matmul_dequantize_diff_zero_types(M, N, K, in_dtype, out_dtype, accum_dtype, bit, + storage_dtype, source_format, with_scaling, with_zeros, + group_size, fast_decoding, with_bias, propagate_a, + propagate_b, layout, source_zeros_mode, + target_zeros_mode): + import torch + torch.random.manual_seed(0) + import numpy as np + from bitblas.quantization.utils import general_compress + + source_quantized_matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + zeros_mode=source_zeros_mode) + source_quantized_matmul = MatmulWeightOnlyDequantize( + config=source_quantized_matmul_config, + target=target, + ) + if not isinstance(M, int): + M = 32 + source_quantized_matmul.hardware_aware_finetune(topk=20) + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + maxq = 2**(bit - 1) - 1 + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().numpy().astype(np.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) + qw_torch = torch.from_numpy(qw_np).cuda() + permuted_inputs = [] + if source_quantized_matmul.input_transform is not None: + permuted_inputs.append(source_quantized_matmul.input_transform(qw_torch.cpu()).cuda()) + else: + permuted_inputs.append(inputs[0]) + if source_quantized_matmul.weight_transform is not None: + permuted_inputs.append(source_quantized_matmul.weight_transform(qw_torch.cpu()).cuda()) + else: + permuted_inputs.append(qw_torch) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.rand([N, K // group_size], dtype=torch.float16).cuda()) + if with_zeros: + if source_zeros_mode == "original": + permuted_inputs.append( + torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + elif source_zeros_mode == "rescale": + original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif source_zeros_mode == "quantized": + original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + + if with_bias: + permuted_inputs.append(bias) + permuted_inputs.append(inputs[2]) + source_quantized_matmul(*permuted_inputs) + ref_result = permuted_inputs[-1] + target_quantized_matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + zeros_mode=target_zeros_mode) + target_quantized_matmul = MatmulWeightOnlyDequantize( + config=target_quantized_matmul_config, + target=target, + ) + if not isinstance(M, int): + M = 32 + target_quantized_matmul.hardware_aware_finetune(topk=20) + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + + target_inputs = [] + target_inputs.append(permuted_inputs[0]) + target_inputs.append(permuted_inputs[1]) + + if with_scaling: + target_inputs.append(permuted_inputs[2]) + if with_zeros: + if target_zeros_mode == "original": + target_inputs.append( + torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + elif target_zeros_mode == "rescale": + original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + scaled_zeros = original_zeros * target_inputs[-1] + target_inputs.append(scaled_zeros) + elif target_zeros_mode == "quantized": + original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + target_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + if with_bias: + target_inputs.append(bias) + target_inputs.append(torch.zeros_like(inputs[2])) + target_quantized_matmul(*target_inputs) + torch.testing.assert_close(target_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + # fmt: on diff --git a/testing/python/operators/test_matmul_ops.py b/testing/python/operators/test_matmul_ops.py index 7418a3781..2baac71d2 100644 --- a/testing/python/operators/test_matmul_ops.py +++ b/testing/python/operators/test_matmul_ops.py @@ -1,16 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import pytest -import tvm import bitblas from bitblas.ops.matmul import Matmul, MatmulConfig -from bitblas.utils import tvm_tensor_to_torch +from bitblas.utils import auto_detect_nvidia_target -target = tvm.target.Target("nvidia/nvidia-a100") +target = auto_detect_nvidia_target() def get_codegen_result(ops, target): - code = ops.codegen(target=target) + code = ops.get_source(target=target) return code @@ -143,6 +142,9 @@ def test_matmul_profile_latency( [ (256, 256, 256, "float16", "float16", "float16", False, False, False, "nt"), (256, 256, 256, "float16", "float16", "float16", False, False, True, "nt"), + (256, 256, 256, "float16", "float16", "float16", False, False, 0, "nt"), + (256, 256, 256, "float16", "float16", "float16", False, False, 1, "nt"), + (256, 256, 256, "float16", "float16", "float16", False, False, 2, "nt"), ], ) def test_matmul_torch_forward( @@ -188,21 +190,18 @@ def test_matmul_torch_forward( permuted_inputs = [] if matmul.input_transform is not None: - permuted_inputs.append( - matmul.input_transform(inputs[0].cpu()) - ).cuda() + permuted_inputs.append(matmul.input_transform(inputs[0].cpu())).cuda() else: permuted_inputs.append(inputs[0]) if matmul.weight_transform is not None: - permuted_inputs.append( - matmul.weight_transform(inputs[1].cpu()).cuda() - ) + permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda()) else: permuted_inputs.append(inputs[1]) permuted_inputs.append(inputs[2]) matmul(*permuted_inputs) torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + # fmt: on if __name__ == "__main__": diff --git a/testing/python/operators/test_param_permutate_ops.py b/testing/python/operators/test_param_permutate_ops.py new file mode 100644 index 000000000..4f8d32c80 --- /dev/null +++ b/testing/python/operators/test_param_permutate_ops.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pytest +import tvm +import bitblas +from bitblas.ops.param_permutate import ParamPermutate, ParamPermutateConfig + +target = tvm.target.Target("llvm") + + +# fmt: off +@pytest.mark.parametrize( + "M,N,datatype,transpose_matrix,group_size,propagate_kind,target_instruction", [ + (1024, 1024, "float16", True, 1, True, "nvidia-mma"), + ]) +def test_param_permutate_profile_latency( + M, + N, + datatype, + transpose_matrix, + group_size, + propagate_kind, + target_instruction, +): + param_permutate_config = ParamPermutateConfig( + M=M, + N=N, + datatype=datatype, + propagate_kind=propagate_kind, + group_size=group_size, + transpose_matrix=transpose_matrix, + target_instruction=target_instruction, + ) + param_permutate = ParamPermutate( + config=param_permutate_config, + target=target, + ) + latency = param_permutate.profile_latency() + assert latency + + +# fmt: on + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/tir_expr/float16xfloat16_gemm.py b/testing/python/tir_expr/float16xfloat16_gemm.py deleted file mode 100644 index 75fc24ef0..000000000 --- a/testing/python/tir_expr/float16xfloat16_gemm.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from tvm.script import tir as T -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.base.utils import apply_and_build -from bitblas.ops.matmul_impl import matmul_nt, matmul_nt_propagate_b_s8_s8_s32_mma -import numpy as np - - -def test_f16_f16_gemm(): - ir_module = matmul_nt(1024, 1024, 1024, "float16", "float16") - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(1) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3) - ) - - numpy_a = np.random.randint(-4, 3, (1024, 1024)).astype("float16") - numpy_b = np.random.randint(-4, 3, (1024, 1024)).astype("float16") - numpy_c = np.matmul(numpy_a.astype("float16"), numpy_b.T.astype("float16")) - ctx = tvm.cuda() - tvm_a = tvm.nd.array(numpy_a, device=ctx) - tvm_b = tvm.nd.array(numpy_b, device=ctx) - tvm_c = tvm.nd.array(np.zeros((1024, 1024), dtype="float16"), device=ctx) - print(best.code) - best.mod(tvm_a, tvm_b, tvm_c) - print(best.config) - print("numpy_c ", numpy_c) - print("tvm_c.asnumpy() ", tvm_c.asnumpy()) - - -def test_i8_i8_gemm_propagate_b(): - ir_module = matmul_nt_propagate_b_s8_s8_s32_mma( - 16384, 16384, 16384, "int8", "int32" - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(1) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3) - ) - print(best.sch.mod) - - -test_f16_f16_gemm() -# test_i8_i8_gemm_propagate_b() diff --git a/testing/python/tir_expr/int8xint8_gemm.py b/testing/python/tir_expr/int8xint8_gemm.py deleted file mode 100644 index 1515c307c..000000000 --- a/testing/python/tir_expr/int8xint8_gemm.py +++ /dev/null @@ -1,368 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -import bitblas -import numpy as np -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.base.utils import apply_and_build -from bitblas.ops.matmul_impl import ( - matmul_nt, - matmul_nt_dequantize_b, - matmul_nt_dequantize_b_propagate_b, - matmul_nt_dequantize_b_propagate_a_b, - matmul_nt_propagate_b_s8_s8_s32_mma, - matmul_nt_propagate_b_s8_s8_s32_cast_s8_mma, - matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma, - matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma_cast_s8, -) - - -def test_i8_i8_gemm(): - ir_module = matmul_nt(16384, 16384, 16384, "int8", "int32") - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - with open("debug/after_memory_rewrite.cu", "+w") as f: - f.write(best.code) - - -def test_i8_i8_gemm_correctness(): - ir_module = matmul_nt(1024, 1024, 1024, "int8", "int32") - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - - numpy_a = np.random.randint(-4, 3, (1024, 1024)).astype("int8") - numpy_b = np.random.randint(-4, 3, (1024, 1024)).astype("int8") - numpy_c = np.matmul(numpy_a.astype("int32"), numpy_b.T.astype("int32")) - ctx = tvm.cuda() - tvm_a = tvm.nd.array(numpy_a, device=ctx) - tvm_b = tvm.nd.array(numpy_b, device=ctx) - tvm_c = tvm.nd.array(np.zeros((1024, 1024), dtype="int32"), device=ctx) - # print(best.sch.mod) - # print(best.code) - best.mod(tvm_a, tvm_b, tvm_c) - print(best.config) - print("numpy_c ", numpy_c) - print("tvm_c.asnumpy() ", tvm_c.asnumpy()) - - np.testing.assert_allclose(tvm_c.asnumpy(), numpy_c, atol=1e-5) - # print(best.code) - - -def test_i8_i8_i32_gemm_propagate_b(): - ir_module = matmul_nt_propagate_b_s8_s8_s32_mma( - 16384, 16384, 16384, "int8", "int32" - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - - -def test_i8_i8_i32_cast_i8_gemm_propagate_b(): - ir_module = matmul_nt_propagate_b_s8_s8_s32_cast_s8_mma( - 16384, 16384, 16384, "int8", "int32" - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - - -def test_i8_i8_i32_gemm_propagate_a_propagate_b(): - ir_module = matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma( - 16384, 16384, 16384, "int8", "int32" - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - - -def test_i8_i8_i32_gemm_propagate_a_propagate_b_cast_s8(): - ir_module = matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma_cast_s8( - 16384, 16384, 16384, "int8", "int32" - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - - -def test_i8_i4_gemm(): - ir_module = matmul_nt_dequantize_b(16384, 16384, 16384, "int8", "int32") - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - - -def test_i8_i4_propagate_b_gemm(): - ir_module = matmul_nt_dequantize_b_propagate_b(16384, 16384, 16384, "int8", "int32") - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - # print(best.sch.mod) - print(best.code) - - -def test_i8_i4_propagate_a_propagate_b_gemm(): - ir_module = matmul_nt_dequantize_b_propagate_a_b( - 16384, 16384, 16384, "int8", "int32" - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - print(best.config) - - -def test_i8_i2_gemm(): - ir_module = matmul_nt_dequantize_b(1, 16384, 16384, "int8", "int32", bit=2) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - print(configs) - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - print(best.code) - - -def test_i8_i2_propagate_b_gemm(): - ir_module = matmul_nt_dequantize_b_propagate_b( - 16384, - 16384, - 16384, - "int8", - "int8", - accum_dtype="int32", - bit=2, - fast_decoding=True, - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - with open("debug/after_memory_rewrite.cu", "+w") as f: - f.write(best.code) - - -def test_i8_i2_propagate_a_propagate_b_gemm(): - ir_module = matmul_nt_dequantize_b_propagate_a_b( - 16384, 16384, 16384, "int8", "int32", "int8", bit=2, fast_decoding=False - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) - with open("debug/after_memory_rewrite.cu", "+w") as f: - f.write(best.code) - - -# test_i8_i8_gemm() -# test_i8_i8_gemm_correctness() -# test_i8_i8_i32_gemm_propagate_b() -# test_i8_i8_i32_cast_i8_gemm_propagate_b() -# test_i8_i8_i32_gemm_propagate_a_propagate_b() -# test_i8_i8_i32_gemm_propagate_a_propagate_b_cast_s8() -# test_i8_i4_gemm() -# test_i8_i4_propagate_b_gemm() -# test_i8_i4_propagate_a_propagate_b_gemm() - -test_i8_i2_gemm() -# test_i8_i2_propagate_b_gemm() -# test_i8_i2_propagate_a_propagate_b_gemm() diff --git a/testing/python/tir_expr/test_fused_decode_matmul.py b/testing/python/tir_expr/test_fused_decode_matmul.py deleted file mode 100644 index fb90ede36..000000000 --- a/testing/python/tir_expr/test_fused_decode_matmul.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.script import relax as R - - -@T.prim_func -def fused_fused_decode3_fused_NT_matmul8_add1( - lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), - lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), - p_lv41: T.handle, - p_lv2: T.handle, - p_output0: T.handle, -): - T.func_attr( - { - "tir.noalias": T.bool(True), - "dequantize_info": { - "B": { - "decode_block": "decode", - "fast_decoding": True, - "source_format": { - "bits": 4, - "format": "int", - }, - "with_scaling": True, - "storage_dtype": "uint32", - "group_size": 32, - "target_format": "float16", - } - }, - } - ) - n = T.int64() - lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16") - T_add_intermediate_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(4096)), "float16" - ) - # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) - T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k]) - T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv2[v_ax0, v_ax1, v_ax2], NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) - T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv2[v_ax0, v_ax1, v_ax2] + NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - - -import tvm -from tvm import dlight as dl -import bitblas -from tvm import relax -import bitblas -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags - -dispatch_target = tvm.target.Target("cuda") -mod_deploy = tvm.IRModule.from_expr(fused_fused_decode3_fused_NT_matmul8_add1.specialize({"n": T.int64(1)})) -target = tvm.target.Target("nvidia/nvidia-a100") -arch = CUDA(target) -func = fused_fused_decode3_fused_NT_matmul8_add1.specialize({"n": T.int64(1)}) -policy = DefaultPolicy(func=func, arch=arch) -try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) -except: - tags = None -if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - -configs = policy.emit_config(20) -print(configs[0]) -sch = bitblas.gpu.gemv.GEMVWithDequantizeInfo().apply_config(func, configs[0]) - -# print(sch.mod) -# with dispatch_target: -# mod_deploy = dl.ApplyDefaultSchedule( # pylint: disable=not-callable -# dl.gpu.Matmul(), -# dl.gpu.GEMV(), -# dl.gpu.Reduction(), -# dl.gpu.GeneralReduction(), -# dl.gpu.Fallback(), -# )(mod_deploy) -# dynamic_range = { -# "n": [64], -# } -# mod_deploy = bitblas.ApplyFastTuning( -# topk=20, -# target=dispatch_target, -# meta_database_dir="vicuna_tune", -# whitelist=["matmul"], -# )(mod_deploy) - -# with tvm.transform.PassContext(config={"tir.use_async_copy": False}): -# mod = tvm.build(mod_deploy, target=dispatch_target) - -# with open("debug/test_dl_fused_decode_matmul.cu", "+w") as f: -# f.write(mod.imported_modules[0].get_source()) diff --git a/testing/python/tir_expr/test_matmul_codegen.py b/testing/python/tir_expr/test_matmul_codegen.py deleted file mode 100644 index 7fe7a4831..000000000 --- a/testing/python/tir_expr/test_matmul_codegen.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from tvm.script import ir as I -from tvm.script import tir as T -import bitblas -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.base.utils import apply_and_build -from bitblas.ops.matmul_impl import matmul_nt, matmul_nt_dequantize_b -import numpy as np - - -@I.ir_module -class Module: - @T.prim_func - def dequantize_gemv( - lv47: T.Buffer((T.int64(256), T.int64(256), T.int64(16), T.int64(2)), "uint32"), - lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), - lv41: T.Buffer((T.int64(1), 1, T.int64(4096)), "float16"), - NT_matmul_intermediate: T.Buffer((T.int64(1), 1, T.int64(4096)), "float16"), - ): - T.func_attr( - { - "dequantize_info": { - "decode": { - "decode_block": "decode", - "fast_decoding": T.bool(True), - "group_size": 32, - "source_format": {"bits": 4, "format": "int"}, - "storage_dtype": "uint32", - "target_format": "float16", - "with_scaling": T.bool(True), - } - }, - "smooth_b": T.bool(True), - "tir.noalias": T.bool(True), - "transform_kind": 2, - } - ) - # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) - lv47_global = T.alloc_buffer((T.int64(4096), T.int64(512)), "uint32") - for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): - with T.block("lv47_global"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads( - lv47[ - v0 // T.int64(16), - v1 // T.int64(2), - v0 % T.int64(16) // T.int64(8) * T.int64(8) - + v0 % T.int64(4) * T.int64(2) - + v1 % T.int64(2), - v0 % T.int64(8) // T.int64(4), - ] - ) - T.writes(lv47_global[v0, v1]) - lv47_global[v0, v1] = lv47[ - v0 // T.int64(16), - v1 // T.int64(2), - v0 % T.int64(16) // T.int64(8) * T.int64(8) - + v0 % T.int64(4) * T.int64(2) - + v1 % T.int64(2), - v0 % T.int64(8) // T.int64(4), - ] - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads( - lv47_global[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)] - ) - T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47_global[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] - for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads( - lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k] - ) - T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] - * decode_intermediate_intermediate[v_i2, v_k] - ) - - @T.prim_func - def dequantize_gemm( - lv47: T.Buffer((T.int64(256), T.int64(256), T.int64(16), T.int64(2)), "uint32"), - lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), - lv41: T.Buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16"), - NT_matmul_intermediate: T.Buffer((T.int64(1), 4096, T.int64(4096)), "float16"), - ): - T.func_attr( - { - "dequantize_info": { - "decode": { - "decode_block": "decode", - "fast_decoding": T.bool(True), - "group_size": 32, - "source_format": {"bits": 4, "format": "int"}, - "storage_dtype": "uint32", - "target_format": "float16", - "with_scaling": T.bool(True), - } - }, - "smooth_b": T.bool(True), - "tir.noalias": T.bool(True), - "transform_kind": 2, - } - ) - # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) - lv47_global = T.alloc_buffer((T.int64(4096), T.int64(512)), "uint32") - for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): - with T.block("lv47_global"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads( - lv47[ - v0 // T.int64(16), - v1 // T.int64(2), - v0 % T.int64(16) // T.int64(8) * T.int64(8) - + v0 % T.int64(4) * T.int64(2) - + v1 % T.int64(2), - v0 % T.int64(8) // T.int64(4), - ] - ) - T.writes(lv47_global[v0, v1]) - lv47_global[v0, v1] = lv47[ - v0 // T.int64(16), - v1 // T.int64(2), - v0 % T.int64(16) // T.int64(8) * T.int64(8) - + v0 % T.int64(4) * T.int64(2) - + v1 % T.int64(2), - v0 % T.int64(8) // T.int64(4), - ] - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads( - lv47_global[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)] - ) - T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47_global[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] - for i0, i1, i2, k in T.grid( - T.int64(1), T.int64(4096), T.int64(4096), T.int64(4096) - ): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads( - lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k] - ) - T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] - * decode_intermediate_intermediate[v_i2, v_k] - ) - - -ir_module = Module -func = ir_module["dequantize_gemm"] -target = tvm.target.Target("nvidia/nvidia-a100") -arch = CUDA(target) -policy = DefaultPolicy(func=func, arch=arch) - -tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) -if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - -configs = policy.emit_config(20) -print(configs) -sch = bitblas.gpu.MatmulTensorizationMMAWithDequantizeInfo().apply_config( - func, configs[0] -) -# cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) -# print( -# "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( -# cpresults[0].latency * 1e3 -# ) -# ) -# print( -# "[BitBLAS] The best latency of top 20 is {:.3f} ms".format( -# best.latency * 1e3 -# ) -# ) -# with open("debug/tmp.cu", "w") as f: -# f.write(str(best.code)) diff --git a/testing/python/tir_expr/test_tir.py b/testing/python/tir_expr/test_tir.py deleted file mode 100644 index 819444d0f..000000000 --- a/testing/python/tir_expr/test_tir.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# Metadata omitted. Use show_meta=True in script() method to show it. -from tvm.script import ir as I -from tvm.script import tir as T - -@I.ir_module -class Module: - @T.prim_func - def main(A: T.Buffer((1, 16384), "float16"), B: T.Buffer((16384, 8192), "int8"), Scale: T.Buffer((16384, 512), "float16"), D: T.Buffer((1, 16384), "float16")): - T.func_attr({"dequantize_info": {"B": {"decode_block": "B_decode", "fast_decoding": T.bool(True), "group_size": 32, "source_format": {"bits": 4, "format": "uint"}, "target_format": "float16", "with_scaling": T.bool(True)}}, "tir.noalias": T.bool(True)}) - # with T.block("root"): - B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local") - A_local = T.alloc_buffer((1, 16384), "float16", scope="local") - B_local = T.alloc_buffer((16384, 8192), "int8", scope="local") - C_local = T.alloc_buffer((1, 16384), "float16", scope="local") - for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"): - for ax0_1 in T.thread_binding(2, thread="threadIdx.y"): - for ax1_0 in range(32): - for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): - for ax0 in range(1): - for ax1 in T.vectorized(4): - with T.block("B_local"): - v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) - v1 = T.axis.spatial(8192, ax1_0 * 256 + ax1_1 * 4 + ax1) - T.reads(B[v0, v1]) - T.writes(B_local[v0, v1]) - B_local[v0, v1] = B[v0, v1] - for ax0 in range(1): - with T.block("B_decode_local_o"): - v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) - v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1) - T.reads(B_local[v0_o, v1_o * 4:v1_o * 4 + 4], Scale[v0_o, v1_o // 4]) - T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8]) - Compressed = T.match_buffer(B_local[v0_o, v1_o * 4:v1_o * 4 + 4], (4,), "int8", scope="local") - Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local") - # Scale_1 = T.match_buffer(Scale[v0_o, v1_o // 4: v1_o // 4 + 1], (1,), "float16") - Scale_1 = T.match_buffer(Scale[v0_o, v1_o // 4], (1,), "float16", elem_offset=Scale.elem_offset) - T.call_extern("handle", "decode_i4s_to_f16_scale", Compressed.data, Decompressed.data, Scale_1.access_ptr("r"), 8) - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("A_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1) - T.reads(A[v0, v1]) - T.writes(A_local[v0, v1]) - A_local[v0, v1] = A[v0, v1] - for ax1_2 in range(8): - with T.block("C"): - v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1) - v1 = T.axis.reduce(16384, ax1_0 * 512 + ax1_1 * 8 + ax1_2) - T.reads(A_local[0, v1], B_decode_local[v0, v1]) - T.writes(C_local[0, v0]) - with T.init(): - C_local[0, v0] = T.float16(0) - C_local[0, v0] = C_local[0, v0] + A_local[0, v1] * B_decode_local[v0, v1] - for ax0, ax1 in T.grid(1, 1): - with T.block("C_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax1) - T.reads(C_local[v0, v1]) - T.writes(D[0, v1]) - D[0, v1] = C_local[v0, v1] - - -import tvm -mod = Module -sch = tvm.tir.Schedule(mod, debug_mask="all") -with tvm.transform.PassContext( - config={"tir.use_async_copy": True} - ): - dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda") -with open("debug/after_memory_rewrite.cu", "+w") as f: - f.write(dense_relu_0_rt_mod.imported_modules[0].get_source()) diff --git a/testing/python/tir_expr/test_tir_0.py b/testing/python/tir_expr/test_tir_0.py deleted file mode 100644 index bd33ce8d9..000000000 --- a/testing/python/tir_expr/test_tir_0.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group - -@I.ir_module -class Module: - @T.prim_func - def main(A: T.Buffer((1024, 512, 16, 32), "int8"), B: T.Buffer((1024, 512, 16, 8), "int8"), C: T.Buffer((16384, 16384), "int32")): - T.func_attr({"dequantize_info": {"B": {"decode_block": "B_decode", "fast_decoding": T.bool(False), "source_format": {"bits": 2, "format": "int"}, "target_format": "int8"}}, "dlight.tensorcore_prenormlized": T.bool(True), "smooth_a": T.bool(True), "smooth_b": T.bool(True), "tir.noalias": T.bool(True)}) - # with T.block("root"): - A_reindex_reindex_shared = T.alloc_buffer((1, 1024, 512, 16, 32), "int8", scope="shared") - B_reindex_reindex_shared = T.alloc_buffer((1, 1024, 512, 16, 32), "int8", scope="shared") - B_reindex_reindex_local = T.alloc_buffer((1, 1024, 512, 16, 32), "int8", scope="local") - B_local = T.alloc_buffer((1024, 512, 16, 8), "int8", scope="local") - B_shared = T.alloc_buffer((1024, 512, 16, 8), "int8", scope="shared") - A_reindex_reindex_shared_warp = T.alloc_buffer((1, 1024, 512, 32, 16), "int8", scope="warp") - B_reindex_reindex_shared_warp = T.alloc_buffer((1, 1024, 512, 32, 16), "int8", scope="warp") - C_reindex_shared = T.alloc_buffer((1, 1024, 1024, 16, 16), "int32", scope="shared") - C_reindex_shared_warp = T.alloc_buffer((1, 1024, 1024, 32, 8), "int32", scope="warp") - for ax0 in range(1): - for ax1_0_0_ax2_0_0_fused in T.thread_binding(64, thread="blockIdx.y"): - for ax1_0_1_ax2_0_1_fused in T.thread_binding(256, thread="blockIdx.x"): - for ax1_0_2 in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_2 in T.thread_binding(2, thread="threadIdx.z"): - for ax1_0_3_init, ax2_0_3_init in T.grid(8, 2): - with T.block("C_o_init"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax1_0_3_init) - v2_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax2_0_3_init) - T.reads() - T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) - with T.block("C_init_o"): - v1_i_init_o = T.axis.spatial(1, 0) - v2_i_init_o = T.axis.spatial(1, 0) - T.reads() - T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) - C_warp = T.match_buffer(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], (32, 8), "int32", scope="warp", offset_factor=1) - for tx in T.thread_binding(32, thread="threadIdx.x"): - T.mma_fill("int32", 8, C_warp.data, C_warp.elem_offset) - for ax3_0_0 in T.serial(256, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2, 3], "software_pipeline_stage": [0, 0, 1, 1]}): - for ax0_ax1_ax2_ax3_ax4_fused_0 in T.thread_binding(2, thread="threadIdx.y"): - for ax0_ax1_ax2_ax3_ax4_fused_1 in T.thread_binding(2, thread="threadIdx.z"): - for ax0_ax1_ax2_ax3_ax4_fused_2 in T.unroll(8, annotations={"pragma_unroll_explicit": 0}): - for ax0_ax1_ax2_ax3_ax4_fused_3 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_ax4_fused_4 in T.vectorized(16): - with T.block("A_reindex_reindex_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) // 1024) - v2 = T.axis.spatial(512, ax3_0_0 * 2 + (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) % 1024 // 512) - v3 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) % 512 // 32) - v4 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) % 32) - T.reads(A[v1, v2, v3, v4]) - T.writes(A_reindex_reindex_shared[v0, v1, v2, v3, v4]) - T.block_attr({"permuted_layout": 0}) - A_reindex_reindex_shared[v0, v1, v2, v3, v4] = A[v1, v2, v3, v4] - for ax0_ax1_ax2_ax3_fused_0 in T.unroll(1, annotations={"pragma_unroll_explicit": 0}): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(2, thread="threadIdx.z"): - for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(2, thread="threadIdx.y"): - for ax0_ax1_ax2_ax3_fused_3 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_4 in T.vectorized(16): - with T.block("B_shared"): - v0 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) // 256) - v1 = T.axis.spatial(512, ax3_0_0 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) % 256 // 128) - v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) % 128 // 8) - v3 = T.axis.spatial(8, (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) % 8) - T.where((((ax0_ax1_ax2_ax3_fused_0 * 2 + ax0_ax1_ax2_ax3_fused_1) * 2 + ax0_ax1_ax2_ax3_fused_2) * 32 + ax0_ax1_ax2_ax3_fused_3) * 16 + ax0_ax1_ax2_ax3_fused_4 < 1024) - T.reads(B[v0, v1, v2, v3]) - T.writes(B_shared[v0, v1, v2, v3]) - B_shared[v0, v1, v2, v3] = B[v0, v1, v2, v3] - for ax0_1, ax1_ax2_ax3_ax4_0_fused_0 in T.grid(1, 2): - for ax1_ax2_ax3_ax4_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"): - for ax1_ax2_ax3_ax4_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"): - for ax1_ax2_ax3_ax4_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): - for ax4_1 in range(1): - for ax0_2, ax1, ax2 in T.grid(1, 1, 1): - for ax3 in T.vectorized(4): - with T.block("B_local"): - v0 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) // 64 + ax0_2) - v1 = T.axis.spatial(512, ax3_0_0 * 2 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 64 // 32 + ax1) - v2 = T.axis.spatial(16, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 32 // 2 + ax2) - v3 = T.axis.spatial(8, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 2 * 4 + ax3) - T.reads(B_shared[v0, v1, v2, v3]) - T.writes(B_local[v0, v1, v2, v3]) - B_local[v0, v1, v2, v3] = B_shared[v0, v1, v2, v3] - for ax0_2, ax1, ax2, ax3, ax4 in T.grid(1, 1, 1, 1, 16): - with T.block("B_reindex_reindex_local"): - v0 = T.axis.spatial(1, ax0_2) - v1 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) // 64 + ax1) - v2 = T.axis.spatial(512, ax3_0_0 * 2 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 64 // 32 + ax2) - v3 = T.axis.spatial(16, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 32 // 2 + ax3) - v4 = T.axis.spatial(32, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 2 * 16 + ax4) - T.reads(B_local[v1, v2, v3, v4 // 4]) - T.writes(B_reindex_reindex_local[v0, v1, v2, v3, v4]) - B_reindex_reindex_local[v0, v1, v2, v3, v4] = T.bitwise_and(T.shift_right(B_local[v1, v2, v3, v4 // 4], T.Cast("int8", v4 % 4 * 2)), T.int8(3)) - for ax4_2 in T.vectorized(16): - with T.block("B_reindex_reindex_shared"): - v0 = T.axis.spatial(1, ax0_1) - v1 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) // 64) - v2 = T.axis.spatial(512, ax3_0_0 * 2 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 64 // 32) - v3 = T.axis.spatial(16, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 32 // 2) - v4 = T.axis.spatial(32, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 2 * 16 + ax4_1 * 16 + ax4_2) - T.reads(B_reindex_reindex_local[v0, v1, v2, v3, v4]) - T.writes(B_reindex_reindex_shared[v0, v1, v2, v3, v4]) - T.block_attr({"permuted_layout": 0}) - B_reindex_reindex_shared[v0, v1, v2, v3, v4] = B_reindex_reindex_local[v0, v1, v2, v3, v4] - for ax3_0_1 in range(2): - for ax0_1, ax1, ax2, ax3_0, ax4_0 in T.grid(1, 8, 1, 1, 1): - with T.block("A_reindex_reindex_shared_warp_o"): - v0_o = T.axis.spatial(1, ax0_1) - v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax1) - v2_o = T.axis.spatial(512, ax3_0_0 * 2 + ax3_0_1 + ax2) - v3_o, v4_o = T.axis.remap("SS", [ax3_0, ax4_0]) - T.reads(A_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32]) - T.writes(A_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16]) - T.block_attr({"permuted_layout": 0}) - warp = T.match_buffer(A_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) - shared = T.match_buffer(A_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32], (16, 32), "int8", strides=("shared_s0", "shared_s1"), scope="shared", offset_factor=32) - for tx in T.thread_binding(32, thread="threadIdx.x"): - T.ptx_ldmatrix("int8", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation("int8"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), tx * 16) - for ax0_1, ax1, ax2, ax3_0, ax4_0 in T.grid(1, 2, 1, 1, 1): - with T.block("B_reindex_reindex_shared_warp_o"): - v0_o = T.axis.spatial(1, ax0_1) - v1_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax1) - v2_o = T.axis.spatial(512, ax3_0_0 * 2 + ax3_0_1 + ax2) - v3_o, v4_o = T.axis.remap("SS", [ax3_0, ax4_0]) - T.reads(B_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32]) - T.writes(B_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16]) - T.block_attr({"permuted_layout": 0}) - warp = T.match_buffer(B_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) - shared = T.match_buffer(B_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32], (16, 32), "int8", strides=("shared_s0", "shared_s1"), scope="shared", offset_factor=32) - for tx in T.thread_binding(32, thread="threadIdx.x"): - T.ptx_ldmatrix("int8", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation("int8"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), tx * 16) - for ax1_0_3, ax2_0_3 in T.grid(8, 2): - with T.block("C_o_update"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax1_0_3) - v2_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax2_0_3) - v3_o = T.axis.reduce(512, ax3_0_0 * 2 + ax3_0_1) - T.reads(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], A_reindex_reindex_shared_warp[0, v1_o, v3_o, 0:32, 0:16], B_reindex_reindex_shared_warp[0, v2_o, v3_o, 0:32, 0:16]) - T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) - with T.block("C_o"): - v1_i_o = T.axis.spatial(1, 0) - v2_i_o = T.axis.spatial(1, 0) - v3_i_o = T.axis.reduce(1, 0) - T.reads(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], A_reindex_reindex_shared_warp[0, v1_o, v3_o, 0:32, 0:16], B_reindex_reindex_shared_warp[0, v2_o, v3_o, 0:32, 0:16]) - T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) - A_1 = T.match_buffer(A_reindex_reindex_shared_warp[0, v1_o, v3_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) - B_1 = T.match_buffer(B_reindex_reindex_shared_warp[0, v2_o, v3_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) - C_1 = T.match_buffer(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], (32, 8), "int32", scope="warp", offset_factor=16) - for tx in T.thread_binding(32, thread="threadIdx.x"): - T.ptx_mma("int32", "m16n8k32", "row", "col", "int8", "int8", "int32", A_1.data, A_1.elem_offset + tx * 16, B_1.data, B_1.elem_offset + tx * 16, C_1.data, C_1.elem_offset + tx * 8, T.bool(False)) - T.ptx_mma("int32", "m16n8k32", "row", "col", "int8", "int8", "int32", A_1.data, A_1.elem_offset + tx * 16, B_1.data, B_1.elem_offset + tx * 16 + 8, C_1.data, C_1.elem_offset + tx * 8 + 4, T.bool(False)) - for ax0_1, ax1 in T.grid(8, 2): - for ax2_0, ax3_0 in T.grid(1, 1): - with T.block("C_reindex_shared_warp_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax0_1) - v2_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax1) - v3_o, v4_o = T.axis.remap("SS", [ax2_0, ax3_0]) - T.reads(C_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:8]) - T.writes(C_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:16]) - C_warp = T.match_buffer(C_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:8], (32, 8), "int32", scope="warp", offset_factor=1) - C_1 = T.match_buffer(C_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) - for tx in T.thread_binding(32, thread="threadIdx.x"): - T.mma_store("int32", 16, 16, T.tvm_access_ptr(T.type_annotation("int32"), C_1.data, C_1.elem_offset, C_1.strides[0] * 16, 2), C_warp.data, C_warp.elem_offset, C_1.strides[0]) - for ax0_ax1_ax2_ax3_ax4_fused_0 in T.unroll(2, annotations={"pragma_unroll_explicit": 0}): - for ax0_ax1_ax2_ax3_ax4_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_ax4_fused_2 in T.vectorized(4): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax0_1) - v2 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax1) - v3 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_ax4_fused_0 * 128 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4 + ax0_ax1_ax2_ax3_ax4_fused_2) // 16) - v4 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_ax4_fused_0 * 128 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4 + ax0_ax1_ax2_ax3_ax4_fused_2) % 16) - T.reads(C_reindex_shared[v0, v1, v2, v3, v4]) - T.writes(C[v3 + v1 * 16, v4 + v2 * 16]) - C[v3 + v1 * 16, v4 + v2 * 16] = C_reindex_shared[v0, v1, v2, v3, v4] - -mod = Module -sch = tvm.tir.Schedule(mod, debug_mask="all") -with tvm.transform.PassContext( - config={"tir.use_async_copy": True} - ): - dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda") -with open("after_memory_rewrite.cu", "+w") as f: - f.write(dense_relu_0_rt_mod.imported_modules[0].get_source()) diff --git a/testing/python/tir_expr/test_tir_1.py b/testing/python/tir_expr/test_tir_1.py deleted file mode 100644 index 49d1f71ec..000000000 --- a/testing/python/tir_expr/test_tir_1.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.tir.tensor_intrin.cuda import * - -# from tvm.script import tir as T -@T.prim_func -def main(input0: T.Buffer[(1024, 512, 16, 32), "int8"], input1: T.Buffer[(1024, 512, 16, 8), "int8"], output0: T.Buffer[(16384, 16384), "int8"]): - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # var definition - tx = T.env_thread("threadIdx.x") - C_s0 = T.var("int32") - C_s1 = T.var("int32") - shared_s0 = T.var("int32") - shared_s0_1 = T.var("int32") - shared_s1 = T.var("int32") - shared_s1_1 = T.var("int32") - # body - # with T.block("root") - input0_shared = T.alloc_buffer([1024, 512, 16, 32], dtype="int8", scope="shared") - mediate0_shared = T.alloc_buffer([1024, 512, 16, 32], dtype="int8", scope="shared") - mediate1_shared = T.alloc_buffer([1024, 1024, 16, 16], dtype="int32", scope="shared") - mediate1_shared_warp = T.alloc_buffer([1024, 1024, 32, 8], dtype="int32", scope="warp") - mediate0_local = T.alloc_buffer([1024, 512, 16, 32], dtype="int8", scope="local") - input1_shared = T.alloc_buffer([1024, 512, 16, 8], dtype="int8", scope="shared") - input1_shared_local = T.alloc_buffer([1024, 512, 16, 8], dtype="int8", scope="local") - input0_shared_warp = T.alloc_buffer([1024, 512, 32, 16], dtype="int8", scope="warp") - mediate0_shared_warp = T.alloc_buffer([1024, 512, 32, 16], dtype="int8", scope="warp") - for i_0 in T.thread_binding(256, thread="blockIdx.y"): - for j_0 in T.thread_binding(64, thread="blockIdx.x"): - for i_1 in T.thread_binding(2, thread="threadIdx.y"): - for j_1 in T.thread_binding(2, thread="threadIdx.z"): - for i_2_init in T.serial(2, annotations={"pragma_unroll_explicit":0, "thread_rasterization":10}): - for j_2_init in T.serial(8, annotations={"pragma_unroll_explicit":0}): - with T.block("mediate1_init_o"): - v_i = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + i_2_init) - v_j = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + j_2_init) - v_ii_o = T.axis.spatial(1, 0) - v_jj_o = T.axis.spatial(1, 0) - T.reads() - T.writes(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8]) - C_warp = T.match_buffer(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8], [32, 8], dtype="int32", scope="warp", offset_factor=1) - T.launch_thread(tx, 32) - T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32") - for k_0 in T.serial(256, annotations={"software_pipeline_async_stages":[0], "software_pipeline_order":[0, 1, 2, 3], "software_pipeline_stage":[0, 0, 1, 1]}): - for ax0_ax1_ax2_ax3_0_fused_0 in T.unroll(2, annotations={"pragma_unroll_explicit":0}): - for ax0_ax1_ax2_ax3_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"): - for ax0_ax1_ax2_ax3_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"): - for ax0_ax1_ax2_ax3_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): - for ax3_1 in T.vectorized(16): - with T.block("input0_shared"): - v0 = T.axis.spatial(1024, i_0 * 4 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) - v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) - v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) - v3 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 16 + ax3_1) - T.reads(input0[v0, v1, v2, v3]) - T.writes(input0_shared[v0, v1, v2, v3]) - input0_shared[v0, v1, v2, v3] = input0[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_fused_0_0_0_0 in T.serial(2): - for ax0_ax1_ax2_ax3_fused_0_0_0_1 in T.thread_binding(2, thread="threadIdx.z"): - for ax0_ax1_ax2_ax3_fused_0_0_1 in T.thread_binding(2, thread="threadIdx.y"): - for ax0_ax1_ax2_ax3_fused_0_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(16): - with T.block("input1_shared"): - v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) // 256) - v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) % 256 // 128) - v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) % 128 // 8) - v3 = T.axis.spatial(8, (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) % 8) - T.reads(input1[v0, v1, v2, v3]) - T.writes(input1_shared[v0, v1, v2, v3]) - input1_shared[v0, v1, v2, v3] = input1[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_0_fused_0 in T.serial(8): - for ax0_ax1_ax2_ax3_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"): - for ax0_ax1_ax2_ax3_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"): - for ax0_ax1_ax2_ax3_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): - for ax3_1 in T.serial(1): - for ax0 in T.vectorized(4): - with T.block("input1_shared_local"): - v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) - v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) - v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) - v3 = T.axis.spatial(8, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 4 + ax0) - T.reads(input1_shared[v0, v1, v2, v3]) - T.writes(input1_shared_local[v0, v1, v2, v3]) - input1_shared_local[v0, v1, v2, v3] = input1_shared[v0, v1, v2, v3] - for ax0 in T.serial(16): - with T.block("mediate0_local"): - v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) - v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) - v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) - v3 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 16 + ax0) - T.reads(input1_shared_local[v0, v1, v2, v3 // 4]) - T.writes(mediate0_local[v0, v1, v2, v3]) - mediate0_local[v0, v1, v2, v3] = T.bitwise_and(T.shift_right(input1_shared_local[v0, v1, v2, v3 // 4], T.Cast("int8", v3 % 4), dtype="int8"), T.int8(1), dtype="int8") - for ax3_2 in T.vectorized(16): - with T.block("mediate0_shared"): - v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) - v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) - v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) - v3 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 16 + ax3_1 * 16 + ax3_2) - T.reads(mediate0_local[v0, v1, v2, v3]) - T.writes(mediate0_shared[v0, v1, v2, v3]) - mediate0_shared[v0, v1, v2, v3] = mediate0_local[v0, v1, v2, v3] - for k_1 in T.serial(2): - for ax0, ax1 in T.grid(2, 1): - with T.block("input0_shared_warp_o"): - v0 = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + ax0) - v1 = T.axis.spatial(512, ax1 * 512 + k_0 * 2 + k_1) - v2_o = T.axis.spatial(1, 0) - v3_o = T.axis.spatial(1, 0) - T.reads(input0_shared[v0, v1, 0 : 16, 0 : 32]) - T.writes(input0_shared_warp[v0, v1, 0 : 32, 0 : 16]) - warp = T.match_buffer(input0_shared_warp[v0, v1, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) - shared = T.match_buffer(input0_shared[v0, v1, 0 : 16, 0 : 32], [16, 32], dtype="int8", strides=[shared_s0, shared_s1], scope="shared", offset_factor=16) - T.launch_thread(tx, 32) - T.ptx_ldmatrix(False, 4, ".b16", warp.data, warp.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation(dtype="int8"), shared.data, shared.elem_offset, shared_s0 * 16, 1, dtype="handle"), 16 * tx, dtype="int8") - for ax0, ax1 in T.grid(8, 1): - with T.block("mediate0_shared_warp_o"): - v0 = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + ax0) - v1 = T.axis.spatial(512, ax1 * 512 + k_0 * 2 + k_1) - v2_o = T.axis.spatial(1, 0) - v3_o = T.axis.spatial(1, 0) - T.reads(mediate0_shared[v0, v1, 0 : 16, 0 : 32]) - T.writes(mediate0_shared_warp[v0, v1, 0 : 32, 0 : 16]) - warp_1 = T.match_buffer(mediate0_shared_warp[v0, v1, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) - shared_1 = T.match_buffer(mediate0_shared[v0, v1, 0 : 16, 0 : 32], [16, 32], dtype="int8", strides=[shared_s0_1, shared_s1_1], scope="shared", offset_factor=16) - T.launch_thread(tx, 32) - T.ptx_ldmatrix(False, 4, ".b16", warp_1.data, warp_1.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation(dtype="int8"), shared_1.data, shared_1.elem_offset, shared_s0_1 * 16, 1, dtype="handle"), 16 * tx, dtype="int8") - for i_2, j_2 in T.grid(2, 8): - with T.block("mediate1_update_o"): - v_i = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + i_2) - v_j = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + j_2) - v_ii_o = T.axis.spatial(1, 0) - v_jj_o = T.axis.spatial(1, 0) - v_k = T.axis.reduce(512, k_0 * 2 + k_1) - v_kk_o = T.axis.reduce(1, 0) - T.reads(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8], input0_shared_warp[v_i, v_k, 0 : 32, 0 : 16], mediate0_shared_warp[v_j, v_k, 0 : 32, 0 : 16]) - T.writes(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8]) - A = T.match_buffer(input0_shared_warp[v_i, v_k, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) - B = T.match_buffer(mediate0_shared_warp[v_j, v_k, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) - C = T.match_buffer(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8], [32, 8], dtype="int32", scope="warp", offset_factor=16) - T.launch_thread(tx, 32) - T.ptx_mma("m16n8k32", "row", "col", "int8", "int8", "int32", A.data, A.elem_offset + tx * 16, B.data, B.elem_offset + tx * 16, C.data, C.elem_offset + tx * 8, False, dtype="int32") - T.ptx_mma("m16n8k32", "row", "col", "int8", "int8", "int32", A.data, A.elem_offset + tx * 16, B.data, B.elem_offset + tx * 16 + T.FloorDiv(16, 2), C.data, C.elem_offset + tx * 8 + T.FloorDiv(8, 2), False, dtype="int32") - for ax0, ax1 in T.grid(2, 8): - with T.block("mediate1_shared_warp_o"): - v0 = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + ax0) - v1 = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + ax1) - v2_o = T.axis.spatial(1, 0) - v3_o = T.axis.spatial(1, 0) - T.reads(mediate1_shared_warp[v0, v1, 0 : 32, 0 : 8]) - T.writes(mediate1_shared[v0, v1, 0 : 16, 0 : 16]) - C_warp_1 = T.match_buffer(mediate1_shared_warp[v0, v1, 0 : 32, 0 : 8], [32, 8], dtype="int32", scope="warp", offset_factor=1) - C_1 = T.match_buffer(mediate1_shared[v0, v1, 0 : 16, 0 : 16], [16, 16], dtype="int32", strides=[C_s0, C_s1], scope="shared", offset_factor=1) - T.launch_thread(tx, 32) - T.mma_store(16, 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_1.data, C_1.elem_offset, C_s0 * 16, 2, dtype="handle"), C_warp_1.data, C_warp_1.elem_offset, C_s0, dtype="int32") - for ax0_ax1_ax2_ax3_fused_0 in T.unroll(2, annotations={"pragma_unroll_explicit":0}): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): - with T.block("mediate1_shared"): - v0 = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + ax0) - v1 = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + ax1) - v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 16) - v3 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 16) - T.reads(mediate1_shared[v0, v1, v2, v3]) - T.writes(output0[v0 * 16 + v2, v1 * 16 + v3]) - output0[v2 + v0 * 16, v3 + v1 * 16] = T.Cast("int8", mediate1_shared[v0, v1, v2, v3]) - -mod = main -sch = tvm.tir.Schedule(mod, debug_mask="all") -with tvm.transform.PassContext( - config={"tir.use_async_copy": True} - ): - dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda") -with open("after_memory_rewrite.cu", "+w") as f: - f.write(dense_relu_0_rt_mod.imported_modules[0].get_source()) diff --git a/testing/python/tir_expr/test_tir_2.py b/testing/python/tir_expr/test_tir_2.py deleted file mode 100644 index d2a93aead..000000000 --- a/testing/python/tir_expr/test_tir_2.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.script import relax as R -import bitblas - - -@T.prim_func -def fused_fused_decode3_fused_NT_matmul8_add1( - lv47: T.Buffer((T.int64(256), T.int64(256), T.int64(16), T.int64(2)), "uint32"), - lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), - lv41: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - NT_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ), -): - T.func_attr( - { - "dequantize_info": { - "decode": { - "decode_block": "decode", - "fast_decoding": T.bool(False), - "group_size": 32, - "source_format": {"bits": 4, "format": "int"}, - "storage_dtype": "uint32", - "target_format": "float16", - "with_scaling": T.bool(True), - } - }, - "smooth_b": T.bool(True), - "tir.noalias": T.bool(True), - "transform_kind": 1, - } - ) - # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) - lv47_global = T.alloc_buffer((T.int64(4096), T.int64(512)), "uint32") - for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): - with T.block("lv47_global"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads( - lv47[ - v0 // T.int64(16), - v1 // T.int64(2), - v0 % T.int64(16), - v1 % T.int64(2), - ] - ) - T.writes(lv47_global[v0, v1]) - lv47_global[v0, v1] = lv47[ - v0 // T.int64(16), - v1 // T.int64(2), - v0 % T.int64(16), - v1 % T.int64(2), - ] - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv47_global[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) - T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47_global[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] - for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k]) - T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k] - ) - - -import tvm - -sch = bitblas.gpu.GEMV().apply( - fused_fused_decode3_fused_NT_matmul8_add1, tvm.target.Target("cuda"), False -) -print(sch) diff --git a/testing/python/tir_expr/test_tir_3.py b/testing/python/tir_expr/test_tir_3.py deleted file mode 100644 index 1a879859c..000000000 --- a/testing/python/tir_expr/test_tir_3.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.script import relax as R -import bitblas - - -@T.prim_func -def fused_fused_decode3_fused_NT_matmul8_add1( - lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), - lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), - p_lv41: T.handle, - p_output0: T.handle, -): - T.func_attr( - { - "dequantize_info": { - "decode": { - "decode_block": "decode", - "fast_decoding": T.bool(False), - "group_size": 32, - "source_format": {"bits": 4, "format": "int"}, - "storage_dtype": "uint32", - "target_format": "float16", - "with_scaling": T.bool(True), - } - }, - "tir.is_scheduled": 1, - "tir.noalias": T.bool(True), - } - ) - n = T.int64() - lv41 = T.match_buffer(p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16") - NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ) - # with T.block("root"): - decode_intermediate_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) - T.writes(decode_intermediate_intermediate[v_i, v_j]) - decode_intermediate_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv47[v_i, v_j // T.int64(8)], - T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv48[v_i, v_j // T.int64(32)] - for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads( - lv41[v_i0, v_i1, v_k], - decode_intermediate_intermediate[v_i2, v_k], - ) - T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k] - ) - - -import tvm -from bitblas.base.roller.policy import DefaultPolicy -from bitblas.base.roller.arch import CUDA - -func = fused_fused_decode3_fused_NT_matmul8_add1 -target = tvm.target.Target("nvidia/nvidia-a100") -arch = CUDA(target) -policy = DefaultPolicy(func=func, arch=arch) -configs = policy.emit_config(20) -print(configs) -sch = bitblas.gpu.gemv.GEMVWithDequantizeInfo().apply_config(func, configs[0]) -print(sch.mod) diff --git a/testing/python/type_conversion/test_lop3_type_conversion.py b/testing/python/type_conversion/test_lop3_type_conversion.py index 5964323b0..e434c8a95 100644 --- a/testing/python/type_conversion/test_lop3_type_conversion.py +++ b/testing/python/type_conversion/test_lop3_type_conversion.py @@ -7,7 +7,7 @@ from bitblas.base.roller.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.base.utils import apply_and_build -from bitblas.ops.matmul_impl import matmul_nt, matmul_nt_dequantize_b +from bitblas.ops.impl.matmul_impl import matmul_nt, matmul_nt_dequantize_b import numpy as np @@ -19,7 +19,7 @@ def test_f16_f16_gemm(): policy = DefaultPolicy(func=func, arch=arch) try: tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: + except Exception: tags = None if tags: policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) @@ -55,7 +55,7 @@ def test_f16_i4_gemm(M=1, N=16384, K=16384, bit=4, fast_decoding=True): policy = DefaultPolicy(func=func, arch=arch) try: tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except: + except Exception: tags = None if tags: policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)