add flux support for Intel xpu #1162

Open · bedovyy wants to merge 2 commits into main
Conversation

@bedovyy commented Aug 15, 2024

Intel XPU doesn't support float64, so this changes the torch device for that computation to "cpu".
Referenced: https://github.com/comfyanonymous/ComfyUI/blob/3f5939add69c2a8fea2b892a46a48c2937dc4128/comfy/ldm/flux/math.py#L15

I have tested this on an A770.
I also tested on an RTX 4090 and there does not appear to be a performance drop (but I'm not sure).
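
A minimal sketch of the underlying limitation (illustrative only, not code from this PR): probing whether a device can allocate float64 tensors at all. The CPU fallback in this change relies on the CPU backend always supporting it.

# Illustrative sketch (not part of this PR): probe whether a device can
# allocate float64 tensors, which is what fails on Intel XPU.
import torch

def supports_float64(device) -> bool:
    try:
        torch.zeros(1, dtype=torch.float64, device=device)
        return True
    except Exception:  # exact error type varies by backend (RuntimeError on DirectML)
        return False

print(supports_float64("cpu"))  # True everywhere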

@rabidcopy commented Aug 17, 2024

I was able to get Flux to work on DirectML following the changes here. (GGUFs don't work; they error out on the de-quant with DirectML, but FP8 works for sure. I can't test NF4 on my AMD card.)
Stack trace before changes:

Moving model(s) has taken 0.03 seconds
  0%|                                                                                            | 0/4 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules_forge\main_thread.py", line 30, in work
    self.result = self.func(*self.args, **self.kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\txt2img.py", line 110, in txt2img_function
    processed = processing.process_images(p)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\processing.py", line 809, in process_images
    res = process_images_inner(p)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\processing.py", line 952, in process_images_inner
    samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\processing.py", line 1323, in sample
    samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\sd_samplers_kdiffusion.py", line 234, in sample
    samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\sd_samplers_common.py", line 272, in launch_sampling
    return func()
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\sd_samplers_kdiffusion.py", line 234, in <lambda>
    samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\venv\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\k_diffusion\sampling.py", line 128, in sample_euler
    denoised = model(x, sigma_hat * s_in, **extra_args)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\modules\sd_samplers_cfg_denoiser.py", line 186, in forward
    denoised, cond_pred, uncond_pred = sampling_function(self, denoiser_params=denoiser_params, cond_scale=cond_scale, cond_composition=cond_composition)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\sampling\sampling_function.py", line 339, in sampling_function
    denoised, cond_pred, uncond_pred = sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options, seed, return_full=True)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\sampling\sampling_function.py", line 284, in sampling_function_inner
    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\sampling\sampling_function.py", line 254, in calc_cond_uncond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\modules\k_model.py", line 45, in apply_model
    model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\nn\flux.py", line 402, in forward
    out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\nn\flux.py", line 370, in inner_forward
    pe = self.pe_embedder(ids)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\nn\flux.py", line 82, in forward
    [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\nn\flux.py", line 82, in <listcomp>
    [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
  File "C:\Users\user\Desktop\stable-diffusion-webui-forge\backend\nn\flux.py", line 22, in rope
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
RuntimeError: The parameter is incorrect.
The parameter is incorrect.

Command:
.\webui.bat --attention-quad --disable-attention-upcast --directml 0 --skip-torch-cuda-test
Changes to flux.py: add "or args.directml is not None" to the device-fallback condition, and "from backend.args import args" to the imports.

diff --git a/backend/nn/flux.py b/backend/nn/flux.py
index 1ab874cc..c21f88e8 100644
--- a/backend/nn/flux.py
+++ b/backend/nn/flux.py
@@ -10,6 +10,8 @@ from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
 from backend.utils import fp16_fix
+from backend import memory_management
+from backend.args import args
 
 
 def attention(q, k, v, pe):
@@ -19,11 +21,16 @@ def attention(q, k, v, pe):
 
 
 def rope(pos, dim, theta):
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    if memory_management.is_device_mps(pos.device) or memory_management.is_intel_xpu() or args.directml is not None:
+        device = torch.device("cpu")
+    else:
+        device = pos.device
+
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
     omega = 1.0 / (theta ** scale)
 
     # out = torch.einsum("...n,d->...nd", pos, omega)
-    out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+    out = pos.unsqueeze(-1).to(device) * omega.unsqueeze(0)
 
     cos_out = torch.cos(out)
     sin_out = torch.sin(out)
@@ -34,7 +41,7 @@ def rope(pos, dim, theta):
     b, n, d, _ = out.shape
     out = out.view(b, n, d, 2, 2)
 
-    return out.float()
+    return out.float().to(pos.device)
 
 
 def apply_rope(xq, xk, freqs_cis):

After the changes:

Moving model(s) has taken 0.03 seconds
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:52<00:00, 58.02s/it]

Which is reasonably surprising considering this is on an RX 570 4GB.

@cocktailpeanut

It works on Macs too; thank you for including the MPS handling code as well.

@bedovyy (Author) commented Aug 17, 2024

I have checked the generation results, and they match ComfyUI's.
They differ from the RTX 3060 results, though, which I think is due to different GPU calculation.

[Image comparison grid: q4_0 and q4_1 outputs for A770 Forge, A770 ComfyUI, RTX 3060 Forge, and RTX 3060 ComfyUI]

@cocktailpeanut

Is there a reason this is on hold?

I would love to use this as the official version instead of having to fork just to run Forge on a Mac. Thank you!

@gameblabla commented Aug 21, 2024

I had to update the torch version, otherwise I would get the same error as #979 (I know this PR doesn't address that, but I thought I'd mention it).

I did the Flux fix somewhat differently: gameblabla@c9adf48
But the issue is that PyTorch reserves too much VRAM for itself (at least 7 GB), so it runs out of memory on Flux fp8.
(I'm guessing that's why you set it to CPU?)
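
(For anyone debugging that reservation, a rough sketch of how to inspect it, assuming IPEX exposes a CUDA-like torch.xpu memory API; the exact function names may differ between versions:)

# Assumption: IPEX registers a torch.xpu namespace with CUDA-style memory stats.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  registers the xpu device

if hasattr(torch, "xpu") and torch.xpu.is_available():
    print(f"allocated: {torch.xpu.memory_allocated() / 2**30:.2f} GiB")
    print(f"reserved:  {torch.xpu.memory_reserved() / 2**30:.2f} GiB")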

@bedovyy (Author) commented Aug 22, 2024

I had to update the torch version, otherwise I would get the same error as #979 (I know this PR doesn't address that, but I thought I'd mention it).

I did the Flux fix somewhat differently: gameblabla@c9adf48 But the issue is that PyTorch reserves too much VRAM for itself (at least 7 GB), so it runs out of memory on Flux fp8. (I'm guessing that's why you set it to CPU?)

I changed it to float32 first, just because the A770 doesn't support float64. Then I figured there must be a reason float64 was chosen;
I don't know much about torch or tensor internals, but it may be a precision-sensitive part or something.
So I checked ComfyUI's code and changed it to match.

By the way, fp8 is already about 12 GB, and it consumes about 3+ GB on top of that, so it is almost out of memory.
The --disable-ipex-hijack option allows partial loading of the diffusion model, so you may want to try it.
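
A rough way to see how much precision the float64 path actually buys for the rope frequencies (illustrative only; the dim and theta values below are just examples in the range Flux uses):

# Illustrative precision check (not part of the PR): compare rope frequencies
# computed in float64 vs float32. dim/theta are example values only.
import torch

def rope_omega(dim: int, theta: int, dtype: torch.dtype) -> torch.Tensor:
    scale = torch.arange(0, dim, 2, dtype=dtype) / dim
    return 1.0 / (theta ** scale)

omega64 = rope_omega(56, 10000, torch.float64)
omega32 = rope_omega(56, 10000, torch.float32)

# Maximum absolute difference between the two precisions.
print((omega64 - omega32.double()).abs().max().item())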

@redrum-llik

Unfortunately, this does not seem to solve the issue for me; no matter what, I still end up with:

  File "C:\Stability Matrix\Packages\Forge\backend\nn\flux.py", line 407, in forward
    out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
  File "C:\Stability Matrix\Packages\Forge\backend\nn\flux.py", line 364, in inner_forward
    img = self.img_in(img)
  File "C:\Stability Matrix\Packages\Forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Stability Matrix\Packages\Forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Stability Matrix\Packages\Forge\backend\operations.py", line 145, in forward
    return torch.nn.functional.linear(x, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4032x64 and 1x98304)
mat1 and mat2 shapes cannot be multiplied (4032x64 and 1x98304)

Windows, Intel Arc A770, default settings for the "flux" preset, flux.py patched with the changes from this PR. Is anyone else having this issue?

@bedovyy (Author) commented Aug 24, 2024

Unfortunately, this does not seem to solve the issue for me; no matter what, I still end up with:

  File "C:\Stability Matrix\Packages\Forge\backend\nn\flux.py", line 407, in forward
    out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
  File "C:\Stability Matrix\Packages\Forge\backend\nn\flux.py", line 364, in inner_forward
    img = self.img_in(img)
  File "C:\Stability Matrix\Packages\Forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Stability Matrix\Packages\Forge\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Stability Matrix\Packages\Forge\backend\operations.py", line 145, in forward
    return torch.nn.functional.linear(x, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4032x64 and 1x98304)
mat1 and mat2 shapes cannot be multiplied (4032x64 and 1x98304)

Windows, Intel Arc A770, default settings for the "flux" preset, flux.py patched with the changes from this PR. Is anyone else having this issue?

I have checked mine with the latest commit, and it was okay with flux1-dev-Q4_0.gguf, t5xxl_fp16, clip_l, and ae.
As far as I know, the A770 cannot use CPU offload, so if you use fp16 or fp8, try adding the --disable-ipex-hijack option, or use a quantized model.

Here are my options, by the way:

set COMMANDLINE_ARGS=--use-ipex --unet-in-bf16 --disable-xformers --always-low-vram

@redrum-llik

As far as I know, the A770 cannot use CPU offload, so if you use fp16 or fp8, try adding the --disable-ipex-hijack option, or use a quantized model.

Yeah, I ran the quantized one, but apparently that is not the culprit: I am getting this "mat1 and mat2 shapes cannot be multiplied" exception for any model type now, and even across different installations (I tried reForge and A1111 today, clean install with the torch-ipex 2.1.0 build). Can you please let me know what driver version you use with the A770? I probably need to downgrade.

@congdm commented Aug 25, 2024

Yeah, I ran the quantized one, but apparently that is not the culprit: I am getting this "mat1 and mat2 shapes cannot be multiplied" exception for any model type now, and even across different installations (I tried reForge and A1111 today, clean install with the torch-ipex 2.1.0 build). Can you please let me know what driver version you use with the A770? I probably need to downgrade.

It runs fine for me; I didn't encounter any "mat1 and mat2 shapes cannot be multiplied" error. Driver 32.0.101.5971.

@bedovyy (Author) commented Aug 25, 2024

As far as I know, the A770 cannot use CPU offload, so if you use fp16 or fp8, try adding the --disable-ipex-hijack option, or use a quantized model.

Yeah, I ran the quantized one, but apparently that is not the culprit: I am getting this "mat1 and mat2 shapes cannot be multiplied" exception for any model type now, and even across different installations (I tried reForge and A1111 today, clean install with the torch-ipex 2.1.0 build). Can you please let me know what driver version you use with the A770? I probably need to downgrade.

I always use the latest driver; Intel Extension for PyTorch is also the latest, which is 2.1.40+xpu.
Try a basic generation, without any LoRA, special options, or additional extensions, just a simple positive prompt. I think mat errors are usually caused by an incompatible option or extension.

If there is still a problem, it seems better to open an issue with very detailed information (because there are not many people who use Intel GPUs).

@redrum-llik commented Aug 25, 2024

I always use the latest driver; Intel Extension for PyTorch is also the latest, which is 2.1.40+xpu. Try a basic generation, without any LoRA, special options, or additional extensions, just a simple positive prompt. I think mat errors are usually caused by an incompatible option or extension.

If there is still a problem, it seems better to open an issue with very detailed information (because there are not many people who use Intel GPUs).

Thanks for the details! I tried 2.1.40 (patched similarly to the Nuuuull releases, with the Intel libs baked into the wheels), but that did not change the behavior either. I still have no idea what is causing this issue, but I was able to work around it with the following dirty patch:

diff --git "a/backend/operations.py" "b/backend/operations.py"
index 72cbfc0d..c06913b1 100644
--- "a/backend/operations.py"
+++ "b/backend/operations.py"
@@ -145,7 +145,8 @@ class ForgeOperations:
                     return torch.nn.functional.linear(x, weight, bias)
             else:
                 weight, bias = get_weight_and_bias(self)
-                return torch.nn.functional.linear(x, weight, bias)
+                with torch.autocast(device_type='xpu', dtype=torch.float16):
+                    return torch.nn.functional.linear(x, weight, bias)
 
     class Conv2d(torch.nn.Conv2d):

UPD: oh, and this issue was present without any extensions or LoRAs at all, just a simple prompt like "forest" on default settings; no hires fix, no ControlNet, nothing at all.

@mirh commented Sep 11, 2024

FP64 might be added in whatever their next environment version is?
https://www.intel.com/content/www/us/en/developer/articles/release-notes/gpu-dependencies-for-pytorch-release-notes.html#inpage-nav-3
EDIT: in retrospect, this might only concern their newest data center architecture, which should ship with native fp64; I'm unsure. Why can't this just ship with the emulation flags enabled (assuming doubles aren't really needed for anything critical)?
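
(For reference, and purely as an assumption on my part: Intel's compute runtime documents software FP64 emulation toggles for GPUs without native fp64. The variable names below are my best understanding, may change between driver releases, and are not something this PR relies on.)

# Untested assumption: FP64 emulation toggles for the Intel compute runtime.
# They must be set before torch / IPEX initializes the device.
import os
os.environ.setdefault("OverrideDefaultFP64Settings", "1")
os.environ.setdefault("IGC_EnableDFEmulation", "1")

import torch  # imported after the env vars so the driver sees them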

@DenOfEquity (Collaborator)

Can any XPU users confirm whether this simple change to the current line 22 works as expected:

    if pos.device.type == "mps" or pos.device.type == "xpu":

(i.e. using float32 rather than changing the device).
I'm not sure that scale needs to be float64 at all; I get identical results after making it float32 on my old 1070. Anyway, the minimal change is best, if it works.
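
For concreteness, here is roughly what rope() would look like with that dtype fallback instead of the device change. The middle of the function is reconstructed from the diff context earlier in this thread, so treat it as a sketch rather than the exact contents of backend/nn/flux.py:

import torch

# Sketch of the float32-fallback variant suggested above (reconstructed from
# the diff context in this thread, not the literal file contents).
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    if pos.device.type == "mps" or pos.device.type == "xpu":
        dtype = torch.float32   # float64 is unsupported on these backends
    else:
        dtype = torch.float64

    scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)

    # out = torch.einsum("...n,d->...nd", pos, omega)
    out = pos.unsqueeze(-1) * omega.unsqueeze(0)

    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)

    b, n, d, _ = out.shape
    out = out.view(b, n, d, 2, 2)
    return out.float()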

@mirh commented Sep 14, 2024

intel/torch-xpu-ops#628
intel/intel-xpu-backend-for-triton#1840
OK, I was starting to suspect my previous discovery was a fluke, but it does seem like they promise this should work even on regular consumer SKUs.

8 participants