From 31b0f6f3d8034371e95024d6bba5c193db79bd9d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 11:10:00 -0500 Subject: [PATCH] UNET weights can now be stored in fp8. --fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats supported by pytorch. --- comfy/cldm/cldm.py | 4 ++-- comfy/cli_args.py | 5 ++++- comfy/controlnet.py | 16 ++++++++++++---- .../ldm/modules/diffusionmodules/openaimodel.py | 4 ++-- comfy/model_base.py | 13 ++++++++++++- comfy/model_management.py | 15 +++++++++++++++ 6 files changed, 47 insertions(+), 10 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 76a525b378a..bbe5891e691 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -283,7 +283,7 @@ def make_zero_conv(self, channels, operations=None): return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) def forward(self, x, hint, timesteps, context, y=None, **kwargs): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) guided_hint = self.input_hint_block(hint, emb, context) @@ -295,7 +295,7 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs): assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for module, zero_conv in zip(self.input_blocks, self.zero_convs): if guided_hint is not None: h = module(h, emb, context) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 72fce10872f..58d0348028f 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -55,7 +55,10 @@ def __call__(self, parser, namespace, values, option_string=None): fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") -parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group = parser.add_mutually_exclusive_group() +fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") +fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6dd99afdc77..5921e6b1d19 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,6 +1,7 @@ import torch import math import os +import contextlib import comfy.utils import comfy.model_management import comfy.model_detection @@ -147,24 +148,31 @@ def get_control(self, x_noisy, t, cond, batched_number): else: return None + dtype = self.control_model.dtype + if comfy.model_management.supports_dtype(self.device, dtype): + precision_scope = lambda a: contextlib.nullcontext(a) + else: + precision_scope = torch.autocast + dtype = torch.float32 + output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] y = cond.get('y', None) if y is not None: - y = y.to(self.control_model.dtype) + y = y.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) + with precision_scope(comfy.model_management.get_autocast_device(self.device)): + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 855c3d1f4cd..12efd833c51 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -841,14 +841,14 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) diff --git a/comfy/model_base.py b/comfy/model_base.py index 253ea66673b..5bfcc391ded 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -5,6 +5,7 @@ import comfy.model_management import comfy.conds from enum import Enum +import contextlib from . import utils class ModelType(Enum): @@ -61,6 +62,13 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, trans context = c_crossattn dtype = self.get_dtype() + + if comfy.model_management.supports_dtype(xc.device, dtype): + precision_scope = lambda a: contextlib.nullcontext(a) + else: + precision_scope = torch.autocast + dtype = torch.float32 + xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() context = context.to(dtype) @@ -70,7 +78,10 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, trans if hasattr(extra, "to"): extra = extra.to(dtype) extra_conds[o] = extra - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + + with precision_scope(comfy.model_management.get_autocast_device(xc.device)): + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): diff --git a/comfy/model_management.py b/comfy/model_management.py index d4acd8950ca..18d15f9d064 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -459,6 +459,10 @@ def unet_inital_load_device(parameters, dtype): def unet_dtype(device=None, model_params=0): if args.bf16_unet: return torch.bfloat16 + if args.fp8_e4m3fn_unet: + return torch.float8_e4m3fn + if args.fp8_e5m2_unet: + return torch.float8_e5m2 if should_use_fp16(device=device, model_params=model_params): return torch.float16 return torch.float32 @@ -515,6 +519,17 @@ def get_autocast_device(dev): return dev.type return "cuda" +def supports_dtype(device, dtype): #TODO + if dtype == torch.float32: + return True + if torch.device("cpu") == device: + return False + if dtype == torch.float16: + return True + if dtype == torch.bfloat16: + return True + return False + def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: