Skip to content

Commit

Permalink
Add wrap command
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Nov 10, 2023
1 parent dcfc03d commit c10391a
Show file tree
Hide file tree
Showing 12 changed files with 232 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/dynamicprompts/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from dynamicprompts.commands.sequence_command import SequenceCommand
from dynamicprompts.commands.variant_command import VariantCommand, VariantOption
from dynamicprompts.commands.wildcard_command import WildcardCommand
from dynamicprompts.commands.wrap_command import WrapCommand

__all__ = [
"Command",
"LiteralCommand",
"SamplingMethod",
"SequenceCommand",
"VariantCommand",
"VariantOption",
"WildcardCommand",
"SamplingMethod",
"WrapCommand",
]
45 changes: 45 additions & 0 deletions src/dynamicprompts/commands/wrap_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

import dataclasses
import logging
import re

from dynamicprompts.commands import Command
from dynamicprompts.enums import SamplingMethod

log = logging.getLogger(__name__)

WRAP_MARKER_CHARACTERS = {
"\u1801", # Mongolian ellipsis
"\u2026", # Horizontal ellipsis
"\u22EE", # Vertical ellipsis
"\u22EF", # Midline horizontal ellipsis
"\u22F0", # Up right diagonal ellipsis
"\u22F1", # Down right diagonal ellipsis
"\uFE19", # Presentation form for vertical horizontal ellipsis
}

WRAP_MARKER_RE = re.compile(
f"[{''.join(WRAP_MARKER_CHARACTERS)}]+" # One or more wrap marker characters
"|"
r"\.{3,}", # ASCII ellipsis of 3 or more dots
)


def split_wrapper_string(s: str) -> tuple[str, str]:
"""
Split a string into a prefix and suffix at the first wrap marker.
"""
match = WRAP_MARKER_RE.search(s)
if match is None:
log.warning("Found no wrap marker in string %r", s)
return s, ""
else:
return s[: match.start()], s[match.end() :]


@dataclasses.dataclass(frozen=True)
class WrapCommand(Command):
wrapper: Command
inner: Command
sampling_method: SamplingMethod | None = None
2 changes: 2 additions & 0 deletions src/dynamicprompts/parser/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class ParserConfig:
wildcard_wrap: str = "__"
variable_start: str = "${"
variable_end: str = "}"
wrap_start: str = "%{"
wrap_end: str = "}"


