Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) Dynamic Wildcards #114

Merged
merged 9 commits into from
Mar 21, 2024
11 changes: 5 additions & 6 deletions src/dynamicprompts/commands/wildcard_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

@dataclasses.dataclass(frozen=True)
class WildcardCommand(Command):
wildcard: str
wildcard: Command | str
sampling_method: SamplingMethod | None = None
variables: dict[str, Command] = dataclasses.field(default_factory=dict)

def __post_init__(self):
if not isinstance(self.wildcard, str):
raise TypeError(f"Wildcard must be a string, not {type(self.wildcard)}")

def with_content(self, content: str) -> WildcardCommand:
return dataclasses.replace(self, wildcard=content)
if not isinstance(self.wildcard, (Command, str)):
raise TypeError(
f"Wildcard must be a Command or str, not {type(self.wildcard)}",
)
71 changes: 45 additions & 26 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,21 @@ def _configure_range() -> pp.ParserElement:

def _configure_wildcard(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
prompt: pp.ParserElement,
) -> pp.ParserElement:
wildcard = _configure_wildcard_path(
parser_config=parser_config,
variable_ref=variable_ref,
)

wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap)
wildcard_variable_spec = (
OPT_WS
+ pp.Suppress("(")
+ pp.Literal("(")
+ pp.Regex(r"[^)]+")("variable_spec")
+ pp.Suppress(")")
)
wc_path = prompt()("path")

