From ff474419aafb0959894295a9c468e5a4a5b4c216 Mon Sep 17 00:00:00 2001 From: Michael Wooten Date: Thu, 23 Nov 2023 08:55:42 -0500 Subject: [PATCH] Add Variable Support in Wildcard Paths (#110) Adds support for the issue described in https://github.com/adieyal/sd-dynamic-prompts/issues/614. The patch adds support for allowing variable refences within wildcard paths (e.g. `__cars/${car}/types__`). During the parsing phase, the variables will be retained in the path. During sampling, before resolving the wildcard, the variable is first resolved and replaced in the wildcard. The update should allow support for variable replacement, including providing default values using the normal `:default` syntax. Some examples: * `${gender=male}${section=upper} __clothes/${gender}/${section}__` * `__clothes/${gender:female}/${section:lower}__` * `__clothes/${gender:female}/${section:*}__` --- .../commands/wildcard_command.py | 3 + src/dynamicprompts/parser/parse.py | 43 ++++++++- src/dynamicprompts/samplers/combinatorial.py | 2 + src/dynamicprompts/samplers/cycle.py | 2 + src/dynamicprompts/samplers/random.py | 2 + src/dynamicprompts/samplers/utils.py | 73 ++++++++++++++- tests/parser/test_parser.py | 4 + tests/samplers/test_common.py | 88 +++++++++++++++++++ tests/samplers/test_utils.py | 53 ++++++++++- tests/test_data/wildcards/animal.txt | 1 + .../wildcards/animals/reptiles/lizards.yaml | 2 + .../wildcards/animals/reptiles/snakes.txt | 2 + tests/test_data/wildcards/cars.yaml | 31 +++++++ tests/wildcard/test_wildcardmanager.py | 25 ++++++ 14 files changed, 324 insertions(+), 7 deletions(-) create mode 100644 tests/test_data/wildcards/animal.txt create mode 100644 tests/test_data/wildcards/animals/reptiles/lizards.yaml create mode 100644 tests/test_data/wildcards/animals/reptiles/snakes.txt create mode 100644 tests/test_data/wildcards/cars.yaml diff --git a/src/dynamicprompts/commands/wildcard_command.py b/src/dynamicprompts/commands/wildcard_command.py index 7958d27..ef0052c 100644 --- a/src/dynamicprompts/commands/wildcard_command.py +++ b/src/dynamicprompts/commands/wildcard_command.py @@ -15,3 +15,6 @@ class WildcardCommand(Command): def __post_init__(self): if not isinstance(self.wildcard, str): raise TypeError(f"Wildcard must be a string, not {type(self.wildcard)}") + + def with_content(self, content: str) -> WildcardCommand: + return dataclasses.replace(self, wildcard=content) diff --git a/src/dynamicprompts/parser/parse.py b/src/dynamicprompts/parser/parse.py index e727762..7b8ca35 100644 --- a/src/dynamicprompts/parser/parse.py +++ b/src/dynamicprompts/parser/parse.py @@ -109,9 +109,13 @@ def _configure_range() -> pp.ParserElement: def _configure_wildcard( parser_config: ParserConfig, + variable_ref: pp.ParserElement, ) -> pp.ParserElement: - wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^({}#])+" - wildcard_path = pp.Regex(wildcard_path_re)("path").leave_whitespace() + wildcard = _configure_wildcard_path( + parser_config=parser_config, + variable_ref=variable_ref, + ) + wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap) wildcard_variable_spec = ( OPT_WS @@ -123,7 +127,7 @@ def _configure_wildcard( wildcard = ( wildcard_enclosure + pp.Opt(sampler_symbol)("sampling_method") - + wildcard_path + + wildcard + pp.Opt(wildcard_variable_spec) + wildcard_enclosure ) @@ -131,6 +135,17 @@ def _configure_wildcard( return wildcard("wildcard").leave_whitespace() +def _configure_wildcard_path( + parser_config: ParserConfig, + variable_ref: pp.ParserElement, +) -> pp.ParserElement: + wildcard_path_literal_re = ( + r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+" + ) + wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace() + return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path") + + def _configure_literal_sequence( parser_config: ParserConfig, is_variant_literal: bool = False, @@ -413,6 +428,23 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) + + variable_ref = pp.Combine( + _configure_variable_access( + parser_config=parser_config, + prompt=pp.SkipTo(parser_config.variable_end), + ), + ).leave_whitespace() + + # by default, the variable starting "${" and end "}" characters are + # stripped with variable access. This restores them so that the + # variable can be properly recognized at a later stage + variable_ref.set_parse_action( + lambda string, + location, + token: f"{parser_config.variable_start}{''.join(token)}{parser_config.variable_end}", + ) + variable_assignment = _configure_variable_assignment( parser_config=parser_config, prompt=variant_prompt, @@ -421,7 +453,10 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) - wildcard = _configure_wildcard(parser_config=parser_config) + wildcard = _configure_wildcard( + parser_config=parser_config, + variable_ref=variable_ref, + ) literal_sequence = _configure_literal_sequence(parser_config=parser_config) variant_literal_sequence = _configure_literal_sequence( is_variant_literal=True, diff --git a/src/dynamicprompts/samplers/combinatorial.py b/src/dynamicprompts/samplers/combinatorial.py index 2e838ac..b939330 100644 --- a/src/dynamicprompts/samplers/combinatorial.py +++ b/src/dynamicprompts/samplers/combinatorial.py @@ -16,6 +16,7 @@ from dynamicprompts.samplers.command_collection import CommandCollection from dynamicprompts.samplers.utils import ( get_wildcard_not_found_fallback, + replace_wildcard_variables, wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext @@ -143,6 +144,7 @@ def _get_wildcard( context: SamplingContext, ) -> ResultGen: # TODO: doesn't support weights + command = replace_wildcard_variables(command=command, context=context) context = context.with_variables(command.variables) values = context.wildcard_manager.get_values(command.wildcard) if not values: diff --git a/src/dynamicprompts/samplers/cycle.py b/src/dynamicprompts/samplers/cycle.py index 4d46e67..0c916ea 100644 --- a/src/dynamicprompts/samplers/cycle.py +++ b/src/dynamicprompts/samplers/cycle.py @@ -13,6 +13,7 @@ from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( get_wildcard_not_found_fallback, + replace_wildcard_variables, wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext @@ -100,6 +101,7 @@ def _get_wildcard( context: SamplingContext, ) -> ResultGen: # TODO: doesn't support weights + command = replace_wildcard_variables(command=command, context=context) wc_values = context.wildcard_manager.get_values(command.wildcard) new_context = context.with_variables( command.variables, diff --git a/src/dynamicprompts/samplers/random.py b/src/dynamicprompts/samplers/random.py index 8523606..b8b3aec 100644 --- a/src/dynamicprompts/samplers/random.py +++ b/src/dynamicprompts/samplers/random.py @@ -13,6 +13,7 @@ from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( get_wildcard_not_found_fallback, + replace_wildcard_variables, wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext @@ -114,6 +115,7 @@ def _get_wildcard( command: WildcardCommand, context: SamplingContext, ) -> ResultGen: + command = replace_wildcard_variables(command=command, context=context) context = context.with_variables(command.variables) values = context.wildcard_manager.get_values(command.wildcard) diff --git a/src/dynamicprompts/samplers/utils.py b/src/dynamicprompts/samplers/utils.py index 35f6554..55a8757 100644 --- a/src/dynamicprompts/samplers/utils.py +++ b/src/dynamicprompts/samplers/utils.py @@ -1,9 +1,21 @@ from __future__ import annotations import logging +from functools import partial -from dynamicprompts.commands import VariantCommand, VariantOption, WildcardCommand -from dynamicprompts.parser.parse import parse +import pyparsing as pp + +from dynamicprompts.commands import ( + LiteralCommand, + VariantCommand, + VariantOption, + WildcardCommand, +) +from dynamicprompts.parser.parse import ( + _configure_variable_access, + _configure_wildcard_path, + parse, +) from dynamicprompts.sampling_context import SamplingContext from dynamicprompts.sampling_result import SamplingResult from dynamicprompts.types import ResultGen @@ -19,6 +31,7 @@ def wildcard_to_variant( max_bound=1, separator=",", ) -> VariantCommand: + command = replace_wildcard_variables(command=command, context=context) values = context.wildcard_manager.get_values(command.wildcard) min_bound = min(min_bound, len(values)) max_bound = min(max_bound, len(values)) @@ -50,3 +63,59 @@ def get_wildcard_not_found_fallback( res = SamplingResult(text=wrapped_wildcard) while True: yield res + + +def replace_wildcard_variables( + command: WildcardCommand, + *, + context: SamplingContext, +) -> WildcardCommand: + if context.parser_config.variable_start not in command.wildcard: + return command + + prompt = pp.SkipTo(context.parser_config.variable_end) + variable_access = _configure_variable_access( + parser_config=context.parser_config, + prompt=prompt, + ) + variable_access.set_parse_action( + partial(_replace_variable, variables=context.variables), + ) + wildcard = _configure_wildcard_path( + parser_config=context.parser_config, + variable_ref=variable_access, + ) + + try: + wildcard_result = wildcard.parse_string(command.wildcard) + return command.with_content("".join(wildcard_result)) + except Exception: + logger.warning("Unable to parse wildcard %r", command.wildcard, exc_info=True) + return command + + +def _replace_variable(string, location, token, *, variables: dict): + if isinstance(token, pp.ParseResults): + var_parts = token[0].as_dict() + var_name = var_parts.get("name") + if var_name: + var_name = var_name.strip() + + default = var_parts.get("default") + if default: + default = default.strip() + + else: + var_name = token + + variable = None + if var_name: + variable = variables.get(var_name) + + if isinstance(variable, LiteralCommand): + variable = variable.literal + if variable and not isinstance(variable, str): + raise NotImplementedError( + "evaluating complex commands within wildcards is not supported right now", + ) + return variable or default or var_name diff --git a/tests/parser/test_parser.py b/tests/parser/test_parser.py index e01debc..a8603c8 100644 --- a/tests/parser/test_parser.py +++ b/tests/parser/test_parser.py @@ -48,6 +48,10 @@ def test_literal_characters(self, input: str): [ "colours", "path/to/colours", + "path/to/${subject}", + "${pallette}_colours", + "${pallette:warm}_colours", + "locations/${room: dining room }/furniture", "änder", ], ) diff --git a/tests/samplers/test_common.py b/tests/samplers/test_common.py index df561c3..c92f170 100644 --- a/tests/samplers/test_common.py +++ b/tests/samplers/test_common.py @@ -626,6 +626,94 @@ def test_wildcard_nested_in_wildcard( ps = [str(p) for p in islice(gen, len(expected))] assert ps == expected + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_nested_variable(self, sampling_context: SamplingContext): + cmd = parse("${temp=cold}wearing __colors-${temp}__ suede shoes") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ( + "wearing blue suede shoes", + "wearing green suede shoes", + ) + else: + assert resolved_value == "wearing blue suede shoes" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_default_variable(self, sampling_context: SamplingContext): + cmd = parse("wearing __colors-${temp:cold}__ suede shoes") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ( + "wearing blue suede shoes", + "wearing green suede shoes", + ) + else: + assert resolved_value == "wearing blue suede shoes" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_undefined_variable(self, sampling_context: SamplingContext): + cmd = parse("wearing __colors-${temp}__ suede shoes") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + assert ( + resolved_value + == f"wearing {sampling_context.wildcard_manager.wildcard_wrap}colors-temp{sampling_context.wildcard_manager.wildcard_wrap} suede shoes" + ) + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_multiple_variables(self, sampling_context: SamplingContext): + cmd = parse( + "${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__", + ) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cat", "tiger") + else: + assert resolved_value == "cat" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_variable_and_glob(self, sampling_context: SamplingContext): + cmd = parse("${genus=reptiles}__animals/${genus}/*__") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cobra", "gecko", "iguana", "python") + else: + assert resolved_value == "cobra" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_variable_in_nested_wildcard( + self, + sampling_context: SamplingContext, + ): + cmd = parse("${genus=mammals}__animal__") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cat", "tiger", "dog", "wolf") + else: + assert resolved_value == "cat" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_nested_wildcard_with_parameterized_variable( + self, + sampling_context: SamplingContext, + ): + cmd = parse("__animal(genus=mammals)__") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cat", "tiger", "dog", "wolf") + else: + assert resolved_value == "cat" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_nested_wildcards_in_single_file(self, sampling_context: SamplingContext): + cmd = parse( + "${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__", + ) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + assert resolved_value in ( + "a sports car made by Porsche", + "a tractor made by John Deere", + ) + class TestVariableCommands: @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) diff --git a/tests/samplers/test_utils.py b/tests/samplers/test_utils.py index 01fca1d..5fae98f 100644 --- a/tests/samplers/test_utils.py +++ b/tests/samplers/test_utils.py @@ -2,7 +2,10 @@ from dynamicprompts.commands import ( WildcardCommand, ) -from dynamicprompts.samplers.utils import wildcard_to_variant +from dynamicprompts.samplers.utils import ( + replace_wildcard_variables, + wildcard_to_variant, +) from dynamicprompts.sampling_context import SamplingContext from pytest_lazyfixture import lazy_fixture @@ -30,3 +33,51 @@ def test_wildcard_to_variant(sampling_context: SamplingContext): ) assert variant_command.separator == "-" assert variant_command.sampling_method == wildcard_command.sampling_method + + +@pytest.mark.parametrize( + ("sampling_context", "initial_wildcard", "expected_wildcard", "variables"), + [ + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish }", + "colors-warm-finish", + {}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-warm-matte", + {}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-cold-matte", + {"temp": "cold"}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-cold-glossy", + {"temp": "cold", "finish": "glossy"}, + ), + ], +) +def test_replace_wildcard_variables_multi_variable( + sampling_context: SamplingContext, + initial_wildcard: str, + expected_wildcard: str, + variables: dict, +): + var_sampling_context = sampling_context.with_variables(variables=variables) + wildcard_command = WildcardCommand(initial_wildcard) + updated_command = replace_wildcard_variables( + command=wildcard_command, + context=var_sampling_context, + ) + assert isinstance( + updated_command, + WildcardCommand, + ), "updated command is also a WildcardCommand" + assert updated_command.wildcard == expected_wildcard diff --git a/tests/test_data/wildcards/animal.txt b/tests/test_data/wildcards/animal.txt new file mode 100644 index 0000000..46c0125 --- /dev/null +++ b/tests/test_data/wildcards/animal.txt @@ -0,0 +1 @@ +__animals/${genus:*}/*__ diff --git a/tests/test_data/wildcards/animals/reptiles/lizards.yaml b/tests/test_data/wildcards/animals/reptiles/lizards.yaml new file mode 100644 index 0000000..e6e6fb5 --- /dev/null +++ b/tests/test_data/wildcards/animals/reptiles/lizards.yaml @@ -0,0 +1,2 @@ +- iguana +- gecko diff --git a/tests/test_data/wildcards/animals/reptiles/snakes.txt b/tests/test_data/wildcards/animals/reptiles/snakes.txt new file mode 100644 index 0000000..fe228ad --- /dev/null +++ b/tests/test_data/wildcards/animals/reptiles/snakes.txt @@ -0,0 +1,2 @@ +python +cobra diff --git a/tests/test_data/wildcards/cars.yaml b/tests/test_data/wildcards/cars.yaml new file mode 100644 index 0000000..ec28991 --- /dev/null +++ b/tests/test_data/wildcards/cars.yaml @@ -0,0 +1,31 @@ +cars: + ford: + name: + - Ford + types: + - muscle car + - compact car + - roadster + colors: + - white + - black + - brown + - red + + porsche: + name: + - Porsche + types: + - sports car + colors: + - red + - black + + john_deere: + name: + - John Deere + types: + - tractor + colors: + - green + - yellow diff --git a/tests/wildcard/test_wildcardmanager.py b/tests/wildcard/test_wildcardmanager.py index 0f19aca..da04739 100644 --- a/tests/wildcard/test_wildcardmanager.py +++ b/tests/wildcard/test_wildcardmanager.py @@ -123,12 +123,24 @@ def test_hierarchy(wildcard_manager: WildcardManager): root = wildcard_manager.tree.root assert {name for name, item in root.walk_items()} == { "dupes", + "animal", "animals/all-references", "animals/mammals/canine", "animals/mammals/feline", "animals/mystical", + "animals/reptiles/lizards", + "animals/reptiles/snakes", "artists/dutch", "artists/finnish", + "cars/ford/colors", + "cars/ford/name", + "cars/ford/types", + "cars/john_deere/colors", + "cars/john_deere/name", + "cars/john_deere/types", + "cars/porsche/colors", + "cars/porsche/name", + "cars/porsche/types", "clothing", "colors-cold", "colors-warm", @@ -144,6 +156,7 @@ def test_hierarchy(wildcard_manager: WildcardManager): "wrappers", } assert set(root.collections) == { + "animal", "clothing", # from pantry YAML "colors-cold", # .txt "colors-warm", # .txt @@ -161,11 +174,17 @@ def test_hierarchy(wildcard_manager: WildcardManager): "canine", "feline", } + assert set(root.child_nodes["animals"].child_nodes["reptiles"].collections) == { + "lizards", + "snakes", + }, "animals/reptiles does not match" assert set(root.child_nodes["animals"].walk_full_names()) == { "animals/all-references", "animals/mammals/canine", "animals/mammals/feline", "animals/mystical", + "animals/reptiles/lizards", + "animals/reptiles/snakes", } assert set(root.child_nodes["flavors"].collections) == { "sour", # .txt @@ -336,6 +355,8 @@ def test_wcm_roots(): "elaimet/mammals/canine", "elaimet/mammals/feline", "elaimet/mystical", + "elaimet/reptiles/lizards", + "elaimet/reptiles/snakes", "elaimet/sopot", "metasyntactic/fnord", "metasyntactic/foo", @@ -344,9 +365,13 @@ def test_wcm_roots(): v for v in wcm.get_values("elaimet/*").string_values if not v.startswith("_") } == { "cat", + "cobra", "dog", + "gecko", + "iguana", "okapi", "pingviini", + "python", "tiger", "unicorn", "wolf",