Skip to content

Commit

Permalink
runtime utils
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketpurandare committed Oct 28, 2024
1 parent c6132d3 commit 049f16b
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 16 deletions.
1 change: 1 addition & 0 deletions torch/distributed/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .memory_tracker import MemoryTracker
from .mod_tracker import ModTracker
from .runtime_estimator import RuntimeEstimator
from .run_est_utils import get_peak_flops_registry
from .sac_estimator import (
MSPS,
SACEstimator,
Expand Down
6 changes: 3 additions & 3 deletions torch/distributed/_tools/mem_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import warnings
from copy import deepcopy
from enum import auto, Enum
from enum import auto, Enum, StrEnum
from functools import partial, wraps
from typing import (
Any,
Expand Down Expand Up @@ -47,11 +47,11 @@
__all__ = ["MemTracker"]


class _RefType(str, Enum):
class _RefType(StrEnum):
"""Base Class for defining memory reference types, categorizing tensors based on their usage within a model."""


class _State(str, Enum):
class _State(StrEnum):
"""Base Class for defining module state to capture snapshots ."""


Expand Down
99 changes: 99 additions & 0 deletions torch/distributed/_tools/run_est_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Dict
import subprocess
import torch

peak_factors: Dict[str, Dict[torch.dtype, float]] = {
"h100": {
torch.float16: 0.75,
torch.bfloat16: 0.75,
torch.float32: 0.5,
torch.float64: 0.5
},
"a100": {
torch.float16: 0.75,
torch.bfloat16: 0.75,
torch.float32: 0.65,
torch.float64: 0.65
}
}

def get_peak_flops_registry(device_name: str) -> Dict[torch.dtype, int]:
"""
Returns peak FLOPS for given device and data type.
Args:
device_name (str): Device name (e.g., "H100", "A100").
Returns:
Dict[torch.dtype, int]: Peak FLOPS reistry for the device.
Raises:
ValueError: If device is not supported.
"""
try:
# Run lspci command and capture output
result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)

# Filter output for lines containing device name
device_lines = [
line
for line in result.stdout.splitlines()
if device_name in line
]

if not device_lines:
raise ValueError(f"Device {device_name} not found")

# Determine model type (NVL or SXM) for H100
model_type = None
if device_name == "H100":
for line in device_lines:
if "NVL" in line:
model_type = "NVL"
break
elif "SXM" in line:
model_type = "SXM"
break
if model_type is None:
raise ValueError(f"Unable to determine model type for device {device_name}")

# Define peak FLOPS registry
peak_flops_registry = {
"A100": {
torch.float64: 9.7e12,
torch.float32: 19.5e12,
torch.bfloat16: 312e12,
torch.float16: 312e12,
torch.int8: 624e12,
},
"H100 SXM": {
torch.float64: 34e12,
torch.float32: 67e12,
torch.bfloat16: 1979e12,
torch.float16: 1979e12,
torch.int8: 3958e12,
},
"H100 NVL": {
torch.float64: 30e12,
torch.float32: 60e12,
torch.bfloat16: 1671e12,
torch.float16: 1671e12,
torch.int8: 3341e12,
},
}

# Get peak FLOPS for device and data type
device_key = device_name if device_name == "A100" else f"{device_name} {model_type}"
peak_flops_reg = peak_flops_registry.get(device_key, {})

if len(peak_flops_reg) == 0:
raise ValueError(f"Unsupported device {device_name}")

return peak_flops_reg

except subprocess.CalledProcessError as e:
print(f"Error running lspci: {e}")
raise
except Exception as e:
print(e)
raise
35 changes: 22 additions & 13 deletions torch/distributed/_tools/runtime_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List, Set, Tuple
from typing_extensions import Self
from torch.distributed._tools.run_est_utils import get_peak_flops_registry, peak_factors

import torch
import torch.utils._pytree as pytree
Expand Down Expand Up @@ -397,7 +398,8 @@ class RuntimeEstimator(TorchDispatchMode):
}
_no_fallback_kernel: Set[torch._ops._OpNamespace] = set()
fake_mode: FakeTensorMode

_peak_flops_reg: Dict[torch.dtype, int] = {}
_factors: Dict[torch.dtype, float] = {}
gpu_types: Dict[int, str] = {}
count = {}

Expand All @@ -417,10 +419,12 @@ def __init__(self) -> None:

gpu_id = torch.cuda.current_device() # Get the current GPU ID
if gpu_id not in RuntimeEstimator.gpu_types:
RuntimeEstimator.gpu_types[gpu_id] = self.get_device_type() # Initialize gpu_type for the GPU
RuntimeEstimator.gpu_types[gpu_id] = RuntimeEstimator.get_device_type() # Initialize gpu_type for the GPU
self.gpu_type = RuntimeEstimator.gpu_types[gpu_id] # Assign gpu_type based on the current GPU

def get_device_type(self) -> int:
RuntimeEstimator._factors = peak_factors[self.gpu_type]

@classmethod
def get_device_type(cls) -> str:
try:
result = subprocess.check_output(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'])
gpu_name = result.decode('utf-8').strip()
Expand Down Expand Up @@ -629,15 +633,20 @@ def _get_compute_time(cls, func_packet, args, kwargs, out, out_dtypes) -> float:
float: The estimated compute time in nanoseconds.
"""
if func_packet in flop_registry:
assert (
len(out_dtypes) == 1
), f"Only support single out dtype got {out_dtypes} for {func_packet}"
dtype = out_dtypes.pop()
# This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
peak_gpu_flops = get_device_tflops(dtype) * 1e15
# We can expect to achieve 75% of theoretical peak flops
factor = 0.75
peak_empirical_flops = factor * peak_gpu_flops
# assert (
# len(out_dtypes) == 1
# ), f"Only support single out dtype got {out_dtypes} for {func_packet}"
# dtype = out_dtypes.pop()
float_dtypes = out_dtypes & cls._float_types
dtype = min(float_dtypes, key=lambda x: x.itemsize)
if dtype == torch.float32:
print(func_packet, dtype)
print([arg.dtype for arg in args])
if len (cls._peak_flops_reg) == 0:
cls._peak_flops_reg = get_peak_flops_registry(cls.get_device_type().upper())

peak_gpu_flops = cls._peak_flops_reg[dtype]
peak_empirical_flops = cls._factors[dtype] * peak_gpu_flops
flop_count_func = flop_registry[func_packet]
# We divide by a factor of 2 to get the MACs (multiply and accumulate)
flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2
Expand Down

0 comments on commit 049f16b

Please sign in to comment.