diff --git a/src/dynamicprompts/commands/__init__.py b/src/dynamicprompts/commands/__init__.py index 26e0c3e..2e82c8d 100644 --- a/src/dynamicprompts/commands/__init__.py +++ b/src/dynamicprompts/commands/__init__.py @@ -3,13 +3,15 @@ from dynamicprompts.commands.sequence_command import SequenceCommand from dynamicprompts.commands.variant_command import VariantCommand, VariantOption from dynamicprompts.commands.wildcard_command import WildcardCommand +from dynamicprompts.commands.wrap_command import WrapCommand __all__ = [ "Command", "LiteralCommand", + "SamplingMethod", "SequenceCommand", "VariantCommand", "VariantOption", "WildcardCommand", - "SamplingMethod", + "WrapCommand", ] diff --git a/src/dynamicprompts/commands/wrap_command.py b/src/dynamicprompts/commands/wrap_command.py new file mode 100644 index 0000000..4f58d93 --- /dev/null +++ b/src/dynamicprompts/commands/wrap_command.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import dataclasses +import logging +import re + +from dynamicprompts.commands import Command +from dynamicprompts.enums import SamplingMethod + +log = logging.getLogger(__name__) + +WRAP_MARKER_CHARACTERS = { + "\u1801", # Mongolian ellipsis + "\u2026", # Horizontal ellipsis + "\u22EE", # Vertical ellipsis + "\u22EF", # Midline horizontal ellipsis + "\u22F0", # Up right diagonal ellipsis + "\u22F1", # Down right diagonal ellipsis + "\uFE19", # Presentation form for vertical horizontal ellipsis +} + +WRAP_MARKER_RE = re.compile( + f"[{''.join(WRAP_MARKER_CHARACTERS)}]+" # One or more wrap marker characters + "|" + r"\.{3,}", # ASCII ellipsis of 3 or more dots +) + + +def split_wrapper_string(s: str) -> tuple[str, str]: + """ + Split a string into a prefix and suffix at the first wrap marker. + """ + match = WRAP_MARKER_RE.search(s) + if match is None: + log.warning("Found no wrap marker in string %r", s) + return s, "" + else: + return s[: match.start()], s[match.end() :] + + +@dataclasses.dataclass(frozen=True) +class WrapCommand(Command): + wrapper: Command + inner: Command + sampling_method: SamplingMethod | None = None diff --git a/src/dynamicprompts/parser/config.py b/src/dynamicprompts/parser/config.py index bed26f5..a46721f 100644 --- a/src/dynamicprompts/parser/config.py +++ b/src/dynamicprompts/parser/config.py @@ -10,6 +10,8 @@ class ParserConfig: wildcard_wrap: str = "__" variable_start: str = "${" variable_end: str = "}" + wrap_start: str = "%{" + wrap_end: str = "}" default_parser_config = ParserConfig() diff --git a/src/dynamicprompts/parser/parse.py b/src/dynamicprompts/parser/parse.py index 50829d8..e727762 100644 --- a/src/dynamicprompts/parser/parse.py +++ b/src/dynamicprompts/parser/parse.py @@ -23,6 +23,7 @@ ::= + ::= "${" "=" "}" ::= "${" (":" )? "}" + ::= "%{" "$$" "}" Note that whitespace is preserved in case it is significant to the user. """ @@ -44,6 +45,7 @@ VariantCommand, VariantOption, WildcardCommand, + WrapCommand, ) from dynamicprompts.commands.variable_commands import ( VariableAccessCommand, @@ -62,6 +64,8 @@ sampler_cyclical = pp.Char("@") sampler_symbol = sampler_random | sampler_combinatorial | sampler_cyclical +variant_delim = pp.Suppress("$$") + OPT_WS = pp.Opt(pp.White()) # Optional whitespace var_name = pp.Word(pp.alphas + "_-", pp.alphanums + "_-") @@ -75,7 +79,6 @@ def _configure_range() -> pp.ParserElement: hyphen = pp.Suppress("-") - variant_delim = pp.Suppress("$$") # Exclude: # - $, which is used to indicate the end of the separator definition i.e. {1$$ and $$X|Y|Z} @@ -136,7 +139,12 @@ def _configure_literal_sequence( # - { denotes the start of a variant (or whatever variant_start is set to ) # - # denotes the start of a comment # - $ denotes the start of a variable command (or whatever variable_start is set to) - non_literal_chars = rf"#{parser_config.variant_start}{parser_config.variable_start}" + # - % denotes the start of a wrap command (or whatever wrap_start is set to) + non_literal_chars = ( + rf"#{parser_config.variant_start}" + rf"{parser_config.variable_start}" + rf"{parser_config.wrap_start}" + ) if is_variant_literal: # Inside a variant the following characters are also not allowed @@ -227,6 +235,23 @@ def _configure_variable_assignment( return variable_assignment.leave_whitespace() +def _configure_wrap_command( + parser_config: ParserConfig, + prompt: pp.ParserElement, +) -> pp.ParserElement: + wrap_command = pp.Group( + pp.Suppress(parser_config.wrap_start) + + OPT_WS + + prompt()("wrapper") + + OPT_WS + + variant_delim + + OPT_WS + + prompt()("inner") + + pp.Suppress(parser_config.wrap_end), + ) + return wrap_command.leave_whitespace() + + def _parse_literal_command(parse_result: pp.ParseResults) -> LiteralCommand: s = " ".join(parse_result) return LiteralCommand(s) @@ -365,6 +390,16 @@ def _parse_variable_assignment_command( ) +def _parse_wrap_command( + parse_result: pp.ParseResults, +) -> WrapCommand: + parts = parse_result[0].as_dict() + return WrapCommand( + inner=parts["inner"], + wrapper=parts["wrapper"], + ) + + def create_parser( *, parser_config: ParserConfig, @@ -382,6 +417,10 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) + wrap_command = _configure_wrap_command( + parser_config=parser_config, + prompt=variant_prompt, + ) wildcard = _configure_wildcard(parser_config=parser_config) literal_sequence = _configure_literal_sequence(parser_config=parser_config) variant_literal_sequence = _configure_literal_sequence( @@ -395,9 +434,16 @@ def create_parser( ) chunk = ( - variable_assignment | variable_access | variants | wildcard | literal_sequence + variable_assignment + | variable_access + | wrap_command + | variants + | wildcard + | literal_sequence + ) + variant_chunk = ( + variable_access | wrap_command | variants | wildcard | variant_literal_sequence ) - variant_chunk = variable_access | variants | wildcard | variant_literal_sequence prompt <<= pp.ZeroOrMore(chunk)("prompt") variant_prompt <<= pp.ZeroOrMore(variant_chunk)("prompt") @@ -417,6 +463,7 @@ def create_parser( variable_assignment.set_parse_action(_parse_variable_assignment_command) prompt.set_parse_action(_parse_sequence_or_single_command) variant_prompt.set_parse_action(_parse_sequence_or_single_command) + wrap_command.set_parse_action(_parse_wrap_command) return prompt diff --git a/src/dynamicprompts/samplers/base.py b/src/dynamicprompts/samplers/base.py index 1399302..a01f248 100644 --- a/src/dynamicprompts/samplers/base.py +++ b/src/dynamicprompts/samplers/base.py @@ -8,6 +8,7 @@ SequenceCommand, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.commands.variable_commands import ( VariableAccessCommand, @@ -43,6 +44,8 @@ def generator_from_command( ) if isinstance(command, VariableAccessCommand): return self._get_variable(command, context) + if isinstance(command, WrapCommand): + return self._get_wrap(command, context) return self._unsupported_command(command) def _unsupported_command(self, command: Command) -> ResultGen: @@ -100,3 +103,10 @@ def _get_variable( return context.for_sampling_variable(variable).generator_from_command( command_to_sample, ) + + def _get_wrap( + self, + command: WrapCommand, + context: SamplingContext, + ) -> ResultGen: + return self._unsupported_command(command) diff --git a/src/dynamicprompts/samplers/combinatorial.py b/src/dynamicprompts/samplers/combinatorial.py index dfa5913..2e838ac 100644 --- a/src/dynamicprompts/samplers/combinatorial.py +++ b/src/dynamicprompts/samplers/combinatorial.py @@ -10,6 +10,7 @@ SequenceCommand, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.command_collection import CommandCollection @@ -158,3 +159,9 @@ def _get_literal( context: SamplingContext, ) -> ResultGen: yield SamplingResult(text=command.literal) + + def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen: + for wrapper_result in context.sample_prompts(command.wrapper): + wrap = wrapper_result.as_wrapper() + for inner in context.sample_prompts(command.inner): + yield wrap(inner) diff --git a/src/dynamicprompts/samplers/random.py b/src/dynamicprompts/samplers/random.py index c1070b7..8523606 100644 --- a/src/dynamicprompts/samplers/random.py +++ b/src/dynamicprompts/samplers/random.py @@ -8,6 +8,7 @@ Command, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( @@ -124,3 +125,11 @@ def _get_wildcard( while True: value = next(gen) yield from context.sample_prompts(value, 1) + + def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen: + wrapper_gen = context.generator_from_command(command.wrapper) + inner_gen = context.generator_from_command(command.inner) + wrapper_result: SamplingResult + inner_result: SamplingResult + for wrapper_result, inner_result in zip(wrapper_gen, inner_gen): + yield wrapper_result.as_wrapper()(inner_result) diff --git a/src/dynamicprompts/sampling_result.py b/src/dynamicprompts/sampling_result.py index 3603c6f..eb52a4d 100644 --- a/src/dynamicprompts/sampling_result.py +++ b/src/dynamicprompts/sampling_result.py @@ -3,6 +3,8 @@ import dataclasses from typing import Iterable +from dynamicprompts.commands.wrap_command import split_wrapper_string + @dataclasses.dataclass(frozen=True) class SamplingResult: @@ -26,6 +28,23 @@ def whitespace_squashed(self) -> SamplingResult: return dataclasses.replace(self, text=squash_whitespace(self.text)) + def text_replaced(self, new_text: str) -> SamplingResult: + return dataclasses.replace(self, text=new_text) + + def as_wrapper(self): + """ + Return a function that wraps a SamplingResult with this one, + partitioning this result's text along the wrap marker. + """ + prefix, suffix = split_wrapper_string(self.text) + prefix_res = self.text_replaced(prefix) + suffix_res = self.text_replaced(suffix) + + def wrapper(inner: SamplingResult) -> SamplingResult: + return SamplingResult.joined([prefix_res, inner, suffix_res], separator="") + + return wrapper + @classmethod def joined( cls, diff --git a/tests/test_data/wildcards/wrappers.txt b/tests/test_data/wildcards/wrappers.txt new file mode 100644 index 0000000..b4a23ea --- /dev/null +++ b/tests/test_data/wildcards/wrappers.txt @@ -0,0 +1,2 @@ +Art Deco, ..., sleek, geometric forms, art deco style +Pop Art, ....., vivid colors, flat color, 2D, strong lines, Pop Art diff --git a/tests/test_wrapping.py b/tests/test_wrapping.py new file mode 100644 index 0000000..ea365de --- /dev/null +++ b/tests/test_wrapping.py @@ -0,0 +1,66 @@ +import pytest +from dynamicprompts.enums import SamplingMethod +from dynamicprompts.parser.parse import parse +from dynamicprompts.sampling_context import SamplingContext +from dynamicprompts.wildcards import WildcardManager + +from tests.utils import sample_n + + +# Methods currently supported by wrap command +@pytest.fixture( + params=[ + SamplingMethod.COMBINATORIAL, + SamplingMethod.RANDOM, + ], +) +def scon(request, wildcard_manager: WildcardManager) -> SamplingContext: + return SamplingContext( + default_sampling_method=request.param, + wildcard_manager=wildcard_manager, + ) + + +def test_wrap_with_wildcard(scon: SamplingContext): + cmd = parse("%{__wrappers__$${fox|cow}}") + assert sample_n(cmd, scon, n=4) == { + "Art Deco, cow, sleek, geometric forms, art deco style", + "Art Deco, fox, sleek, geometric forms, art deco style", + "Pop Art, cow, vivid colors, flat color, 2D, strong lines, Pop Art", + "Pop Art, fox, vivid colors, flat color, 2D, strong lines, Pop Art", + } + + +@pytest.mark.parametrize("placeholder", ["…", "᠁", ".........", "..."]) +def test_wrap_with_literal(scon: SamplingContext, placeholder: str): + cmd = parse("%{happy ... on a meadow$${fox|cow}}".replace("...", placeholder)) + assert sample_n(cmd, scon, n=2) == { + "happy fox on a meadow", + "happy cow on a meadow", + } + + +def test_bad_wrap_is_prefix(scon: SamplingContext): + cmd = parse("%{happy $${fox|cow}}") + assert sample_n(cmd, scon, n=2) == { + "happy fox", + "happy cow", + } + + +def test_wrap_suffix(scon: SamplingContext): + cmd = parse("%{... in jail$${fox|cow}}") + assert sample_n(cmd, scon, n=2) == { + "fox in jail", + "cow in jail", + } + + +def test_wrap_with_variant(scon): + cmd = parse("%{ {cool|hot} ...$${fox|cow}}") + assert sample_n(cmd, scon, n=4) == { + "cool fox", + "cool cow", + "hot fox", + "hot cow", + } diff --git a/tests/utils.py b/tests/utils.py index e17a7f6..86ee8c1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,8 @@ from __future__ import annotations +from dynamicprompts.commands import Command +from dynamicprompts.sampling_context import SamplingContext + def cross(list1: list[str], list2: list[str], sep=",") -> list[str]: return [f"{x}{sep}{y}" for x in list1 for y in list2 if x != y] @@ -15,3 +18,15 @@ def interleave(list1: list[str], list2: list[str]) -> list[str]: new_list[1::2] = list2 return new_list + + +def sample_n(cmd: Command, scon: SamplingContext, n: int) -> set[str]: + """ + Sample until we have n unique prompts. + """ + seen = set() + for p in scon.sample_prompts(cmd): + seen.add(str(p)) + if len(seen) == n: + break + return seen diff --git a/tests/wildcard/test_wildcardmanager.py b/tests/wildcard/test_wildcardmanager.py index 7e7a7e4..0f19aca 100644 --- a/tests/wildcard/test_wildcardmanager.py +++ b/tests/wildcard/test_wildcardmanager.py @@ -141,15 +141,17 @@ def test_hierarchy(wildcard_manager: WildcardManager): "variant", "weighted-animals/heavy", "weighted-animals/light", + "wrappers", } assert set(root.collections) == { "clothing", # from pantry YAML "colors-cold", # .txt "colors-warm", # .txt + "dupes", # .txt "referencing-colors", # .txt "shapes", # flat list YAML "variant", # .txt - "dupes", # .txt + "wrappers", # .txt } assert set(root.child_nodes["animals"].collections) == { "all-references",