Fix BT benchmark script (#1344)
* fix script

* fix test
fxmarty authored Sep 4, 2023
1 parent b771e04 commit 5663aae
Showing 1 changed file with 48 additions and 22 deletions.
70 changes: 48 additions & 22 deletions tests/benchmark/benchmark_bettertransformer.py
@@ -6,7 +6,8 @@

from optimum.bettertransformer import BetterTransformer
from optimum.exporters import TasksManager

+import numpy as np
+import pandas as pd


def get_parser():
    parser = argparse.ArgumentParser()
@@ -26,13 +27,13 @@ def get_parser():
        "--avg-seqlen",
        type=int,
        default=256,
-        help="",
+        help="True average sequence length (the rest will be padding).",
    )
    parser.add_argument(
        "--max-seqlen",
        type=int,
        default=256,
-        help="",
+        help="Input padded sequence length.",
    )
    parser.add_argument(
        "--model-name",
@@ -61,12 +62,17 @@ def get_parser():
    parser.add_argument(
        "--is_decoder",
        action="store_true",
+        help="Benchmark the generate method."
    )
+    parser.add_argument(
+        "--sweep",
+        action="store_true",
+    )
    parser.add_argument(
        "--max_token",
        type=int,
        default=100,
-        help="",
+        help="Number of new tokens, for autoregressive models using generate.",
    )
    return parser
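For reference, the updated script can be driven with the flags defined above; a typical run (the model name is only an illustration, and flags defined outside the hunks shown here are omitted) might look like:

    python tests/benchmark/benchmark_bettertransformer.py --model-name gpt2 --is_decoder --sweep --max_token 100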

@@ -103,38 +109,51 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s


def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_config=None):
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    if is_decoder:
+        model.generation_config.eos_token_id = None

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

-    start_event.record()
+    # We need NOT call torch.cuda.empty_cache() here as it appears to negate the warmup.
+
+    latencies = []
    for _ in tqdm(range(num_batches)):
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+        start_event.record()
+
        if is_decoder:
            _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
        else:
            _ = model(input_ids, masks)
-    end_event.record()
-    torch.cuda.synchronize()
+        end_event.record()
+        torch.cuda.synchronize()

+        latency_ms = start_event.elapsed_time(end_event)
+        print(f"\nLatency per token: {latency_ms / generation_config.min_new_tokens:.3f} ms")
+        latencies.append(latency_ms)

    max_memory = torch.cuda.max_memory_allocated(device)

-    return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory
+    return np.mean(latencies), max_memory


def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id):
-    # Warmup
    if is_decoder:
        gen_config = GenerationConfig(
            max_new_tokens=max_token,
            min_new_tokens=max_token,
            use_cache=True,
            pad_token_id=pad_token_id,
+            eos_token_id=None,
        )
+
+    # warmup
+    if is_decoder:
        _ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)
        torch.cuda.synchronize()
    else:
        _ = model(input_ids, masks)
        torch.cuda.synchronize()
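The substance of the change to timing_cuda above is that each iteration now gets its own pair of CUDA events, with a synchronize before and after the measured call, and the reported latency is the mean over iterations rather than a single timing of the whole loop. A minimal standalone sketch of that pattern, with a placeholder model and inputs (not part of the commit):

import numpy as np
import torch

def time_forward(model, inputs, num_batches=10):
    # One event pair per iteration; synchronize before and after each
    # measurement so queued kernels do not leak into the next timing.
    latencies = []
    for _ in range(num_batches):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        _ = model(**inputs)
        end.record()
        torch.cuda.synchronize()
        latencies.append(start.elapsed_time(end))  # milliseconds
    return np.mean(latencies)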
@@ -152,8 +171,13 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
parser = get_parser()
args = parser.parse_args()

-BATCH_SIZES = [2]
-SEQ_LEN = [64]
+if args.sweep:
+    BATCH_SIZES = [1, 2, 4]
+    SEQ_LEN = [64, 128]
+else:
+    BATCH_SIZES = [args.batch_size]
+    SEQ_LEN = [args.max_seqlen]

if args.is_decoder:
    PAD_PERCENTAGES = [0]
else:
@@ -183,9 +207,10 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
else:
    hf_model = autoclass.from_pretrained(args.model_name, torch_dtype=torch.float16 if args.use_half else None)

-output_file = open("log_{}.csv".format(args.model_name.replace("/", "-")), "w")
+output_name = "log_{}.csv".format(args.model_name.replace("/", "-"))
+output_file = open(output_name, "w")
output_file.write(
-    "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, Latency eager (s), Latency BT (s), Speedup (%), Mem eager (MB), Mem BT (MB), Mem saved (%)\n"
+    "num_batches, batch_size, seq_len, is cuda, is half, use mask, pad percentage, Latency eager (ms), Latency BT (ms), Speedup (%), Mem eager (MB), Mem BT (MB), Mem saved (%)\n"
)

all_total_hf_time = {}
@@ -229,8 +254,6 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
for seq_len in tqdm(SEQ_LEN):
    for pad_perc in tqdm(PAD_PERCENTAGES):
        print(f"-- Running: bs={bs}, seq_len={seq_len}")
-        # current_std = int(seq_len*pad_perc)
-        # max_seqlen = seq_len + current_std
        max_seqlen = seq_len
        mean_seqlen = int((1 - pad_perc) * max_seqlen)
        input_ids, _, masks = get_batch(
@@ -268,15 +291,15 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
max_mem_eager = max_mem_eager * 1e-6
max_mem_bt = max_mem_bt * 1e-6

-print(f"PT eager: {total_hf_time:.3f} s, peak {max_mem_eager:.2f} MB")
-print(f"PT native: {total_bt_time:.3f} s, peak {max_mem_bt:.2f} MB")
+print(f"PT eager: {total_hf_time:.3f} ms, peak {max_mem_eager:.2f} MB")
+print(f"PT native: {total_bt_time:.3f} ms, peak {max_mem_bt:.2f} MB")

output_file.write(
    "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
        args.num_batches,
-        args.use_cuda,
        bs,
        seq_len,
+        args.use_cuda,
        args.use_half,
        args.use_mask,
        pad_perc,
@@ -290,3 +313,6 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t
)

output_file.close()
+print("RESULTS:")
+df = pd.read_csv(output_name)
+print(df.to_markdown(index=False))
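For decoder models, the script pins the generation length by setting min_new_tokens equal to max_new_tokens and disabling eos_token_id (both in the GenerationConfig and on model.generation_config), so every timed call produces the same number of new tokens. A small illustration of that behaviour, assuming an arbitrary causal LM checkpoint such as gpt2 (not part of the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

gen_config = GenerationConfig(
    max_new_tokens=100,
    min_new_tokens=100,
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=None,  # never stop early on an end-of-sequence token
)

inputs = tokenizer("Hello", return_tensors="pt")
output = model.generate(**inputs, generation_config=gen_config)
# Exactly 100 new tokens every time, so per-token latency is well defined.
print(output.shape[-1] - inputs["input_ids"].shape[-1])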
