Skip to content

Commit

Permalink
[Bug #4529] Fix graph partial validation failure (#4588)
Browse files Browse the repository at this point in the history
Currently, if a graph partially fails validation (i.e. some outputs are
valid while others have links from missing nodes), the execution loop
could get an exception resulting in server lockup.

This isn't actually possible to reproduce via the default UI, but is a
potential issue for people using the API to construct invalid graphs.
  • Loading branch information
guill committed Aug 24, 2024
1 parent 07dcbc3 commit 6ab1e6f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
9 changes: 9 additions & 0 deletions comfy_execution/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
Expand All @@ -74,6 +76,8 @@ def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
Expand All @@ -87,6 +91,9 @@ def get_node_signature(self, dynprompt, node_id):
return to_hashable(signature)

def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
Expand All @@ -112,6 +119,8 @@ def get_ordered_ancestry(self, dynprompt, node_id):
return ancestors, order_mapping

def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:
Expand Down
23 changes: 21 additions & 2 deletions tests/inference/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,25 @@ def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"

def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
# We have multiple outputs. The first is invalid, but the second is valid
g.node("SaveImage", images=mix1.out(0))
g.node("SaveImage", images=mix2.out(0))
g.remove_node("removeme")

client.run(g)

# Add back in the missing node to make sure the error doesn't break the server
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
client.run(g)

def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Creating the nodes in this specific order previously caused a bug
Expand Down Expand Up @@ -450,8 +469,8 @@ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

output1 = g.node("PreviewImage", images=input1.out(0))
output2 = g.node("PreviewImage", images=input1.out(0))
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input1.out(0))

result = client.run(g)
images1 = result.get_images(output1)
Expand Down

0 comments on commit 6ab1e6f

Please sign in to comment.