Capture max memory reserved and malloc_retries metric (#2520)
Summary:

# This diff

Adds two metrics to the pipeline benchmarks:
* `num_alloc_retries` - incremented by one every time the allocator cannot obtain memory from the device and has to reclaim/defragment cached memory before retrying
* `max reserved memory` - captures the peak reserved memory, in addition to the already collected `max allocated memory`
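
Both counters come straight from the CUDA caching allocator's stats dictionary; below is a minimal sketch of reading them directly (device index 0 is just an example, and the keys mirror the ones used in the diff below):

```python
import torch

# Counters kept by the CUDA caching allocator for device 0.
stats = torch.cuda.memory_stats(0)

# Bumped each time a cudaMalloc fails and the allocator has to free cached
# blocks and retry the allocation.
alloc_retries = stats.get("num_alloc_retries", 0)

# Peak bytes handed out to tensors vs. peak bytes reserved (cached) from the device.
max_allocated_mb = stats.get("allocated_bytes.all.peak", 0) // 1024 // 1024
max_reserved_mb = stats.get("reserved_bytes.all.peak", 0) // 1024 // 1024

print(f"retries={alloc_retries}, allocated={max_allocated_mb}mb, reserved={max_reserved_mb}mb")
```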

Differential Revision: D64896100
che-sh authored and facebook-github-bot committed Oct 25, 2024
1 parent 43a20d0 commit 9bcad19
Showing 4 changed files with 89 additions and 35 deletions.
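
Before the diff itself, a rough usage sketch of how the new per-rank stats surface in a result (the benchmark name and values are made up for illustration; class and field names follow the diff):

```python
import torch
from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats

# One MemoryStats entry per rank; the numbers here are hypothetical.
result = BenchmarkResult(
    short_name="example_pipeline",                  # illustrative name only
    elapsed_time=torch.tensor([12.1, 11.8, 12.4]),  # milliseconds
    mem_stats=[
        MemoryStats(rank=0, malloc_retries=1, max_mem_allocated_mbs=5120, max_mem_reserved_mbs=6144),
        MemoryStats(rank=1, malloc_retries=0, max_mem_allocated_mbs=5100, max_mem_reserved_mbs=6100),
    ],
)

# The new __str__ reports malloc retries (P50/P90/P100), runtime P90,
# and peak allocated/reserved memory P90, per the methods added below.
print(result)
```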
100 changes: 76 additions & 24 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -100,16 +100,46 @@ class CompileMode(Enum):
FX_SCRIPT = "fx_script"


@dataclass
class MemoryStats:
rank: int
malloc_retries: int
max_mem_allocated_mbs: int
max_mem_reserved_mbs: int

@classmethod
def for_device(cls, rank: int) -> "MemoryStats":
stats = torch.cuda.memory_stats(rank)
alloc_retries = stats.get("num_alloc_retries", 0)
max_allocated = stats.get("allocated_bytes.all.peak", 0)
max_reserved = stats.get("reserved_bytes.all.peak", 0)
return cls(
rank,
alloc_retries,
max_allocated // 1024 // 1024,
max_reserved // 1024 // 1024,
)

def __str__(self) -> str:
return f"Rank {self.rank}: retries={self.malloc_retries}, allocated={self.max_mem_allocated_mbs:7}mb, reserved={self.max_mem_reserved_mbs:7}mb"


@dataclass
class BenchmarkResult:
"Class for holding results of benchmark runs"
short_name: str
elapsed_time: torch.Tensor # milliseconds
max_mem_allocated: List[int] # megabytes
mem_stats: List[MemoryStats] # memory stats per rank
rank: int = -1

def __str__(self) -> str:
return f"{self.short_name: <{35}} | Runtime (P90): {self.runtime_percentile(90):g} ms | Memory (P90): {self.max_mem_percentile(90)/1000:.2g} GB"
runtime = f"Runtime (P90): {self.runtime_percentile(90):g} ms"
mem_alloc = (
f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2g} GB"
)
mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2g} GB"
malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50) } / {self.mem_retries(90)} / {self.mem_retries(100)}"
return f"{self.short_name: <{35}} | {malloc_retries} | {runtime} | {mem_alloc} | {mem_reserved}"

def runtime_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
@@ -120,11 +150,37 @@ def runtime_percentile(
interpolation=interpolation,
)

def max_mem_percentile(
def max_mem_alloc_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
return self._mem_percentile(
lambda m: m.max_mem_allocated_mbs, percentile, interpolation
)

def max_mem_reserved_percentile(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
return self._mem_percentile(
lambda m: m.max_mem_reserved_mbs, percentile, interpolation
)

def mem_retries(
self, percentile: int = 50, interpolation: str = "nearest"
) -> torch.Tensor:
max_mem = torch.tensor(self.max_mem_allocated, dtype=torch.float)
return torch.quantile(max_mem, percentile / 100.0, interpolation=interpolation)
return self._mem_percentile(
lambda m: m.malloc_retries, percentile, interpolation
)

def _mem_percentile(
self,
mem_selector: Callable[[MemoryStats], int],
percentile: int = 50,
interpolation: str = "nearest",
) -> torch.Tensor:
mem_data = torch.tensor(
[mem_selector(mem_stat) for mem_stat in self.mem_stats], dtype=torch.float
)
return torch.quantile(mem_data, percentile / 100.0, interpolation=interpolation)


class ECWrapper(torch.nn.Module):
@@ -345,11 +401,9 @@ def write_report(

qps = int(num_requests / avg_dur_s)

mem_allocated_by_rank = benchmark_res.max_mem_allocated

mem_str = ""
for i, mem_mb in enumerate(mem_allocated_by_rank):
mem_str += f"Rank {i}: {mem_mb:7}mb "
for memory_stats in benchmark_res.mem_stats:
mem_str += f"{memory_stats}\n"

report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000*avg_dur_s):5}"
report_str += f"ms Standard Dev Duration: {(1000*std_dur_s):.2f}ms\n"
@@ -521,7 +575,7 @@ def benchmark(
device_type: str = "cuda",
benchmark_unsharded_module: bool = False,
) -> BenchmarkResult:
max_mem_allocated: List[int] = []
memory_stats: List[MemoryStats] = []
if enable_logging:
logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}")

@@ -580,12 +634,10 @@ def benchmark(
if rank == -1:
# Add up all memory allocated in inference mode
for di in range(world_size):
b = torch.cuda.max_memory_allocated(di)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(di))
else:
# Only add up memory allocated for current rank in training mode
b = torch.cuda.max_memory_allocated(rank)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(rank))

if output_dir != "":
# Only do profiling if output_dir is set
@@ -640,7 +692,7 @@ def trace_handler(prof) -> None:
return BenchmarkResult(
short_name=name,
elapsed_time=elapsed_time,
max_mem_allocated=max_mem_allocated,
mem_stats=memory_stats,
rank=rank,
)

@@ -660,14 +712,16 @@ def benchmark_func(
device_type: str = "cuda",
pre_gpu_load: int = 0,
) -> BenchmarkResult:
max_mem_allocated: List[int] = []
memory_stats: List[MemoryStats] = []
if device_type == "cuda":
if rank == -1:
# Reset memory for measurement, no process per rank so do all
for di in range(world_size):
torch.cuda.reset_peak_memory_stats(di)
torch.cuda.reset_accumulated_memory_stats(di)
else:
torch.cuda.reset_peak_memory_stats(rank)
torch.cuda.reset_accumulated_memory_stats(rank)

start = []
end = []
@@ -716,12 +770,10 @@ def benchmark_func(
if rank == -1:
# Add up all memory allocated in inference mode
for di in range(world_size):
b = torch.cuda.max_memory_allocated(di)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(di))
else:
# Only add up memory allocated for current rank in training mode
b = torch.cuda.max_memory_allocated(rank)
max_mem_allocated.append(b // 1024 // 1024)
memory_stats.append(MemoryStats.for_device(rank))

if profile_dir != "":
# Only do profiling if output_dir is set
@@ -768,7 +820,7 @@ def trace_handler(prof) -> None:
return BenchmarkResult(
short_name=name,
elapsed_time=elapsed_time,
max_mem_allocated=max_mem_allocated,
mem_stats=memory_stats,
rank=rank,
)

@@ -942,7 +994,7 @@ def setUp() -> None:
res = qq.get()

benchmark_res_per_rank.append(res)
assert len(res.max_mem_allocated) == 1
assert len(res.mem_stats) == 1

for p in processes:
p.join()
@@ -951,13 +1003,13 @@ def setUp() -> None:
total_benchmark_res = BenchmarkResult(
benchmark_res_per_rank[0].short_name,
benchmark_res_per_rank[0].elapsed_time,
[0] * world_size,
[MemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
0,
)

for res in benchmark_res_per_rank:
# Each rank's BenchmarkResult contains 1 memory measurement
total_benchmark_res.max_mem_allocated[res.rank] = res.max_mem_allocated[0]
total_benchmark_res.mem_stats[res.rank] = res.mem_stats[0]

return total_benchmark_res

@@ -382,9 +382,7 @@ def _func_to_benchmark(
rank=rank,
)
if rank == 0:
print(
f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
)
print(result)


def single_runner(
@@ -456,9 +454,7 @@ def _func_to_benchmark(
rank=0,
)

print(
f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
)
print(result)


if __name__ == "__main__":
12 changes: 9 additions & 3 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -15,7 +15,11 @@
import click

import torch
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
from torchrec.distributed.benchmark.benchmark_utils import (
benchmark,
BenchmarkResult,
MemoryStats,
)
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import (
_fbgemm_permute_pooled_embs,
@@ -104,11 +108,13 @@ def wrapped_func(
result = BenchmarkResult(
short_name=name,
elapsed_time=torch.tensor(times) * 1e3,
max_mem_allocated=[0],
mem_stats=[MemoryStats(0, 0, 0, 0)],
)

mem_alloc = f"Memory alloc (P90): {result.max_mem_alloc_percentile(90):5.1f}"
mem_reserved = f"Memory reserved (P90): {result.max_mem_reserved_percentile(90):5.1f}"
print(
f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | Memory (P90): {result.max_mem_percentile(90):5.1f}"
f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | {mem_alloc} | {mem_reserved}"
)


4 changes: 2 additions & 2 deletions torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
@@ -20,7 +20,7 @@
# Otherwise will get error
# NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator.
from fbgemm_gpu import sparse_ops # noqa: F401, E402
from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult
from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats
from torchrec.distributed.dist_data import _get_recat

from torchrec.distributed.test_utils.test_model import ModelInput
@@ -227,7 +227,7 @@ def benchmark_kjt(
result = BenchmarkResult(
short_name=f"{test_name}-{transform_type.name}",
elapsed_time=torch.tensor(times),
max_mem_allocated=[0],
mem_stats=[MemoryStats(0, 0, 0, 0)],
)

p50_runtime = result.runtime_percentile(50, interpolation="linear").item()
