From 20c204f6b2e04a63061fc3c9b65210f432d907fb Mon Sep 17 00:00:00 2001 From: Bas Krahmer Date: Fri, 14 Jul 2023 17:18:27 +0200 Subject: [PATCH] add cpu support; refactor/cleanup --- optimum/onnxruntime/runs/__init__.py | 15 +----- optimum/runs_base.py | 2 +- .../benchmark/benchmark_bettertransformer.py | 48 +++++------------- .../benchmark_bettertransformer_training.py | 49 +++++++------------ ...mark_bettertransformer_training_minimal.py | 29 +++-------- .../benchmark_bettertransformer_vit.py | 43 ++++++---------- ...st_transformers_optimum_examples_parity.py | 1 + 7 files changed, 55 insertions(+), 132 deletions(-) diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index e77e2cab21..7d7aa34d2d 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -1,4 +1,3 @@ -import copy import os from pathlib import Path @@ -56,20 +55,8 @@ def __init__(self, run_config): processing_class = task_processing_map[self.task] self.task_processor = processing_class( - dataset_path=run_config["dataset"]["path"], - dataset_name=run_config["dataset"]["name"], - calibration_split=run_config["dataset"]["calibration_split"], - eval_split=run_config["dataset"]["eval_split"], preprocessor=self.preprocessor, - data_keys=run_config["dataset"]["data_keys"], - ref_keys=run_config["dataset"]["ref_keys"], - task_args=run_config["task_args"], - static_quantization=self.static_quantization, - num_calibration_samples=run_config["calibration"]["num_calibration_samples"] - if self.static_quantization - else None, - config=trfs_model.config, - max_eval_samples=run_config["max_eval_samples"], + config=run_config, ) self.metric_names = run_config["metrics"] diff --git a/optimum/runs_base.py b/optimum/runs_base.py index 3a1d164c60..275ad00ea8 100644 --- a/optimum/runs_base.py +++ b/optimum/runs_base.py @@ -145,7 +145,7 @@ def launch_eval(self): def load_datasets(self): """Load evaluation dataset, and if needed, calibration dataset for static quantization.""" - datasets_dict = self.task_processor.load_datasets() + datasets_dict = self.task_processor.load_dataset() self._eval_dataset = datasets_dict["eval"] if self.static_quantization: diff --git a/tests/benchmark/benchmark_bettertransformer.py b/tests/benchmark/benchmark_bettertransformer.py index be8e8ef1eb..5978dacf0c 100644 --- a/tests/benchmark/benchmark_bettertransformer.py +++ b/tests/benchmark/benchmark_bettertransformer.py @@ -1,7 +1,7 @@ import argparse import torch -from torch.profiler import profile +from benchmark_common import timing_cpu, timing_cuda from tqdm import tqdm from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig @@ -107,34 +107,6 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s return tokens, lengths, mask -def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_config=None): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - torch.cuda.reset_peak_memory_stats(device) - torch.cuda.empty_cache() - torch.cuda.synchronize() - - start_event.record() - inference_fn(generation_config, input_ids, is_decoder, masks, model, num_batches) - - end_event.record() - torch.cuda.synchronize() - max_memory = torch.cuda.max_memory_allocated(device) - - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory - - -def timing_cpu(model, num_batches, input_ids, masks, is_decoder, generation_config=None): - 
with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p: - inference_fn(generation_config, input_ids, is_decoder, masks, model, num_batches) - - elapsed_time = p.key_averages().self_cpu_time_total - max_memory = max([event.cpu_memory_usage for event in p.key_averages()]) - - return elapsed_time / num_batches, max_memory - - def inference_fn(generation_config, input_ids, is_decoder, masks, model, num_batches): for _ in tqdm(range(num_batches)): if is_decoder: @@ -144,7 +116,6 @@ def inference_fn(generation_config, input_ids, is_decoder, masks, model, num_bat def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id, use_cuda): - # Warmup if is_decoder: gen_config = GenerationConfig( max_new_tokens=max_token, @@ -152,21 +123,24 @@ def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_t use_cache=True, pad_token_id=pad_token_id, ) - _ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config) else: - _ = model(input_ids, masks) + gen_config = None + + # Warmup + inference_fn(gen_config, input_ids, is_decoder, masks, model, num_batches=1) if use_cuda: torch.cuda.synchronize() - # benchmark - timing_fn = timing_cuda if args.use_cuda else timing_cpu + def benchmarked_fn(): + inference_fn(gen_config, input_ids, is_decoder, masks, model, num_batches) - if is_decoder: - total_time, max_mem = timing_fn(model, num_batches, input_ids, masks, is_decoder, gen_config) + # benchmark + if args.use_cuda: + total_time, max_mem = timing_cuda(benchmarked_fn, num_batches, device) else: - total_time, max_mem = timing_fn(model, num_batches, input_ids, masks, is_decoder) + total_time, max_mem = timing_cpu(benchmarked_fn, num_batches) return total_time, max_mem diff --git a/tests/benchmark/benchmark_bettertransformer_training.py b/tests/benchmark/benchmark_bettertransformer_training.py index a46ca6c444..042055b109 100644 --- a/tests/benchmark/benchmark_bettertransformer_training.py +++ b/tests/benchmark/benchmark_bettertransformer_training.py @@ -3,8 +3,8 @@ import numpy as np import torch +from benchmark_common import timing_cpu, timing_cuda from datasets import load_dataset -from torch.profiler import profile from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler @@ -78,39 +78,28 @@ def benchmark_training(model, num_epochs: int, train_dataloader, device): lr_scheduler.step() optimizer.zero_grad() + def benchmark_fn(): + training_fn(num_training_steps, batch, device, lr_scheduler, model, optimizer, progress_bar) + if device.type == "cpu": - with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p: - for _ in range(num_training_steps): - training_fn(batch, device, lr_scheduler, model, optimizer, progress_bar) + total_time, max_mem = timing_cpu(benchmark_fn, num_training_steps) + return total_time, max_mem - elapsed_time = p.key_averages().self_cpu_time_total - max_memory = max([event.cpu_memory_usage for event in p.key_averages()]) + else: + total_time, max_mem = timing_cuda(benchmark_fn, num_training_steps, device) + return total_time, max_mem - return elapsed_time / num_training_steps, max_memory - else: - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for _ in range(num_epochs): - for _, batch in enumerate(train_dataloader): - training_fn(batch, device, lr_scheduler, model, optimizer, progress_bar) - - 
end_event.record() - torch.cuda.synchronize() - - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs - - -def training_fn(batch, device, lr_scheduler, model, optimizer, progress_bar): - batch = {k: v.to(device) for k, v in batch.items()} - outputs = model(**batch) - loss = outputs.logits.sum() - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - progress_bar.update(1) +def training_fn(num_train_steps, batch, device, lr_scheduler, model, optimizer, progress_bar): + for _ in range(num_train_steps): + batch = {k: v.to(device) for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.logits.sum() + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) if __name__ == "__main__": diff --git a/tests/benchmark/benchmark_bettertransformer_training_minimal.py b/tests/benchmark/benchmark_bettertransformer_training_minimal.py index 2efe3b8b22..1a7d8103af 100644 --- a/tests/benchmark/benchmark_bettertransformer_training_minimal.py +++ b/tests/benchmark/benchmark_bettertransformer_training_minimal.py @@ -4,7 +4,7 @@ import numpy as np import torch -from torch.profiler import profile +from benchmark_common import timing_cpu, timing_cuda from tqdm.auto import tqdm from transformers import AutoModelForCausalLM @@ -62,32 +62,17 @@ def benchmark_training(model, inputs: Dict, num_training_steps: int, use_cuda: b loss = outputs.logits.sum() loss.backward() - if use_cuda: - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - torch.cuda.reset_peak_memory_stats(device) - torch.cuda.empty_cache() - - torch.cuda.synchronize() - start_event.record() + def benchmark_fn(): training_fn(inputs, model, num_training_steps, progress_bar) - end_event.record() - torch.cuda.synchronize() - - max_memory = torch.cuda.max_memory_allocated(device) - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_training_steps, max_memory + if use_cuda: + total_time, max_mem = timing_cuda(benchmark_fn, num_training_steps, device) + return total_time, max_mem # CPU profiling else: - with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p: - training_fn(inputs, model, num_training_steps, progress_bar) - - elapsed_time = p.key_averages().self_cpu_time_total - max_memory = max([event.cpu_memory_usage for event in p.key_averages()]) - - return elapsed_time / num_training_steps, max_memory + total_time, max_mem = timing_cpu(benchmark_fn, num_training_steps) + return total_time, max_mem def training_fn(inputs, model, num_training_steps, progress_bar): diff --git a/tests/benchmark/benchmark_bettertransformer_vit.py b/tests/benchmark/benchmark_bettertransformer_vit.py index 8dbf2aa408..78cc576b7e 100644 --- a/tests/benchmark/benchmark_bettertransformer_vit.py +++ b/tests/benchmark/benchmark_bettertransformer_vit.py @@ -2,8 +2,8 @@ import requests import torch +from benchmark_common import timing_cpu, timing_cuda from PIL import Image -from torch.profiler import profile from transformers import AutoFeatureExtractor, AutoModel from optimum.bettertransformer import BetterTransformer @@ -47,26 +47,6 @@ def get_batch(batch_size, model_name): return input_features -def timing_cpu(model, num_batches, input_features): - with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p: - inference_fn(input_features, model, num_batches) - - elapsed_time = p.key_averages().self_cpu_time_total - max([event.cpu_memory_usage for event 
in p.key_averages()]) - - return elapsed_time / num_batches - - -def timing_cuda(model, num_batches, input_features): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - inference_fn(input_features, model, num_batches) - end_event.record() - torch.cuda.synchronize() - return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches - - def inference_fn(input_features, model, num_batches): for _ in range(num_batches): _ = model(input_features) @@ -90,15 +70,22 @@ def benchmark(model_name, num_batches, batch_size, is_cuda, is_half): input_features = input_features.to(0) # Warmup - for _ in range(2): - _ = hf_model(input_features) - if is_cuda: - torch.cuda.synchronize() + inference_fn(input_features, hf_model, 2) + if is_cuda: + torch.cuda.synchronize() + + def hf_benchmark_fn(): + inference_fn(input_features, hf_model, num_batches) - timing_fn = timing_cuda if is_cuda else timing_cpu + def bt_benchmark_fn(): + inference_fn(input_features, bt_model, num_batches) - total_hf_time = timing_fn(hf_model, num_batches, input_features) - total_bt_time = timing_fn(bt_model, num_batches, input_features) + if is_cuda: + total_hf_time, _ = timing_cuda(hf_benchmark_fn, num_batches, torch.device(0)) + total_bt_time, _ = timing_cuda(bt_benchmark_fn, num_batches, torch.device(0)) + else: + total_hf_time, _ = timing_cpu(hf_benchmark_fn, num_batches) + total_bt_time, _ = timing_cpu(bt_benchmark_fn, num_batches) return total_bt_time, total_hf_time diff --git a/tests/benchmark/test_transformers_optimum_examples_parity.py b/tests/benchmark/test_transformers_optimum_examples_parity.py index d6ee0fe111..32e1ca7e25 100644 --- a/tests/benchmark/test_transformers_optimum_examples_parity.py +++ b/tests/benchmark/test_transformers_optimum_examples_parity.py @@ -278,6 +278,7 @@ def test_question_answering_parity(self): optimum_results["exact_match"], benchmark_results["evaluation"]["others"]["optimized"]["exact_match"] ) + # TODO shuffle issue is solved; fix this test @unittest.skip( "failing related to shuffle issue https://github.com/huggingface/datasets/issues/5145 , skip for now" )
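
Note: the refactored scripts above import timing_cpu and timing_cuda from a shared
tests/benchmark/benchmark_common.py, which is not included in this patch. As a reference
for reviewers, a minimal sketch consistent with the call sites in the diff
(timing_cuda(fn, num_batches, device) and timing_cpu(fn, num_batches), each returning a
(time_per_batch, peak_memory) tuple) could look like the following; the helper names,
signatures, and unit choices are inferred from the deleted per-script helpers, not taken
from the actual module.

# benchmark_common.py -- hypothetical sketch, not the module shipped with this change.
import torch
from torch.profiler import ProfilerActivity, profile


def timing_cuda(benchmark_fn, num_batches, device):
    """Run benchmark_fn once; return (seconds per batch, peak GPU memory in bytes)."""
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    benchmark_fn()
    end_event.record()
    torch.cuda.synchronize()

    max_memory = torch.cuda.max_memory_allocated(device)
    # elapsed_time() is in milliseconds; convert to seconds, as the deleted CUDA helpers did.
    return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory


def timing_cpu(benchmark_fn, num_batches):
    """Run benchmark_fn under the profiler; return (CPU time per batch, peak CPU memory in bytes)."""
    with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as p:
        benchmark_fn()

    # self_cpu_time_total is reported in microseconds, matching the deleted CPU helpers.
    elapsed_time = p.key_averages().self_cpu_time_total
    max_memory = max(event.cpu_memory_usage for event in p.key_averages())
    return elapsed_time / num_batches, max_memory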