Allow bf16 computations on CPU with IPEX

Modern CPUs have native AVX-512 BF16 instructions, which significantly improve
matmul and conv2d operations. At the moment PyTorch has almost no native support
for these optimizations (even with oneDNN it does not use the optimal methods);
IPEX, however, adds everything needed.
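
For reference, a minimal sketch (not part of this commit) of what bf16 inference on CPU
with IPEX looks like, assuming intel_extension_for_pytorch is installed:

    import torch
    import intel_extension_for_pytorch as ipex  # provides optimized bf16 kernels on CPU

    model = torch.nn.Conv2d(4, 4, 3).eval()
    model = ipex.optimize(model, dtype=torch.bfloat16)  # repack weights for bf16 execution
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(torch.randn(1, 4, 64, 64))  # matmul/conv run in bf16 where the CPU supports it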

There is a known issue with IPEX: it significantly reduces performance on AMD CPUs,
but such situations can be detected and mitigated (see https://documentation.sigma2.no/jobs/mkl.html).
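
The detection added below in model_management.py boils down to checking whether an
mkl_serv_intel_cpu_true override has been preloaded. A condensed, Linux-only sketch of
that check, usable as a standalone diagnostic, looks roughly like this:

    import ctypes
    import ctypes.util

    # mirrors the check in model_management.py below (wrapped in try/except there)
    libdl = ctypes.CDLL(ctypes.util.find_library('dl'))
    libdl.dlsym.restype = ctypes.c_void_p
    libdl.dlsym.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    if libdl.dlsym(ctypes.c_void_p(0), b'mkl_serv_intel_cpu_true'):
        print("MKL override preloaded: fast AVX-512/bf16 code paths are in use")
    else:
        print("no override: on AMD CPUs MKL may fall back to slow generic kernels")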

After mitigation, UNET steps are 40-50% faster on both AMD and Intel CPUs.
There are minor visible changes with bf16, but no avalanche effects, so this feature
is enabled by default via the new `--autocast=auto` option.
It can be disabled with `--autocast=no` even if IPEX is installed and the CPU is compatible.
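
Typical invocations (assuming the standard main.py entry point):

    python main.py --cpu                  # bf16 autocast is picked up automatically (--autocast=auto)
    python main.py --cpu --autocast=no    # stay in fp32 even when IPEX is installed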

Signed-off-by: Sv. Lockal <[email protected]>
AngryLoki committed Jun 4, 2024
1 parent b1fd26f commit 9ae59d2
Showing 3 changed files with 95 additions and 11 deletions.
comfy/cli_args.py: 7 changes (6 additions, 1 deletion)
@@ -78,7 +78,7 @@ def __call__(self, parser, namespace, values, option_string=None):

parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs or CPUs.")

class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -108,6 +108,11 @@ class LatentPreviewMethod(enum.Enum):
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

class AutocastMode(enum.Enum):
Auto = "auto"
Yes = "yes"
No = "no"
parser.add_argument("--autocast", type=AutocastMode, default=AutocastMode.Auto, help="When CPU mode is enabled and IPEX is installed, use bf16 autocast to improve performance.", action=EnumAction)

parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
comfy/model_management.py: 98 changes (88 additions, 10 deletions)
@@ -1,7 +1,7 @@
import psutil
import logging
from enum import Enum
-from comfy.cli_args import args
+from comfy.cli_args import args, AutocastMode
import torch
import sys
import platform
@@ -27,6 +27,7 @@ class CPUState(Enum):
total_vram = 0

lowvram_available = True
ipex_available = False
xpu_available = False

if args.deterministic:
@@ -50,6 +51,7 @@ class CPUState(Enum):
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
ipex_available = True
except:
pass

@@ -62,13 +64,18 @@ class CPUState(Enum):

if args.cpu:
cpu_state = CPUState.CPU

def is_cpu_with_ipex():
return cpu_state == CPUState.CPU and ipex_available and not mkl_is_crippled()

def is_intel_xpu():
-global cpu_state
-global xpu_available
-if cpu_state == CPUState.GPU:
-    if xpu_available:
-        return True
+return cpu_state == CPUState.GPU and xpu_available

def use_cpu_autocast():
if args.autocast == AutocastMode.Auto:
return is_cpu_with_ipex()
if args.autocast == AutocastMode.Yes:
return cpu_state == CPUState.CPU and ipex_available
return False

def get_torch_device():
@@ -303,9 +310,13 @@ def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e

if cpu_state == CPUState.CPU:
torch.set_autocast_cpu_enabled(use_cpu_autocast())

-if is_intel_xpu() and not args.disable_ipex_optimize:
-    self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
+if (is_cpu_with_ipex() or is_intel_xpu()) and not args.disable_ipex_optimize:
+    ipex_dtype = torch.bfloat16 if cpu_state == CPUState.CPU and cpu_has_fast_bf16() else None
+    self.real_model = ipex.optimize(self.real_model.eval(), dtype=ipex_dtype, graph_mode=True, concat_linear=True)

self.weights_loaded = True
return self.real_model
@@ -828,10 +839,77 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):

return True

_mkl_is_crippled = None
_cpu_has_fast_bf16 = None

def _collect_cpu_info():
"""
Collect some flags, remember results.
Imports are deferred to reduce startup time.
"""
global _mkl_is_crippled, _cpu_has_fast_bf16

if _mkl_is_crippled is not None:
return

from cpuinfo import get_cpu_info

cpu_info = get_cpu_info()
is_intel = cpu_info.get('vendor_id_raw') == 'GenuineIntel'
has_avx512f = 'avx512f' in cpu_info['flags']
has_bf16 = 'avx512_bf16' in cpu_info['flags']

# All Intel CPUs are ok, non AVX512 CPUs are probably ok
if is_intel or not has_avx512f:
_mkl_is_crippled = False
_cpu_has_fast_bf16 = has_bf16
return

_mkl_is_crippled = True
_cpu_has_fast_bf16 = False

import os
if os.name == 'nt':
# non-intel MKL on Windows is always crippled and slow
return

# Search for preloaded symbol mkl_serv_intel_cpu_true.
# If CPU supports avx512_bf16, but symbol is not defined,
# MKL will be extremely slow.
import ctypes
import ctypes.util

try:
libdl = ctypes.CDLL(ctypes.util.find_library('dl'))
libdl.dlsym.restype = ctypes.c_void_p
libdl.dlsym.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
psym = libdl.dlsym(ctypes.c_void_p(0), b'mkl_serv_intel_cpu_true')

if psym:
psym_func = ctypes.CFUNCTYPE(ctypes.c_int)(psym)
_mkl_is_crippled = psym_func() != 1
_cpu_has_fast_bf16 = has_bf16 and not _mkl_is_crippled
except:
pass

if _mkl_is_crippled and not has_bf16:
logging.info("CPU supports avx512 (without bf16), but MKL degrades performance")
elif has_bf16 and not _cpu_has_fast_bf16:
logging.info("CPU supports avx512_bf16 instructions, but MKL degrades performance")


def mkl_is_crippled() -> bool:
_collect_cpu_info()
return _mkl_is_crippled

def cpu_has_fast_bf16() -> bool:
_collect_cpu_info()
return ipex_available and _cpu_has_fast_bf16

def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if device is not None:
-if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
-    return False
+if is_device_cpu(device):
+    return cpu_has_fast_bf16()

if device is not None: #TODO not sure about mps bf16 support
if is_device_mps(device):
requirements.txt: 1 change (1 addition, 0 deletions)
@@ -10,6 +10,7 @@ Pillow
scipy
tqdm
psutil
py-cpuinfo

#non essential dependencies:
kornia>=0.7.1
