
Commit

style
fxmarty committed Aug 18, 2023
1 parent 6b58913 commit b0b10d8
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions tests/benchmark/benchmark_gptq.py
@@ -136,6 +136,7 @@ def warmup(

return gen_config


def benchmark_latency(
model,
input_ids: torch.Tensor,
@@ -163,6 +164,7 @@ def benchmark_latency(

return total_time


def benchmark_memory(
model,
input_ids: torch.Tensor,
@@ -191,11 +193,11 @@ def benchmark_memory(
_ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
else:
_ = model(input_ids, masks)

torch.cuda.synchronize()

memory_stats = torch.cuda.memory_stats()

peak_allocated_torch_mb = memory_stats["allocated_bytes.all.peak"] * 1e-6
peak_reserved_torch_mb = memory_stats["reserved_bytes.all.peak"] * 1e-6

@@ -206,7 +208,7 @@ def benchmark_memory(
assert peak_external_mb > 0

# This formula is to be confirmed. We measure the actual allocated PyTorch memory, plus the additional non-PyTorch memory (such as the CUDA context and CUDA extension device memory). We need to subtract the PyTorch peak reserved memory since that one appears in the peak reported by nvidia-smi/nvmlDeviceGetMemoryInfo.

# NOTE: I verified this is only a ROUGH estimate. It may be better to use PYTORCH_NO_CUDA_MEMORY_CACHING=1 and just nvmlDeviceGetMemoryInfo.
# We can actually doubt whether it makes sense to try to estimate when we would OOM, given that different devices and CUDA versions have
# different CUDA context sizes.
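
To make the estimate described in the comment above concrete, here is a minimal sketch of the formula, assuming pynvml is installed and a single CUDA device; the function name and structure are illustrative, not the benchmark script's actual code. Note that nvmlDeviceGetMemoryInfo reports instantaneous usage rather than a peak, so it must be sampled while the reserved memory is still held.

import pynvml
import torch

def estimate_peak_memory_mb(device_index: int = 0) -> float:
    # Peak bytes actually allocated to tensors by the PyTorch allocator.
    stats = torch.cuda.memory_stats(device_index)
    peak_allocated_mb = stats["allocated_bytes.all.peak"] * 1e-6
    # Peak bytes reserved (cached) by the PyTorch allocator; this is the
    # PyTorch-side number that shows up in nvidia-smi.
    peak_reserved_mb = stats["reserved_bytes.all.peak"] * 1e-6

    # Driver-side view of the device: includes the CUDA context and any
    # non-PyTorch allocations. This is an instantaneous value, not a peak.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    used_mb = pynvml.nvmlDeviceGetMemoryInfo(handle).used * 1e-6
    pynvml.nvmlShutdown()

    # External (non-PyTorch) memory: what the driver reports minus what
    # the PyTorch allocator has reserved.
    peak_external_mb = used_mb - peak_reserved_mb

    # Rough estimate of the true peak: PyTorch's allocated peak plus the
    # external overhead.
    return peak_allocated_mb + peak_external_mb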
@@ -334,7 +336,9 @@ def benchmark_memory(
latencies = {}
throughputs = {}
all_max_mem = {}
print("WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit.")
print(
"WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit."
)

for batch_size in tqdm(batch_sizes):
for prompt_length in tqdm(prompt_lengths):
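For context, the hunk above initializes one result table per metric and then sweeps every (batch size, prompt length) pair. A minimal, self-contained sketch of that bookkeeping pattern follows; the measured values here are placeholders standing in for the script's actual benchmark_latency and benchmark_memory calls, whose full signatures are not shown in this diff.

from itertools import product

batch_sizes = [1, 2, 4]
prompt_lengths = [128, 512]

latencies = {}
throughputs = {}
all_max_mem = {}

for batch_size, prompt_length in product(batch_sizes, prompt_lengths):
    key = (batch_size, prompt_length)
    # Placeholder measurements; the real script times generation and
    # reads peak memory for each configuration.
    total_time_s = 1.0   # end-to-end generation time in seconds
    new_tokens = 256     # tokens generated per sequence
    latencies[key] = total_time_s * 1000 / new_tokens          # ms per token
    throughputs[key] = batch_size * new_tokens / total_time_s  # tokens/s
    all_max_mem[key] = 0.0                                     # peak memory, MB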
