Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Improve TVM Target related items #45

Merged
merged 2 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/bitblas/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Licensed under the MIT License.
from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4 # noqa: F401
from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401
from .target_detector import auto_detect_nvidia_target # noqa: F401
from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401
18 changes: 16 additions & 2 deletions python/bitblas/utils/target_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
# Licensed under the MIT License.

import subprocess
from typing import List
from thefuzz import process
from tvm.target import Target
from tvm.target.tag import list_tags

import logging

logger = logging.getLogger(__name__)

TARGET_MISSING_ERROR = (
"TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, "
"where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`."
)

def get_gpu_model_from_nvidia_smi():
"""
Expand Down Expand Up @@ -41,13 +47,21 @@ def find_best_match(tags, query):
def check_target(best, default):
return best if Target(best).arch == Target(default).arch else default

if check_target(best_match, "cuda"):
if check_target(best_match, "cuda") == best_match:
return best_match if score >= MATCH_THRESHOLD else "cuda"
else:
logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'")
logger.warning(TARGET_MISSING_ERROR)
return "cuda"


def get_all_nvidia_targets() -> List[str]:
"""
Returns all available NVIDIA targets.
"""
all_tags = list_tags()
return [tag for tag in all_tags if "nvidia" in tag]


def auto_detect_nvidia_target() -> str:
"""
Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target.
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ scipy
tornado
torch
thefuzz
tabulate
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ scipy
tornado
torch
thefuzz
tabulate
17 changes: 17 additions & 0 deletions tools/get_available_targets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas.utils import get_all_nvidia_targets
from tabulate import tabulate

def main():
# Get all available Nvidia targets
targets = get_all_nvidia_targets()

# Print available targets to console in a table format
table = [[i + 1, target] for i, target in enumerate(targets)]
headers = ["Index", "Target"]
print(tabulate(table, headers, tablefmt="pretty"))

if __name__ == "__main__":
main()
Loading