Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimum_neuron sample with sd_15_512 #879

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/benchmark/pytorch/sd_15_512_optimum_neuron_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
os.environ["NEURON_FUSE_SOFTMAX"] = "1"
import time
import math

from optimum.neuron import NeuronStableDiffusionPipeline

# Specialized benchmarking helper for stable diffusion pipelines.
def benchmark(n_runs, test_name, model, model_inputs):
    """Measure end-to-end latency of ``model`` and print a percentile report.

    Performs one untimed warmup call (the first invocation typically pays
    one-time initialization cost), then times ``n_runs`` invocations and
    prints P0/P50/P90/P95/P99/P100 latencies in milliseconds on one line.

    Args:
        n_runs: Number of timed invocations (warmup excluded).
        test_name: Label used in the printed ``RESULT FOR ...`` report.
        model: Callable pipeline invoked as ``model(*model_inputs)``.
        model_inputs: A single positional input, or a tuple of them.
    """
    if not isinstance(model_inputs, tuple):
        model_inputs = (model_inputs,)

    # Warmup: exclude first-call initialization cost from the measurements.
    model(*model_inputs)

    latency_collector = LatencyCollector()
    # Can't use register_forward_pre_hook / register_forward_hook because
    # StableDiffusionPipeline is not a torch.nn.Module, so time manually.
    for _ in range(n_runs):
        latency_collector.pre_hook()
        model(*model_inputs)
        latency_collector.hook()

    # Convert seconds -> milliseconds; insertion order matches the report order.
    report_dict = {
        f"Latency P{p}": f"{latency_collector.percentile(p) * 1000:.1f}"
        for p in (0, 50, 90, 95, 99, 100)
    }

    report = f'RESULT FOR {test_name}:'
    for key, value in report_dict.items():
        report += f' {key}={value}'
    print(report)

class LatencyCollector:
    """Accumulates wall-clock durations via explicit pre/post hook calls."""

    def __init__(self):
        # Timestamp captured by the most recent pre_hook call.
        self.start = None
        # Recorded durations, in seconds, one entry per pre_hook/hook pair.
        self.latency_list = []

    def pre_hook(self, *args):
        """Mark the beginning of a timed run."""
        self.start = time.time()

    def hook(self, *args):
        """Record the elapsed time since the matching pre_hook call."""
        self.latency_list.append(time.time() - self.start)

    def percentile(self, percent):
        """Return the given percentile (0-100) of collected latencies, in seconds."""
        samples = sorted(self.latency_list)
        last_index = len(samples) - 1
        position = len(samples) * percent / 100
        # Clamp both bracketing ranks into range, then pick the nearer one
        # (ties and sub-half fractions resolve to the lower rank).
        lower = min(math.floor(position), last_index)
        upper = min(math.ceil(position), last_index)
        return samples[upper if position - lower > 0.5 else samples.index(samples[lower]) * 0 + lower]

# Directory holding the pre-compiled Neuron artifacts produced by the
# companion compile script (sd_15_512_optimum_neuron_compile.py).
COMPILER_WORKDIR_ROOT = 'sd_1_5_fp32_512_compile_workdir'
# Load the already-compiled pipeline from disk; no re-export happens here.
pipe = NeuronStableDiffusionPipeline.from_pretrained(COMPILER_WORKDIR_ROOT)

prompt = "a photo of an astronaut riding a horse on mars"
n_runs = 20  # number of timed inference runs (benchmark adds one warmup call)
benchmark(n_runs, "stable_diffusion_15_512", pipe, prompt)
23 changes: 23 additions & 0 deletions src/benchmark/pytorch/sd_15_512_optimum_neuron_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os
import time
from optimum.neuron import NeuronStableDiffusionPipeline

# Directory where the compiled Neuron artifacts are saved; the benchmark
# script loads the pipeline back from this path.
COMPILER_WORKDIR_ROOT = 'sd_1_5_fp32_512_compile_workdir'

# Hugging Face model ID for the Stable Diffusion 1.5 pipeline.
model_id = "runwayml/stable-diffusion-v1-5"

# Compilation config.
# BUG FIX: inline_weights_to_neff must be a real boolean. The original passed
# the string "True", which only worked because any non-empty string is truthy
# (the string "False" would also have enabled inlining).
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16", "inline_weights_to_neff": True}
input_shapes = {"batch_size": 1, "height": 512, "width": 512}

# --- Compile the model (export=True triggers tracing/compilation for Neuron)
start_time = time.time()
stable_diffusion = NeuronStableDiffusionPipeline.from_pretrained(
    model_id, export=True, **compiler_args, **input_shapes
)

# Save the compiled model so the benchmark script can load it without recompiling
stable_diffusion.save_pretrained(COMPILER_WORKDIR_ROOT)

compile_time = time.time() - start_time
print('Total compile time:', compile_time)