Skip to content

Commit

Permalink
Deduplicate commas in generated magic prompts (#749)
Browse files Browse the repository at this point in the history
Fixes #645
  • Loading branch information
akx authored Mar 27, 2024
1 parent 3280638 commit e8326b5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
14 changes: 13 additions & 1 deletion sd_dynamic_prompts/magic_prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from itertools import zip_longest

from dynamicprompts.generators.magicprompt import MagicPromptGenerator
Expand All @@ -8,6 +9,14 @@
)


def massage_prompt(prompt: str) -> str:
# Coalesce repeated punctuation to a single instance
prompt = re.sub(r"([.,])\1+", r"\1", prompt)
# Remove leading/trailing whitespace
prompt = prompt.strip()
return prompt


class SpecialSyntaxAwareMagicPromptGenerator(MagicPromptGenerator):
"""
Magic Prompt generator that is aware of A1111 special syntax (LoRA, hypernet, etc.).
Expand All @@ -18,7 +27,10 @@ def _generate_magic_prompts(self, orig_prompts: list[str]) -> list[str]:
*(remove_a1111_special_syntax_chunks(p) for p in orig_prompts),
)
# `transformers` is rather particular that the input is a list, not a tuple
magic_prompts = super()._generate_magic_prompts(list(orig_prompts))
magic_prompts = [
massage_prompt(prompt)
for prompt in super()._generate_magic_prompts(list(orig_prompts))
]
# in case we somehow get less magic prompts than we started with,
# use zip_longest instead of zip.
return [
Expand Down
5 changes: 4 additions & 1 deletion tests/prompts/test_magic_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ def fake_generator(prompts, **_kwargs):
assert isinstance(prompts, list) # be as particular as transformers is
for prompt in prompts:
assert "<" not in prompt # should have been stripped
yield [{"generated_text": f"magical {prompt}"}]
yield [{"generated_text": f"magical {prompt},,,, wow, so nice"}]


def test_magic_prompts(monkeypatch):
Expand Down Expand Up @@ -30,3 +30,6 @@ def test_magic_prompts(monkeypatch):
assert "<hypernet:v18000Steps:1>" in prompt
# but we should expect to see some magic
assert prompt.startswith("magical ")
# See that multiple commas are coalesced
assert ",,,," not in prompt
assert ", wow, so nice" in prompt

0 comments on commit e8326b5

Please sign in to comment.