From 54b5d3fb8e79611b0347b43c6a5075f48d30afc6 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 9 Aug 2024 16:33:53 +0000
Subject: [PATCH] optimize cache.

---
 bitblas/benchmark/operator/__init__.py | 17 +++++++++++++++++
 bitblas/cache/__init__.py              |  1 +
 bitblas/gpu/matmul_mma_dequantize.py   |  2 +-
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py
index f59ca34ee..7c21d9d0c 100644
--- a/bitblas/benchmark/operator/__init__.py
+++ b/bitblas/benchmark/operator/__init__.py
@@ -9,6 +9,10 @@
 from bitblas.utils import get_default_cache_path
 from bitblas import auto_detect_nvidia_target
 from bitblas import tvm as tvm
+from bitblas.cache import OperatorCache
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class BitblasOperatorBenchmarkBase(ABC):
@@ -28,6 +32,9 @@ class BitblasOperatorBenchmarkBase(ABC):
     # Log path
     log_path: Optional[str] = path.join(get_default_cache_path(), "benchmark")
 
+    # Operator cache
+    operator_cache: OperatorCache = OperatorCache()
+
     @abstractmethod
     def prepare_benchmark_sets(self):
         pass
@@ -98,6 +105,14 @@ def run_benchmark(
         dynamic_profiling_shape: Optional[Dict[str, int]] = None,
     ) -> Optional[float]:
         """Run a single benchmark."""
+
+        if self.operator_cache.exists(config):
+            logger.info(f"Operator {config} found in cache")
+            op_inst = self.operator_cache.get(config)
+            latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape)
+            op_inst.cleanup()
+            return latency, None
+
         op_inst = self.make_operator(operator, config)
 
         tuning_time = None
@@ -106,6 +121,8 @@
             op_inst.hardware_aware_finetune(topk=20, parallel_build=True)
             tuning_time = perf_counter() - start
 
+        self.operator_cache.add(config, op_inst)
+
         latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape)
         op_inst.cleanup()
 
diff --git a/bitblas/cache/__init__.py b/bitblas/cache/__init__.py
index 0c8fd3b9c..ee522ec3f 100644
--- a/bitblas/cache/__init__.py
+++ b/bitblas/cache/__init__.py
@@ -6,4 +6,5 @@
     load_global_ops_cache,  # noqa: F401
     get_database_path,  # noqa: F401
     set_database_path,  # noqa: F401
+    OperatorCache,  # noqa: F401
 )
diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py
index 7421cbd47..7c40d3243 100644
--- a/bitblas/gpu/matmul_mma_dequantize.py
+++ b/bitblas/gpu/matmul_mma_dequantize.py
@@ -1186,7 +1186,7 @@ def sch_shared_memory_prefetch_with_config(
         """
         weight_transform_kind = config.intrin_info.weight_transform_kind
-        if weight_transform_kind == TransformKind.LDMatrixTransform:
+        if weight_transform_kind == TransformKind.LDMatrixTransform and config.block_reduction_depth is not None:
             return self.sch_warp_memory_prefetch_with_config(func, config)
 
         is_cross_thread_reduce = (
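
Note on the benchmark change: run_benchmark now treats OperatorCache (re-exported from bitblas/cache/__init__.py so the benchmark module can import it) as a look-aside cache keyed by the operator config. On a hit it profiles the cached operator and returns None in place of a tuning time; on a miss it runs hardware_aware_finetune, benchmarks, and stores the freshly tuned operator for later entries with the same config. The sketch below is a minimal illustration of that flow under stated assumptions: SimpleOperatorCache, tune_and_build, and profile are hypothetical stand-ins, and only the exists/get/add method names mirror the real bitblas.cache.OperatorCache used by the patch.

from time import perf_counter
from typing import Any, Dict, Optional, Tuple


class SimpleOperatorCache:
    """Hypothetical in-memory stand-in for bitblas.cache.OperatorCache."""

    def __init__(self) -> None:
        # Assumes the config object is hashable; the real cache may key
        # entries differently (e.g. by a serialized config).
        self._entries: Dict[Any, Any] = {}

    def exists(self, config: Any) -> bool:
        return config in self._entries

    def get(self, config: Any) -> Any:
        return self._entries[config]

    def add(self, config: Any, op_inst: Any) -> None:
        self._entries[config] = op_inst


def run_benchmark_cached(cache: SimpleOperatorCache, config: Any,
                         tune_and_build, profile) -> Tuple[float, Optional[float]]:
    """Look-aside pattern used by the patched run_benchmark.

    tune_and_build(config) and profile(op_inst) are placeholders for the
    tuning step and the latency profiling done by the real runner.
    """
    if cache.exists(config):
        # Cache hit: reuse the tuned operator, no tuning time to report.
        return profile(cache.get(config)), None

    # Cache miss: tune once, remember the operator for later runs.
    start = perf_counter()
    op_inst = tune_and_build(config)
    tuning_time = perf_counter() - start
    cache.add(config, op_inst)
    return profile(op_inst), tuning_time

With this pattern, benchmark entries that share a config skip the hardware-aware finetuning step (topk=20, parallel build), which is the expensive part the cache is meant to avoid repeating.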
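
Note on the scheduler change: the edit in matmul_mma_dequantize.py narrows when the warp-memory prefetch schedule is dispatched. It now requires both an LDMatrixTransform-transformed weight and a configured block_reduction_depth; configs without block reduction fall through to the regular shared-memory prefetch path. The predicate below is only an illustration of that condition, with strings standing in for the TransformKind enum.

from typing import Optional


def use_warp_memory_prefetch(weight_transform_kind: str,
                             block_reduction_depth: Optional[int]) -> bool:
    """Illustrative mirror of the patched condition in
    sch_shared_memory_prefetch_with_config: before the patch only the
    transform kind was checked, now both conditions must hold."""
    return (weight_transform_kind == "LDMatrixTransform"
            and block_reduction_depth is not None)


# LDMatrixTransform alone no longer selects the warp-memory prefetch schedule.
assert use_warp_memory_prefetch("LDMatrixTransform", None) is False
assert use_warp_memory_prefetch("LDMatrixTransform", 2) is True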