Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support graph control #3658

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@

import comfy.model_management

def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}, skipped={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in skipped:
return {}, True
if input_unique_id not in outputs:
input_data_all[x] = (None,)
continue
Expand All @@ -38,7 +40,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
return input_data_all
return input_data_all, False

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
Expand Down Expand Up @@ -116,7 +118,7 @@ def format_value(x):
else:
return str(x)

def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, skipped, prompt_id, outputs_ui, object_storage):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
Expand All @@ -131,14 +133,17 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, skipped, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result

input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
input_data_all, is_skipped = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data, skipped)
if is_skipped:
skipped.add(unique_id)
return (True, None, None)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
Expand All @@ -149,6 +154,15 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
object_storage[(unique_id, class_type)] = obj

output_data, output_ui = get_output_data(obj, input_data_all)

is_subsequent_skipped = False
for values in output_data:
for value in values:
if value == nodes.GraphControlSignal.SKIP:
is_subsequent_skipped = True
if is_subsequent_skipped:
skipped.add(unique_id)

outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
Expand Down Expand Up @@ -373,6 +387,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
broadcast=False)
executed = set()
skipped = set()
output_node_id = None
to_execute = []

Expand All @@ -388,7 +403,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, skipped, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break
Expand Down Expand Up @@ -577,7 +592,7 @@ def validate_inputs(prompt, item, validated):
continue

if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
(input_data_all, _) = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
Expand Down
32 changes: 32 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
import random
import logging
import enum

from PIL import Image, ImageOps, ImageSequence, ImageFile
from PIL.PngImagePlugin import PngInfo
Expand Down Expand Up @@ -44,6 +45,12 @@ def interrupt_processing(value=True):

MAX_RESOLUTION=16384

class AnyType(str):
def __ne__(self, __value: object) -> bool:
return False

any_type = AnyType("*")

class CLIPTextEncode:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -1730,6 +1737,30 @@ def expand_image(self, image, left, top, right, bottom, feathering):
return (new_image, mask)


class GraphControlSignal(enum.Enum):
SKIP = 0
CONTINUE = 1

class GraphControl:
ALL_VALUES = [signal.name for signal in GraphControlSignal]
@classmethod
def INPUT_TYPES(s):
return {"required": {
"anything": (any_type, ),
"boolean": ("BOOLEAN", {"default": True}),
"true_signal": (s.ALL_VALUES, {"default": GraphControlSignal.CONTINUE.name}),
"false_signal": (s.ALL_VALUES, {"default": GraphControlSignal.SKIP.name}),
}}
RETURN_TYPES = (any_type, "GraphControlSignal", "BOOLEAN")
RETURN_NAMES = ("anything", "signal", "boolean")
FUNCTION = "main"

CATEGORY = "graph"

def main(self, anything, boolean, true_signal, false_signal):
signal = GraphControlSignal[true_signal if boolean else false_signal]
return (anything, signal, boolean)

NODE_CLASS_MAPPINGS = {
"KSampler": KSampler,
"CheckpointLoaderSimple": CheckpointLoaderSimple,
Expand Down Expand Up @@ -1797,6 +1828,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,
"GraphControl": GraphControl,
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand Down