add proper bench
fxmarty committed Aug 18, 2023
1 parent 05fb48d commit 6b58913
Showing 2 changed files with 174 additions and 50 deletions.
56 changes: 54 additions & 2 deletions tests/benchmark/README.md
@@ -2,19 +2,71 @@

Run

```
```shell
CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --sweep --num-batches 4 --task text-generation
```

and

```
```shell
git clone --branch main https://huggingface.co/TheBloke/Llama-2-13B-chat-GPTQ
cd Llama-2-13B-chat-GPTQ
mv gptq_model-4bit-128g.safetensors model.safetensors
mv quantize_config.json quantization_config.json

# and then
# with exllama kernel
CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --gptq-model /path/to/Llama-2-13B-chat-GPTQ/ --sweep --num-batches 4 --gptq --task text-generation

# without exllama kernel
CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --gptq-model /path/to/Llama-2-13B-chat-GPTQ/ --sweep --num-batches 4 --gptq --task text-generation --disable-exllama
```

### Benchmark results

Here are the results obtained on a single NVIDIA A100-SXM4-80GB GPU. We use a prompt length of 512 and generate exactly 512 new tokens. Each generation is repeated over 4 batches, and the metrics are averaged over the number of batches and the generation length.
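
For reference, the per-token latency and throughput columns can be derived from the mean per-batch generation time roughly as in the sketch below (an illustration only; the function and variable names are placeholders, not the exact code of `benchmark_gptq.py`):

```python
# Minimal sketch (not the exact benchmark_gptq.py code): deriving the reported
# metrics from a mean per-batch generation time, given in seconds.
def derive_metrics(mean_batch_time_s: float, batch_size: int, new_tokens: int):
    per_token_latency_ms = mean_batch_time_s / new_tokens * 1000   # ms per generated token
    throughput_tok_s = batch_size * new_tokens / mean_batch_time_s  # tokens generated per second
    return per_token_latency_ms, throughput_tok_s

# Example: batch_size=1, 512 new tokens in ~17.3 s per batch
# -> roughly (33.8 ms/token, 29.6 tok/s), in line with the exllama row below.
print(derive_metrics(17.3, batch_size=1, new_tokens=512))
```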

Additional benchmarks could be done in the act-order case.

From the benchmark, it appears that the exllama kernel is best-in-class for GPTQ, although it becomes comparatively slow at larger batch sizes. The memory savings are not exactly 4x even though the weights are stored in int4. This can be explained by the static buffers possibly allocated by the kernels, the CUDA context (which is included in the measurements), and the KV cache, which is still kept in fp16.
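
As a back-of-the-envelope check of the weight memory alone (a hypothetical estimate, not something measured by the script; the ~13B parameter count and the group-size overhead are assumptions):

```python
# Rough, hypothetical estimate of weight-only memory for a ~13B-parameter model.
# It ignores the CUDA context, kernel buffers and the fp16 KV cache, which is
# why the measured peak memory shrinks by much less than 4x.
n_params = 13e9
fp16_weights_gb = n_params * 2 / 1e9      # 2 bytes per parameter   -> ~26 GB
int4_weights_gb = n_params * 0.5 / 1e9    # 0.5 byte per parameter  -> ~6.5 GB
# GPTQ additionally stores scales (and packed zero-points) per group of 128
# weights, adding a small overhead on top of the packed int4 weights.
group_overhead_gb = (n_params / 128) * 2 / 1e9  # ~0.2 GB of fp16 scales
print(fp16_weights_gb, int4_weights_gb + group_overhead_gb)
```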

#### Batch size = 1

|gptq |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
|-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
|False|None |None|None |None |26.0 |36.958 |27.058 |29152.98 |
|True |False |4 |128 |exllama|36.2 |33.711 |29.663 |10484.34 |
|True |False |4 |128 |autogptq-cuda-old|36.2 |46.44 |21.53 |10344.62 |


#### Batch size = 2

|gptq |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
|-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
|False|None |None|None |None |26.0 |37.35 |53.53 |30831.09 |
|True |False |4 |128 |exllama|36.2 |37.25 |53.68 |12162.43 |
|True |False |4 |128 |autogptq-cuda-old|36.2 |47.41 |42.18 |12020.34 |

#### Batch size = 4

|gptq |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
|-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
|False|None |None|None |None |26.0 |37.89 |105.55 |34187.22 |
|True |False |4 |128 |exllama|36.2 |54.14 |73.87 |15518.55 |
|True |False |4 |128 |autogptq-cuda-old|36.2 |60.98 |65.59 |15374.67 |

#### Batch size = 8

|gptq |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
|-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
|False|None |None|None |None |26.0 |47.37 |168.86 |40327.62 |
|True |False |4 |128 |exllama|36.2 |73.57 |108.73 |21864.56 |
|True |False |4 |128 |autogptq-cuda-old|36.2 |104.44 |76.59 |20987.68 |

#### Batch size = 16

|gptq |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
|-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
|False|None |None|None |None |26.0 |69.94 |228.76 |53986.51 |
|True |False |4 |128 |exllama|36.2 |95.41 |167.68 |34777.04 |
|True |False |4 |128 |autogptq-cuda-old|36.2 |192.48 |83.12 |35497.62 |
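
The script also writes one CSV row per configuration (see `benchmark_gptq.py` below). A minimal sketch for inspecting such a file with pandas; `results.csv` is a placeholder, since the actual file name is constructed inside the script:

```python
# Minimal sketch for inspecting the CSV written by benchmark_gptq.py.
import pandas as pd

# "results.csv" is a placeholder: the real file name is built inside the script.
df = pd.read_csv("results.csv", skipinitialspace=True)  # the header uses ", " separators
cols = ["kernel", "batch_size", "Per-token latency (ms)", "Throughput (tok/s)", "Max memory (MB)"]
# Sort so that kernels are easy to compare at each batch size.
print(df.sort_values(["batch_size", "kernel"])[cols].to_string(index=False))
```
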
168 changes: 120 additions & 48 deletions tests/benchmark/benchmark_gptq.py
@@ -106,37 +106,14 @@ def timing_cuda(
return np.mean(latencies)


def memory_cuda(
def warmup(
model,
input_ids: torch.Tensor,
masks: torch.Tensor,
is_decoder: bool,
memory_tracker: MemoryTracker,
generation_config=None,
):
with memory_tracker.track():
if is_decoder:
_ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
else:
_ = model(input_ids, masks)

return memory_tracker.peak_memory


def benchmark(
model,
input_ids: torch.Tensor,
masks: torch.Tensor,
num_batches: int,
is_decoder: bool,
new_tokens: int,
pad_token_id: int,
memory_tracker: MemoryTracker,
):
torch.cuda.empty_cache()
gc.collect()

# It appears running the warmup only once is not enough to get low variance on the latency in later runs. Hence the `for i in range(2):` below.
print("Warmup...")
if is_decoder:
gen_config = GenerationConfig(
@@ -157,19 +134,98 @@ def benchmark(
_ = model(input_ids, masks)
torch.cuda.synchronize()

return gen_config

def benchmark_latency(
model,
input_ids: torch.Tensor,
masks: torch.Tensor,
num_batches: int,
is_decoder: bool,
new_tokens: int,
pad_token_id: int,
memory_tracker: MemoryTracker,
):
torch.cuda.empty_cache()
gc.collect()

gen_config = warmup(
model,
input_ids,
masks,
is_decoder,
new_tokens,
pad_token_id,
)

print("Measuring latency...")
total_time = timing_cuda(model, num_batches, input_ids, masks, is_decoder, gen_config)

return total_time

def benchmark_memory(
model,
input_ids: torch.Tensor,
masks: torch.Tensor,
num_batches: int,
is_decoder: bool,
new_tokens: int,
pad_token_id: int,
memory_tracker: MemoryTracker,
):
torch.cuda.empty_cache()
gc.collect()

print("Measuring peak memory...")
max_mem = memory_cuda(model, input_ids, masks, is_decoder, memory_tracker, gen_config)
with memory_tracker.track():
gen_config = warmup(
model,
input_ids,
masks,
is_decoder,
new_tokens,
pad_token_id,
)

if is_decoder:
_ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
else:
_ = model(input_ids, masks)

torch.cuda.synchronize()

memory_stats = torch.cuda.memory_stats()

peak_allocated_torch_mb = memory_stats["allocated_bytes.all.peak"] * 1e-6
peak_reserved_torch_mb = memory_stats["reserved_bytes.all.peak"] * 1e-6

return total_time, max_mem
peak_nvml_mb = memory_tracker.peak_memory

# I am not sure whether we should also subtract `inactive_split_bytes.all.peak` here (it is unclear what it corresponds to, though it can get quite large, up to several GB).
peak_external_mb = peak_nvml_mb - peak_reserved_torch_mb
assert peak_external_mb > 0

# This formula remains to be confirmed. We measure the memory actually allocated by PyTorch, plus the additional non-PyTorch memory (e.g. the CUDA context, CUDA extension device memory). We need to subtract the PyTorch peak reserved memory since it is already included in the peak reported by nvidia-smi/nvmlDeviceGetMemoryInfo.

# NOTE: I verified this is only a ROUGH estimate. It may be better to use PYTORCH_NO_CUDA_MEMORY_CACHING=1 and rely on nvmlDeviceGetMemoryInfo alone.
# One may also question whether it makes sense to try to estimate when we would OOM, given that different devices and CUDA versions
# have different CUDA context sizes.
peak_memory_mb = peak_allocated_torch_mb + peak_external_mb

print(f"DEBUG: peak allocated torch: {peak_allocated_torch_mb:.2f} MB")
print(f"DEBUG: peak nvidia-smi/nvml: {peak_nvml_mb:.2f} MB")
print(f"DEBUG: peak reserved torch: {peak_reserved_torch_mb:.2f} MB")
print(f"DEBUG: peak external: {peak_external_mb:.2f} MB")
print(f"DEBUG: global peak: {peak_memory_mb:.2f} MB")

return peak_memory_mb


parser = get_parser()
args = parser.parse_args()

if args.sweep:
batch_sizes = [1, 4, 8, 16]
batch_sizes = [1, 2, 4, 8, 16]
prompt_lengths = [512]
new_tokens = [512]
else:
@@ -226,6 +282,7 @@ def benchmark(
group_size = quantize_config_dict["group_size"]

if not args.disable_exllama:
# The exllama kernel can handle both the act-order and no act-order cases.
kernel = "exllama"
elif act_order:
kernel = "autotogptq-cuda"
@@ -251,6 +308,7 @@ def benchmark(
load_end = time.time_ns()

load_time = (load_end - load_start) * 1e-9
print(f"Model load time: {load_time:.1f} s")

uses_gptq = args.gptq
print(f"Model uses GPTQ: {uses_gptq}")
@@ -270,24 +328,38 @@ def benchmark(
file_name = file_name + ".csv"

output_file = open(file_name, "w")
output_file.write(
"gptq, act_order, bits, group_size, kernel, num_batches, batch_size, prompt_length, new_tokens, Load time (s), Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
)
header = "gptq, act_order, bits, group_size, kernel, num_batches, batch_size, prompt_length, new_tokens, Load time (s), Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
output_file.write(header)

latencies = {}
throughputs = {}
all_max_mem = {}
print("WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit.")

for batch_size in tqdm(batch_sizes):
for prompt_length in tqdm(prompt_lengths):
for new_token in tqdm(new_tokens):
print(f"---- Running: batch_size={batch_size}, prompt_length={prompt_length}, new_tokens={new_token}")

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)

with torch.no_grad():
mean_latency, max_mem = benchmark(
max_mem = benchmark_memory(
model,
input_ids,
masks,
args.num_batches,
is_decoder,
new_token,
tokenizer.pad_token_id,
memory_tracker=memory_tracker,
)

mean_latency = benchmark_latency(
model,
input_ids,
masks,
@@ -307,27 +379,27 @@ def benchmark(
throughputs[index] = throughput
all_max_mem[index] = max_mem

# TODO: validate that max_mem is correct
print(
f"Latency per token: {per_token_latency:.3f} ms, throughput: {throughput:.3f} tok/s, peak mem: {max_mem:.2f} MB"
)

output_file.write(
"{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
uses_gptq,
act_order,
bits,
group_size,
kernel,
args.num_batches,
batch_size,
prompt_length,
new_token,
f"{load_time:.2f}",
f"{per_token_latency:.4f}",
f"{throughput:.4f}",
f"{max_mem:.4f}",
)
line = "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
uses_gptq,
act_order,
bits,
group_size,
kernel,
args.num_batches,
batch_size,
prompt_length,
new_token,
f"{load_time:.2f}",
f"{per_token_latency:.2f}",
f"{throughput:.2f}",
f"{max_mem:.2f}",
)
print(header)
print(line)
output_file.write(line)

output_file.close()
