Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) Add Variable Support in Wildcard Paths #110

Merged
merged 8 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/dynamicprompts/commands/wildcard_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ class WildcardCommand(Command):
def __post_init__(self):
if not isinstance(self.wildcard, str):
raise TypeError(f"Wildcard must be a string, not {type(self.wildcard)}")

def with_content(self, content: str) -> WildcardCommand:
return dataclasses.replace(self, wildcard=content)
43 changes: 39 additions & 4 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,13 @@ def _configure_range() -> pp.ParserElement:

def _configure_wildcard(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^({}#])+"
wildcard_path = pp.Regex(wildcard_path_re)("path").leave_whitespace()
wildcard = _configure_wildcard_path(
parser_config=parser_config,
variable_ref=variable_ref,
)

wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap)
wildcard_variable_spec = (
OPT_WS
Expand All @@ -123,14 +127,25 @@ def _configure_wildcard(
wildcard = (
wildcard_enclosure
+ pp.Opt(sampler_symbol)("sampling_method")
+ wildcard_path
+ wildcard
+ pp.Opt(wildcard_variable_spec)
+ wildcard_enclosure
)

return wildcard("wildcard").leave_whitespace()


def _configure_wildcard_path(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_literal_re = (
r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+"
)
wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace()
return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path")


def _configure_literal_sequence(
parser_config: ParserConfig,
is_variant_literal: bool = False,
Expand Down Expand Up @@ -413,6 +428,23 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)

variable_ref = pp.Combine(
_configure_variable_access(
parser_config=parser_config,
prompt=pp.SkipTo(parser_config.variable_end),
),
).leave_whitespace()

# by default, the variable starting "${" and end "}" characters are
# stripped with variable access. This restores them so that the
# variable can be properly recognized at a later stage
variable_ref.set_parse_action(
lambda string,
location,
token: f"{parser_config.variable_start}{''.join(token)}{parser_config.variable_end}",
)

variable_assignment = _configure_variable_assignment(
parser_config=parser_config,
prompt=variant_prompt,
Expand All @@ -421,7 +453,10 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(parser_config=parser_config)
wildcard = _configure_wildcard(
parser_config=parser_config,
variable_ref=variable_ref,
)
literal_sequence = _configure_literal_sequence(parser_config=parser_config)
variant_literal_sequence = _configure_literal_sequence(
is_variant_literal=True,
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/combinatorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dynamicprompts.samplers.command_collection import CommandCollection
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -143,6 +144,7 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)
if not values:
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -100,6 +101,7 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
wc_values = context.wildcard_manager.get_values(command.wildcard)
new_context = context.with_variables(
command.variables,
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -114,6 +115,7 @@ def _get_wildcard(
command: WildcardCommand,
context: SamplingContext,
) -> ResultGen:
command = replace_wildcard_variables(command=command, context=context)
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)

Expand Down
73 changes: 71 additions & 2 deletions src/dynamicprompts/samplers/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from __future__ import annotations

import logging
from functools import partial

from dynamicprompts.commands import VariantCommand, VariantOption, WildcardCommand
from dynamicprompts.parser.parse import parse
import pyparsing as pp

from dynamicprompts.commands import (
LiteralCommand,
VariantCommand,
VariantOption,
WildcardCommand,
)
from dynamicprompts.parser.parse import (
_configure_variable_access,
_configure_wildcard_path,
parse,
)
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
from dynamicprompts.types import ResultGen
Expand All @@ -19,6 +31,7 @@ def wildcard_to_variant(
max_bound=1,
separator=",",
) -> VariantCommand:
command = replace_wildcard_variables(command=command, context=context)
values = context.wildcard_manager.get_values(command.wildcard)
min_bound = min(min_bound, len(values))
max_bound = min(max_bound, len(values))
Expand Down Expand Up @@ -50,3 +63,59 @@ def get_wildcard_not_found_fallback(
res = SamplingResult(text=wrapped_wildcard)
while True:
yield res


def replace_wildcard_variables(
command: WildcardCommand,
*,
context: SamplingContext,
) -> WildcardCommand:
if context.parser_config.variable_start not in command.wildcard:
return command

prompt = pp.SkipTo(context.parser_config.variable_end)
variable_access = _configure_variable_access(
parser_config=context.parser_config,
prompt=prompt,
)
variable_access.set_parse_action(
partial(_replace_variable, variables=context.variables),
)
wildcard = _configure_wildcard_path(
parser_config=context.parser_config,
variable_ref=variable_access,
)
akx marked this conversation as resolved.
Show resolved Hide resolved

try:
wildcard_result = wildcard.parse_string(command.wildcard)
return command.with_content("".join(wildcard_result))
except Exception:
logger.warning("Unable to parse wildcard %r", command.wildcard, exc_info=True)
return command


def _replace_variable(string, location, token, *, variables: dict):
if isinstance(token, pp.ParseResults):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if var_name:
var_name = var_name.strip()

default = var_parts.get("default")
if default:
default = default.strip()

else:
var_name = token

variable = None
if var_name:
variable = variables.get(var_name)

if isinstance(variable, LiteralCommand):
variable = variable.literal
akx marked this conversation as resolved.
Show resolved Hide resolved
if variable and not isinstance(variable, str):
raise NotImplementedError(
"evaluating complex commands within wildcards is not supported right now",
)
return variable or default or var_name
4 changes: 4 additions & 0 deletions tests/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def test_literal_characters(self, input: str):
[
"colours",
"path/to/colours",
"path/to/${subject}",
"${pallette}_colours",
"${pallette:warm}_colours",
"locations/${room: dining room }/furniture",
"änder",
],
)
Expand Down
88 changes: 88 additions & 0 deletions tests/samplers/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,94 @@ def test_wildcard_nested_in_wildcard(
ps = [str(p) for p in islice(gen, len(expected))]
assert ps == expected

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_nested_variable(self, sampling_context: SamplingContext):
cmd = parse("${temp=cold}wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in (
"wearing blue suede shoes",
"wearing green suede shoes",
)
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_default_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp:cold}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in (
"wearing blue suede shoes",
"wearing green suede shoes",
)
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_undefined_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert (
resolved_value
== f"wearing {sampling_context.wildcard_manager.wildcard_wrap}colors-temp{sampling_context.wildcard_manager.wildcard_wrap} suede shoes"
)

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_multiple_variables(self, sampling_context: SamplingContext):
cmd = parse(
"${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__",
)
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cat", "tiger")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_and_glob(self, sampling_context: SamplingContext):
cmd = parse("${genus=reptiles}__animals/${genus}/*__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cobra", "gecko", "iguana", "python")
else:
assert resolved_value == "cobra"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_in_nested_wildcard(
self,
sampling_context: SamplingContext,
):
cmd = parse("${genus=mammals}__animal__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cat", "tiger", "dog", "wolf")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcard_with_parameterized_variable(
self,
sampling_context: SamplingContext,
):
cmd = parse("__animal(genus=mammals)__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ("cat", "tiger", "dog", "wolf")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcards_in_single_file(self, sampling_context: SamplingContext):
cmd = parse(
"${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__",
)
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert resolved_value in (
"a sports car made by Porsche",
"a tractor made by John Deere",
)


class TestVariableCommands:
@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
Expand Down
53 changes: 52 additions & 1 deletion tests/samplers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from dynamicprompts.commands import (
WildcardCommand,
)
from dynamicprompts.samplers.utils import wildcard_to_variant
from dynamicprompts.samplers.utils import (
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
from pytest_lazyfixture import lazy_fixture

Expand Down Expand Up @@ -30,3 +33,51 @@ def test_wildcard_to_variant(sampling_context: SamplingContext):
)
assert variant_command.separator == "-"
assert variant_command.sampling_method == wildcard_command.sampling_method


@pytest.mark.parametrize(
("sampling_context", "initial_wildcard", "expected_wildcard", "variables"),
[
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish }",
"colors-warm-finish",
{},
),
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish: matte }",
"colors-warm-matte",
{},
),
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish: matte }",
"colors-cold-matte",
{"temp": "cold"},
),
(
lazy_fixture("random_sampling_context"),
"colors-${temp: warm}-${ finish: matte }",
"colors-cold-glossy",
{"temp": "cold", "finish": "glossy"},
),
],
)
def test_replace_wildcard_variables_multi_variable(
sampling_context: SamplingContext,
initial_wildcard: str,
expected_wildcard: str,
variables: dict,
):
var_sampling_context = sampling_context.with_variables(variables=variables)
wildcard_command = WildcardCommand(initial_wildcard)
updated_command = replace_wildcard_variables(
command=wildcard_command,
context=var_sampling_context,
)
assert isinstance(
updated_command,
WildcardCommand,
), "updated command is also a WildcardCommand"
assert updated_command.wildcard == expected_wildcard
1 change: 1 addition & 0 deletions tests/test_data/wildcards/animal.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__animals/${genus:*}/*__
Loading