Commit 54b5d3f

optimize cache.

LeiWang1999 committed Aug 9, 2024
1 parent 8b5f083 commit 54b5d3f

Showing 3 changed files with 19 additions and 1 deletion.
bitblas/benchmark/operator/__init__.py (17 additions, 0 deletions)

@@ -9,6 +9,10 @@
 from bitblas.utils import get_default_cache_path
 from bitblas import auto_detect_nvidia_target
 from bitblas import tvm as tvm
+from bitblas.cache import OperatorCache
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class BitblasOperatorBenchmarkBase(ABC):
@@ -28,6 +32,9 @@ class BitblasOperatorBenchmarkBase(ABC):
     # Log path
     log_path: Optional[str] = path.join(get_default_cache_path(), "benchmark")
 
+    # Operator cache
+    operator_cache: OperatorCache = OperatorCache()
+
     @abstractmethod
     def prepare_benchmark_sets(self):
         pass
@@ -98,6 +105,14 @@ def run_benchmark(
         dynamic_profiling_shape: Optional[Dict[str, int]] = None,
     ) -> Optional[float]:
         """Run a single benchmark."""
+
+        if self.operator_cache.exists(config):
+            logger.info(f"Operator {config} found in cache")
+            op_inst = self.operator_cache.get(config)
+            latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape)
+            op_inst.cleanup()
+            return latency, None
+
         op_inst = self.make_operator(operator, config)
         tuning_time = None
 
@@ -106,6 +121,8 @@
             op_inst.hardware_aware_finetune(topk=20, parallel_build=True)
             tuning_time = perf_counter() - start
 
+        self.operator_cache.add(config, op_inst)
+
         latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape)
 
         op_inst.cleanup()
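
A minimal sketch of the cache-first flow this commit adds to run_benchmark, assuming only the exists/get/add interface visible in the diff; the OperatorCache body, the make_operator callable, and the operator methods below are illustrative stand-ins, not BitBLAS internals. On a hit the benchmark reuses the tuned operator and reports no tuning time; on a miss it tunes, measures, and populates the cache:

from time import perf_counter


class OperatorCache:
    """Toy in-memory cache keyed by operator config (illustrative)."""

    def __init__(self):
        self._ops = {}

    def exists(self, config):
        return config in self._ops

    def get(self, config):
        return self._ops[config]

    def add(self, config, op_inst):
        self._ops[config] = op_inst


def run_benchmark(cache, config, make_operator):
    # Cache hit: skip hardware-aware finetuning; tuning time is None.
    if cache.exists(config):
        return cache.get(config).profile_latency(), None
    # Cache miss: build, tune, measure, and remember the operator.
    op_inst = make_operator(config)
    start = perf_counter()
    op_inst.hardware_aware_finetune()
    tuning_time = perf_counter() - start
    cache.add(config, op_inst)
    return op_inst.profile_latency(), tuning_time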

bitblas/cache/__init__.py (1 addition, 0 deletions)

@@ -6,4 +6,5 @@
     load_global_ops_cache,  # noqa: F401
     get_database_path,  # noqa: F401
     set_database_path,  # noqa: F401
+    OperatorCache,  # noqa: F401
 )
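
With OperatorCache re-exported here, callers can import it from bitblas.cache alongside the existing helpers; a hypothetical caller:

# OperatorCache is now importable from the package, not just its
# defining submodule.
from bitblas.cache import OperatorCache

operator_cache = OperatorCache()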

bitblas/gpu/matmul_mma_dequantize.py (1 addition, 1 deletion)

@@ -1186,7 +1186,7 @@ def sch_shared_memory_prefetch_with_config(
         """
 
         weight_transform_kind = config.intrin_info.weight_transform_kind
-        if weight_transform_kind == TransformKind.LDMatrixTransform:
+        if weight_transform_kind == TransformKind.LDMatrixTransform and config.block_reduction_depth is not None:
            return self.sch_warp_memory_prefetch_with_config(func, config)
 
        is_cross_thread_reduce = (
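
The tightened guard dispatches to the warp-memory prefetch schedule only when block reduction is actually configured, instead of for every LDMatrix-transformed weight. A minimal sketch of that dispatch, with TransformKind stubbed out (the enum values and the pick_schedule helper are illustrative, not the BitBLAS scheduler):

from enum import Enum
from types import SimpleNamespace


class TransformKind(Enum):
    # Member values here are illustrative.
    NonTransform = 0
    LDMatrixTransform = 3


def pick_schedule(config):
    # Warp-memory prefetch now requires both the LDMatrix weight layout
    # and an explicit block reduction depth; otherwise fall back to the
    # shared-memory prefetch schedule.
    if (config.intrin_info.weight_transform_kind == TransformKind.LDMatrixTransform
            and config.block_reduction_depth is not None):
        return "warp_memory_prefetch"
    return "shared_memory_prefetch"


cfg = SimpleNamespace(
    intrin_info=SimpleNamespace(weight_transform_kind=TransformKind.LDMatrixTransform),
    block_reduction_depth=None,
)
assert pick_schedule(cfg) == "shared_memory_prefetch"  # LDMatrix alone no longer suffices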
