Skip to content

Commit

Permalink
fix gpu model missing from tvm target remap (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qubitium authored Jun 26, 2024
1 parent 2634815 commit e1fa655
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions python/bitblas/utils/target_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
"where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`."
)

# Nvidia produces non-public oem gpu models that are part of drivers but not mapped to correct tvm target
# Remap list to match the oem model name to the closest public model name
NVIDIA_GPU_REMAP = {
"NVIDIA PG506-230": "NVIDIA A100",
"NVIDIA PG506-232": "NVIDIA A100",
}

def get_gpu_model_from_nvidia_smi(gpu_id: int = 0):
"""
Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU.
Expand Down Expand Up @@ -88,11 +95,9 @@ def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
# Get the current GPU model and find the best matching target
gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id)

# TODO: move to a more res-usable device remapping util method
# compat: Nvidia makes several oem (non-public) versions of A100 and perhaps other models that
# do not have clearly defined TVM matching target so we need to manually map them to the correct one.
if gpu_model == "NVIDIA PG506-230":
gpu_model = "NVIDIA A100"
# Compat: remap oem devices to their correct non-oem model names for tvm target
if gpu_model in NVIDIA_GPU_REMAP:
gpu_model = NVIDIA_GPU_REMAP[gpu_model]

target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"
return target

0 comments on commit e1fa655

Please sign in to comment.