From aa7c98722b328d0cc1c5a89587ae8fadae7dc07a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 28 Aug 2023 20:22:57 -0700 Subject: [PATCH] sd timing (#1706) --- .github/workflows/benchmark.yml | 4 ++-- examples/stable_diffusion.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 6cab94e5f282..96b32dac5be6 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -26,7 +26,7 @@ jobs: run: | ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz - time python3 examples/stable_diffusion.py --noshow + time python3 examples/stable_diffusion.py --noshow --timing - name: Run LLaMA run: | ln -s ~/tinygrad/weights/LLaMA weights/LLaMA @@ -72,7 +72,7 @@ jobs: run: | ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz - time DEBUG=1 python3 examples/stable_diffusion.py --noshow + time DEBUG=1 python3 examples/stable_diffusion.py --noshow --timing - name: Run LLaMA run: | ln -s ~/tinygrad/weights/LLaMA weights/LLaMA diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 87ce23a68c80..12acc5bcaa42 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -9,7 +9,8 @@ from tqdm import tqdm from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes, GlobalCounters +from tinygrad.ops import Device +from tinygrad.helpers import dtypes, GlobalCounters, Timing from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from extra.utils import download_file from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict @@ -566,6 +567,7 @@ def __init__(self): parser.add_argument('--out', type=str, default=os.path.join(tempfile.gettempdir(), "rendered.png"), help="Output filename") parser.add_argument('--noshow', action='store_true', help="Don't show the image") parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16") + parser.add_argument('--timing', action='store_true', help="Print timing per step") args = parser.parse_args() Tensor.no_grad = True @@ -638,7 +640,10 @@ def do_step(latent, timestep): for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])): GlobalCounters.reset() t.set_description("%3d %3d" % (index, timestep)) - latent = do_step(latent, Tensor([timestep])) + with Timing("step in ", enabled=args.timing): + latent = do_step(latent, Tensor([timestep])) + if args.timing: Device[Device.DEFAULT].synchronize() + del do_step # upsample latent space to image with autoencoder x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)