-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor BitBLASLinear test module for improved readability and maint…
…ainability
- Loading branch information
1 parent
cffe3fd
commit a8bed74
Showing
3 changed files
with
113 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Dict, List, Tuple, Optional | ||
from bitblas.ops import Operator, OperatorConfig | ||
from bitblas import auto_detect_nvidia_target | ||
|
||
class BitblasOperatorBenchmarkBase(ABC): | ||
|
||
# separate benchmark sets for different operators set | ||
benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig]]] = {} | ||
|
||
# currently we only support nvidia target for benchmarking | ||
benchmark_target: str = auto_detect_nvidia_target() | ||
|
||
# benchmark results | ||
benchmark_results: Dict[str, List[Optional[float]]] = {} | ||
|
||
@abstractmethod | ||
def prepare_benchmark_sets(self): | ||
pass | ||
|
||
def add_benchmark_set(self, name:str, benchmark_set:List[Tuple[Operator, OperatorConfig]]): | ||
if name in self.benchmark_sets: | ||
self.benchmark_sets[name].extend(benchmark_set) | ||
else: | ||
self.benchmark_sets[name] = benchmark_set | ||
|
||
def run(self): | ||
self.prepare_benchmark_sets() | ||
self.benchmark() | ||
print("Benchmark results:", self.benchmark_results) | ||
self.report() | ||
self.cleanup() | ||
|
||
def report(self): | ||
return NotImplementedError | ||
|
||
def cleanup(self): | ||
# clean up the benchmark sets | ||
self.benchmark_sets.clear() | ||
|
||
def benchmark(self): | ||
for name, benchmark_set in self.benchmark_sets.items(): | ||
self.benchmark_results[name] = [] | ||
for operator, config in benchmark_set: | ||
self.benchmark_results[name].append(self.run_benchmark(operator, config)) | ||
|
||
def run_benchmark(self, operator:Operator, config:OperatorConfig) -> Optional[float]: | ||
op_inst = operator(config, target=self.benchmark_target) | ||
return op_inst.profile_latency() | ||
|
||
@abstractmethod | ||
def get_operator(self) -> Operator: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get_operator_config(self) -> OperatorConfig: | ||
raise NotImplementedError | ||
|
||
def get_benchmark_sets(self, name:Optional[str]=None) -> List[Tuple[Operator, OperatorConfig]]: | ||
if name is None: | ||
return self.benchmark_sets | ||
else: | ||
assert name in self.benchmark_sets, f"Operator {name} not found in benchmark sets" | ||
return self.benchmark_sets[name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
from base import BitblasOperatorBenchmarkBase | ||
from bitblas.ops import Matmul, MatmulConfig | ||
from bitblas import set_log_level | ||
|
||
set_log_level("DEBUG") | ||
|
||
class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): | ||
|
||
config_map = { | ||
"FP16xFP16_ACCFP16_NT": { | ||
"in_dtype": "float16", | ||
"out_dtype": "float16", | ||
"accum_dtype": "float16", | ||
} | ||
} | ||
|
||
def prepare_benchmark_sets(self): | ||
self.add_benchmark_set( | ||
"FP16xFP16_ACCFP16_NT", | ||
[ | ||
(Matmul, self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384)), | ||
], | ||
) | ||
|
||
def generate_operator_config(self, name:str, M, N, K) -> MatmulConfig: | ||
if name not in self.config_map: | ||
raise ValueError(f"Operator {name} not found in config map") | ||
return MatmulConfig( | ||
M=M, | ||
N=N, | ||
K=K, | ||
**self.config_map[name], | ||
) | ||
|
||
def get_operator(self): | ||
return Matmul | ||
|
||
def get_operator_config(self): | ||
return MatmulConfig | ||
|
||
if __name__ == "__main__": | ||
BitblasMatmulOpsBenchmark().run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters