add cpu support; refactor/cleanup
baskrahmer committed Jul 14, 2023
1 parent 41479b5 commit 20c204f
Showing 7 changed files with 55 additions and 132 deletions.
15 changes: 1 addition & 14 deletions optimum/onnxruntime/runs/__init__.py
@@ -1,4 +1,3 @@
import copy
import os
from pathlib import Path

@@ -56,20 +55,8 @@ def __init__(self, run_config):

processing_class = task_processing_map[self.task]
self.task_processor = processing_class(
dataset_path=run_config["dataset"]["path"],
dataset_name=run_config["dataset"]["name"],
calibration_split=run_config["dataset"]["calibration_split"],
eval_split=run_config["dataset"]["eval_split"],
preprocessor=self.preprocessor,
data_keys=run_config["dataset"]["data_keys"],
ref_keys=run_config["dataset"]["ref_keys"],
task_args=run_config["task_args"],
static_quantization=self.static_quantization,
num_calibration_samples=run_config["calibration"]["num_calibration_samples"]
if self.static_quantization
else None,
config=trfs_model.config,
max_eval_samples=run_config["max_eval_samples"],
config=run_config,
)

self.metric_names = run_config["metrics"]
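The thirteen explicit keyword arguments above are collapsed into a single `config=run_config`. The deleted lookups imply the expected shape of the run configuration; below is a hypothetical sketch reconstructed only from those keys — the key paths come from the removed lines, while every value is illustrative, not taken from the commit:

```python
# Hypothetical run_config layout, inferred from the kwargs deleted above.
# Key paths mirror the removed lookups: dataset.*, task_args,
# calibration.num_calibration_samples, max_eval_samples, metrics.
run_config = {
    "dataset": {
        "path": "squad",                # was dataset_path
        "name": "plain_text",           # was dataset_name
        "calibration_split": "train",   # was calibration_split
        "eval_split": "validation",     # was eval_split
        "data_keys": {"primary": "question", "secondary": "context"},
        "ref_keys": ["answers"],
    },
    "task_args": {},
    "calibration": {"num_calibration_samples": 100},
    "max_eval_samples": 1000,
    "metrics": ["exact_match", "f1"],   # consumed just below as run_config["metrics"]
}
```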
2 changes: 1 addition & 1 deletion optimum/runs_base.py
@@ -145,7 +145,7 @@ def launch_eval(self):

def load_datasets(self):
"""Load evaluation dataset, and if needed, calibration dataset for static quantization."""
datasets_dict = self.task_processor.load_datasets()
datasets_dict = self.task_processor.load_dataset()

self._eval_dataset = datasets_dict["eval"]
if self.static_quantization:
48 changes: 11 additions & 37 deletions tests/benchmark/benchmark_bettertransformer.py
@@ -1,7 +1,7 @@
import argparse

import torch
from torch.profiler import profile
from benchmark_common import timing_cpu, timing_cuda
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

@@ -107,34 +107,6 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s
return tokens, lengths, mask


def timing_cuda(model, num_batches, input_ids, masks, is_decoder, generation_config=None):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.empty_cache()
torch.cuda.synchronize()

start_event.record()
inference_fn(generation_config, input_ids, is_decoder, masks, model, num_batches)

end_event.record()
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated(device)

return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory


def timing_cpu(model, num_batches, input_ids, masks, is_decoder, generation_config=None):
with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p:
inference_fn(generation_config, input_ids, is_decoder, masks, model, num_batches)

elapsed_time = p.key_averages().self_cpu_time_total
max_memory = max([event.cpu_memory_usage for event in p.key_averages()])

return elapsed_time / num_batches, max_memory


def inference_fn(generation_config, input_ids, is_decoder, masks, model, num_batches):
for _ in tqdm(range(num_batches)):
if is_decoder:
@@ -144,29 +116,31 @@ def inference_fn(generation_config, input_ids, is_decoder, masks, model, num_batches):


def benchmark(model, input_ids, masks, num_batches, is_decoder, max_token, pad_token_id, use_cuda):
# Warmup
if is_decoder:
gen_config = GenerationConfig(
max_new_tokens=max_token,
min_new_tokens=max_token,
use_cache=True,
pad_token_id=pad_token_id,
)
_ = model.generate(input_ids, attention_mask=masks, generation_config=gen_config)

else:
_ = model(input_ids, masks)
gen_config = None

# Warmup
inference_fn(gen_config, input_ids, is_decoder, masks, model, num_batches=1)

if use_cuda:
torch.cuda.synchronize()

# benchmark
timing_fn = timing_cuda if args.use_cuda else timing_cpu
def benchmarked_fn():
inference_fn(gen_config, input_ids, is_decoder, masks, model, num_batches)

if is_decoder:
total_time, max_mem = timing_fn(model, num_batches, input_ids, masks, is_decoder, gen_config)
# benchmark
if args.use_cuda:
total_time, max_mem = timing_cuda(benchmarked_fn, num_batches, device)
else:
total_time, max_mem = timing_fn(model, num_batches, input_ids, masks, is_decoder)
total_time, max_mem = timing_cpu(benchmarked_fn, num_batches)

return total_time, max_mem

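The per-script timing helpers deleted above are replaced by shared ones imported from `benchmark_common`, a module whose contents are not shown in this diff. Judging from the deleted implementations and the new call sites — `timing_cuda(benchmarked_fn, num_batches, device)` and `timing_cpu(benchmarked_fn, num_batches)` — the shared module plausibly looks like the following sketch; a reconstruction under those assumptions, not the commit's actual file:

```python
# benchmark_common.py — hypothetical reconstruction from the helpers removed above.
# Both timers take a zero-argument closure that runs all num_batches iterations
# itself, and return (seconds_per_iteration, peak_memory_in_bytes).
import torch
from torch.profiler import profile


def timing_cuda(fn, num_batches, device):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    fn()  # the closure owns its own iteration loop
    end_event.record()
    torch.cuda.synchronize()

    max_memory = torch.cuda.max_memory_allocated(device)
    # elapsed_time() is in milliseconds; convert to seconds, normalize per batch
    return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches, max_memory


def timing_cpu(fn, num_batches):
    with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p:
        fn()

    # Mirrors the removed CPU helper: total self CPU time from the profiler,
    # normalized per batch, plus peak CPU memory usage across recorded events.
    elapsed_time = p.key_averages().self_cpu_time_total
    max_memory = max(event.cpu_memory_usage for event in p.key_averages())
    return elapsed_time / num_batches, max_memory
```

Passing a closure rather than `(model, inputs, flags)` keeps the timers agnostic to whether the workload is inference or training, which is what lets the remaining benchmark scripts below share them.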
49 changes: 19 additions & 30 deletions tests/benchmark/benchmark_bettertransformer_training.py
@@ -3,8 +3,8 @@

import numpy as np
import torch
from benchmark_common import timing_cpu, timing_cuda
from datasets import load_dataset
from torch.profiler import profile
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
@@ -78,39 +78,28 @@ def benchmark_training(model, num_epochs: int, train_dataloader, device):
lr_scheduler.step()
optimizer.zero_grad()

def benchmark_fn():
training_fn(num_training_steps, batch, device, lr_scheduler, model, optimizer, progress_bar)

if device.type == "cpu":
with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p:
for _ in range(num_training_steps):
training_fn(batch, device, lr_scheduler, model, optimizer, progress_bar)
total_time, max_mem = timing_cpu(benchmark_fn, num_training_steps)
return total_time, max_mem

elapsed_time = p.key_averages().self_cpu_time_total
max_memory = max([event.cpu_memory_usage for event in p.key_averages()])
else:
total_time, max_mem = timing_cuda(benchmark_fn, num_training_steps, device)
return total_time, max_mem

return elapsed_time / num_training_steps, max_memory

else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_epochs):
for _, batch in enumerate(train_dataloader):
training_fn(batch, device, lr_scheduler, model, optimizer, progress_bar)

end_event.record()
torch.cuda.synchronize()

return (start_event.elapsed_time(end_event) * 1.0e-3) / num_epochs


def training_fn(batch, device, lr_scheduler, model, optimizer, progress_bar):
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.logits.sum()
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
def training_fn(num_train_steps, batch, device, lr_scheduler, model, optimizer, progress_bar):
for _ in range(num_train_steps):
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.logits.sum()
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)


if __name__ == "__main__":
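After this refactor the training benchmark follows the same dispatch shape as the inference scripts: the step loop moves into the workload (`training_fn` now iterates `num_train_steps` times), a zero-argument `benchmark_fn` closure captures it, and the device decides which timer runs. A distilled sketch of that shared pattern — illustrative glue, not a line from the commit:

```python
# Hypothetical illustration of the dispatch pattern the benchmark scripts share.
# timing_cuda/timing_cpu are the (assumed) benchmark_common helpers sketched above.
def run_timed(workload_fn, num_iterations, device):
    if device.type == "cuda":
        # one CUDA event pair brackets the whole loop inside workload_fn
        return timing_cuda(workload_fn, num_iterations, device)
    # one profiler region covers the whole loop on CPU
    return timing_cpu(workload_fn, num_iterations)
```

Keeping the loop inside the closure leaves all device-specific instrumentation in one place instead of duplicated across the individual scripts.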
29 changes: 7 additions & 22 deletions tests/benchmark/benchmark_bettertransformer_training_minimal.py
@@ -4,7 +4,7 @@

import numpy as np
import torch
from torch.profiler import profile
from benchmark_common import timing_cpu, timing_cuda
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM

@@ -62,32 +62,17 @@ def benchmark_training(model, inputs: Dict, num_training_steps: int, use_cuda: bool):
loss = outputs.logits.sum()
loss.backward()

if use_cuda:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.empty_cache()

torch.cuda.synchronize()
start_event.record()
def benchmark_fn():
training_fn(inputs, model, num_training_steps, progress_bar)
end_event.record()
torch.cuda.synchronize()

max_memory = torch.cuda.max_memory_allocated(device)

return (start_event.elapsed_time(end_event) * 1.0e-3) / num_training_steps, max_memory
if use_cuda:
total_time, max_mem = timing_cuda(benchmark_fn, num_training_steps, device)
return total_time, max_mem

# CPU profiling
else:
with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p:
training_fn(inputs, model, num_training_steps, progress_bar)

elapsed_time = p.key_averages().self_cpu_time_total
max_memory = max([event.cpu_memory_usage for event in p.key_averages()])

return elapsed_time / num_training_steps, max_memory
total_time, max_mem = timing_cpu(benchmark_fn, num_training_steps)
return total_time, max_mem


def training_fn(inputs, model, num_training_steps, progress_bar):
43 changes: 15 additions & 28 deletions tests/benchmark/benchmark_bettertransformer_vit.py
@@ -2,8 +2,8 @@

import requests
import torch
from benchmark_common import timing_cpu, timing_cuda
from PIL import Image
from torch.profiler import profile
from transformers import AutoFeatureExtractor, AutoModel

from optimum.bettertransformer import BetterTransformer
@@ -47,26 +47,6 @@ def get_batch(batch_size, model_name):
return input_features


def timing_cpu(model, num_batches, input_features):
with profile(activities=[torch.profiler.ProfilerActivity.CPU], profile_memory=True) as p:
inference_fn(input_features, model, num_batches)

elapsed_time = p.key_averages().self_cpu_time_total
max([event.cpu_memory_usage for event in p.key_averages()])

return elapsed_time / num_batches


def timing_cuda(model, num_batches, input_features):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
inference_fn(input_features, model, num_batches)
end_event.record()
torch.cuda.synchronize()
return (start_event.elapsed_time(end_event) * 1.0e-3) / num_batches


def inference_fn(input_features, model, num_batches):
for _ in range(num_batches):
_ = model(input_features)
@@ -90,15 +70,22 @@ def benchmark(model_name, num_batches, batch_size, is_cuda, is_half):
input_features = input_features.to(0)

# Warmup
for _ in range(2):
_ = hf_model(input_features)
if is_cuda:
torch.cuda.synchronize()
inference_fn(input_features, hf_model, 2)
if is_cuda:
torch.cuda.synchronize()

def hf_benchmark_fn():
inference_fn(input_features, hf_model, num_batches)

timing_fn = timing_cuda if is_cuda else timing_cpu
def bt_benchmark_fn():
inference_fn(input_features, bt_model, num_batches)

total_hf_time = timing_fn(hf_model, num_batches, input_features)
total_bt_time = timing_fn(bt_model, num_batches, input_features)
if is_cuda:
total_hf_time, _ = timing_cuda(hf_benchmark_fn, num_batches, torch.device(0))
total_bt_time, _ = timing_cuda(bt_benchmark_fn, num_batches, torch.device(0))
else:
total_hf_time, _ = timing_cpu(hf_benchmark_fn, num_batches)
total_bt_time, _ = timing_cpu(bt_benchmark_fn, num_batches)

return total_bt_time, total_hf_time

@@ -278,6 +278,7 @@ def test_question_answering_parity(self):
optimum_results["exact_match"], benchmark_results["evaluation"]["others"]["optimized"]["exact_match"]
)

# TODO: the shuffle issue is solved; fix this test
@unittest.skip(
"failing related to shuffle issue https://github.com/huggingface/datasets/issues/5145 , skip for now"
)
