From 32a60a7bacc9f623099ef41dcc1b4a7a2d22f23d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 8 Sep 2024 09:31:41 -0400
Subject: [PATCH 1/3] Support onetrainer text encoder Flux lora.

---
 comfy/lora.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/comfy/lora.py b/comfy/lora.py
index ad951bbafa2..02c27bf07cf 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -207,6 +207,7 @@ def model_lora_keys_clip(model, key_map={}):
 
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
     clip_l_present = False
+    clip_g_present = False
     for b in range(32): #TODO: clean up
         for c in LORA_CLIP_MAP:
             k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
@@ -230,6 +231,7 @@
 
             k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
+                clip_g_present = True
                 if clip_l_present:
                     lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                     key_map[lora_key] = k
@@ -245,9 +247,15 @@ def model_lora_keys_clip(model, key_map={}):
 
     for k in sdk:
         if k.endswith(".weight"):
-            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
+            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
+                t5_index = 1
+                if clip_l_present:
+                    t5_index += 1
+                if clip_g_present:
+                    t5_index += 1
+
                 l_key = k[len("t5xxl.transformer."):-len(".weight")]
-                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
+                lora_key = "lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))
                 key_map[lora_key] = k
             elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
                 l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]

From a5da4d0b3e72bb3b6ceafd54c0684fbf95a9d81f Mon Sep 17 00:00:00 2001
From: guill
Date: Sun, 8 Sep 2024 06:48:47 -0700
Subject: [PATCH 2/3] Fix error with ExecutionBlocker and OUTPUT_IS_LIST
 (#4836)

This change resolves an error when a node with OUTPUT_IS_LIST=(True,)
receives an ExecutionBlocker. I've also added a unit test for this case.
---
 execution.py                      |  8 +++++++-
 tests/inference/test_execution.py | 26 ++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/execution.py b/execution.py
index e66aabbcc37..6c386341bfe 100644
--- a/execution.py
+++ b/execution.py
@@ -179,7 +179,13 @@ def merge_result_data(results, obj):
     # merge node execution results
     for i, is_list in zip(range(len(results[0])), output_is_list):
         if is_list:
-            output.append([x for o in results for x in o[i]])
+            value = []
+            for o in results:
+                if isinstance(o[i], ExecutionBlocker):
+                    value.append(o[i])
+                else:
+                    value.extend(o[i])
+            output.append(value)
         else:
             output.append([o[i] for o in results])
     return output
diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py
index c7daddeb636..3909ca68de9 100644
--- a/tests/inference/test_execution.py
+++ b/tests/inference/test_execution.py
@@ -496,3 +496,29 @@ def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilde
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
         assert not result.did_run(test_node), "The execution should have been cached"
+
+    # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
+    # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
+    # only that one entry in the list is blocked.
+    def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
+        int1 = g.node("StubInt", value=1)
+        int2 = g.node("StubInt", value=2)
+        int3 = g.node("StubInt", value=3)
+        int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
+        compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
+        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
+
+        list_output = g.node("TestMakeListNode", value1=blocker.out(0))
+        output = g.node("PreviewImage", images=list_output.out(0))
+
+        result = client.run(g)
+        assert result.did_run(output), "The execution should have run"
+        images = result.get_images(output)
+        assert len(images) == 2, "Should have 2 images"
+        assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
+        assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"

From 9c5fca75f46f7b9f18c07385925f151a7629a94f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 8 Sep 2024 10:10:47 -0400
Subject: [PATCH 3/3] Fix lora issue.

---
 comfy/lora.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/comfy/lora.py b/comfy/lora.py
index 02c27bf07cf..61979e5004a 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -248,15 +248,17 @@ def model_lora_keys_clip(model, key_map={}):
     for k in sdk:
         if k.endswith(".weight"):
             if k.startswith("t5xxl.transformer."):#OneTrainer SD3 and Flux lora
+                l_key = k[len("t5xxl.transformer."):-len(".weight")]
                 t5_index = 1
-                if clip_l_present:
-                    t5_index += 1
                 if clip_g_present:
                     t5_index += 1
+                if clip_l_present:
+                    t5_index += 1
+                if t5_index == 2:
+                    key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k #OneTrainer Flux
+                    t5_index += 1
 
-                l_key = k[len("t5xxl.transformer."):-len(".weight")]
-                lora_key = "lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))
-                key_map[lora_key] = k
+                key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
             elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
                 l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                 lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
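A note on the execution.py change in patch 2, for readers tracing the fix: merge_result_data() flattens each output slot across a node's batched executions, and for slots marked OUTPUT_IS_LIST the old list comprehension iterated every per-execution value. An ExecutionBlocker arriving in such a slot is (as the hunk implies) not iterable, which is where the error came from; the new loop keeps a blocker as a single list entry instead. A minimal standalone sketch of the new merge semantics (the ExecutionBlocker stub and the simplified signature are illustrative, not ComfyUI's actual definitions; the real function takes the node object and reads OUTPUT_IS_LIST from it):

    class ExecutionBlocker:
        # Stub standing in for ComfyUI's ExecutionBlocker; only identity matters here.
        def __init__(self, message=None):
            self.message = message

    def merge_result_data(results, output_is_list):
        # results: one tuple of output values per execution of the node
        output = []
        for i, is_list in zip(range(len(results[0])), output_is_list):
            if is_list:
                value = []
                for o in results:
                    if isinstance(o[i], ExecutionBlocker):
                        value.append(o[i])  # keep the blocker as one opaque entry
                    else:
                        value.extend(o[i])  # normal case: splice the sublist in
                output.append(value)
            else:
                output.append([o[i] for o in results])
        return output

    blocked = ExecutionBlocker(None)
    merged = merge_result_data([(["a", "b"],), (blocked,)], (True,))
    assert merged == [["a", "b", blocked]]  # blocker survives as a single list entry

With the old comprehension [x for o in results for x in o[i]], the same call would raise a TypeError while trying to iterate the blocker.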
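Similarly, for the key-mapping logic that patches 1 and 3 converge on: the final comfy/lora.py hunk derives which lora_teN prefix a t5xxl weight maps to from the set of CLIP encoders the model exposes, and when exactly one CLIP encoder is present (the Flux case, clip_l plus T5) it registers OneTrainer's te2 name in addition to te3. A standalone sketch of that index arithmetic (the helper name and sample key are hypothetical; the real code writes entries into key_map while scanning a model state dict):

    def t5_lora_key_names(k, clip_l_present, clip_g_present):
        # k is a checkpoint key such as "t5xxl.transformer.shared.weight"
        l_key = k[len("t5xxl.transformer."):-len(".weight")]
        t5_index = 1
        if clip_g_present:
            t5_index += 1
        if clip_l_present:
            t5_index += 1
        names = []
        if t5_index == 2:
            # exactly one CLIP encoder present (Flux): register OneTrainer's
            # te2 name, then fall through so te3 is registered as well
            names.append("lora_te{}_{}".format(t5_index, l_key.replace(".", "_")))
            t5_index += 1
        names.append("lora_te{}_{}".format(t5_index, l_key.replace(".", "_")))
        return names

    # Flux (clip_l + t5xxl): both te2 (OneTrainer Flux) and te3 are registered
    assert t5_lora_key_names("t5xxl.transformer.shared.weight", True, False) == [
        "lora_te2_shared", "lora_te3_shared"]
    # SD3 (clip_l + clip_g + t5xxl): te3 only
    assert t5_lora_key_names("t5xxl.transformer.shared.weight", True, True) == [
        "lora_te3_shared"]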