diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index cc1d9f075..0045934f3 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -162,16 +162,15 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]): object.__setattr__(self, "zeros_mode", "original") def __initialize_fast_decoding(self, fast_decoding: Optional[bool]): - if ( + if fast_decoding is not None: + object.__setattr__(self, "fast_decoding", fast_decoding) + elif ( "int" not in self.W_dtype - or "nf" not in self.W_dtype or self.W_dtype == self.A_dtype ): object.__setattr__(self, "fast_decoding", False) else: object.__setattr__(self, "fast_decoding", True) - if fast_decoding is not None: - object.__setattr__(self, "fast_decoding", fast_decoding) def __post_init__(self): # set M to default dynamic range if it is None