Skip to content

Commit

Permalink
Make strip_template_info a free function, have it return None to sign…
Browse files Browse the repository at this point in the history
…ify it modifies in-place
  • Loading branch information
akx committed Dec 7, 2023
1 parent 0645c94 commit dbdb49e
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 32 deletions.
12 changes: 8 additions & 4 deletions sd_dynamic_prompts/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from modules.generation_parameters_copypaste import parse_generation_parameters
from modules.script_callbacks import ImageSaveParams

from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver, PromptTemplates
from sd_dynamic_prompts.pnginfo_saver import (
PngInfoSaver,
PromptTemplates,
strip_template_info,
)
from sd_dynamic_prompts.prompt_writer import PromptWriter
from sd_dynamic_prompts.settings import on_ui_settings
from sd_dynamic_prompts.wildcards_tab import initialize as initialize_wildcards_tab
Expand Down Expand Up @@ -47,17 +51,17 @@ def on_save(image_save_params: ImageSaveParams) -> None:
script_callbacks.on_before_image_saved(on_save)


def register_on_infotext_pasted(pnginfo_saver: PngInfoSaver) -> None:
def register_on_infotext_pasted() -> None:
def on_infotext_pasted(infotext: str, parameters: dict[str, Any]) -> None:
new_parameters = {}
if "Prompt" in parameters and "Template:" in parameters["Prompt"]:
parameters = pnginfo_saver.strip_template_info(parameters)
strip_template_info(parameters)
new_parameters = parse_generation_parameters(parameters["Prompt"])
elif (
"Negative prompt" in parameters
and "Template:" in parameters["Negative prompt"]
):
parameters = pnginfo_saver.strip_template_info(parameters)
strip_template_info(parameters)
new_parameters = parse_generation_parameters(parameters["Negative prompt"])
new_parameters["Negative prompt"] = new_parameters["Prompt"]
new_parameters["Prompt"] = parameters["Prompt"]
Expand Down
2 changes: 1 addition & 1 deletion sd_dynamic_prompts/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self):

callbacks.register_pnginfo_saver(self._pnginfo_saver)
callbacks.register_prompt_writer(self._prompt_writer)
callbacks.register_on_infotext_pasted(self._pnginfo_saver)
callbacks.register_on_infotext_pasted()
callbacks.register_settings()
callbacks.register_wildcards_tab(self._wildcard_manager)

Expand Down
47 changes: 23 additions & 24 deletions sd_dynamic_prompts/pnginfo_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,29 @@
NEGATIVE_TEMPLATE_LABEL = "Negative Template"


def strip_template_info(parameters: dict[str, Any]) -> None:
if "Prompt" in parameters and f"{TEMPLATE_LABEL}:" in parameters["Prompt"]:
parameters["Prompt"] = (
parameters["Prompt"].split(f"{TEMPLATE_LABEL}:")[0].strip()
)
elif "Negative prompt" in parameters:
split_by = None
if (
f"\n{TEMPLATE_LABEL}:" in parameters["Negative prompt"]
and f"\n{NEGATIVE_TEMPLATE_LABEL}:" in parameters["Negative prompt"]
):
split_by = f"{TEMPLATE_LABEL}"
elif f"\n{NEGATIVE_TEMPLATE_LABEL}:" in parameters["Negative prompt"]:
split_by = f"\n{NEGATIVE_TEMPLATE_LABEL}:"
elif f"\n{TEMPLATE_LABEL}:" in parameters["Negative prompt"]:
split_by = f"\n{TEMPLATE_LABEL}:"

if split_by:
parameters["Negative prompt"] = (
parameters["Negative prompt"].split(split_by)[0].strip()
)


@dataclass
class PromptTemplates:
positive_template: str
Expand Down Expand Up @@ -41,27 +64,3 @@ def update_pnginfo(self, parameters: str, prompt_templates: PromptTemplates) ->
)

return parameters

def strip_template_info(self, parameters: dict[str, Any]) -> dict[str, Any]:
if "Prompt" in parameters and f"{TEMPLATE_LABEL}:" in parameters["Prompt"]:
parameters["Prompt"] = (
parameters["Prompt"].split(f"{TEMPLATE_LABEL}:")[0].strip()
)
elif "Negative prompt" in parameters:
split_by = None
if (
f"\n{TEMPLATE_LABEL}:" in parameters["Negative prompt"]
and f"\n{NEGATIVE_TEMPLATE_LABEL}:" in parameters["Negative prompt"]
):
split_by = f"{TEMPLATE_LABEL}"
elif f"\n{NEGATIVE_TEMPLATE_LABEL}:" in parameters["Negative prompt"]:
split_by = f"\n{NEGATIVE_TEMPLATE_LABEL}:"
elif f"\n{TEMPLATE_LABEL}:" in parameters["Negative prompt"]:
split_by = f"\n{TEMPLATE_LABEL}:"

if split_by:
parameters["Negative prompt"] = (
parameters["Negative prompt"].split(split_by)[0].strip()
)

return parameters
9 changes: 6 additions & 3 deletions tests/prompts/ui/test_pnginfo_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import pytest

from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver, PromptTemplates
from sd_dynamic_prompts.pnginfo_saver import (
PngInfoSaver,
PromptTemplates,
strip_template_info,
)


@pytest.fixture
Expand Down Expand Up @@ -72,7 +76,6 @@ def test_remove_template_from_infotext(
positive_prompt: str,
negative_prompt: str,
) -> None:
png_info_saver = PngInfoSaver()
if not negative_prompt:
basic_parameters["Prompt"] = build_parameters(positive_prompt, negative_prompt)
basic_parameters["Negative prompt"] = ""
Expand All @@ -83,7 +86,7 @@ def test_remove_template_from_infotext(
negative_prompt,
)

png_info_saver.strip_template_info(basic_parameters)
strip_template_info(basic_parameters)

if negative_prompt:
assert basic_parameters["Prompt"] == positive_prompt
Expand Down

0 comments on commit dbdb49e

Please sign in to comment.