Skip to content

Commit

Permalink
Add wrap command (beta/proof of concept/WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Nov 6, 2023
1 parent 14d432b commit 9972cc3
Show file tree
Hide file tree
Showing 14 changed files with 179 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/dynamicprompts/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from dynamicprompts.commands.sequence_command import SequenceCommand
from dynamicprompts.commands.variant_command import VariantCommand, VariantOption
from dynamicprompts.commands.wildcard_command import WildcardCommand
from dynamicprompts.commands.wrap_command import WrapCommand

__all__ = [
"Command",
"LiteralCommand",
"SamplingMethod",
"SequenceCommand",
"VariantCommand",
"VariantOption",
"WildcardCommand",
"SamplingMethod",
"WrapCommand",
]
13 changes: 13 additions & 0 deletions src/dynamicprompts/commands/wrap_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

import dataclasses

from dynamicprompts.commands import Command
from dynamicprompts.enums import SamplingMethod


@dataclasses.dataclass(frozen=True)
class WrapCommand(Command):
wrapper: Command
inner: Command
sampling_method: SamplingMethod | None = None
3 changes: 3 additions & 0 deletions src/dynamicprompts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@
DEFAULT_COMBO_JOINER = ","
MAX_IMAGES = 1000
DEFAULT_RANDOM = Random()

WILDCARD_PROMPT_WRAP_TEXT = "{prompt}"
WRAP_MARKER = "\u2709\u2709\u2709" # Triple envelope
2 changes: 2 additions & 0 deletions src/dynamicprompts/parser/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class ParserConfig:
wildcard_wrap: str = "__"
variable_start: str = "${"
variable_end: str = "}"
wrap_start: str = "%{"
wrap_end: str = "}"


default_parser_config = ParserConfig()
55 changes: 51 additions & 4 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<variant_literal_sequence> ::= <variant_literal>+
<variable_assignment> ::= "${" <variable_name> "=" <variant_chunk> "}"
<variable_access> ::= "${" <variable_name> (":" <variant_chunk>)? "}"
<wrap_command> ::= "%{" <variant_chunk> "$$" <variant_chunk> "}"
Note that whitespace is preserved in case it is significant to the user.
"""
Expand All @@ -44,6 +45,7 @@
VariantCommand,
VariantOption,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.commands.variable_commands import (
VariableAccessCommand,
Expand All @@ -62,6 +64,8 @@
sampler_cyclical = pp.Char("@")
sampler_symbol = sampler_random | sampler_combinatorial | sampler_cyclical

variant_delim = pp.Suppress("$$")

OPT_WS = pp.Opt(pp.White()) # Optional whitespace

var_name = pp.Word(pp.alphas + "_-", pp.alphanums + "_-")
Expand All @@ -75,7 +79,6 @@

def _configure_range() -> pp.ParserElement:
hyphen = pp.Suppress("-")
variant_delim = pp.Suppress("$$")

# Exclude:
# - $, which is used to indicate the end of the separator definition i.e. {1$$ and $$X|Y|Z}
Expand Down Expand Up @@ -136,7 +139,12 @@ def _configure_literal_sequence(
# - { denotes the start of a variant (or whatever variant_start is set to )
# - # denotes the start of a comment
# - $ denotes the start of a variable command (or whatever variable_start is set to)
non_literal_chars = rf"#{parser_config.variant_start}{parser_config.variable_start}"
# - % denotes the start of a wrap command (or whatever wrap_start is set to)
non_literal_chars = (
rf"#{parser_config.variant_start}"
rf"{parser_config.variable_start}"
rf"{parser_config.wrap_start}"
)

if is_variant_literal:
# Inside a variant the following characters are also not allowed
Expand Down Expand Up @@ -227,6 +235,23 @@ def _configure_variable_assignment(
return variable_assignment.leave_whitespace()


def _configure_wrap_command(
parser_config: ParserConfig,
prompt: pp.ParserElement,
) -> pp.ParserElement:
wrap_command = pp.Group(
pp.Suppress(parser_config.wrap_start)
+ OPT_WS
+ prompt()("wrapper")
+ OPT_WS
+ variant_delim
+ OPT_WS
+ prompt()("inner")
+ pp.Suppress(parser_config.wrap_end),
)
return wrap_command.leave_whitespace()


def _parse_literal_command(parse_result: pp.ParseResults) -> LiteralCommand:
s = " ".join(parse_result)
return LiteralCommand(s)
Expand Down Expand Up @@ -365,6 +390,16 @@ def _parse_variable_assignment_command(
)


def _parse_wrap_command(
parse_result: pp.ParseResults,
) -> WrapCommand:
parts = parse_result[0].as_dict()
return WrapCommand(
inner=parts["inner"],
wrapper=parts["wrapper"],
)


def create_parser(
*,
parser_config: ParserConfig,
Expand All @@ -382,6 +417,10 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)
wrap_command = _configure_wrap_command(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(parser_config=parser_config)
literal_sequence = _configure_literal_sequence(parser_config=parser_config)
variant_literal_sequence = _configure_literal_sequence(
Expand All @@ -395,9 +434,16 @@ def create_parser(
)

chunk = (
variable_assignment | variable_access | variants | wildcard | literal_sequence
variable_assignment
| variable_access
| wrap_command
| variants
| wildcard
| literal_sequence
)
variant_chunk = (
variable_access | wrap_command | variants | wildcard | variant_literal_sequence
)
variant_chunk = variable_access | variants | wildcard | variant_literal_sequence

prompt <<= pp.ZeroOrMore(chunk)("prompt")
variant_prompt <<= pp.ZeroOrMore(variant_chunk)("prompt")
Expand All @@ -417,6 +463,7 @@ def create_parser(
variable_assignment.set_parse_action(_parse_variable_assignment_command)
prompt.set_parse_action(_parse_sequence_or_single_command)
variant_prompt.set_parse_action(_parse_sequence_or_single_command)
wrap_command.set_parse_action(_parse_wrap_command)
return prompt


Expand Down
10 changes: 10 additions & 0 deletions src/dynamicprompts/samplers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SequenceCommand,
VariantCommand,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.commands.variable_commands import (
VariableAccessCommand,
Expand Down Expand Up @@ -43,6 +44,8 @@ def generator_from_command(
)
if isinstance(command, VariableAccessCommand):
return self._get_variable(command, context)
if isinstance(command, WrapCommand):
return self._get_wrap(command, context)
return self._unsupported_command(command)

def _unsupported_command(self, command: Command) -> ResultGen:
Expand Down Expand Up @@ -100,3 +103,10 @@ def _get_variable(
return context.for_sampling_variable(variable).generator_from_command(
command_to_sample,
)

def _get_wrap(
self,
command: WrapCommand,
context: SamplingContext,
) -> ResultGen:
return self._unsupported_command(command)
7 changes: 7 additions & 0 deletions src/dynamicprompts/samplers/combinatorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SequenceCommand,
VariantCommand,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.command_collection import CommandCollection
Expand Down Expand Up @@ -158,3 +159,9 @@ def _get_literal(
context: SamplingContext,
) -> ResultGen:
yield SamplingResult(text=command.literal)

def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen:
for wrapper_result in context.sample_prompts(command.wrapper):
wrap = wrapper_result.as_wrapper()
for inner in context.sample_prompts(command.inner):
yield wrap(inner)
9 changes: 9 additions & 0 deletions src/dynamicprompts/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Command,
VariantCommand,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
Expand Down Expand Up @@ -124,3 +125,11 @@ def _get_wildcard(
while True:
value = next(gen)
yield from context.sample_prompts(value, 1)

def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen:
wrapper_gen = context.generator_from_command(command.wrapper)
inner_gen = context.generator_from_command(command.inner)
wrapper_result: SamplingResult
inner_result: SamplingResult
for wrapper_result, inner_result in zip(wrapper_gen, inner_gen):
yield wrapper_result.as_wrapper()(inner_result)
19 changes: 19 additions & 0 deletions src/dynamicprompts/sampling_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import dataclasses
from typing import Iterable

from dynamicprompts.constants import WRAP_MARKER


@dataclasses.dataclass(frozen=True)
class SamplingResult:
Expand All @@ -26,6 +28,23 @@ def whitespace_squashed(self) -> SamplingResult:

return dataclasses.replace(self, text=squash_whitespace(self.text))

def text_replaced(self, new_text: str) -> SamplingResult:
return dataclasses.replace(self, text=new_text)

def as_wrapper(self):
"""
Return a function that wraps a SamplingResult with this one,
partitioning this result's text along the wrap marker.
"""
prefix, _, suffix = str(self).partition(WRAP_MARKER)
prefix_res = self.text_replaced(prefix)
suffix_res = self.text_replaced(suffix)

def wrapper(inner: SamplingResult) -> SamplingResult:
return SamplingResult.joined([prefix_res, inner, suffix_res], separator="")

return wrapper

@classmethod
def joined(
cls,
Expand Down
20 changes: 20 additions & 0 deletions src/dynamicprompts/wildcards/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

import dataclasses
import re
from typing import Iterable

from dynamicprompts.constants import WILDCARD_PROMPT_WRAP_TEXT, WRAP_MARKER
from dynamicprompts.utils import removeprefix, removesuffix
from dynamicprompts.wildcards.item import WildcardItem


def clean_wildcard(wildcard: str, *, wildcard_wrap: str) -> str:
Expand Down Expand Up @@ -35,3 +39,19 @@ def combine_name_parts(*parts: str) -> str:
Combine and normalize tree node name parts.
"""
return "/".join(parts).strip("/")


def process_wrap_marker(value: str) -> str:
return str(value).replace(WILDCARD_PROMPT_WRAP_TEXT, WRAP_MARKER)


def process_wrap_markers(
values: Iterable[str | WildcardItem],
) -> Iterable[str | WildcardItem]:
for val in values:
if isinstance(val, str):
yield process_wrap_marker(val)
elif isinstance(val, WildcardItem):
yield dataclasses.replace(val, content=process_wrap_marker(val.content))
else:
raise NotImplementedError()
8 changes: 6 additions & 2 deletions src/dynamicprompts/wildcards/wildcard_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
WildcardTree,
build_tree_from_root_map,
)
from dynamicprompts.wildcards.utils import clean_wildcard
from dynamicprompts.wildcards.utils import (
clean_wildcard,
process_wrap_markers,
)
from dynamicprompts.wildcards.values import WildcardValues

