From b0b10d8824b4ca62bd9188e9e6ed7bb64acea80c Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 18 Aug 2023 19:10:52 +0000 Subject: [PATCH] style --- tests/benchmark/benchmark_gptq.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/benchmark/benchmark_gptq.py b/tests/benchmark/benchmark_gptq.py index 61e665d9e1..b3745a4a2c 100644 --- a/tests/benchmark/benchmark_gptq.py +++ b/tests/benchmark/benchmark_gptq.py @@ -136,6 +136,7 @@ def warmup( return gen_config + def benchmark_latency( model, input_ids: torch.Tensor, @@ -163,6 +164,7 @@ def benchmark_latency( return total_time + def benchmark_memory( model, input_ids: torch.Tensor, @@ -191,11 +193,11 @@ def benchmark_memory( _ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config) else: _ = model(input_ids, masks) - + torch.cuda.synchronize() - + memory_stats = torch.cuda.memory_stats() - + peak_allocated_torch_mb = memory_stats["allocated_bytes.all.peak"] * 1e-6 peak_reserved_torch_mb = memory_stats["reserved_bytes.all.peak"] * 1e-6 @@ -206,7 +208,7 @@ def benchmark_memory( assert peak_external_mb > 0 # This formula is to be confirmed. We measure the actual allocated PyTorch memory, plus the additional non-PyTorch memory (as the CUDA context, CUDA extension device memory). We need to subtract the PyTorch peak reserved memory since this one appears in the peak nvidia-smi/nvmlDeviceGetMemoryInfo. - + # NOTE: I verified this is only a ROUGH estimate. It may be better to use PYTORCH_NO_CUDA_MEMORY_CACHING=1 and just nvmlDeviceGetMemoryInfo. # We can actually doubt whether it makes sense to try to estimate when we would OOM, given that different devices, CUDA version do have # a different CUDA context size.
@@ -334,7 +336,9 @@ def benchmark_memory( latencies = {} throughputs = {} all_max_mem = {} -print("WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit.") +print( + "WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit." +) for batch_size in tqdm(batch_sizes): for prompt_length in tqdm(prompt_lengths):