Skip to content

Commit

Permalink
Group node fixes (#2259)
Browse files Browse the repository at this point in the history
* Prevent cleaning graph state on undo/redo

* Remove pause rendering due to LG bug

* Fix crash on disconnected internal reroutes

* Fix widget inputs being incorrect order and value

* Fix initial primitive values on connect

* basic support for basic rerouted converted inputs

* Populate primitive to reroute input

* dont crash on bad primitive links

* Fix convert to group changing control value

* reduce restrictions

* fix random crash in tests
  • Loading branch information
pythongosssss committed Dec 13, 2023
1 parent b454a67 commit 3900789
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 27 deletions.
134 changes: 129 additions & 5 deletions tests-ui/tests/groupNode.test.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />

const { start, createDefaultWorkflow } = require("../utils");
const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
const lg = require("../utils/litegraph");

describe("group node", () => {
Expand Down Expand Up @@ -273,7 +273,7 @@ describe("group node", () => {

let reroutes = [];
let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) {
for (let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
Expand All @@ -283,7 +283,7 @@ describe("group node", () => {

const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());

group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
Expand Down Expand Up @@ -407,12 +407,18 @@ describe("group node", () => {
const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
const preview = ez.PreviewImage(decode.outputs[0]);

expect((await graph.toPrompt()).output).toEqual({
const output = {
[latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
[preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
});
};
expect((await graph.toPrompt()).output).toEqual(output);

// Ensure missing connections dont cause errors
group2.inputs.VAE.disconnect();
delete output[decode.id].inputs.vae;
expect((await graph.toPrompt()).output).toEqual(output);
});
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
Expand Down Expand Up @@ -673,6 +679,55 @@ describe("group node", () => {
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
test("correctly handles widget inputs", async () => {
const { ez, graph, app } = await start();
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];

const image = ez.LoadImage();
const scale1 = ez.ImageScaleBy(image.outputs[0]);
const scale2 = ez.ImageScaleBy(image.outputs[0]);
const preview1 = ez.PreviewImage(scale1.outputs[0]);
const preview2 = ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.upscale_method.value = upscaleMethods[1];
scale1.widgets.upscale_method.convertToInput();

const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
expect(group.inputs.length).toBe(3);
expect(group.inputs[0].input.type).toBe("IMAGE");
expect(group.inputs[1].input.type).toBe("IMAGE");
expect(group.inputs[2].input.type).toBe("COMBO");

// Ensure links are maintained
expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[2].connection).toBeFalsy();

// Ensure primitive gets correct type
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs[2]);
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
primitive.widgets.value.value = upscaleMethods[1];

await checkBeforeAndAfterReload(graph, async (r) => {
const scale1id = r ? `${group.id}:0` : scale1.id;
const scale2id = r ? `${group.id}:1` : scale2.id;
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
});
});
});
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
Expand Down Expand Up @@ -846,4 +901,73 @@ describe("group node", () => {
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
test("internal reroutes work with converted inputs and merge options", async () => {
const { ez, graph, app } = await start();
const vae = ez.VAELoader();
const latent = ez.EmptyLatentImage();
const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
const scale = ez.ImageScale(decode.outputs.IMAGE);
ez.PreviewImage(scale.outputs.IMAGE);

const r1 = ez.Reroute();
const r2 = ez.Reroute();

latent.widgets.width.value = 64;
latent.widgets.height.value = 128;

latent.widgets.width.convertToInput();
latent.widgets.height.convertToInput();
latent.widgets.batch_size.convertToInput();

scale.widgets.width.convertToInput();
scale.widgets.height.convertToInput();

r1.inputs[0].input.label = "hbw";
r1.outputs[0].connectTo(latent.inputs.height);
r1.outputs[0].connectTo(latent.inputs.batch_size);
r1.outputs[0].connectTo(scale.inputs.width);

r2.inputs[0].input.label = "wh";
r2.outputs[0].connectTo(latent.inputs.width);
r2.outputs[0].connectTo(scale.inputs.height);

const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);

expect(group.inputs[0].input.type).toBe("VAE");
expect(group.inputs[1].input.type).toBe("INT");
expect(group.inputs[2].input.type).toBe("INT");

const p1 = ez.PrimitiveNode();
const p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(group.inputs[1]);
p2.outputs[0].connectTo(group.inputs[2]);

expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10

expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10

expect(p1.widgets.value.value).toBe(128);
expect(p2.widgets.value.value).toBe(64);

p1.widgets.value.value = 16;
p2.widgets.value.value = 32;

await checkBeforeAndAfterReload(graph, async (r) => {
const id = (v) => (r ? `${group.id}:` : "") + v;
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
[id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
[id(4)]: {
inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
class_type: "ImageScale",
},
5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
});
});
});
});
8 changes: 8 additions & 0 deletions tests-ui/utils/ezgraph.js
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
this.input = input;
}

get connection() {
const link = this.node.node.inputs?.[this.index]?.link;
if (link == null) {
return null;
}
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
}

disconnect() {
this.node.node.disconnectInput(this.index);
}
Expand Down
9 changes: 9 additions & 0 deletions tests-ui/utils/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {

return { ckpt, pos, neg, empty, sampler, decode, save };
}

export async function getNodeDefs() {
const { api } = require("../../web/scripts/api");
return api.getNodeDefs();
}

export async function getNodeDef(nodeId) {
return (await getNodeDefs())[nodeId];
}
Loading

0 comments on commit 3900789

Please sign in to comment.