if TYPE_CHECKING:
Expand Down Expand Up @@ -191,7 +194,8 @@ def _get_values(self, wildcard: str) -> WildcardValues:
", ".join(str(coll) for coll in rec_colls),
)

wildcards = list(values)
# TODO: this is an inelegant place to do this replacement (must be fixed before out-of-POC)
wildcards = list(process_wrap_markers(values))

if self.dedup_wildcards:
wildcards = list(dict.fromkeys(wildcards, None))
Expand Down
2 changes: 2 additions & 0 deletions tests/test_data/wildcards/wrappers.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Art Deco, {prompt}, sleek, geometric forms, art deco style
Pop Art, {prompt}, vivid colors, flat color, 2D, strong lines, Pop Art
31 changes: 31 additions & 0 deletions tests/test_wrapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
from dynamicprompts.enums import SamplingMethod
from dynamicprompts.parser.parse import parse
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.wildcards import WildcardManager


@pytest.mark.parametrize(
"sampling_method",
[
SamplingMethod.COMBINATORIAL,
SamplingMethod.RANDOM,
],
)
def test_wrap(wildcard_manager: WildcardManager, sampling_method: SamplingMethod):
cmd = parse("%{__wrappers__$${fox|cow}}")
scon = SamplingContext(
default_sampling_method=sampling_method,
wildcard_manager=wildcard_manager,
)
seen = set()
for p in scon.sample_prompts(cmd):
seen.add(str(p))
if len(seen) == 4: # no matter the sampler, we should get 4 prompts
break
assert seen == {
"Art Deco, cow, sleek, geometric forms, art deco style",
"Art Deco, fox, sleek, geometric forms, art deco style",
"Pop Art, cow, vivid colors, flat color, 2D, strong lines, Pop Art",
"Pop Art, fox, vivid colors, flat color, 2D, strong lines, Pop Art",
}
4 changes: 3 additions & 1 deletion tests/wildcard/test_wildcardmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,17 @@ def test_hierarchy(wildcard_manager: WildcardManager):
"variant",
"weighted-animals/heavy",
"weighted-animals/light",
"wrappers",
}
assert set(root.collections) == {
"clothing", # from pantry YAML
"colors-cold", # .txt
"colors-warm", # .txt
"dupes", # .txt
"referencing-colors", # .txt
"shapes", # flat list YAML
"variant", # .txt
"dupes", # .txt
"wrappers", # .txt
}
assert set(root.child_nodes["animals"].collections) == {
"all-references",
Expand Down

0 comments on commit 9972cc3

Please sign in to comment.