diff --git a/python/bitblas/utils/__init__.py b/python/bitblas/utils/__init__.py index f9587964c..416d6b1f2 100644 --- a/python/bitblas/utils/__init__.py +++ b/python/bitblas/utils/__init__.py @@ -2,4 +2,4 @@ # Licensed under the MIT License. from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4 # noqa: F401 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 -from .target_detector import auto_detect_nvidia_target # noqa: F401 +from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401 diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py index 927e9f8e8..33bf70d0a 100644 --- a/python/bitblas/utils/target_detector.py +++ b/python/bitblas/utils/target_detector.py @@ -2,13 +2,19 @@ # Licensed under the MIT License. import subprocess +from typing import List from thefuzz import process from tvm.target import Target from tvm.target.tag import list_tags import logging + logger = logging.getLogger(__name__) +TARGET_MISSING_ERROR = ( + "TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=`, " + "where is one of the available targets can be found in the output of `tools/get_available_targets.py`." +) def get_gpu_model_from_nvidia_smi(): """ @@ -41,13 +47,21 @@ def find_best_match(tags, query): def check_target(best, default): return best if Target(best).arch == Target(default).arch else default - if check_target(best_match, "cuda"): + if check_target(best_match, "cuda") == best_match: return best_match if score >= MATCH_THRESHOLD else "cuda" else: - logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'") + logger.warning(TARGET_MISSING_ERROR) return "cuda" +def get_all_nvidia_targets() -> List[str]: + """ + Returns all available NVIDIA targets. + """ + all_tags = list_tags() + return [tag for tag in all_tags if "nvidia" in tag] + + def auto_detect_nvidia_target() -> str: """ Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target. diff --git a/requirements-dev.txt b/requirements-dev.txt index 4fd416900..40906bc20 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -27,3 +27,4 @@ scipy tornado torch thefuzz +tabulate diff --git a/requirements.txt b/requirements.txt index e8257a571..27ae8420f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ scipy tornado torch thefuzz +tabulate diff --git a/tools/get_available_targets.py b/tools/get_available_targets.py new file mode 100644 index 000000000..2c2753d7a --- /dev/null +++ b/tools/get_available_targets.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas.utils import get_all_nvidia_targets +from tabulate import tabulate + +def main(): + # Get all available Nvidia targets + targets = get_all_nvidia_targets() + + # Print available targets to console in a table format + table = [[i + 1, target] for i, target in enumerate(targets)] + headers = ["Index", "Target"] + print(tabulate(table, headers, tablefmt="pretty")) + +if __name__ == "__main__": + main() \ No newline at end of file