default_parser_config = ParserConfig()
55 changes: 51 additions & 4 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<variant_literal_sequence> ::= <variant_literal>+
<variable_assignment> ::= "${" <variable_name> "=" <variant_chunk> "}"
<variable_access> ::= "${" <variable_name> (":" <variant_chunk>)? "}"
<wrap_command> ::= "%{" <variant_chunk> "$$" <variant_chunk> "}"
Note that whitespace is preserved in case it is significant to the user.
"""
Expand All @@ -44,6 +45,7 @@
VariantCommand,
VariantOption,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.commands.variable_commands import (
VariableAccessCommand,
Expand All @@ -62,6 +64,8 @@
sampler_cyclical = pp.Char("@")
sampler_symbol = sampler_random | sampler_combinatorial | sampler_cyclical

variant_delim = pp.Suppress("$$")

OPT_WS = pp.Opt(pp.White()) # Optional whitespace

var_name = pp.Word(pp.alphas + "_-", pp.alphanums + "_-")
Expand All @@ -75,7 +79,6 @@

def _configure_range() -> pp.ParserElement:
hyphen = pp.Suppress("-")
variant_delim = pp.Suppress("$$")

# Exclude:
# - $, which is used to indicate the end of the separator definition i.e. {1$$ and $$X|Y|Z}
Expand Down Expand Up @@ -136,7 +139,12 @@ def _configure_literal_sequence(
# - { denotes the start of a variant (or whatever variant_start is set to )
# - # denotes the start of a comment
# - $ denotes the start of a variable command (or whatever variable_start is set to)
non_literal_chars = rf"#{parser_config.variant_start}{parser_config.variable_start}"
# - % denotes the start of a wrap command (or whatever wrap_start is set to)
non_literal_chars = (
rf"#{parser_config.variant_start}"
rf"{parser_config.variable_start}"
rf"{parser_config.wrap_start}"
)

if is_variant_literal:
# Inside a variant the following characters are also not allowed
Expand Down Expand Up @@ -227,6 +235,23 @@ def _configure_variable_assignment(
return variable_assignment.leave_whitespace()


def _configure_wrap_command(
parser_config: ParserConfig,
prompt: pp.ParserElement,
) -> pp.ParserElement:
wrap_command = pp.Group(
pp.Suppress(parser_config.wrap_start)
+ OPT_WS
+ prompt()("wrapper")
+ OPT_WS
+ variant_delim
+ OPT_WS
+ prompt()("inner")
+ pp.Suppress(parser_config.wrap_end),
)
return wrap_command.leave_whitespace()


def _parse_literal_command(parse_result: pp.ParseResults) -> LiteralCommand:
s = " ".join(parse_result)
return LiteralCommand(s)
Expand Down Expand Up @@ -365,6 +390,16 @@ def _parse_variable_assignment_command(
)


def _parse_wrap_command(
parse_result: pp.ParseResults,
) -> WrapCommand:
parts = parse_result[0].as_dict()
return WrapCommand(
inner=parts["inner"],
wrapper=parts["wrapper"],
)


def create_parser(
*,
parser_config: ParserConfig,
Expand All @@ -382,6 +417,10 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)
wrap_command = _configure_wrap_command(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(parser_config=parser_config)
literal_sequence = _configure_literal_sequence(parser_config=parser_config)
variant_literal_sequence = _configure_literal_sequence(
Expand All @@ -395,9 +434,16 @@ def create_parser(
)

chunk = (
variable_assignment | variable_access | variants | wildcard | literal_sequence
variable_assignment
| variable_access
| wrap_command
| variants
| wildcard
| literal_sequence
)
variant_chunk = (
variable_access | wrap_command | variants | wildcard | variant_literal_sequence
)
variant_chunk = variable_access | variants | wildcard | variant_literal_sequence

prompt <<= pp.ZeroOrMore(chunk)("prompt")
variant_prompt <<= pp.ZeroOrMore(variant_chunk)("prompt")
Expand All @@ -417,6 +463,7 @@ def create_parser(
variable_assignment.set_parse_action(_parse_variable_assignment_command)
prompt.set_parse_action(_parse_sequence_or_single_command)
variant_prompt.set_parse_action(_parse_sequence_or_single_command)
wrap_command.set_parse_action(_parse_wrap_command)
return prompt


Expand Down
10 changes: 10 additions & 0 deletions src/dynamicprompts/samplers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SequenceCommand,
VariantCommand,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.commands.variable_commands import (
VariableAccessCommand,
Expand Down Expand Up @@ -43,6 +44,8 @@ def generator_from_command(
)
if isinstance(command, VariableAccessCommand):
return self._get_variable(command, context)
if isinstance(command, WrapCommand):
return self._get_wrap(command, context)
return self._unsupported_command(command)

def _unsupported_command(self, command: Command) -> ResultGen:
Expand Down Expand Up @@ -100,3 +103,10 @@ def _get_variable(
return context.for_sampling_variable(variable).generator_from_command(
command_to_sample,
)

def _get_wrap(
self,
command: WrapCommand,
context: SamplingContext,
) -> ResultGen:
return self._unsupported_command(command)
7 changes: 7 additions & 0 deletions src/dynamicprompts/samplers/combinatorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SequenceCommand,
VariantCommand,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.command_collection import CommandCollection
Expand Down Expand Up @@ -158,3 +159,9 @@ def _get_literal(
context: SamplingContext,
) -> ResultGen:
yield SamplingResult(text=command.literal)

def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen:
for wrapper_result in context.sample_prompts(command.wrapper):
wrap = wrapper_result.as_wrapper()
for inner in context.sample_prompts(command.inner):
yield wrap(inner)
9 changes: 9 additions & 0 deletions src/dynamicprompts/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Command,
VariantCommand,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
Expand Down Expand Up @@ -124,3 +125,11 @@ def _get_wildcard(
while True:
value = next(gen)
yield from context.sample_prompts(value, 1)

def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen:
wrapper_gen = context.generator_from_command(command.wrapper)
inner_gen = context.generator_from_command(command.inner)
wrapper_result: SamplingResult
inner_result: SamplingResult
for wrapper_result, inner_result in zip(wrapper_gen, inner_gen):
yield wrapper_result.as_wrapper()(inner_result)
19 changes: 19 additions & 0 deletions src/dynamicprompts/sampling_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import dataclasses
from typing import Iterable

from dynamicprompts.commands.wrap_command import split_wrapper_string


@dataclasses.dataclass(frozen=True)
class SamplingResult:
Expand All @@ -26,6 +28,23 @@ def whitespace_squashed(self) -> SamplingResult:

return dataclasses.replace(self, text=squash_whitespace(self.text))

def text_replaced(self, new_text: str) -> SamplingResult:
return dataclasses.replace(self, text=new_text)

def as_wrapper(self):
"""
Return a function that wraps a SamplingResult with this one,
partitioning this result's text along the wrap marker.
"""
prefix, suffix = split_wrapper_string(self.text)
prefix_res = self.text_replaced(prefix)
suffix_res = self.text_replaced(suffix)

def wrapper(inner: SamplingResult) -> SamplingResult:
return SamplingResult.joined([prefix_res, inner, suffix_res], separator="")

return wrapper

@classmethod
def joined(
cls,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_data/wildcards/wrappers.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Art Deco, ..., sleek, geometric forms, art deco style
Pop Art, ....., vivid colors, flat color, 2D, strong lines, Pop Art
66 changes: 66 additions & 0 deletions tests/test_wrapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest
from dynamicprompts.enums import SamplingMethod
from dynamicprompts.parser.parse import parse
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.wildcards import WildcardManager

from tests.utils import sample_n


# Methods currently supported by wrap command
@pytest.fixture(
params=[
SamplingMethod.COMBINATORIAL,
SamplingMethod.RANDOM,
],
)
def scon(request, wildcard_manager: WildcardManager) -> SamplingContext:
return SamplingContext(
default_sampling_method=request.param,
wildcard_manager=wildcard_manager,
)


def test_wrap_with_wildcard(scon: SamplingContext):
cmd = parse("%{__wrappers__$${fox|cow}}")
assert sample_n(cmd, scon, n=4) == {
"Art Deco, cow, sleek, geometric forms, art deco style",
"Art Deco, fox, sleek, geometric forms, art deco style",
"Pop Art, cow, vivid colors, flat color, 2D, strong lines, Pop Art",
"Pop Art, fox, vivid colors, flat color, 2D, strong lines, Pop Art",
}


@pytest.mark.parametrize("placeholder", ["…", "᠁", ".........", "..."])
def test_wrap_with_literal(scon: SamplingContext, placeholder: str):
cmd = parse("%{happy ... on a meadow$${fox|cow}}".replace("...", placeholder))
assert sample_n(cmd, scon, n=2) == {
"happy fox on a meadow",
"happy cow on a meadow",
}


def test_bad_wrap_is_prefix(scon: SamplingContext):
cmd = parse("%{happy $${fox|cow}}")
assert sample_n(cmd, scon, n=2) == {
"happy fox",
"happy cow",
}


def test_wrap_suffix(scon: SamplingContext):
cmd = parse("%{... in jail$${fox|cow}}")
assert sample_n(cmd, scon, n=2) == {
"fox in jail",
"cow in jail",
}


def test_wrap_with_variant(scon):
cmd = parse("%{ {cool|hot} ...$${fox|cow}}")
assert sample_n(cmd, scon, n=4) == {
"cool fox",
"cool cow",
"hot fox",
"hot cow",
}
15 changes: 15 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from dynamicprompts.commands import Command
from dynamicprompts.sampling_context import SamplingContext


def cross(list1: list[str], list2: list[str], sep=",") -> list[str]:
return [f"{x}{sep}{y}" for x in list1 for y in list2 if x != y]
Expand All @@ -15,3 +18,15 @@ def interleave(list1: list[str], list2: list[str]) -> list[str]:
new_list[1::2] = list2

return new_list


def sample_n(cmd: Command, scon: SamplingContext, n: int) -> set[str]:
"""
Sample until we have n unique prompts.
"""
seen = set()
for p in scon.sample_prompts(cmd):
seen.add(str(p))
if len(seen) == n:
break
return seen
Loading

0 comments on commit c10391a

Please sign in to comment.