Split nvlink check into nvidia vs amd. Add tests
TensorTemplar committed Oct 5, 2024
1 parent badc76a commit 09df0a7
Showing 2 changed files with 322 additions and 135 deletions.
218 changes: 156 additions & 62 deletions litgpt/utils.py
@@ -8,31 +8,43 @@
import pickle
import re
import shutil
import subprocess
import sys
from dataclasses import asdict, is_dataclass
import warnings
from dataclasses import asdict
from dataclasses import is_dataclass
from io import BytesIO
from packaging import version
from pathlib import Path
import subprocess
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
import warnings
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
import yaml
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.cli import instantiate_class
from lightning.pytorch.loggers import WandbLogger
from packaging import version
from torch.serialization import normalize_storage_type
from typing_extensions import Self


if TYPE_CHECKING:
from litgpt import GPT, Config
from litgpt import Config
from litgpt import GPT


def init_out_dir(out_dir: Path) -> Path:
@@ -84,23 +96,25 @@ def reset_parameters(module: nn.Module) -> None:


def check_valid_checkpoint_dir(
checkpoint_dir: Path,
model_filename: str = "lit_model.pth",
verbose: bool = True,
raise_error: bool = False,
ignore_tokenizer_files: bool = False
) -> None:
checkpoint_dir: Path,
model_filename: str = "lit_model.pth",
verbose: bool = True,
raise_error: bool = False,
ignore_tokenizer_files: bool = False,
) -> None:

files = {
model_filename: (checkpoint_dir / model_filename).is_file(),
"model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
}
if not ignore_tokenizer_files:
files.update({
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or
(checkpoint_dir / "tokenizer.model").is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
})
files.update(
{
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
or (checkpoint_dir / "tokenizer.model").is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
}
)

if checkpoint_dir.is_dir():
if all(files.values()):
@@ -458,7 +472,9 @@ def copy_config_files(source_dir: Path, out_dir: Path) -> None:


def CLI(*args: Any, **kwargs: Any) -> Any:
from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
from jsonargparse import CLI
from jsonargparse import set_config_read_mode
from jsonargparse import set_docstring_parse_options

set_docstring_parse_options(attribute_docstrings=True)
set_config_read_mode(urls_enabled=True)
@@ -539,15 +555,21 @@ def choose_logger(

def get_argument_names(cls):
sig = inspect.signature(cls.__init__)
return {name for name, param in sig.parameters.items()
if param.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]}
return {
name
for name, param in sig.parameters.items()
if param.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]
}


def instantiate_bnb_optimizer(optimizer, model_parameters):
if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")):
if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
):
raise ValueError("The chosen quantization format only supports the AdamW optimizer.")

import bitsandbytes as bnb

if isinstance(optimizer, str):
optimizer = bnb.optim.PagedAdamW(model_parameters)
else:
@@ -594,10 +616,12 @@ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):

def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
new_checkpoint_dir = "checkpoints" / checkpoint_dir
should_return_new_dir = (not checkpoint_dir.is_dir() and
checkpoint_dir.parts[0] != "checkpoints" and
not checkpoint_dir.is_absolute() and
new_checkpoint_dir.exists())
should_return_new_dir = (
not checkpoint_dir.is_dir()
and checkpoint_dir.parts[0] != "checkpoints"
and not checkpoint_dir.is_absolute()
and new_checkpoint_dir.exists()
)
return new_checkpoint_dir if should_return_new_dir else checkpoint_dir

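As a hedged illustration of the redirect above (the model path and the pre-existing checkpoints/ layout are assumptions for this sketch, not part of the diff):

# Illustrative only: a relative path that is not itself a directory is
# redirected under "checkpoints/" when that copy exists on disk.
from pathlib import Path

from litgpt.utils import extend_checkpoint_dir

# Assuming ./checkpoints/microsoft/phi-2 exists and ./microsoft/phi-2 does not:
extend_checkpoint_dir(Path("microsoft/phi-2"))  # -> Path("checkpoints/microsoft/phi-2")
# Absolute paths, existing directories, and paths already rooted at
# "checkpoints" are returned unchanged.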

@@ -622,7 +646,9 @@ def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_fil

checkpoint_dir = extend_checkpoint_dir(Path(model_name))
try:
check_valid_checkpoint_dir(checkpoint_dir, verbose=False, raise_error=True, ignore_tokenizer_files=ignore_tokenizer_files)
check_valid_checkpoint_dir(
checkpoint_dir, verbose=False, raise_error=True, ignore_tokenizer_files=ignore_tokenizer_files
)
except FileNotFoundError as e:
if access_token is None:
access_token = os.getenv("HF_TOKEN")
@@ -637,52 +663,120 @@ def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_fil


def check_nvlink_connectivity(fabric=None):
"""Checks GPU connectivity for both NVIDIA and AMD GPUs.
This function delegates to vendor-specific implementations based on
the detected GPU vendor.
"""
if fabric is not None:
custom_print = fabric.print
else:
custom_print = print

# Only execute on the primary process
if os.getenv("RANK", "0") == "0":
try:
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)

if result.returncode != 0:
custom_print("Failed to run nvidia-smi")
return

lines = result.stdout.split('\n')
gpu_matrix = []

start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None) + 1
headers_line = lines[start_index - 1]
headers = headers_line.split()
# The regex is to avoid counting the "GPU NUMA ID" header as a GPU
# in headers like ['\x1b[4mGPU0', 'GPU1', 'GPU2', 'GPU3', 'GPU4', 'GPU5', 'GPU6', 'GPU7', 'NIC0', 'NIC1', 'NIC2', 'NIC3', 'NIC4', 'NIC5', 'NIC6', 'NIC7', 'NIC8', 'NIC9', 'CPU', 'Affinity', 'NUMA', 'Affinity', 'GPU', 'NUMA', 'ID\x1b[0m']
gpu_regex = re.compile(r'^GPU\d+$')
gpu_count = len([header for header in headers if gpu_regex.match(header)])

all_nvlink = True
for line in lines[start_index:start_index + gpu_count]:
gpu_matrix.append(line.strip())
connections = line.split()[1:1 + gpu_count]
if not all("NV" in conn for conn in connections if conn != "X"):
all_nvlink = False
break

if all_nvlink:
custom_print("All GPUs are fully connected via NVLink.")
if torch.cuda.is_available():
device_properties = torch.cuda.get_device_properties(0)
gpu_name = device_properties.name.lower()
if "nvidia" in gpu_name:
_check_nvidia_connectivity(custom_print)
elif "advanced micro devices" in gpu_name or "amd" in gpu_name:
_check_amd_connectivity(custom_print)
else:
custom_print(f"Unrecognized GPU vendor: {device_properties.name}")
else:
custom_print(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)

custom_print("No GPUs available")
except Exception as e:
custom_print(f"An error occurred: {e}")
custom_print(f"An error occurred while checking GPU connectivity: {e}")

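A minimal usage sketch for the dispatcher above (the Fabric setup is an assumption for illustration; only check_nvlink_connectivity comes from this diff):

import lightning as L

from litgpt.utils import check_nvlink_connectivity

fabric = L.Fabric(accelerator="cuda", devices=2)
fabric.launch()
# Executes only when the RANK env var is unset or "0", so the topology
# summary is printed once rather than once per process.
check_nvlink_connectivity(fabric)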

def _check_nvidia_connectivity(custom_print):
"""Checks NVLink connectivity on NVIDIA GPUs."""
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run nvidia-smi")
return

lines = result.stdout.strip().split("\n")
start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None)
if start_index is None:
custom_print("Failed to parse nvidia-smi output")
return

headers_line = lines[start_index]
headers = headers_line.split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

all_nvlink = True
for line in lines[start_index + 1 : start_index + 1 + gpu_count]:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
if not all("NV" in conn for conn in connections if conn != "X"):
all_nvlink = False
break

if all_nvlink:
custom_print("All GPUs are fully connected via NVLink.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)

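For reference, the parser above expects nvidia-smi topo -m output of roughly this shape (illustrative and trimmed; real output also includes NIC columns and affinity details):

        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      NV12    0-63            0
GPU1    NV12     X      0-63            0

Any inter-GPU cell without an NV prefix (e.g. PHB, PIX, or SYS) marks a slower link and routes execution to the warning branch.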

def _check_amd_connectivity(custom_print):
"""Checks XGMI connectivity on AMD GPUs."""
result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run rocm-smi")
return

lines = result.stdout.strip().split("\n")
# Find the line that starts with "GPU0"
gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None)
if gpu_header_index is None or gpu_header_index == 0:
custom_print("Failed to parse rocm-smi output (no GPU headers found)")
return

header_line = lines[gpu_header_index - 1]
headers = header_line.strip().split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

# Collect GPU connection lines
gpu_lines = []
for line in lines[gpu_header_index : gpu_header_index + gpu_count]:
if re.match(r"^\s*GPU\d+", line):
gpu_lines.append(line.strip())
if len(gpu_lines) != gpu_count:
custom_print("Mismatch in GPU count when parsing rocm-smi output")
return

all_xgmi = True
for line in gpu_lines:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
for conn in connections:
if conn not in ("XGMI", "0"):
all_xgmi = False
break
if not all_xgmi:
break

if all_xgmi:
custom_print("All GPUs are fully connected via XGMI.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via XGMI. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)

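By analogy, the AMD path assumes rocm-smi --showtopotype prints a link-type matrix along these lines (an inferred shape, reconstructed from the parser rather than from captured rocm-smi output; note the corner label, which keeps the header row from matching the ^\s*GPU0 pattern used to find the first data row):

GPU    GPU0   GPU1
GPU0   0      XGMI
GPU1   XGMI   0

Any cell other than XGMI or the 0 self-link (e.g. PCIE) triggers the slower-interconnect warning.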

def fix_and_load_json(s):
# Remove trailing commas before } or ]
s = re.sub(r',(\s*[}\]])', r'\1', s)
s = re.sub(r",(\s*[}\]])", r"\1", s)

# Insert missing commas between properties
# Match positions where a value is followed by a newline and then a quote without a comma
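The commit message also notes that tests were added (the second changed file). A sketch of how these vendor checks can be unit-tested by stubbing subprocess.run (hypothetical, assuming pytest conventions; not the commit's actual test file):

# Hypothetical test sketch; the fixture text and test name are illustrative.
from unittest import mock

from litgpt.utils import _check_nvidia_connectivity

NVIDIA_TOPO = (
    "        GPU0    GPU1    CPU Affinity\n"
    "GPU0     X      NV12    0-63\n"
    "GPU1    NV12     X      0-63\n"
)

def test_nvidia_all_nvlink():
    printed = []
    fake_result = mock.Mock(returncode=0, stdout=NVIDIA_TOPO)
    # Patch subprocess.run globally so the parser sees the canned topology.
    with mock.patch("subprocess.run", return_value=fake_result):
        _check_nvidia_connectivity(printed.append)
    assert printed == ["All GPUs are fully connected via NVLink."]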
