[Dev] Fix a bug within FP8 E4M3 Fast Decoding #54

Merged
merged 29 commits into main from lei/splitk
Jun 6, 2024

Commits
75d2f3d
improve e4m3 decoding.
May 21, 2024
dd744d0
Merge branch 'main' of https://github.com/microsoft/BitBLAS into main
May 23, 2024
00bfa31
append fp16xint1
May 25, 2024
8cd8b10
Update submodule commit reference
Jun 1, 2024
9122ff7
chore: Update shared memory scope for float32 output dtype
Jun 1, 2024
b508acc
BUGFIX: UINT8/INT8 Decoding
Jun 2, 2024
58d55b7
feat: Add rasterization options for roller module
Jun 5, 2024
e7547ce
Refactor tensorcore_legalization method to optimize tensor core usage
Jun 5, 2024
678a2e1
feat: Add function to collect variables from expression, improve for …
Jun 5, 2024
3088b35
chore: Update typing import in __init__.py
Jun 5, 2024
5d206b3
chore: Refactor CPU execution of operators
Jun 5, 2024
e06ce10
Refactor matmul implementation for splitk layout
Jun 5, 2024
d67cc6d
Refactor matmul implementation for splitk layout
Jun 5, 2024
9e36b6d
Refactor matmul implementation for splitk layout
Jun 5, 2024
e1a0149
chore: Update version to 0.0.1.dev8
Jun 5, 2024
df0ed7a
chore: Enable debug output in bitblas.set_debug_level()
Jun 5, 2024
a0f651a
Refactor Linear module matmul implementation for splitk layout
Jun 5, 2024
88295a7
Refactor matmul implementation for splitk layout
Jun 5, 2024
3366dce
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lei/…
Jun 5, 2024
25b5c63
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 5, 2024
26a9f1b
Bump version to v0.0.1.dev9
Jun 5, 2024
251bf08
Merge branch 'main' of https://github.com/microsoft/BitBLAS into lei/…
Jun 5, 2024
e0cf62c
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 6, 2024
2e4e8dd
Bump version to v0.0.1.dev10
Jun 6, 2024
0dec7d8
Merge branch 'main' into lei/splitk
LeiWang1999 Jun 6, 2024
81f5b9a
Refactor CUDA kernel launch string for dynamic symbolic set
Jun 6, 2024
ec64f91
Merge branch 'lei/splitk' of https://github.com/LeiWang1999/MSBitBLAS…
Jun 6, 2024
5e71163
Bump version to v0.0.1.dev12 and add MatmulConfigWithSplitK and Matmu…
Jun 6, 2024
d0e0726
Merge branch 'main' into lei/splitk
LeiWang1999 Jun 6, 2024
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-0.0.1.dev10
+0.0.1.dev12
3 changes: 2 additions & 1 deletion python/bitblas/__init__.py
@@ -33,6 +33,7 @@
 from . import testing # noqa: F401
 from .utils import auto_detect_nvidia_target # noqa: F401
 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
+from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
 from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401
 from .module import Linear # noqa: F401
@@ -81,4 +82,4 @@ def _init_logger():

 _init_logger()

-__version__ = "0.0.1.dev10"
+__version__ = "0.0.1.dev12"
15 changes: 12 additions & 3 deletions python/bitblas/quantization/quantization.py
@@ -139,14 +139,23 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
     return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)


-def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
+def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
     assert nbit == 8
     assert dtype == "float16"
     s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
-    prefix = tir.Select(s_f16 == 0, tir.const(0x2000, "uint16"), tir.const(0xc000, "uint16"))
-    e_f16 = (((val & tir.const(127, "uint16")) << tir.const(7, "uint16"))) | prefix
+    e4 = val & tir.const(0x40, "uint16")
+    prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), tir.const(0x4000, "uint16"))
+    e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix
     return tir.reinterpret("float16", s_f16 | e_f16)

+def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
+    assert nbit == 8
+    assert dtype == "float16"
+    s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
+    e4 = val & tir.const(0x40, "uint16")
+    e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
+    e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
+    return tir.reinterpret("float16", s_f16 | e_f16)

 def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
     assert nbit == 8
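The bug: the old decoder chose the exponent prefix from the sign bit (s_f16 == 0) and OR'ed in 0xc000, and it shifted all seven low bits (val & 127), so any value with the top exponent bit set, and every negative value, decoded incorrectly. The fix tests the exponent's top bit (val & 0x40) instead: E4M3 stores exponents with bias 7 and FP16 with bias 15, so decoding adds 8 to the exponent field, giving FP16 exponent pattern 01xxx (prefix 0x2000) when that bit is 0 and 10xxx (prefix 0x4000) when it is 1. The fast variant folds the Select away by placing the bit at positions 14 and 13 and XOR-ing bit 13, making the decode branchless. Below is a standalone sketch (plain Python + NumPy, not from the repo) that checks the branchless trick against a straightforward reference; zero, subnormal, and NaN encodings are skipped since neither decoder special-cases them.

import numpy as np

def decode_e4m3_fast(val: int) -> float:
    # Branchless trick from the PR, transcribed to plain Python.
    s_f16 = ((val >> 7) & 1) << 15   # sign: bit 7 -> bit 15
    e4 = val & 0x40                  # top bit of the 4-bit exponent
    # Placing e4 at bits 14 and 13 and XOR-ing bit 13 adds 8 to the
    # exponent field, i.e. rebias from E4M3 (bias 7) to FP16 (bias 15).
    e_f16 = ((val & 63) << 7) | (e4 << 8) | (e4 << 7)
    e_f16 ^= 0x2000
    return float(np.uint16(s_f16 | e_f16).view(np.float16))

def decode_e4m3_ref(val: int) -> float:
    # Straightforward reference for normal E4M3 encodings.
    sign = -1.0 if val & 0x80 else 1.0
    exp = (val >> 3) & 0xF
    man = val & 0x7
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

for v in range(256):
    exp = (v >> 3) & 0xF
    if exp == 0 or (exp == 0xF and (v & 7) == 7):
        continue  # zero/subnormal and NaN encodings: outside the fast mapping
    assert decode_e4m3_fast(v) == decode_e4m3_ref(v), hex(v)
print("branchless decode matches the reference for all normal encodings")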
4 changes: 4 additions & 0 deletions testing/python/operators/test_general_matmul_splitk_ops.py
@@ -143,6 +143,10 @@ def map_torch_type(intype):
     matmul.forward(torch_a, torch_b, output=bitblas_out)
     print("torch_ref_out", ref_out)
     print("bitblas_out", bitblas_out)

+    matmul.forward(torch_a, torch_b, output=bitblas_out)
+    print("torch_ref_out", ref_out)
+    print("bitblas_out", bitblas_out)
+
     torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1)