Skip to content

Commit

Permalink
Refactor replace_wildcard_variables to use `resolve_variable_refere…
Browse files Browse the repository at this point in the history
…nces`
  • Loading branch information
akx committed Nov 23, 2023
1 parent ff47441 commit 6d83a28
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 52 deletions.
68 changes: 68 additions & 0 deletions src/dynamicprompts/parser/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from functools import partial

import pyparsing as pp

from dynamicprompts.commands import LiteralCommand
from dynamicprompts.parser.config import ParserConfig
from dynamicprompts.parser.parse import (
_configure_variable_access,
_configure_wildcard_path,
)


def resolve_variable_references(
subject: str,
*,
parser_config: ParserConfig,
variables: dict,
) -> str:
"""
Parse `subject` for variable references and resolve them.
If there are no variable references, returns `subject` unchanged.
"""
if parser_config.variable_start not in subject:
return subject
prompt = pp.SkipTo(parser_config.variable_end)
variable_access = _configure_variable_access(
parser_config=parser_config,
prompt=prompt,
)
variable_access.set_parse_action(
partial(_replace_variable, variables=variables),
)
wildcard = _configure_wildcard_path(
parser_config=parser_config,
variable_ref=variable_access,
)
return "".join(wildcard.parse_string(subject))


def _replace_variable(string, location, token, *, variables: dict):
if isinstance(token, pp.ParseResults):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if var_name:
var_name = var_name.strip()

default = var_parts.get("default")
if default:
default = default.strip()

else:
var_name = token

variable = None
if var_name:
variable = variables.get(var_name)

if isinstance(variable, LiteralCommand):
variable = variable.literal
if variable and not isinstance(variable, str):
raise NotImplementedError(
"evaluating complex commands within wildcards is not supported right now",
)
return variable or default or var_name
61 changes: 9 additions & 52 deletions src/dynamicprompts/samplers/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from __future__ import annotations

import logging
from functools import partial

import pyparsing as pp

from dynamicprompts.commands import (
LiteralCommand,
VariantCommand,
VariantOption,
WildcardCommand,
)
from dynamicprompts.parser.parse import (
_configure_variable_access,
_configure_wildcard_path,
parse,
)
from dynamicprompts.parser.utils import resolve_variable_references
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
from dynamicprompts.types import ResultGen
Expand Down Expand Up @@ -70,52 +65,14 @@ def replace_wildcard_variables(
*,
context: SamplingContext,
) -> WildcardCommand:
if context.parser_config.variable_start not in command.wildcard:
return command

prompt = pp.SkipTo(context.parser_config.variable_end)
variable_access = _configure_variable_access(
parser_config=context.parser_config,
prompt=prompt,
)
variable_access.set_parse_action(
partial(_replace_variable, variables=context.variables),
)
wildcard = _configure_wildcard_path(
parser_config=context.parser_config,
variable_ref=variable_access,
)

try:
wildcard_result = wildcard.parse_string(command.wildcard)
return command.with_content("".join(wildcard_result))
wildcard_result = resolve_variable_references(
command.wildcard,
parser_config=context.parser_config,
variables=context.variables,
)
if wildcard_result != command.wildcard:
return command.with_content(wildcard_result)
except Exception:
logger.warning("Unable to parse wildcard %r", command.wildcard, exc_info=True)
return command


def _replace_variable(string, location, token, *, variables: dict):
if isinstance(token, pp.ParseResults):
var_parts = token[0].as_dict()
var_name = var_parts.get("name")
if var_name:
var_name = var_name.strip()

default = var_parts.get("default")
if default:
default = default.strip()

else:
var_name = token

variable = None
if var_name:
variable = variables.get(var_name)

if isinstance(variable, LiteralCommand):
variable = variable.literal
if variable and not isinstance(variable, str):
raise NotImplementedError(
"evaluating complex commands within wildcards is not supported right now",
)
return variable or default or var_name
return command

0 comments on commit 6d83a28

Please sign in to comment.