wildcard = (
wildcard_enclosure
+ pp.Opt(sampler_symbol)("sampling_method")
+ wildcard
+ wc_path
+ pp.Opt(wildcard_variable_spec)
+ wildcard_enclosure
)
Expand All @@ -149,6 +145,7 @@ def _configure_wildcard_path(
def _configure_literal_sequence(
parser_config: ParserConfig,
is_variant_literal: bool = False,
is_wildcard_literal: bool = False,
) -> pp.ParserElement:
# Characters that are not allowed in a literal
# - { denotes the start of a variant (or whatever variant_start is set to )
Expand All @@ -168,6 +165,12 @@ def _configure_literal_sequence(
# - $ denotes the end of a bound expression
non_literal_chars += rf"|${parser_config.variant_end}"

if is_wildcard_literal:
# Inside a wildcard the following characters are also not allowed
# - ( denotes the beginning of wildcard variable parameters
# - ) denotes the end of wildcard variable parameters
non_literal_chars += r")("

non_literal_chars = re.escape(non_literal_chars)
literal = pp.Regex(
rf"((?!{re.escape(parser_config.wildcard_wrap)})[^{non_literal_chars}])+",
Expand Down Expand Up @@ -356,7 +359,9 @@ def _parse_wildcard_command(
else:
variables = {}

assert isinstance(wildcard, str)
assert isinstance(wildcard, (Command, str))
if isinstance(wildcard, LiteralCommand):
wildcard = wildcard.literal
return WildcardCommand(
wildcard=wildcard,
sampling_method=sampling_method,
Expand Down Expand Up @@ -394,6 +399,18 @@ def _parse_variable_access_command(
return VariableAccessCommand(name=parts["name"], default=parts.get("default"))


def _parse_wildcard_variable_access_command(
parse_result: pp.ParseResults,
) -> VariableAccessCommand:
parts = parse_result[0].as_dict()
name = parts["name"]
default = parts.get("default") or LiteralCommand(name)
return VariableAccessCommand(
name=name,
default=LiteralCommand(default.literal.strip()),
)


def _parse_variable_assignment_command(
parse_result: pp.ParseResults,
) -> VariableAssignmentCommand:
Expand Down Expand Up @@ -423,28 +440,16 @@ def create_parser(

prompt = pp.Forward()
variant_prompt = pp.Forward()
wildcard_prompt = pp.Forward()

variable_access = _configure_variable_access(
parser_config=parser_config,
prompt=variant_prompt,
)

variable_ref = pp.Combine(
_configure_variable_access(
parser_config=parser_config,
prompt=pp.SkipTo(parser_config.variable_end),
),
).leave_whitespace()

# by default, the variable starting "${" and end "}" characters are
# stripped with variable access. This restores them so that the
# variable can be properly recognized at a later stage
variable_ref.set_parse_action(
lambda string,
location,
token: f"{parser_config.variable_start}{''.join(token)}{parser_config.variable_end}",
wildcard_variable_access = _configure_variable_access(
parser_config=parser_config,
prompt=variant_prompt,
)

variable_assignment = _configure_variable_assignment(
parser_config=parser_config,
prompt=variant_prompt,
Expand All @@ -455,9 +460,13 @@ def create_parser(
)
wildcard = _configure_wildcard(
parser_config=parser_config,
variable_ref=variable_ref,
prompt=wildcard_prompt,
)
literal_sequence = _configure_literal_sequence(parser_config=parser_config)
wildcard_literal_sequence = _configure_literal_sequence(
parser_config=parser_config,
is_wildcard_literal=True,
)
variant_literal_sequence = _configure_literal_sequence(
is_variant_literal=True,
parser_config=parser_config,
Expand All @@ -479,9 +488,16 @@ def create_parser(
variant_chunk = (
variable_access | wrap_command | variants | wildcard | variant_literal_sequence
)
wildcard_chunk = (
wildcard_variable_access
| variants
| wildcard_literal_sequence
| variant_literal_sequence
)

prompt <<= pp.ZeroOrMore(chunk)("prompt")
variant_prompt <<= pp.ZeroOrMore(variant_chunk)("prompt")
wildcard_prompt <<= pp.OneOrMore(wildcard_chunk, stop_on=pp.Char("("))("prompt")

# Configure comments
prompt.ignore("#" + pp.restOfLine)
Expand All @@ -493,12 +509,15 @@ def create_parser(
)
variants.set_parse_action(_parse_variant_command)
literal_sequence.set_parse_action(_parse_literal_command)
wildcard_literal_sequence.set_parse_action(_parse_literal_command)
variant_literal_sequence.set_parse_action(_parse_literal_command)
variable_access.set_parse_action(_parse_variable_access_command)
wildcard_variable_access.set_parse_action(_parse_wildcard_variable_access_command)
variable_assignment.set_parse_action(_parse_variable_assignment_command)
prompt.set_parse_action(_parse_sequence_or_single_command)
variant_prompt.set_parse_action(_parse_sequence_or_single_command)
wrap_command.set_parse_action(_parse_wrap_command)
wildcard_prompt.set_parse_action(_parse_sequence_or_single_command)
return prompt


Expand Down
6 changes: 3 additions & 3 deletions src/dynamicprompts/samplers/combinatorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from dynamicprompts.samplers.command_collection import CommandCollection
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -144,9 +143,10 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
wildcard_path = next(iter(context.sample_prompts(command.wildcard, 1))).text
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)
values = context.wildcard_manager.get_values(wildcard_path)

if not values:
yield from get_wildcard_not_found_fallback(command, context)
return
Expand Down
5 changes: 2 additions & 3 deletions src/dynamicprompts/samplers/cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -101,8 +100,8 @@ def _get_wildcard(
context: SamplingContext,
) -> ResultGen:
# TODO: doesn't support weights
command = replace_wildcard_variables(command=command, context=context)
wc_values = context.wildcard_manager.get_values(command.wildcard)
wildcard_path = next(iter(context.sample_prompts(command.wildcard, 1))).text
wc_values = context.wildcard_manager.get_values(wildcard_path)
new_context = context.with_variables(
command.variables,
).with_sampling_method(SamplingMethod.CYCLICAL)
Expand Down
5 changes: 2 additions & 3 deletions src/dynamicprompts/samplers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from dynamicprompts.samplers.base import Sampler
from dynamicprompts.samplers.utils import (
get_wildcard_not_found_fallback,
replace_wildcard_variables,
wildcard_to_variant,
)
from dynamicprompts.sampling_context import SamplingContext
Expand Down Expand Up @@ -115,9 +114,9 @@ def _get_wildcard(
command: WildcardCommand,
context: SamplingContext,
) -> ResultGen:
command = replace_wildcard_variables(command=command, context=context)
wildcard_path = next(iter(context.sample_prompts(command.wildcard, 1))).text
context = context.with_variables(command.variables)
values = context.wildcard_manager.get_values(command.wildcard)
values = context.wildcard_manager.get_values(wildcard_path)

if len(values) == 0:
yield from get_wildcard_not_found_fallback(command, context)
Expand Down
79 changes: 10 additions & 69 deletions src/dynamicprompts/samplers/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
from __future__ import annotations

import logging
from functools import partial

import pyparsing as pp

from dynamicprompts.commands import (
LiteralCommand,
Command,
VariantCommand,
VariantOption,
WildcardCommand,
)
from dynamicprompts.parser.parse import (
_configure_variable_access,
_configure_wildcard_path,
parse,
)
from dynamicprompts.parser.parse import parse
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
from dynamicprompts.types import ResultGen
Expand All @@ -31,8 +24,8 @@ def wildcard_to_variant(
max_bound=1,
separator=",",
) -> VariantCommand:
command = replace_wildcard_variables(command=command, context=context)
values = context.wildcard_manager.get_values(command.wildcard)
wildcard = next(iter(context.sample_prompts(command.wildcard, 1))).text
values = context.wildcard_manager.get_values(wildcard)
min_bound = min(min_bound, len(values))
max_bound = min(max_bound, len(values))

Expand All @@ -58,64 +51,12 @@ def get_wildcard_not_found_fallback(
"""
Logs a warning, then infinitely yields the wrapped wildcard.
"""
logger.warning(f"No values found for wildcard {command.wildcard!r}")
wrapped_wildcard = context.wildcard_manager.to_wildcard(command.wildcard)
if isinstance(command.wildcard, Command):
wildcard = next(iter(context.sample_prompts(command.wildcard, 1))).text
else:
wildcard = str(command.wildcard)
logger.warning(f"No values found for wildcard {wildcard!r}")
wrapped_wildcard = context.wildcard_manager.to_wildcard(wildcard)
res = SamplingResult(text=wrapped_wildcard)
while True:
yield res


def replace_wildcard_variables(
command: WildcardCommand,
*,
context: SamplingContext,
) -> WildcardCommand:
if context.parser_config.variable_start not in command.wildcard:
return command

prompt = pp.SkipTo(context.parser_config.variable_end)
variable_access = _configure_variable_access(
parser_config=context.parser_config,
prompt=prompt,
)
variable_access.set_parse_action(
partial(_replace_variable, variables=context.variables),
)
wildcard = _configure_wildcard_path(
parser_config=context.parser_config,
variable_ref=variable_access,
)

try:
wildcard_result = wildcard.parse_string(command.wildcard)
return command.with_content("".join(wildcard_result))
except Exception:
logger.warning("Unable to parse wildcard %r", command.wildcard, exc_info=True)
return command


def _replace_variable(string, location, token, *, variables: dict):
if isinstance(token, pp.ParseResults):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if var_name:
var_name = var_name.strip()

default = var_parts.get("default")
if default:
default = default.strip()

else:
var_name = token

variable = None
if var_name:
variable = variables.get(var_name)

if isinstance(variable, LiteralCommand):
variable = variable.literal
if variable and not isinstance(variable, str):
raise NotImplementedError(
"evaluating complex commands within wildcards is not supported right now",
)
return variable or default or var_name
Loading