Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) Add Variable Support in Wildcard Paths #110

Merged
merged 8 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,13 @@ def _configure_range() -> pp.ParserElement:

def _configure_wildcard(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^({}#])+"
wildcard_path = pp.Regex(wildcard_path_re)("path").leave_whitespace()
wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+"
wildcard_path = pp.Regex(wildcard_path_re).leave_whitespace()

wildcard = pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path")

wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap)
wildcard_variable_spec = (
OPT_WS
Expand All @@ -120,7 +124,7 @@ def _configure_wildcard(
wildcard = (
wildcard_enclosure
+ pp.Opt(sampler_symbol)("sampling_method")
+ wildcard_path
+ wildcard
+ pp.Opt(wildcard_variable_spec)
+ wildcard_enclosure
)
Expand Down Expand Up @@ -207,7 +211,6 @@ def _configure_variable_access(
)
return variable_access.leave_whitespace()


def _configure_variable_assignment(
parser_config: ParserConfig,
prompt: pp.ParserElement,
Expand Down Expand Up @@ -378,11 +381,15 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)

variable_ref = pp.Combine(_configure_variable_access(parser_config=parser_config, prompt=pp.SkipTo(parser_config.variable_end))).leave_whitespace()
variable_ref.set_parse_action(lambda s,l,t: parser_config.variable_start + "".join(t) + parser_config.variable_end)
akx marked this conversation as resolved.
Show resolved Hide resolved

variable_assignment = _configure_variable_assignment(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(parser_config=parser_config)
wildcard = _configure_wildcard(parser_config=parser_config, variable_ref=variable_ref)
literal_sequence = _configure_literal_sequence(parser_config=parser_config)
variant_literal_sequence = _configure_literal_sequence(
is_variant_literal=True,
Expand Down Expand Up @@ -413,7 +420,7 @@ def create_parser(
variants.set_parse_action(_parse_variant_command)
literal_sequence.set_parse_action(_parse_literal_command)
variant_literal_sequence.set_parse_action(_parse_literal_command)
variable_access.set_parse_action(_parse_variable_access_command)
variable_access.set_parse_action(_parse_variable_access_command)
variable_assignment.set_parse_action(_parse_variable_assignment_command)
prompt.set_parse_action(_parse_sequence_or_single_command)
variant_prompt.set_parse_action(_parse_sequence_or_single_command)
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/combinatorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dynamicprompts.samplers.command_collection import CommandCollection
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -142,6 +143,7 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)
if not values:
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -100,6 +101,7 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
wc_values = context.wildcard_manager.get_values(command.wildcard)
new_context = context.with_variables(
command.variables,
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
wildcard_to_variant,
replace_wildcard_variables
)
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
Expand Down Expand Up @@ -113,6 +114,7 @@ def _get_wildcard(
command: WildcardCommand,
context: SamplingContext,
) -> ResultGen:
command = replace_wildcard_variables(command=command, context=context)
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)

Expand Down
53 changes: 51 additions & 2 deletions src/dynamicprompts/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import logging

from dynamicprompts.commands import VariantCommand, VariantOption, WildcardCommand
from dynamicprompts.parser.parse import parse
import pyparsing as pp
import re

from dynamicprompts.commands import LiteralCommand, VariantCommand, VariantOption, WildcardCommand
from dynamicprompts.parser.parse import parse, _configure_variable_access, _configure_wildcard, _parse_variable_access_command
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
from dynamicprompts.types import ResultGen
Expand All @@ -19,6 +22,7 @@ def wildcard_to_variant(
max_bound=1,
separator=",",
) -> VariantCommand:
command = replace_wildcard_variables(command=command, context=context)
values = context.wildcard_manager.get_values(command.wildcard)
min_bound = min(min_bound, len(values))
max_bound = min(max_bound, len(values))
Expand Down Expand Up @@ -50,3 +54,48 @@ def get_wildcard_not_found_fallback(
res = SamplingResult(text=wrapped_wildcard)
while True:
yield res

def replace_wildcard_variables(
command: WildcardCommand,
*,
context: SamplingContext,
) -> WildcardCommand:
akx marked this conversation as resolved.
Show resolved Hide resolved
prompt = pp.SkipTo(context.parser_config.variable_end)
variable_access = _configure_variable_access(
parser_config=context.parser_config,
prompt=prompt
)
variable_access.set_parse_action(_variable_replacement(context.variables))

wildcard_path_re = r"((?!" + re.escape(context.parser_config.wildcard_wrap) + r")[^(${}#])+"
wildcard_path = pp.Regex(wildcard_path_re).leave_whitespace()
wildcard = pp.Combine(pp.OneOrMore(variable_access | wildcard_path))("path")
akx marked this conversation as resolved.
Show resolved Hide resolved

wildcard_result = wildcard.parse_string(command.wildcard)
akx marked this conversation as resolved.
Show resolved Hide resolved

return WildcardCommand(wildcard="".join(wildcard_result), sampling_method=command.sampling_method, variables=command.variables)
akx marked this conversation as resolved.
Show resolved Hide resolved

def _variable_replacement(variables: dict):
def var_replace(string, location, token):
akx marked this conversation as resolved.
Show resolved Hide resolved
if (isinstance(token, pp.ParseResults)):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if (var_name != None):
akx marked this conversation as resolved.
Show resolved Hide resolved
var_name = var_name.strip()

default = var_parts.get("default")
if (default != None):
akx marked this conversation as resolved.
Show resolved Hide resolved
default = default.strip()

else:
var_name = token

variable = None
if (var_name != None):
akx marked this conversation as resolved.
Show resolved Hide resolved

variable = variables.get(var_name)

if (isinstance(variable, LiteralCommand)):
variable = variable.literal
return variable or default or var_name
return var_replace
4 changes: 4 additions & 0 deletions tests/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def test_literal_characters(self, input: str):
[
"colours",
"path/to/colours",
"path/to/${subject}",
"${pallette}_colours",
"${pallette:warm}_colours",
"locations/${room: dining room }/furniture",
"änder",
],
)
Expand Down
71 changes: 71 additions & 0 deletions tests/samplers/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
VariantOption,
WildcardCommand,
)
from dynamicprompts.commands.variable_commands import VariableAssignmentCommand
from dynamicprompts.enums import SamplingMethod
from dynamicprompts.parser.parse import parse
from dynamicprompts.samplers import CombinatorialSampler, CyclicalSampler, RandomSampler
Expand Down Expand Up @@ -626,6 +627,76 @@ def test_wildcard_nested_in_wildcard(
ps = [str(p) for p in islice(gen, len(expected))]
assert ps == expected

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_nested_variable(self, sampling_context: SamplingContext):
cmd = parse("${temp=cold}wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value == "wearing blue suede shoes" or resolved_value == "wearing green suede shoes"
akx marked this conversation as resolved.
Show resolved Hide resolved
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_default_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp:cold}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value == "wearing blue suede shoes" or resolved_value == "wearing green suede shoes"
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_undefined_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert resolved_value == "wearing {wildcard_wrap}colors-temp{wildcard_wrap} suede shoes".format(wildcard_wrap=sampling_context.wildcard_manager.wildcard_wrap)

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_multiple_variables(self, sampling_context: SamplingContext):
cmd = parse("${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cat", "tiger"]
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_and_glob(self, sampling_context: SamplingContext):
cmd = parse("${genus=reptiles}__animals/${genus}/*__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cobra", "gecko", "iguana", "python"]
else:
assert resolved_value == "cobra"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_in_nested_wildcard(self, sampling_context: SamplingContext):
cmd = parse("${genus=mammals}__animal__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cat", "tiger", "dog", "wolf"]
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcard_with_parameterized_variable(self, sampling_context: SamplingContext):
cmd = parse("__animal(genus=mammals)__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cat", "tiger", "dog", "wolf"]
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcards_in_single_file(self, sampling_context: SamplingContext):
cmd = parse("${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert resolved_value in [
"a sports car made by Porsche",
"a tractor made by John Deere"
]
akx marked this conversation as resolved.
Show resolved Hide resolved



class TestVariableCommands:
@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
Expand Down
24 changes: 23 additions & 1 deletion tests/samplers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from dynamicprompts.commands import (
WildcardCommand,
)
from dynamicprompts.samplers.utils import wildcard_to_variant
from dynamicprompts.samplers.utils import (
wildcard_to_variant,
replace_wildcard_variables
)
from dynamicprompts.sampling_context import SamplingContext
from pytest_lazyfixture import lazy_fixture

Expand Down Expand Up @@ -30,3 +33,22 @@ def test_wildcard_to_variant(sampling_context: SamplingContext):
)
assert variant_command.separator == "-"
assert variant_command.sampling_method == wildcard_command.sampling_method

@pytest.mark.parametrize(
("sampling_context","initial_wildcard","expected_wildcard","variables"),
[
(lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish }", "colors-warm-finish", {}),
(lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish: matte }", "colors-warm-matte", {}),
(lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish: matte }", "colors-cold-matte", {"temp": "cold"}),
(lazy_fixture("random_sampling_context"),"colors-${temp: warm}-${ finish: matte }", "colors-cold-glossy", {"temp": "cold", "finish": "glossy"}),
],
)
def test_replace_wildcard_variables_multi_variable(sampling_context: SamplingContext,
initial_wildcard: str,
expected_wildcard: str,
variables: dict):
var_sampling_context = sampling_context.with_variables(variables=variables)
wildcard_command = WildcardCommand(initial_wildcard)
updated_command = replace_wildcard_variables(command=wildcard_command, context=var_sampling_context)
assert isinstance(updated_command, WildcardCommand), "updated command is also a WildcardCommand"
assert updated_command.wildcard == expected_wildcard
1 change: 1 addition & 0 deletions tests/test_data/wildcards/animal.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__animals/${genus:*}/*__
2 changes: 2 additions & 0 deletions tests/test_data/wildcards/animals/reptiles/lizards.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- iguana
- gecko
2 changes: 2 additions & 0 deletions tests/test_data/wildcards/animals/reptiles/snakes.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
python
cobra
31 changes: 31 additions & 0 deletions tests/test_data/wildcards/cars.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
cars:
ford:
name:
- Ford
types:
- muscle car
- compact car
- roadster
colors:
- white
- black
- brown
- red

porsche:
name:
- Porsche
types:
- sports car
colors:
- red
- black

john_deere:
name:
- John Deere
types:
- tractor
colors:
- green
- yellow
Loading
Loading