Skip to content

Commit

Permalink
Addressed PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
mwootendev committed Nov 21, 2023
1 parent 0ccc4ef commit fb1bfe1
Show file tree
Hide file tree
Showing 11 changed files with 193 additions and 100 deletions.
3 changes: 3 additions & 0 deletions src/dynamicprompts/commands/wildcard_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ class WildcardCommand(Command):
def __post_init__(self):
if not isinstance(self.wildcard, str):
raise TypeError(f"Wildcard must be a string, not {type(self.wildcard)}")

def with_content(self, content: str) -> WildcardCommand:
return dataclasses.replace(self, wildcard=content)
48 changes: 38 additions & 10 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def _configure_wildcard(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+"
wildcard_path = pp.Regex(wildcard_path_re).leave_whitespace()

wildcard = pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path")
wildcard = _configure_wildcard_path(
parser_config=parser_config,
variable_ref=variable_ref,
)

wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap)
wildcard_variable_spec = (
Expand All @@ -135,6 +135,17 @@ def _configure_wildcard(
return wildcard("wildcard").leave_whitespace()


def _configure_wildcard_path(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_literal_re = (
r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+"
)
wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace()
return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path")


def _configure_literal_sequence(
parser_config: ParserConfig,
is_variant_literal: bool = False,
Expand Down Expand Up @@ -219,6 +230,7 @@ def _configure_variable_access(
)
return variable_access.leave_whitespace()


def _configure_variable_assignment(
parser_config: ParserConfig,
prompt: pp.ParserElement,
Expand Down Expand Up @@ -416,10 +428,23 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)

variable_ref = pp.Combine(_configure_variable_access(parser_config=parser_config, prompt=pp.SkipTo(parser_config.variable_end))).leave_whitespace()
variable_ref.set_parse_action(lambda s,l,t: parser_config.variable_start + "".join(t) + parser_config.variable_end)


variable_ref = pp.Combine(
_configure_variable_access(
parser_config=parser_config,
prompt=pp.SkipTo(parser_config.variable_end),
),
).leave_whitespace()

# by default, the variable starting "${" and end "}" characters are
# stripped with variable access. This restores them so that the
# variable can be properly recognized at a later stage
variable_ref.set_parse_action(
lambda s, l, t: parser_config.variable_start

Check failure on line 443 in src/dynamicprompts/parser/parse.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (E741)

src/dynamicprompts/parser/parse.py:443:19: E741 Ambiguous variable name: `l`
+ "".join(t)
+ parser_config.variable_end,
)

variable_assignment = _configure_variable_assignment(
parser_config=parser_config,
prompt=variant_prompt,
Expand All @@ -428,7 +453,10 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(parser_config=parser_config, variable_ref=variable_ref)
wildcard = _configure_wildcard(
parser_config=parser_config,
variable_ref=variable_ref,
)
literal_sequence = _configure_literal_sequence(parser_config=parser_config)
variant_literal_sequence = _configure_literal_sequence(
is_variant_literal=True,
Expand Down Expand Up @@ -466,7 +494,7 @@ def create_parser(
variants.set_parse_action(_parse_variant_command)
literal_sequence.set_parse_action(_parse_literal_command)
variant_literal_sequence.set_parse_action(_parse_literal_command)
variable_access.set_parse_action(_parse_variable_access_command)
variable_access.set_parse_action(_parse_variable_access_command)
variable_assignment.set_parse_action(_parse_variable_assignment_command)
prompt.set_parse_action(_parse_sequence_or_single_command)
variant_prompt.set_parse_action(_parse_sequence_or_single_command)
Expand Down
2 changes: 1 addition & 1 deletion src/dynamicprompts/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
replace_wildcard_variables
)
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
Expand Down
98 changes: 57 additions & 41 deletions src/dynamicprompts/samplers/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from __future__ import annotations

import logging
from functools import partial

import pyparsing as pp
import re

from dynamicprompts.commands import LiteralCommand, VariantCommand, VariantOption, WildcardCommand
from dynamicprompts.parser.parse import parse, _configure_variable_access, _configure_wildcard, _parse_variable_access_command
from dynamicprompts.commands import (
LiteralCommand,
VariantCommand,
VariantOption,
WildcardCommand,
)
from dynamicprompts.parser.parse import (
_configure_variable_access,
_configure_wildcard_path,
parse,
)
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
from dynamicprompts.types import ResultGen
Expand Down Expand Up @@ -55,47 +64,54 @@ def get_wildcard_not_found_fallback(
while True:
yield res


def replace_wildcard_variables(
command: WildcardCommand,
*,
command: WildcardCommand,
*,
context: SamplingContext,
) -> WildcardCommand:
) -> WildcardCommand:
if not command.wildcard.__contains__(context.parser_config.variable_start):
return command

prompt = pp.SkipTo(context.parser_config.variable_end)
variable_access = _configure_variable_access(
parser_config=context.parser_config,
prompt=prompt
prompt=prompt,
)
variable_access.set_parse_action(
partial(_replace_variable, variables=context.variables),
)
wildcard = _configure_wildcard_path(
parser_config=context.parser_config,
variable_ref=variable_access,
)
variable_access.set_parse_action(_variable_replacement(context.variables))

wildcard_path_re = r"((?!" + re.escape(context.parser_config.wildcard_wrap) + r")[^(${}#])+"
wildcard_path = pp.Regex(wildcard_path_re).leave_whitespace()
wildcard = pp.Combine(pp.OneOrMore(variable_access | wildcard_path))("path")

wildcard_result = wildcard.parse_string(command.wildcard)

return WildcardCommand(wildcard="".join(wildcard_result), sampling_method=command.sampling_method, variables=command.variables)

def _variable_replacement(variables: dict):
def var_replace(string, location, token):
if (isinstance(token, pp.ParseResults)):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if (var_name != None):
var_name = var_name.strip()

default = var_parts.get("default")
if (default != None):
default = default.strip()

else:
var_name = token

variable = None
if (var_name != None):

variable = variables.get(var_name)

if (isinstance(variable, LiteralCommand)):
variable = variable.literal
return variable or default or var_name
return var_replace

try:
wildcard_result = wildcard.parse_string(command.wildcard)
return command.with_content("".join(wildcard_result))
except Exception as ex:

Check failure on line 92 in src/dynamicprompts/samplers/utils.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F841)

src/dynamicprompts/samplers/utils.py:92:25: F841 Local variable `ex` is assigned to but never used
logger.warning("Unable to parse wildcard path %s", command.wildcard)
return command


def _replace_variable(string, location, token, *, variables: dict):
if isinstance(token, pp.ParseResults):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if var_name:
var_name = var_name.strip()

default = var_parts.get("default")
if default:
default = default.strip()

else:
var_name = token

variable = None
if var_name:
variable = variables.get(var_name)

if isinstance(variable, LiteralCommand):
variable = variable.literal
return variable or default or var_name
73 changes: 45 additions & 28 deletions tests/samplers/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
VariantOption,
WildcardCommand,
)
from dynamicprompts.commands.variable_commands import VariableAssignmentCommand
from dynamicprompts.enums import SamplingMethod
from dynamicprompts.parser.parse import parse
from dynamicprompts.samplers import CombinatorialSampler, CyclicalSampler, RandomSampler
Expand Down Expand Up @@ -630,72 +629,90 @@ def test_wildcard_nested_in_wildcard(
@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_nested_variable(self, sampling_context: SamplingContext):
cmd = parse("${temp=cold}wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value == "wearing blue suede shoes" or resolved_value == "wearing green suede shoes"
assert resolved_value in (
"wearing blue suede shoes",
"wearing green suede shoes",
)
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_default_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp:cold}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value == "wearing blue suede shoes" or resolved_value == "wearing green suede shoes"
assert resolved_value in (
"wearing blue suede shoes",
"wearing green suede shoes",
)
else:
assert resolved_value == "wearing blue suede shoes"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_undefined_variable(self, sampling_context: SamplingContext):
cmd = parse("wearing __colors-${temp}__ suede shoes")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert resolved_value == "wearing {wildcard_wrap}colors-temp{wildcard_wrap} suede shoes".format(wildcard_wrap=sampling_context.wildcard_manager.wildcard_wrap)
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert (
resolved_value
== f"wearing {sampling_context.wildcard_manager.wildcard_wrap}colors-temp{sampling_context.wildcard_manager.wildcard_wrap} suede shoes"
)

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_multiple_variables(self, sampling_context: SamplingContext):
cmd = parse("${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
cmd = parse(
"${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__",
)
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cat", "tiger"]
assert resolved_value in ("cat", "tiger")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_and_glob(self, sampling_context: SamplingContext):
cmd = parse("${genus=reptiles}__animals/${genus}/*__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cobra", "gecko", "iguana", "python"]
assert resolved_value in ("cobra", "gecko", "iguana", "python")
else:
assert resolved_value == "cobra"
assert resolved_value == "cobra"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_wildcard_with_variable_in_nested_wildcard(self, sampling_context: SamplingContext):
def test_wildcard_with_variable_in_nested_wildcard(
self,
sampling_context: SamplingContext,
):
cmd = parse("${genus=mammals}__animal__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cat", "tiger", "dog", "wolf"]
assert resolved_value in ("cat", "tiger", "dog", "wolf")
else:
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcard_with_parameterized_variable(self, sampling_context: SamplingContext):
def test_nested_wildcard_with_parameterized_variable(
self,
sampling_context: SamplingContext,
):
cmd = parse("__animal(genus=mammals)__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
if isinstance(sampling_context.default_sampler, RandomSampler):
assert resolved_value in ["cat", "tiger", "dog", "wolf"]
assert resolved_value in ("cat", "tiger", "dog", "wolf")
else:
assert resolved_value == "cat"
assert resolved_value == "cat"

@pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures)
def test_nested_wildcards_in_single_file(self, sampling_context: SamplingContext):
cmd = parse("${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__")
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert resolved_value in [
cmd = parse(
"${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__",
)
resolved_value = str(next(sampling_context.generator_from_command(cmd)))
assert resolved_value in (
"a sports car made by Porsche",
"a tractor made by John Deere"
]

"a tractor made by John Deere",
)


class TestVariableCommands:
Expand Down
Loading

0 comments on commit fb1bfe1

Please sign in to comment.