Refactor BitBLASLinear test module for improved readability and maintainability
LeiWang1999 committed Jul 23, 2024
1 parent cffe3fd commit a8bed74
Showing 3 changed files with 113 additions and 1 deletion.
67 changes: 67 additions & 0 deletions benchmark/operators/base.py
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Optional, Union
from bitblas.ops import Operator, OperatorConfig
from bitblas import auto_detect_nvidia_target

class BitblasOperatorBenchmarkBase(ABC):

    # Separate benchmark sets for different operators
benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig]]] = {}

    # Currently only the NVIDIA target is supported for benchmarking
benchmark_target: str = auto_detect_nvidia_target()

# benchmark results
benchmark_results: Dict[str, List[Optional[float]]] = {}

@abstractmethod
def prepare_benchmark_sets(self):
pass

    def add_benchmark_set(self, name: str, benchmark_set: List[Tuple[Operator, OperatorConfig]]):
if name in self.benchmark_sets:
self.benchmark_sets[name].extend(benchmark_set)
else:
self.benchmark_sets[name] = benchmark_set

def run(self):
self.prepare_benchmark_sets()
self.benchmark()
print("Benchmark results:", self.benchmark_results)
self.report()
self.cleanup()

    def report(self):
        raise NotImplementedError

def cleanup(self):
# clean up the benchmark sets
self.benchmark_sets.clear()

def benchmark(self):
for name, benchmark_set in self.benchmark_sets.items():
self.benchmark_results[name] = []
for operator, config in benchmark_set:
self.benchmark_results[name].append(self.run_benchmark(operator, config))

    def run_benchmark(self, operator: Operator, config: OperatorConfig) -> Optional[float]:
op_inst = operator(config, target=self.benchmark_target)
return op_inst.profile_latency()

@abstractmethod
def get_operator(self) -> Operator:
raise NotImplementedError

@abstractmethod
def get_operator_config(self) -> OperatorConfig:
raise NotImplementedError

    def get_benchmark_sets(self, name: Optional[str] = None) -> Union[Dict[str, List[Tuple[Operator, OperatorConfig]]], List[Tuple[Operator, OperatorConfig]]]:
if name is None:
return self.benchmark_sets
else:
assert name in self.benchmark_sets, f"Operator {name} not found in benchmark sets"
return self.benchmark_sets[name]
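
Taken together, run() drives the full lifecycle: prepare the benchmark sets, profile every (Operator, OperatorConfig) pair, report, and clean up. As a usage sketch (not part of this commit; names like MyMatmulBenchmark are illustrative, and profiling assumes a CUDA-capable machine), a concrete subclass only needs to fill in the abstract hooks, plus report() since the base leaves it unimplemented:

from bitblas.ops import Matmul, MatmulConfig

class MyMatmulBenchmark(BitblasOperatorBenchmarkBase):

    def prepare_benchmark_sets(self):
        # One named set with a single square FP16 GEMM workload.
        self.add_benchmark_set(
            "example",
            [(Matmul, MatmulConfig(M=1024, N=1024, K=1024, in_dtype="float16",
                                   out_dtype="float16", accum_dtype="float16"))],
        )

    def report(self):
        # Minimal report: dump the raw latencies collected per set.
        for name, latencies in self.benchmark_results.items():
            print(name, latencies)

    def get_operator(self):
        return Matmul

    def get_operator_config(self):
        return MatmulConfig

MyMatmulBenchmark().run()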
45 changes: 45 additions & 0 deletions benchmark/operators/benchmark_ops_matmul.py
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from base import BitblasOperatorBenchmarkBase
from bitblas.ops import Matmul, MatmulConfig
from bitblas import set_log_level

set_log_level("DEBUG")

class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase):

config_map = {
"FP16xFP16_ACCFP16_NT": {
"in_dtype": "float16",
"out_dtype": "float16",
"accum_dtype": "float16",
}
}

def prepare_benchmark_sets(self):
self.add_benchmark_set(
"FP16xFP16_ACCFP16_NT",
[
(Matmul, self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384)),
],
)

    def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig:
if name not in self.config_map:
raise ValueError(f"Operator {name} not found in config map")
return MatmulConfig(
M=M,
N=N,
K=K,
**self.config_map[name],
)

def get_operator(self):
return Matmul

def get_operator_config(self):
return MatmulConfig

if __name__ == "__main__":
BitblasMatmulOpsBenchmark().run()
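
Because the config_map keys double as benchmark-set names, widening coverage is just a matter of adding tuples. A sketch of an M-dimension sweep reusing the existing FP16 configuration (an illustration of the intended usage, not code from this commit; this body would replace prepare_benchmark_sets inside the class):

    def prepare_benchmark_sets(self):
        # Sweep M while keeping N and K fixed at 16384, reusing the FP16 config.
        self.add_benchmark_set(
            "FP16xFP16_ACCFP16_NT",
            [
                (Matmul, self.generate_operator_config("FP16xFP16_ACCFP16_NT", M, 16384, 16384))
                for M in (1, 16, 256, 4096, 16384)
            ],
        )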
2 changes: 1 addition & 1 deletion bitblas/ops/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .operator import Operator # noqa: F401
from .operator import Operator, OperatorConfig # noqa: F401
from .matmul import Matmul, MatmulConfig # noqa: F401
from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401
from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401
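
This one-line change re-exports OperatorConfig from the package root, which is what lets the new benchmark base type-annotate against it directly:

from bitblas.ops import Operator, OperatorConfig  # OperatorConfig is now importable alongside Operator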