diff --git a/src/dynamicprompts/commands/wildcard_command.py b/src/dynamicprompts/commands/wildcard_command.py index 7958d27..ef0052c 100644 --- a/src/dynamicprompts/commands/wildcard_command.py +++ b/src/dynamicprompts/commands/wildcard_command.py @@ -15,3 +15,6 @@ class WildcardCommand(Command): def __post_init__(self): if not isinstance(self.wildcard, str): raise TypeError(f"Wildcard must be a string, not {type(self.wildcard)}") + + def with_content(self, content: str) -> WildcardCommand: + return dataclasses.replace(self, wildcard=content) diff --git a/src/dynamicprompts/parser/parse.py b/src/dynamicprompts/parser/parse.py index aa4d0eb..56cacfe 100644 --- a/src/dynamicprompts/parser/parse.py +++ b/src/dynamicprompts/parser/parse.py @@ -111,10 +111,10 @@ def _configure_wildcard( parser_config: ParserConfig, variable_ref: pp.ParserElement, ) -> pp.ParserElement: - wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+" - wildcard_path = pp.Regex(wildcard_path_re).leave_whitespace() - - wildcard = pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path") + wildcard = _configure_wildcard_path( + parser_config=parser_config, + variable_ref=variable_ref, + ) wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap) wildcard_variable_spec = ( @@ -135,6 +135,17 @@ def _configure_wildcard( return wildcard("wildcard").leave_whitespace() +def _configure_wildcard_path( + parser_config: ParserConfig, + variable_ref: pp.ParserElement, +) -> pp.ParserElement: + wildcard_path_literal_re = ( + r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+" + ) + wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace() + return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path") + + def _configure_literal_sequence( parser_config: ParserConfig, is_variant_literal: bool = False, @@ -219,6 +230,7 @@ def _configure_variable_access( ) return variable_access.leave_whitespace() + def _configure_variable_assignment( parser_config: ParserConfig, prompt: pp.ParserElement, @@ -416,10 +428,23 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) - - variable_ref = pp.Combine(_configure_variable_access(parser_config=parser_config, prompt=pp.SkipTo(parser_config.variable_end))).leave_whitespace() - variable_ref.set_parse_action(lambda s,l,t: parser_config.variable_start + "".join(t) + parser_config.variable_end) - + + variable_ref = pp.Combine( + _configure_variable_access( + parser_config=parser_config, + prompt=pp.SkipTo(parser_config.variable_end), + ), + ).leave_whitespace() + + # by default, the variable starting "${" and end "}" characters are + # stripped with variable access. This restores them so that the + # variable can be properly recognized at a later stage + variable_ref.set_parse_action( + lambda s, l, t: parser_config.variable_start + + "".join(t) + + parser_config.variable_end, + ) + variable_assignment = _configure_variable_assignment( parser_config=parser_config, prompt=variant_prompt, @@ -428,7 +453,10 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) - wildcard = _configure_wildcard(parser_config=parser_config, variable_ref=variable_ref) + wildcard = _configure_wildcard( + parser_config=parser_config, + variable_ref=variable_ref, + ) literal_sequence = _configure_literal_sequence(parser_config=parser_config) variant_literal_sequence = _configure_literal_sequence( is_variant_literal=True, @@ -466,7 +494,7 @@ def create_parser( variants.set_parse_action(_parse_variant_command) literal_sequence.set_parse_action(_parse_literal_command) variant_literal_sequence.set_parse_action(_parse_literal_command) - variable_access.set_parse_action(_parse_variable_access_command) + variable_access.set_parse_action(_parse_variable_access_command) variable_assignment.set_parse_action(_parse_variable_assignment_command) prompt.set_parse_action(_parse_sequence_or_single_command) variant_prompt.set_parse_action(_parse_sequence_or_single_command) diff --git a/src/dynamicprompts/samplers/random.py b/src/dynamicprompts/samplers/random.py index 3779fc9..b8b3aec 100644 --- a/src/dynamicprompts/samplers/random.py +++ b/src/dynamicprompts/samplers/random.py @@ -13,8 +13,8 @@ from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( get_wildcard_not_found_fallback, + replace_wildcard_variables, wildcard_to_variant, - replace_wildcard_variables ) from dynamicprompts.sampling_context import SamplingContext from dynamicprompts.sampling_result import SamplingResult diff --git a/src/dynamicprompts/samplers/utils.py b/src/dynamicprompts/samplers/utils.py index c462aef..a06fd05 100644 --- a/src/dynamicprompts/samplers/utils.py +++ b/src/dynamicprompts/samplers/utils.py @@ -1,12 +1,21 @@ from __future__ import annotations import logging +from functools import partial import pyparsing as pp -import re -from dynamicprompts.commands import LiteralCommand, VariantCommand, VariantOption, WildcardCommand -from dynamicprompts.parser.parse import parse, _configure_variable_access, _configure_wildcard, _parse_variable_access_command +from dynamicprompts.commands import ( + LiteralCommand, + VariantCommand, + VariantOption, + WildcardCommand, +) +from dynamicprompts.parser.parse import ( + _configure_variable_access, + _configure_wildcard_path, + parse, +) from dynamicprompts.sampling_context import SamplingContext from dynamicprompts.sampling_result import SamplingResult from dynamicprompts.types import ResultGen @@ -55,47 +64,54 @@ def get_wildcard_not_found_fallback( while True: yield res + def replace_wildcard_variables( - command: WildcardCommand, - *, + command: WildcardCommand, + *, context: SamplingContext, -) -> WildcardCommand: +) -> WildcardCommand: + if not command.wildcard.__contains__(context.parser_config.variable_start): + return command + prompt = pp.SkipTo(context.parser_config.variable_end) variable_access = _configure_variable_access( parser_config=context.parser_config, - prompt=prompt + prompt=prompt, + ) + variable_access.set_parse_action( + partial(_replace_variable, variables=context.variables), + ) + wildcard = _configure_wildcard_path( + parser_config=context.parser_config, + variable_ref=variable_access, ) - variable_access.set_parse_action(_variable_replacement(context.variables)) - - wildcard_path_re = r"((?!" + re.escape(context.parser_config.wildcard_wrap) + r")[^(${}#])+" - wildcard_path = pp.Regex(wildcard_path_re).leave_whitespace() - wildcard = pp.Combine(pp.OneOrMore(variable_access | wildcard_path))("path") - - wildcard_result = wildcard.parse_string(command.wildcard) - - return WildcardCommand(wildcard="".join(wildcard_result), sampling_method=command.sampling_method, variables=command.variables) - -def _variable_replacement(variables: dict): - def var_replace(string, location, token): - if (isinstance(token, pp.ParseResults)): - var_parts = token[0].as_dict() - var_name = var_parts.get("name") - if (var_name != None): - var_name = var_name.strip() - - default = var_parts.get("default") - if (default != None): - default = default.strip() - - else: - var_name = token - - variable = None - if (var_name != None): - - variable = variables.get(var_name) - - if (isinstance(variable, LiteralCommand)): - variable = variable.literal - return variable or default or var_name - return var_replace \ No newline at end of file + + try: + wildcard_result = wildcard.parse_string(command.wildcard) + return command.with_content("".join(wildcard_result)) + except Exception as ex: + logger.warning("Unable to parse wildcard path %s", command.wildcard) + return command + + +def _replace_variable(string, location, token, *, variables: dict): + if isinstance(token, pp.ParseResults): + var_parts = token[0].as_dict() + var_name = var_parts.get("name") + if var_name: + var_name = var_name.strip() + + default = var_parts.get("default") + if default: + default = default.strip() + + else: + var_name = token + + variable = None + if var_name: + variable = variables.get(var_name) + + if isinstance(variable, LiteralCommand): + variable = variable.literal + return variable or default or var_name diff --git a/tests/samplers/test_common.py b/tests/samplers/test_common.py index a602bea..c92f170 100644 --- a/tests/samplers/test_common.py +++ b/tests/samplers/test_common.py @@ -12,7 +12,6 @@ VariantOption, WildcardCommand, ) -from dynamicprompts.commands.variable_commands import VariableAssignmentCommand from dynamicprompts.enums import SamplingMethod from dynamicprompts.parser.parse import parse from dynamicprompts.samplers import CombinatorialSampler, CyclicalSampler, RandomSampler @@ -630,72 +629,90 @@ def test_wildcard_nested_in_wildcard( @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) def test_wildcard_with_nested_variable(self, sampling_context: SamplingContext): cmd = parse("${temp=cold}wearing __colors-${temp}__ suede shoes") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) if isinstance(sampling_context.default_sampler, RandomSampler): - assert resolved_value == "wearing blue suede shoes" or resolved_value == "wearing green suede shoes" + assert resolved_value in ( + "wearing blue suede shoes", + "wearing green suede shoes", + ) else: assert resolved_value == "wearing blue suede shoes" @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) def test_wildcard_with_default_variable(self, sampling_context: SamplingContext): cmd = parse("wearing __colors-${temp:cold}__ suede shoes") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) if isinstance(sampling_context.default_sampler, RandomSampler): - assert resolved_value == "wearing blue suede shoes" or resolved_value == "wearing green suede shoes" + assert resolved_value in ( + "wearing blue suede shoes", + "wearing green suede shoes", + ) else: assert resolved_value == "wearing blue suede shoes" @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) def test_wildcard_with_undefined_variable(self, sampling_context: SamplingContext): cmd = parse("wearing __colors-${temp}__ suede shoes") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) - assert resolved_value == "wearing {wildcard_wrap}colors-temp{wildcard_wrap} suede shoes".format(wildcard_wrap=sampling_context.wildcard_manager.wildcard_wrap) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + assert ( + resolved_value + == f"wearing {sampling_context.wildcard_manager.wildcard_wrap}colors-temp{sampling_context.wildcard_manager.wildcard_wrap} suede shoes" + ) @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) def test_wildcard_with_multiple_variables(self, sampling_context: SamplingContext): - cmd = parse("${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) + cmd = parse( + "${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__", + ) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) if isinstance(sampling_context.default_sampler, RandomSampler): - assert resolved_value in ["cat", "tiger"] + assert resolved_value in ("cat", "tiger") else: assert resolved_value == "cat" @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) def test_wildcard_with_variable_and_glob(self, sampling_context: SamplingContext): cmd = parse("${genus=reptiles}__animals/${genus}/*__") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) if isinstance(sampling_context.default_sampler, RandomSampler): - assert resolved_value in ["cobra", "gecko", "iguana", "python"] + assert resolved_value in ("cobra", "gecko", "iguana", "python") else: - assert resolved_value == "cobra" - + assert resolved_value == "cobra" + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) - def test_wildcard_with_variable_in_nested_wildcard(self, sampling_context: SamplingContext): + def test_wildcard_with_variable_in_nested_wildcard( + self, + sampling_context: SamplingContext, + ): cmd = parse("${genus=mammals}__animal__") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) if isinstance(sampling_context.default_sampler, RandomSampler): - assert resolved_value in ["cat", "tiger", "dog", "wolf"] + assert resolved_value in ("cat", "tiger", "dog", "wolf") else: assert resolved_value == "cat" @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) - def test_nested_wildcard_with_parameterized_variable(self, sampling_context: SamplingContext): + def test_nested_wildcard_with_parameterized_variable( + self, + sampling_context: SamplingContext, + ): cmd = parse("__animal(genus=mammals)__") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) if isinstance(sampling_context.default_sampler, RandomSampler): - assert resolved_value in ["cat", "tiger", "dog", "wolf"] + assert resolved_value in ("cat", "tiger", "dog", "wolf") else: - assert resolved_value == "cat" - + assert resolved_value == "cat" + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) def test_nested_wildcards_in_single_file(self, sampling_context: SamplingContext): - cmd = parse("${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__") - resolved_value = str(next(sampling_context.generator_from_command(cmd))) - assert resolved_value in [ + cmd = parse( + "${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__", + ) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + assert resolved_value in ( "a sports car made by Porsche", - "a tractor made by John Deere" - ] - + "a tractor made by John Deere", + ) class TestVariableCommands: diff --git a/tests/samplers/test_utils.py b/tests/samplers/test_utils.py index 02fa1f1..5fae98f 100644 --- a/tests/samplers/test_utils.py +++ b/tests/samplers/test_utils.py @@ -3,8 +3,8 @@ WildcardCommand, ) from dynamicprompts.samplers.utils import ( - wildcard_to_variant, - replace_wildcard_variables + replace_wildcard_variables, + wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext from pytest_lazyfixture import lazy_fixture @@ -34,21 +34,50 @@ def test_wildcard_to_variant(sampling_context: SamplingContext): assert variant_command.separator == "-" assert variant_command.sampling_method == wildcard_command.sampling_method + @pytest.mark.parametrize( - ("sampling_context","initial_wildcard","expected_wildcard","variables"), + ("sampling_context", "initial_wildcard", "expected_wildcard", "variables"), [ - (lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish }", "colors-warm-finish", {}), - (lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish: matte }", "colors-warm-matte", {}), - (lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish: matte }", "colors-cold-matte", {"temp": "cold"}), - (lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish: matte }", "colors-cold-glossy", {"temp": "cold", "finish": "glossy"}), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish }", + "colors-warm-finish", + {}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-warm-matte", + {}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-cold-matte", + {"temp": "cold"}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-cold-glossy", + {"temp": "cold", "finish": "glossy"}, + ), ], ) -def test_replace_wildcard_variables_multi_variable(sampling_context: SamplingContext, - initial_wildcard: str, - expected_wildcard: str, - variables: dict): +def test_replace_wildcard_variables_multi_variable( + sampling_context: SamplingContext, + initial_wildcard: str, + expected_wildcard: str, + variables: dict, +): var_sampling_context = sampling_context.with_variables(variables=variables) - wildcard_command = WildcardCommand(initial_wildcard) - updated_command = replace_wildcard_variables(command=wildcard_command, context=var_sampling_context) - assert isinstance(updated_command, WildcardCommand), "updated command is also a WildcardCommand" + wildcard_command = WildcardCommand(initial_wildcard) + updated_command = replace_wildcard_variables( + command=wildcard_command, + context=var_sampling_context, + ) + assert isinstance( + updated_command, + WildcardCommand, + ), "updated command is also a WildcardCommand" assert updated_command.wildcard == expected_wildcard diff --git a/tests/test_data/wildcards/animal.txt b/tests/test_data/wildcards/animal.txt index 56a00d3..46c0125 100644 --- a/tests/test_data/wildcards/animal.txt +++ b/tests/test_data/wildcards/animal.txt @@ -1 +1 @@ -__animals/${genus:*}/*__ \ No newline at end of file +__animals/${genus:*}/*__ diff --git a/tests/test_data/wildcards/animals/reptiles/lizards.yaml b/tests/test_data/wildcards/animals/reptiles/lizards.yaml index 673d8bf..e6e6fb5 100644 --- a/tests/test_data/wildcards/animals/reptiles/lizards.yaml +++ b/tests/test_data/wildcards/animals/reptiles/lizards.yaml @@ -1,2 +1,2 @@ - iguana -- gecko \ No newline at end of file +- gecko diff --git a/tests/test_data/wildcards/animals/reptiles/snakes.txt b/tests/test_data/wildcards/animals/reptiles/snakes.txt index 500ff03..fe228ad 100644 --- a/tests/test_data/wildcards/animals/reptiles/snakes.txt +++ b/tests/test_data/wildcards/animals/reptiles/snakes.txt @@ -1,2 +1,2 @@ python -cobra \ No newline at end of file +cobra diff --git a/tests/test_data/wildcards/cars.yaml b/tests/test_data/wildcards/cars.yaml index 8e6ecb2..ec28991 100644 --- a/tests/test_data/wildcards/cars.yaml +++ b/tests/test_data/wildcards/cars.yaml @@ -28,4 +28,4 @@ cars: - tractor colors: - green - - yellow \ No newline at end of file + - yellow diff --git a/tests/wildcard/test_wildcardmanager.py b/tests/wildcard/test_wildcardmanager.py index 9465667..da04739 100644 --- a/tests/wildcard/test_wildcardmanager.py +++ b/tests/wildcard/test_wildcardmanager.py @@ -156,7 +156,7 @@ def test_hierarchy(wildcard_manager: WildcardManager): "wrappers", } assert set(root.collections) == { - "animal", + "animal", "clothing", # from pantry YAML "colors-cold", # .txt "colors-warm", # .txt @@ -184,7 +184,7 @@ def test_hierarchy(wildcard_manager: WildcardManager): "animals/mammals/feline", "animals/mystical", "animals/reptiles/lizards", - "animals/reptiles/snakes" + "animals/reptiles/snakes", } assert set(root.child_nodes["flavors"].collections) == { "sour", # .txt