Capture max memory reserved and malloc_retries metric #2520

Closed
100 changes: 76 additions & 24 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -101,16 +101,46 @@ class CompileMode(Enum):
FX_SCRIPT = "fx_script"


@dataclass
class MemoryStats:
rank: int
malloc_retries: int
max_mem_allocated_mbs: int
max_mem_reserved_mbs: int

@classmethod
def for_device(cls, rank: int) -> "MemoryStats":
stats = torch.cuda.memory_stats(rank)
alloc_retries = stats.get("num_alloc_retries", 0)
max_allocated = stats.get("allocated_bytes.all.peak", 0)
max_reserved = stats.get("reserved_bytes.all.peak", 0)
return cls(
rank,
alloc_retries,
max_allocated // 1024 // 1024,
max_reserved // 1024 // 1024,
)

def __str__(self) -> str:
return f"Rank {self.rank}: retries={self.malloc_retries}, allocated={self.max_mem_allocated_mbs:7}mb, reserved={self.max_mem_reserved_mbs:7}mb"


@dataclass
class BenchmarkResult:
"Class for holding results of benchmark runs"
short_name: str
elapsed_time: torch.Tensor # milliseconds
max_mem_allocated: List[int] # megabytes
mem_stats: List[MemoryStats] # memory stats per rank
rank: int = -1

def __str__(self) -> str:
return f"{self.short_name: <{35}} | Runtime (P90): {self.runtime_percentile(90):g} ms | Memory (P90): {self.max_mem_percentile(90)/1000:.2g} GB"
runtime = f"Runtime (P90): {self.runtime_percentile(90):g} ms"
mem_alloc = (
f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2g} GB"
)
mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2g} GB"
malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50) } / {self.mem_retries(90)} / {self.mem_retries(100)}"
return f"{self.short_name: <{35}} | {malloc_retries} | {runtime} | {mem_alloc} | {mem_reserved}"

def runtime_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
@@ -121,11 +151,37 @@ def runtime_percentile(
interpolation=interpolation,
)

def max_mem_percentile(
def max_mem_alloc_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
return self._mem_percentile(
lambda m: m.max_mem_allocated_mbs, percentile, interpolation
)

def max_mem_reserved_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
return self._mem_percentile(
lambda m: m.max_mem_reserved_mbs, percentile, interpolation
)

def mem_retries(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
max_mem = torch.tensor(self.max_mem_allocated, dtype=torch.float)
return torch.quantile(max_mem, percentile / 100.0, interpolation=interpolation)
return self._mem_percentile(
lambda m: m.malloc_retries, percentile, interpolation
)

def _mem_percentile(
self,
mem_selector: Callable[[MemoryStats], int],
percentile: int = 50,
interpolation: str = "nearest",
) -> torch.Tensor:
mem_data = torch.tensor(
[mem_selector(mem_stat) for mem_stat in self.mem_stats], dtype=torch.float
)
return torch.quantile(mem_data, percentile / 100.0, interpolation=interpolation)


class ECWrapper(torch.nn.Module):
@@ -346,11 +402,9 @@ def write_report(

qps = int(num_requests / avg_dur_s)

mem_allocated_by_rank = benchmark_res.max_mem_allocated

mem_str = ""
for i, mem_mb in enumerate(mem_allocated_by_rank):
mem_str += f"Rank {i}: {mem_mb:7}mb "
for memory_stats in benchmark_res.mem_stats:
mem_str += f"{memory_stats}\n"

report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000*avg_dur_s):5}"
report_str += f"ms Standard Dev Duration: {(1000*std_dur_s):.2f}ms\n"
@@ -523,7 +577,7 @@ def benchmark(
device_type: str = "cuda",
benchmark_unsharded_module: bool = False,
) -> BenchmarkResult:
max_mem_allocated: List[int] = []
memory_stats: List[MemoryStats] = []
if enable_logging:
logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}")

@@ -582,12 +636,10 @@ def benchmark(
if rank == -1:
# Add up all memory allocated in inference mode
for di in range(world_size):
b = torch.cuda.max_memory_allocated(di)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(di))
else:
# Only add up memory allocated for current rank in training mode
b = torch.cuda.max_memory_allocated(rank)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(rank))

if output_dir != "":
# Only do profiling if output_dir is set
@@ -642,7 +694,7 @@ def trace_handler(prof) -> None:
return BenchmarkResult(
short_name=name,
elapsed_time=elapsed_time,
max_mem_allocated=max_mem_allocated,
mem_stats=memory_stats,
rank=rank,
)

@@ -662,14 +714,16 @@ def benchmark_func(
device_type: str = "cuda",
pre_gpu_load: int = 0,
) -> BenchmarkResult:
max_mem_allocated: List[int] = []
memory_stats: List[MemoryStats] = []
if device_type == "cuda":
if rank == -1:
# Reset memory for measurement, no process per rank so do all
for di in range(world_size):
torch.cuda.reset_peak_memory_stats(di)
torch.cuda.reset_accumulated_memory_stats(di)
else:
torch.cuda.reset_peak_memory_stats(rank)
torch.cuda.reset_accumulated_memory_stats(rank)

start = []
end = []
@@ -718,12 +772,10 @@ def benchmark_func(
if rank == -1:
# Add up all memory allocated in inference mode
for di in range(world_size):
b = torch.cuda.max_memory_allocated(di)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(di))
else:
# Only add up memory allocated for current rank in training mode
b = torch.cuda.max_memory_allocated(rank)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(rank))

if profile_dir != "":
# Only do profiling if output_dir is set
@@ -770,7 +822,7 @@ def trace_handler(prof) -> None:
return BenchmarkResult(
short_name=name,
elapsed_time=elapsed_time,
max_mem_allocated=max_mem_allocated,
mem_stats=memory_stats,
rank=rank,
)

@@ -944,7 +996,7 @@ def setUp() -> None:
res = qq.get()

benchmark_res_per_rank.append(res)
assert len(res.max_mem_allocated) == 1
assert len(res.mem_stats) == 1

for p in processes:
p.join()
@@ -953,13 +1005,13 @@ def setUp() -> None:
total_benchmark_res = BenchmarkResult(
benchmark_res_per_rank[0].short_name,
benchmark_res_per_rank[0].elapsed_time,
[0] * world_size,
[MemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
0,
)

for res in benchmark_res_per_rank:
# Each rank's BenchmarkResult contains 1 memory measurement
total_benchmark_res.max_mem_allocated[res.rank] = res.max_mem_allocated[0]
total_benchmark_res.mem_stats[res.rank] = res.mem_stats[0]

return total_benchmark_res

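For orientation, here is a minimal usage sketch (not part of the diff) of how the new MemoryStats and BenchmarkResult pieces fit together. The workload, tensor sizes, and timings are made up for illustration; the allocator-stat keys ("num_alloc_retries", "allocated_bytes.all.peak", "reserved_bytes.all.peak") are the ones read by MemoryStats.for_device above.

import torch

from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats

rank = 0

# Reset peak and accumulated CUDA allocator counters before the measured
# region, mirroring what benchmark_func() now does per device.
torch.cuda.reset_peak_memory_stats(rank)
torch.cuda.reset_accumulated_memory_stats(rank)

# ... run the workload being benchmarked (illustrative stand-in) ...
a = torch.randn(1024, 1024, device=f"cuda:{rank}")
b = torch.randn(1024, 1024, device=f"cuda:{rank}")
_ = a @ b

# Snapshot num_alloc_retries and the peak allocated/reserved bytes for this
# rank; MemoryStats.for_device() converts the byte counts to MB.
stats = MemoryStats.for_device(rank)

result = BenchmarkResult(
    short_name="example",
    elapsed_time=torch.tensor([12.0, 13.5, 11.8]),  # made-up timings, in ms
    mem_stats=[stats],
    rank=rank,
)

# Per-metric percentiles across ranks (only one rank here), plus the new
# one-line summary covering retries, runtime, and alloc/reserved memory.
print(result.max_mem_alloc_percentile(90))
print(result.max_mem_reserved_percentile(90))
print(result.mem_retries(90))
print(result)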
@@ -382,9 +382,7 @@ def _func_to_benchmark(
rank=rank,
)
if rank == 0:
print(
f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
)
print(result)


def single_runner(
@@ -456,9 +454,7 @@ def _func_to_benchmark(
rank=0,
)

print(
f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
)
print(result)


if __name__ == "__main__":
12 changes: 9 additions & 3 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -15,7 +15,11 @@
import click

import torch
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
from torchrec.distributed.benchmark.benchmark_utils import (
benchmark,
BenchmarkResult,
MemoryStats,
)
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import (
_fbgemm_permute_pooled_embs,
@@ -104,11 +108,13 @@ def wrapped_func(
result = BenchmarkResult(
short_name=name,
elapsed_time=torch.tensor(times) * 1e3,
max_mem_allocated=[0],
mem_stats=[MemoryStats(0, 0, 0, 0)],
)

mem_alloc = f"Memory alloc (P90): {result.max_mem_alloc_percentile(90):5.1f}"
mem_reserved = f"Memory alloc (P90): {result.max_mem_reserved_percentile(90):5.1f}"
print(
f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | Memory (P90): {result.max_mem_percentile(90):5.1f}"
f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | {mem_alloc} | {mem_reserved}"
)


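As an aside (not part of the diff): benchmarks that only collect timings, as in this file and the next one, pass a zeroed MemoryStats placeholder so that BenchmarkResult always carries per-rank memory stats. A minimal sketch, with made-up timings:

import torch

from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats

# Zeroed placeholder for rank 0: no malloc retries, 0 MB allocated/reserved,
# matching the MemoryStats(0, 0, 0, 0) pattern in the call sites above.
placeholder = MemoryStats(
    rank=0, malloc_retries=0, max_mem_allocated_mbs=0, max_mem_reserved_mbs=0
)

result = BenchmarkResult(
    short_name="timing-only-example",
    elapsed_time=torch.tensor([1.0, 1.2, 0.9]),  # made-up timings, in ms
    mem_stats=[placeholder],
)

print(result.runtime_percentile(90, interpolation="linear"))
print(result.max_mem_reserved_percentile(90))  # 0.0 for the placeholder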
4 changes: 2 additions & 2 deletions torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
@@ -20,7 +20,7 @@
# Otherwise will get error
# NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator.
from fbgemm_gpu import sparse_ops # noqa: F401, E402
from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult
from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats
from torchrec.distributed.dist_data import _get_recat

from torchrec.distributed.test_utils.test_model import ModelInput
@@ -227,7 +227,7 @@ def benchmark_kjt(
result = BenchmarkResult(
short_name=f"{test_name}-{transform_type.name}",
elapsed_time=torch.tensor(times),
max_mem_allocated=[0],
mem_stats=[MemoryStats(0, 0, 0, 0)],
)

p50_runtime = result.runtime_percentile(50, interpolation="linear").item()