diff --git a/comfy/model_management.py b/comfy/model_management.py index bcc93779223..edbe6a8a4bb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -668,6 +668,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo if bf16_supported and weight_dtype == torch.bfloat16: return None + fp16_supported = should_use_fp16(inference_device, prioritize_performance=True) for dt in supported_dtypes: if dt == torch.float16 and fp16_supported: return torch.float16