sd timing (tinygrad#1706)
geohot committed Aug 29, 2023
1 parent 8844a0a commit aa7c987
Showing 2 changed files with 9 additions and 4 deletions.
.github/workflows/benchmark.yml: 4 changes (2 additions, 2 deletions)
@@ -26,7 +26,7 @@ jobs:
       run: |
         ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
         ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
-        time python3 examples/stable_diffusion.py --noshow
+        time python3 examples/stable_diffusion.py --noshow --timing
     - name: Run LLaMA
       run: |
         ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
@@ -72,7 +72,7 @@ jobs:
       run: |
         ln -s ~/tinygrad/weights/sd-v1-4.ckpt weights/sd-v1-4.ckpt
         ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
-        time DEBUG=1 python3 examples/stable_diffusion.py --noshow
+        time DEBUG=1 python3 examples/stable_diffusion.py --noshow --timing
     - name: Run LLaMA
       run: |
         ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
examples/stable_diffusion.py: 9 changes (7 additions, 2 deletions)
@@ -9,7 +9,8 @@
 
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes, GlobalCounters
+from tinygrad.ops import Device
+from tinygrad.helpers import dtypes, GlobalCounters, Timing
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
@@ -566,6 +567,7 @@ def __init__(self):
   parser.add_argument('--out', type=str, default=os.path.join(tempfile.gettempdir(), "rendered.png"), help="Output filename")
   parser.add_argument('--noshow', action='store_true', help="Don't show the image")
   parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
+  parser.add_argument('--timing', action='store_true', help="Print timing per step")
   args = parser.parse_args()
 
   Tensor.no_grad = True
@@ -638,7 +640,10 @@ def do_step(latent, timestep):
   for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
     GlobalCounters.reset()
     t.set_description("%3d %3d" % (index, timestep))
-    latent = do_step(latent, Tensor([timestep]))
+    with Timing("step in ", enabled=args.timing):
+      latent = do_step(latent, Tensor([timestep]))
+      if args.timing: Device[Device.DEFAULT].synchronize()
+  del do_step
 
   # upsample latent space to image with autoencoder
   x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
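
For readers skimming the diff, the pattern added here is: wrap the denoising step in the Timing context manager from tinygrad.helpers so it prints how long the block took, and call Device[Device.DEFAULT].synchronize() inside that block so any asynchronously queued device work finishes before the timer stops. A minimal standalone sketch of the same pattern, assuming only the tinygrad APIs already shown in the diff (the matmul is a stand-in for do_step and is illustrative only, not part of the commit):

from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from tinygrad.helpers import Timing

def work():
  # stand-in for do_step(latent, Tensor([timestep])); queues device kernels
  return (Tensor.rand(512, 512) @ Tensor.rand(512, 512)).realize()

for step in range(3):
  with Timing(f"step {step} in "):        # prints the elapsed time when the block exits
    out = work()
    Device[Device.DEFAULT].synchronize()  # wait for the device so the printed time includes GPU work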
