From 3d2c9eade1b69a414faa2bb7828dc8cfedbbfecd Mon Sep 17 00:00:00 2001 From: Evgenii Kolpakov Date: Tue, 29 Oct 2024 21:09:05 -0700 Subject: [PATCH] Capture max memory reserved and malloc_retries metric (#2520) Summary: # This diff Adds two metrics to the pipeline benchmarks: * `num_alloc_retries` - this is bumped by one every time allocator cannot grab memory from device, and have to perform memory defrag/reclaiming * `max reserved memory` - metric that captures the total reserved memory in addition to already collected `max allocated memory` Reviewed By: dstaay-fb Differential Revision: D64896100 --- .../distributed/benchmark/benchmark_utils.py | 100 +++++++++++++----- .../tests/pipeline_benchmarks.py | 8 +- .../sparse/tests/jagged_tensor_benchmark.py | 12 ++- .../keyed_jagged_tensor_benchmark_lib.py | 4 +- 4 files changed, 89 insertions(+), 35 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index 389786641..2177830cc 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -101,16 +101,46 @@ class CompileMode(Enum): FX_SCRIPT = "fx_script" +@dataclass +class MemoryStats: + rank: int + malloc_retries: int + max_mem_allocated_mbs: int + max_mem_reserved_mbs: int + + @classmethod + def for_device(cls, rank: int) -> "MemoryStats": + stats = torch.cuda.memory_stats(rank) + alloc_retries = stats.get("num_alloc_retries", 0) + max_allocated = stats.get("allocated_bytes.all.peak", 0) + max_reserved = stats.get("reserved_bytes.all.peak", 0) + return cls( + rank, + alloc_retries, + max_allocated // 1024 // 1024, + max_reserved // 1024 // 1024, + ) + + def __str__(self) -> str: + return f"Rank {self.rank}: retries={self.malloc_retries}, allocated={self.max_mem_allocated_mbs:7}mb, reserved={self.max_mem_reserved_mbs:7}mb" + + @dataclass class BenchmarkResult: "Class for holding results of benchmark runs" short_name: str elapsed_time: torch.Tensor # milliseconds - max_mem_allocated: List[int] # megabytes + mem_stats: List[MemoryStats] # memory stats per rank rank: int = -1 def __str__(self) -> str: - return f"{self.short_name: <{35}} | Runtime (P90): {self.runtime_percentile(90):g} ms | Memory (P90): {self.max_mem_percentile(90)/1000:.2g} GB" + runtime = f"Runtime (P90): {self.runtime_percentile(90):g} ms" + mem_alloc = ( + f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2g} GB" + ) + mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2g} GB" + malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50) } / {self.mem_retries(90)} / {self.mem_retries(100)}" + return f"{self.short_name: <{35}} | {malloc_retries} | {runtime} | {mem_alloc} | {mem_reserved}" def runtime_percentile( self, percentile: int = 50, interpolation: str = "nearest" @@ -121,11 +151,37 @@ def runtime_percentile( interpolation=interpolation, ) - def max_mem_percentile( + def max_mem_alloc_percentile( + self, percentile: int = 50, interpolation: str = "nearest" + ) -> torch.Tensor: + return self._mem_percentile( + lambda m: m.max_mem_allocated_mbs, percentile, interpolation + ) + + def max_mem_reserved_percentile( + self, percentile: int = 50, interpolation: str = "nearest" + ) -> torch.Tensor: + return self._mem_percentile( + lambda m: m.max_mem_reserved_mbs, percentile, interpolation + ) + + def mem_retries( self, percentile: int = 50, interpolation: str = "nearest" ) -> torch.Tensor: - max_mem = torch.tensor(self.max_mem_allocated, dtype=torch.float) - return torch.quantile(max_mem, percentile / 100.0, interpolation=interpolation) + return self._mem_percentile( + lambda m: m.malloc_retries, percentile, interpolation + ) + + def _mem_percentile( + self, + mem_selector: Callable[[MemoryStats], int], + percentile: int = 50, + interpolation: str = "nearest", + ) -> torch.Tensor: + mem_data = torch.tensor( + [mem_selector(mem_stat) for mem_stat in self.mem_stats], dtype=torch.float + ) + return torch.quantile(mem_data, percentile / 100.0, interpolation=interpolation) class ECWrapper(torch.nn.Module): @@ -346,11 +402,9 @@ def write_report( qps = int(num_requests / avg_dur_s) - mem_allocated_by_rank = benchmark_res.max_mem_allocated - mem_str = "" - for i, mem_mb in enumerate(mem_allocated_by_rank): - mem_str += f"Rank {i}: {mem_mb:7}mb " + for memory_stats in benchmark_res.mem_stats: + mem_str += f"{memory_stats}\n" report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000*avg_dur_s):5}" report_str += f"ms Standard Dev Duration: {(1000*std_dur_s):.2f}ms\n" @@ -523,7 +577,7 @@ def benchmark( device_type: str = "cuda", benchmark_unsharded_module: bool = False, ) -> BenchmarkResult: - max_mem_allocated: List[int] = [] + memory_stats: List[MemoryStats] = [] if enable_logging: logger.info(f" BENCHMARK_MODEL[{name}]:\n{model}") @@ -582,12 +636,10 @@ def benchmark( if rank == -1: # Add up all memory allocated in inference mode for di in range(world_size): - b = torch.cuda.max_memory_allocated(di) - max_mem_allocated.append(b // 1024 // 1024) + memory_stats.append(MemoryStats.for_device(di)) else: # Only add up memory allocated for current rank in training mode - b = torch.cuda.max_memory_allocated(rank) - max_mem_allocated.append(b // 1024 // 1024) + memory_stats.append(MemoryStats.for_device(rank)) if output_dir != "": # Only do profiling if output_dir is set @@ -642,7 +694,7 @@ def trace_handler(prof) -> None: return BenchmarkResult( short_name=name, elapsed_time=elapsed_time, - max_mem_allocated=max_mem_allocated, + mem_stats=memory_stats, rank=rank, ) @@ -662,14 +714,16 @@ def benchmark_func( device_type: str = "cuda", pre_gpu_load: int = 0, ) -> BenchmarkResult: - max_mem_allocated: List[int] = [] + memory_stats: List[MemoryStats] = [] if device_type == "cuda": if rank == -1: # Reset memory for measurement, no process per rank so do all for di in range(world_size): torch.cuda.reset_peak_memory_stats(di) + torch.cuda.reset_accumulated_memory_stats(di) else: torch.cuda.reset_peak_memory_stats(rank) + torch.cuda.reset_accumulated_memory_stats(rank) start = [] end = [] @@ -718,12 +772,10 @@ def benchmark_func( if rank == -1: # Add up all memory allocated in inference mode for di in range(world_size): - b = torch.cuda.max_memory_allocated(di) - max_mem_allocated.append(b // 1024 // 1024) + memory_stats.append(MemoryStats.for_device(di)) else: # Only add up memory allocated for current rank in training mode - b = torch.cuda.max_memory_allocated(rank) - max_mem_allocated.append(b // 1024 // 1024) + memory_stats.append(MemoryStats.for_device(rank)) if profile_dir != "": # Only do profiling if output_dir is set @@ -770,7 +822,7 @@ def trace_handler(prof) -> None: return BenchmarkResult( short_name=name, elapsed_time=elapsed_time, - max_mem_allocated=max_mem_allocated, + mem_stats=memory_stats, rank=rank, ) @@ -944,7 +996,7 @@ def setUp() -> None: res = qq.get() benchmark_res_per_rank.append(res) - assert len(res.max_mem_allocated) == 1 + assert len(res.mem_stats) == 1 for p in processes: p.join() @@ -953,13 +1005,13 @@ def setUp() -> None: total_benchmark_res = BenchmarkResult( benchmark_res_per_rank[0].short_name, benchmark_res_per_rank[0].elapsed_time, - [0] * world_size, + [MemoryStats(rank, 0, 0, 0) for rank in range(world_size)], 0, ) for res in benchmark_res_per_rank: # Each rank's BenchmarkResult contains 1 memory measurement - total_benchmark_res.max_mem_allocated[res.rank] = res.max_mem_allocated[0] + total_benchmark_res.mem_stats[res.rank] = res.mem_stats[0] return total_benchmark_res diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index 81bd54928..538264c04 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -382,9 +382,7 @@ def _func_to_benchmark( rank=rank, ) if rank == 0: - print( - f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB" - ) + print(result) def single_runner( @@ -456,9 +454,7 @@ def _func_to_benchmark( rank=0, ) - print( - f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB" - ) + print(result) if __name__ == "__main__": diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index b9dd12d3b..34862e380 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -15,7 +15,11 @@ import click import torch -from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult +from torchrec.distributed.benchmark.benchmark_utils import ( + benchmark, + BenchmarkResult, + MemoryStats, +) from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import ( _fbgemm_permute_pooled_embs, @@ -104,11 +108,13 @@ def wrapped_func( result = BenchmarkResult( short_name=name, elapsed_time=torch.tensor(times) * 1e3, - max_mem_allocated=[0], + mem_stats=[MemoryStats(0, 0, 0, 0)], ) + mem_alloc = f"Memory alloc (P90): {result.max_mem_alloc_percentile(90):5.1f}" + mem_reserved = f"Memory alloc (P90): {result.max_mem_reserved_percentile(90):5.1f}" print( - f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | Memory (P90): {result.max_mem_percentile(90):5.1f}" + f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | {mem_alloc} | {mem_reserved}" ) diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py index e7e8e50df..235495494 100644 --- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py @@ -20,7 +20,7 @@ # Otherwise will get error # NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator. from fbgemm_gpu import sparse_ops # noqa: F401, E402 -from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult +from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats from torchrec.distributed.dist_data import _get_recat from torchrec.distributed.test_utils.test_model import ModelInput @@ -227,7 +227,7 @@ def benchmark_kjt( result = BenchmarkResult( short_name=f"{test_name}-{transform_type.name}", elapsed_time=torch.tensor(times), - max_mem_allocated=[0], + mem_stats=[MemoryStats(0, 0, 0, 0)], ) p50_runtime = result.runtime_percentile(50, interpolation="linear").item()