Skip to content

Commit

Permalink
Add Variable Support in Wildcard Paths (#110)
Browse files Browse the repository at this point in the history
Adds support for the issue described in adieyal/sd-dynamic-prompts#614.

The patch adds support for allowing variable refences within wildcard paths (e.g. `__cars/${car}/types__`). During the parsing phase, the variables will be retained in the path. During sampling, before resolving the wildcard, the variable is first resolved and replaced in the wildcard.

The update should allow support for variable replacement, including
providing default values using the normal `:default` syntax.

Some examples:

* `${gender=male}${section=upper} __clothes/${gender}/${section}__`
* `__clothes/${gender:female}/${section:lower}__`
* `__clothes/${gender:female}/${section:*}__`
  • Loading branch information
mwootendev authored Nov 23, 2023
1 parent 510d1e9 commit ff47441
Show file tree
Hide file tree
Showing 14 changed files with 324 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/dynamicprompts/commands/wildcard_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ class WildcardCommand(Command):
def __post_init__(self):
if not isinstance(self.wildcard, str):
raise TypeError(f"Wildcard must be a string, not {type(self.wildcard)}")

def with_content(self, content: str) -> WildcardCommand:
return dataclasses.replace(self, wildcard=content)
43 changes: 39 additions & 4 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,13 @@ def _configure_range() -> pp.ParserElement:

def _configure_wildcard(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^({}#])+"
wildcard_path = pp.Regex(wildcard_path_re)("path").leave_whitespace()
wildcard = _configure_wildcard_path(
parser_config=parser_config,
variable_ref=variable_ref,
)

wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap)
wildcard_variable_spec = (
OPT_WS
Expand All @@ -123,14 +127,25 @@ def _configure_wildcard(
wildcard = (
wildcard_enclosure
+ pp.Opt(sampler_symbol)("sampling_method")
+ wildcard_path
+ wildcard
+ pp.Opt(wildcard_variable_spec)
+ wildcard_enclosure
)

return wildcard("wildcard").leave_whitespace()


def _configure_wildcard_path(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_literal_re = (
r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+"
)
wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace()
return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path")


def _configure_literal_sequence(
parser_config: ParserConfig,
is_variant_literal: bool = False,
Expand Down Expand Up @@ -413,6 +428,23 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)

variable_ref = pp.Combine(
_configure_variable_access(
parser_config=parser_config,
prompt=pp.SkipTo(parser_config.variable_end),
),
).leave_whitespace()

# by default, the variable starting "${" and end "}" characters are
# stripped with variable access. This restores them so that the
# variable can be properly recognized at a later stage
variable_ref.set_parse_action(
lambda string,
location,
token: f"{parser_config.variable_start}{''.join(token)}{parser_config.variable_end}",
)

variable_assignment = _configure_variable_assignment(
parser_config=parser_config,
prompt=variant_prompt,
Expand All @@ -421,7 +453,10 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(parser_config=parser_config)
wildcard = _configure_wildcard(
parser_config=parser_config,
variable_ref=variable_ref,
)
literal_sequence = _configure_literal_sequence(parser_config=parser_config)
variant_literal_sequence = _configure_literal_sequence(
is_variant_literal=True,
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/combinatorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dynamicprompts.samplers.command_collection import CommandCollection
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -143,6 +144,7 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)
if not values:
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -100,6 +101,7 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
wc_values = context.wildcard_manager.get_values(command.wildcard)
new_context = context.with_variables(
command.variables,
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -114,6 +115,7 @@ def _get_wildcard(
command: WildcardCommand,
context: SamplingContext,
) -> ResultGen:
command = replace_wildcard_variables(command=command, context=context)
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)

Expand Down
73 changes: 71 additions & 2 deletions src/dynamicprompts/samplers/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from __future__ import annotations

import logging
from functools import partial

from dynamicprompts.commands import VariantCommand, VariantOption, WildcardCommand
from dynamicprompts.parser.parse import parse
import pyparsing as pp

from dynamicprompts.commands import (
LiteralCommand,
VariantCommand,
VariantOption,
WildcardCommand,
)
from dynamicprompts.parser.parse import (
_configure_variable_access,
_configure_wildcard_path,
parse,
)
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
from dynamicprompts.types import ResultGen
Expand All @@ -19,6 +31,7 @@ def wildcard_to_variant(
max_bound=1,
separator=",",
) -> VariantCommand:
command = replace_wildcard_variables(command=command, context=context)
values = context.wildcard_manager.get_values(command.wildcard)
min_bound = min(min_bound, len(values))
max_bound = min(max_bound, len(values))
Expand Down Expand Up @@ -50,3 +63,59 @@ def get_wildcard_not_found_fallback(
res = SamplingResult(text=wrapped_wildcard)
while True:
yield res


def replace_wildcard_variables(
command: WildcardCommand,
*,
context: SamplingContext,
) -> WildcardCommand:
if context.parser_config.variable_start not in command.wildcard:
return command

prompt = pp.SkipTo(context.parser_config.variable_end)
variable_access = _configure_variable_access(
parser_config=context.parser_config,
prompt=prompt,
)
variable_access.set_parse_action(
partial(_replace_variable, variables=context.variables),
)
wildcard = _configure_wildcard_path(
parser_config=context.parser_config,
variable_ref=variable_access,
)

try:
wildcard_result = wildcard.parse_string(command.wildcard)
return command.with_content("".join(wildcard_result))
except Exception:
logger.warning("Unable to parse wildcard %r", command.wildcard, exc_info=True)
return command


def _replace_variable(string, location, token, *, variables: dict):
if isinstance(token, pp.ParseResults):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if var_name:
var_name = var_name.strip()

default = var_parts.get("default")
if default:
default = default.strip()

else:
var_name = token

variable = None
if var_name:
variable = variables.get(var_name)

if isinstance(variable, LiteralCommand):
variable = variable.literal
if variable and not isinstance(variable, str):
raise NotImplementedError(
"evaluating complex commands within wildcards is not supported right now",
)
return variable or default or var_name
4 changes: 4 additions & 0 deletions tests/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def test_literal_characters(self, input: str):
[
"colours",
"path/to/colours",
"path/to/${subject}",
"${pallette}_colours",
"${pallette:warm}_colours",
"locations/${room: dining room }/furniture",
"änder",
],
)
Expand Down
88 changes: 88 additions & 0 deletions tests/samplers/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,94 @@ def test_wildcard_nested_in_wildcard(
ps = [str(p) for p in islice(gen, len(expected))]
assert ps == expected

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_nested_variable(self, sampling_context: SamplingContext):
cmd = parse("${temp=cold}wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in (
"wearing blue suede shoes",
"wearing green suede shoes",
)
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_default_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp:cold}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in (
"wearing blue suede shoes",
"wearing green suede shoes",
)
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_undefined_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert (
resolved_value
== f"wearing {sampling_context.wildcard_manager.wildcard_wrap}colors-temp{sampling_context.wildcard_manager.wildcard_wrap} suede shoes"
)

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_multiple_variables(self, sampling_context: SamplingContext):
cmd = parse(
"${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__",
)
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cat", "tiger")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_and_glob(self, sampling_context: SamplingContext):
cmd = parse("${genus=reptiles}__animals/${genus}/*__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cobra", "gecko", "iguana", "python")
else:
assert resolved_value == "cobra"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_in_nested_wildcard(
self,
sampling_context: SamplingContext,
):
cmd = parse("${genus=mammals}__animal__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cat", "tiger", "dog", "wolf")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcard_with_parameterized_variable(
self,
sampling_context: SamplingContext,
):
cmd = parse("__animal(genus=mammals)__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cat", "tiger", "dog", "wolf")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcards_in_single_file(self, sampling_context: SamplingContext):
cmd = parse(
"${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__",
)
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert resolved_value in (
"a sports car made by Porsche",
"a tractor made by John Deere",
)


class TestVariableCommands:
@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
Expand Down
53 changes: 52 additions & 1 deletion tests/samplers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from dynamicprompts.commands import (
WildcardCommand,
)
from dynamicprompts.samplers.utils import wildcard_to_variant
from dynamicprompts.samplers.utils import (
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
from pytest_lazyfixture import lazy_fixture

Expand Down Expand Up @@ -30,3 +33,51 @@ def test_wildcard_to_variant(sampling_context: SamplingContext):
)
assert variant_command.separator == "-"
assert variant_command.sampling_method == wildcard_command.sampling_method


@pytest.mark.parametrize(
("sampling_context", "initial_wildcard", "expected_wildcard", "variables"),
[
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish }",
"colors-warm-finish",
{},
),
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish: matte }",
"colors-warm-matte",
{},
),
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish: matte }",
"colors-cold-matte",
{"temp": "cold"},
),
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish: matte }",
"colors-cold-glossy",
{"temp": "cold", "finish": "glossy"},
),
],
)
def test_replace_wildcard_variables_multi_variable(
sampling_context: SamplingContext,
initial_wildcard: str,
expected_wildcard: str,
variables: dict,
):
var_sampling_context = sampling_context.with_variables(variables=variables)
wildcard_command = WildcardCommand(initial_wildcard)
updated_command = replace_wildcard_variables(
command=wildcard_command,
context=var_sampling_context,
)
assert isinstance(
updated_command,
WildcardCommand,
), "updated command is also a WildcardCommand"
assert updated_command.wildcard == expected_wildcard
1 change: 1 addition & 0 deletions tests/test_data/wildcards/animal.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__animals/${genus:*}/*__
Loading

0 comments on commit ff47441

Please sign in to comment.