Load controlnet in fp8 if weights are in fp8.
comfyanonymous committed Sep 21, 2024
1 parent 2d810b0 commit dc96a1a
Showing 3 changed files with 33 additions and 20 deletions.
47 changes: 29 additions & 18 deletions comfy/controlnet.py
```diff
@@ -400,19 +400,22 @@ def inference_memory_requirements(self, dtype):
 def controlnet_config(sd, model_options={}):
     model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
 
-    supported_inference_dtypes = model_config.supported_inference_dtypes
+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = comfy.utils.weight_dtype(sd)
+
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
 
-    controlnet_config = model_config.unet_config
-    unet_dtype = model_options.get("dtype", comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes))
     load_device = comfy.model_management.get_torch_device()
     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
 
     operations = model_options.get("custom_operations", None)
     if operations is None:
-        if manual_cast_dtype is not None:
-            operations = comfy.ops.manual_cast
-        else:
-            operations = comfy.ops.disable_weight_init
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
 
     offload_device = comfy.model_management.unet_offload_device()
     return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
```
```diff
@@ -583,22 +586,30 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
     if controlnet_config is None:
         model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
-        supported_inference_dtypes = model_config.supported_inference_dtypes
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
         controlnet_config = model_config.unet_config
 
+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = comfy.utils.weight_dtype(controlnet_data)
+
+        if supported_inference_dtypes is None:
+            supported_inference_dtypes = [comfy.model_management.unet_dtype()]
+
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+
     load_device = comfy.model_management.get_torch_device()
-    if supported_inference_dtypes is None:
-        unet_dtype = comfy.model_management.unet_dtype()
-    else:
-        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
 
     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
-    if manual_cast_dtype is not None:
-        controlnet_config["operations"] = comfy.ops.manual_cast
-    if "custom_operations" in model_options:
-        controlnet_config["operations"] = model_options["custom_operations"]
-    if "dtype" in model_options:
-        controlnet_config["dtype"] = model_options["dtype"]
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
+
+    controlnet_config["operations"] = operations
+    controlnet_config["dtype"] = unet_dtype
     controlnet_config["device"] = comfy.model_management.unet_offload_device()
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
```
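Both hunks above converge on the same pattern. A condensed, hedged restatement (not code from the repository; `resolve_controlnet_dtype` is a hypothetical helper name used only for illustration):

```python
import comfy.utils
import comfy.model_management

# Hypothetical helper summarizing the new selection logic: the dtype the
# weights are stored in (e.g. torch.float8_e4m3fn) is appended to the
# supported candidates, so an fp8 checkpoint can resolve to an fp8 load dtype.
def resolve_controlnet_dtype(sd, model_config, model_options={}):
    unet_dtype = model_options.get("dtype", None)
    if unet_dtype is None:
        weight_dtype = comfy.utils.weight_dtype(sd)  # dtype of the stored weights
        supported = list(model_config.supported_inference_dtypes)
        if weight_dtype is not None:
            supported.append(weight_dtype)  # fp8 becomes a candidate here
        # model_params=-1 signals that the size is unknown/irrelevant
        # (see the model_management.py hunk below)
        unet_dtype = comfy.model_management.unet_dtype(
            model_params=-1, supported_dtypes=supported)
    return unet_dtype
```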
2 changes: 2 additions & 0 deletions comfy/model_management.py
```diff
@@ -626,6 +626,8 @@ def maximum_vram_for_weights(device=None):
     return (get_total_memory(device) * 0.88 - minimum_inference_memory())
 
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    if model_params < 0:
+        model_params = 1000000000000000000000
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
```
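One plausible reading of this guard: `unet_dtype()` weighs `model_params` in its size-based heuristics, and the controlnet paths above pass `model_params=-1` because the parameter count is unknown or irrelevant there, so the sentinel expands to an effectively infinite size. A toy illustration of the pattern, not ComfyUI's actual heuristics:

```python
import torch

def choose_dtype(model_params, supported):
    # -1 expands to an effectively infinite parameter count, the same trick
    # as the guard above, so size checks behave as if the model were huge
    # and never argue for keeping weights in full precision.
    if model_params < 0:
        model_params = 10 ** 21
    fits_in_fp32 = model_params * 4 <= 8 * 1024 ** 3  # pretend 8 GiB budget
    if not fits_in_fp32 and torch.float16 in supported:
        return torch.float16
    return torch.float32

print(choose_dtype(-1, [torch.float16]))         # torch.float16
print(choose_dtype(1_000_000, [torch.float16]))  # torch.float32
```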
4 changes: 2 additions & 2 deletions comfy/ops.py
```diff
@@ -300,10 +300,10 @@ def forward_comfy_cast_weights(self, input):
         return torch.nn.functional.linear(input, weight, bias)
 
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
-    if args.fast:
+    if args.fast and not disable_fast_fp8:
         if comfy.model_management.supports_fp8_compute(load_device):
             return fp8_ops
     return manual_cast
```
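The net effect of the new flag, sketched as a standalone mirror of the dispatch above (a hedged sketch; string returns stand in for the real ops classes):

```python
def pick_operations_sketch(weight_dtype, compute_dtype, fast=False,
                           fp8_compute_ok=False, disable_fast_fp8=False):
    if compute_dtype is None or weight_dtype == compute_dtype:
        return "disable_weight_init"  # weights usable as stored, no casting
    if fast and not disable_fast_fp8 and fp8_compute_ok:
        return "fp8_ops"              # --fast path: native fp8 compute
    return "manual_cast"              # cast weights to compute dtype per layer

# controlnet_config() passes disable_fast_fp8=True, so even with --fast on
# fp8-capable hardware a controlnet falls back to manual casting:
print(pick_operations_sketch("fp8_e4m3fn", "bf16",
                             fast=True, fp8_compute_ok=True,
                             disable_fast_fp8=True))  # -> manual_cast
```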

4 comments on commit dc96a1a

@JorgeR81

Do you mean it loads ControlNet in FP8 automatically if Flux is loaded in FP8?

What about GGUF models?
People using Q4 may want ControlNet in FP8, but if you use Q8_0 you may want the full-quality ControlNet model.

Why not add an option to load ControlNet in FP8 in an "advanced" ControlNet loader node?

@comfyanonymous (Owner, Author)

This loads the controlnet in fp8 if the controlnet weights are in fp8.
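To check what a given checkpoint will do under this commit, a quick way to inspect the stored dtype (a hedged sketch; the filename is a placeholder):

```python
from safetensors.torch import load_file

# Placeholder filename: point this at an actual controlnet checkpoint.
sd = load_file("controlnet-fp8_e4m3fn.safetensors")
print({t.dtype for t in sd.values()})
# e.g. {torch.float8_e4m3fn} -> this commit will load it in fp8
```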

@JorgeR81

> This loads the controlnet in fp8 if the controlnet weights are in fp8.

OK, thanks, that makes more sense.

But do you know of any ControlNet models in FP8?
Maybe you could convert some of the larger ones.

Do you think ControlNet in Q8_0 could work too?

@JorgeR81

> But do you know of any ControlNet models in FP8?
> Maybe you could convert some of the larger ones.

It's already done:

https://huggingface.co/Kijai/flux-fp8/blob/main/flux_shakker_labs_union_pro-fp8_e4m3fn.safetensors