diff --git a/src/dynamicprompts/parser/utils.py b/src/dynamicprompts/parser/utils.py new file mode 100644 index 0000000..b7210a3 --- /dev/null +++ b/src/dynamicprompts/parser/utils.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from functools import partial + +import pyparsing as pp + +from dynamicprompts.commands import LiteralCommand +from dynamicprompts.parser.config import ParserConfig +from dynamicprompts.parser.parse import ( + _configure_variable_access, + _configure_wildcard_path, +) + + +def resolve_variable_references( + subject: str, + *, + parser_config: ParserConfig, + variables: dict, +) -> str: + """ + Parse `subject` for variable references and resolve them. + + If there are no variable references, returns `subject` unchanged. + + """ + if parser_config.variable_start not in subject: + return subject + prompt = pp.SkipTo(parser_config.variable_end) + variable_access = _configure_variable_access( + parser_config=parser_config, + prompt=prompt, + ) + variable_access.set_parse_action( + partial(_replace_variable, variables=variables), + ) + wildcard = _configure_wildcard_path( + parser_config=parser_config, + variable_ref=variable_access, + ) + return "".join(wildcard.parse_string(subject)) + + +def _replace_variable(string, location, token, *, variables: dict): + if isinstance(token, pp.ParseResults): + var_parts = token[0].as_dict() + var_name = var_parts.get("name") + if var_name: + var_name = var_name.strip() + + default = var_parts.get("default") + if default: + default = default.strip() + + else: + var_name = token + + variable = None + if var_name: + variable = variables.get(var_name) + + if isinstance(variable, LiteralCommand): + variable = variable.literal + if variable and not isinstance(variable, str): + raise NotImplementedError( + "evaluating complex commands within wildcards is not supported right now", + ) + return variable or default or var_name diff --git a/src/dynamicprompts/samplers/utils.py b/src/dynamicprompts/samplers/utils.py index 55a8757..ed9af6e 100644 --- a/src/dynamicprompts/samplers/utils.py +++ b/src/dynamicprompts/samplers/utils.py @@ -1,21 +1,16 @@ from __future__ import annotations import logging -from functools import partial - -import pyparsing as pp from dynamicprompts.commands import ( - LiteralCommand, VariantCommand, VariantOption, WildcardCommand, ) from dynamicprompts.parser.parse import ( - _configure_variable_access, - _configure_wildcard_path, parse, ) +from dynamicprompts.parser.utils import resolve_variable_references from dynamicprompts.sampling_context import SamplingContext from dynamicprompts.sampling_result import SamplingResult from dynamicprompts.types import ResultGen @@ -70,52 +65,14 @@ def replace_wildcard_variables( *, context: SamplingContext, ) -> WildcardCommand: - if context.parser_config.variable_start not in command.wildcard: - return command - - prompt = pp.SkipTo(context.parser_config.variable_end) - variable_access = _configure_variable_access( - parser_config=context.parser_config, - prompt=prompt, - ) - variable_access.set_parse_action( - partial(_replace_variable, variables=context.variables), - ) - wildcard = _configure_wildcard_path( - parser_config=context.parser_config, - variable_ref=variable_access, - ) - try: - wildcard_result = wildcard.parse_string(command.wildcard) - return command.with_content("".join(wildcard_result)) + wildcard_result = resolve_variable_references( + command.wildcard, + parser_config=context.parser_config, + variables=context.variables, + ) + if wildcard_result != command.wildcard: + return command.with_content(wildcard_result) except Exception: logger.warning("Unable to parse wildcard %r", command.wildcard, exc_info=True) - return command - - -def _replace_variable(string, location, token, *, variables: dict): - if isinstance(token, pp.ParseResults): - var_parts = token[0].as_dict() - var_name = var_parts.get("name") - if var_name: - var_name = var_name.strip() - - default = var_parts.get("default") - if default: - default = default.strip() - - else: - var_name = token - - variable = None - if var_name: - variable = variables.get(var_name) - - if isinstance(variable, LiteralCommand): - variable = variable.literal - if variable and not isinstance(variable, str): - raise NotImplementedError( - "evaluating complex commands within wildcards is not supported right now", - ) - return variable or default or var_name + return command