diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ffa966e..f3474b9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: pre-commit/action@v3.0.0 + env: + RUFF_OUTPUT_FORMAT: github mypy: runs-on: ubuntu-latest steps: @@ -26,7 +28,7 @@ jobs: cache: "pip" cache-dependency-path: | pyproject.toml - - run: python -m pip install mypy -e .[dev,attentiongrabber,magicprompt,feelinglucky] + - run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky,yaml] - run: mypy --install-types --non-interactive src test: runs-on: ${{ matrix.os }} @@ -45,7 +47,7 @@ jobs: cache-dependency-path: | pyproject.toml - name: Install dependencies - run: python -m pip install -e .[dev,attentiongrabber,magicprompt,feelinglucky] + run: python -m pip install -e .[dev,attentiongrabber,feelinglucky,yaml] - run: pytest --cov --cov-report=term-missing --cov-report=xml . env: PYPARSINGENABLEALLWARNINGS: 1 @@ -67,9 +69,6 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.11" - cache: "pip" - cache-dependency-path: | - pyproject.toml - run: python -m pip install hatch - run: hatch build -t wheel - name: Publish package distributions to PyPI diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ddf7bc2..0fe058d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,18 +1,13 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.290 + rev: v0.1.6 hooks: - id: ruff args: - --fix + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/psf/black - rev: 23.9.1 - hooks: - - id: black - args: - - --quiet diff --git a/pyproject.toml b/pyproject.toml index 0e00396..ff19746 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ attentiongrabber = [] # empty list for backwards compatibility (no "no extra" warnings) magicprompt = ["transformers[torch]~=4.19"] feelinglucky = ["requests~=2.28"] +yaml = ["pyyaml~=6.0"] dev = [ "pytest-cov~=4.0", "pytest-lazy-fixture~=0.6", @@ -84,6 +85,10 @@ exclude = "tests" module = "transformers" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "torch" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "spacy.*" ignore_missing_imports = true diff --git a/src/dynamicprompts/commands/__init__.py b/src/dynamicprompts/commands/__init__.py index 26e0c3e..2e82c8d 100644 --- a/src/dynamicprompts/commands/__init__.py +++ b/src/dynamicprompts/commands/__init__.py @@ -3,13 +3,15 @@ from dynamicprompts.commands.sequence_command import SequenceCommand from dynamicprompts.commands.variant_command import VariantCommand, VariantOption from dynamicprompts.commands.wildcard_command import WildcardCommand +from dynamicprompts.commands.wrap_command import WrapCommand __all__ = [ "Command", "LiteralCommand", + "SamplingMethod", "SequenceCommand", "VariantCommand", "VariantOption", "WildcardCommand", - "SamplingMethod", + "WrapCommand", ] diff --git a/src/dynamicprompts/commands/wrap_command.py b/src/dynamicprompts/commands/wrap_command.py new file mode 100644 index 0000000..4f58d93 --- /dev/null +++ b/src/dynamicprompts/commands/wrap_command.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import dataclasses +import logging +import re + +from dynamicprompts.commands import Command +from dynamicprompts.enums import SamplingMethod + +log = logging.getLogger(__name__) + +WRAP_MARKER_CHARACTERS = { + "\u1801", # Mongolian ellipsis + "\u2026", # Horizontal ellipsis + "\u22EE", # Vertical ellipsis + "\u22EF", # Midline horizontal ellipsis + "\u22F0", # Up right diagonal ellipsis + "\u22F1", # Down right diagonal ellipsis + "\uFE19", # Presentation form for vertical horizontal ellipsis +} + +WRAP_MARKER_RE = re.compile( + f"[{''.join(WRAP_MARKER_CHARACTERS)}]+" # One or more wrap marker characters + "|" + r"\.{3,}", # ASCII ellipsis of 3 or more dots +) + + +def split_wrapper_string(s: str) -> tuple[str, str]: + """ + Split a string into a prefix and suffix at the first wrap marker. + """ + match = WRAP_MARKER_RE.search(s) + if match is None: + log.warning("Found no wrap marker in string %r", s) + return s, "" + else: + return s[: match.start()], s[match.end() :] + + +@dataclasses.dataclass(frozen=True) +class WrapCommand(Command): + wrapper: Command + inner: Command + sampling_method: SamplingMethod | None = None diff --git a/src/dynamicprompts/generators/magicprompt.py b/src/dynamicprompts/generators/magicprompt.py index 96edd6d..7b9b780 100644 --- a/src/dynamicprompts/generators/magicprompt.py +++ b/src/dynamicprompts/generators/magicprompt.py @@ -9,22 +9,13 @@ logger = logging.getLogger(__name__) -try: +if TYPE_CHECKING: + import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, Pipeline, - pipeline, - set_seed, ) -except ImportError as ie: - raise ImportError( - "You need to install the transformers library to use the MagicPrompt generator. " - "You can do this by running `pip install -U dynamicprompts[magicprompt]`.", - ) from ie - -if TYPE_CHECKING: - import torch DEFAULT_MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion" MAX_SEED = 2**32 - 1 @@ -71,6 +62,18 @@ def clean_up_magic_prompt(orig_prompt: str, prompt: str) -> str: return prompt +def _import_transformers(): # pragma: no cover + try: + import transformers + + return transformers + except ImportError as ie: + raise ImportError( + "You need to install the transformers library to use the MagicPrompt generator. " + "You can do this by running `pip install -U dynamicprompts[magicprompt]`.", + ) from ie + + class MagicPromptGenerator(PromptGenerator): generator: Pipeline | None = None tokenizer: AutoTokenizer | None = None @@ -83,13 +86,14 @@ def _load_pipeline(self, model_name: str) -> Pipeline: logger.warning("First load of MagicPrompt may take a while.") if MagicPromptGenerator.generator is None: - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name) + transformers = _import_transformers() + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + model = transformers.AutoModelForCausalLM.from_pretrained(model_name) tokenizer.pad_token_id = model.config.eos_token_id MagicPromptGenerator.tokenizer = tokenizer MagicPromptGenerator.model = model - MagicPromptGenerator.generator = pipeline( + MagicPromptGenerator.generator = transformers.pipeline( task="text-generation", tokenizer=tokenizer, model=model, @@ -123,6 +127,7 @@ def __init__( :param blocklist_regex: A regex to use to filter out prompts that match it. :param batch_size: The batch size to use when generating prompts. """ + transformers = _import_transformers() self._device = device self.set_model(model_name) @@ -140,7 +145,7 @@ def __init__( self._blocklist_regex = None if seed is not None: - set_seed(int(seed)) + transformers.set_seed(int(seed)) self._batch_size = batch_size diff --git a/src/dynamicprompts/jinja_extensions.py b/src/dynamicprompts/jinja_extensions.py index b6c37f7..75bc57b 100644 --- a/src/dynamicprompts/jinja_extensions.py +++ b/src/dynamicprompts/jinja_extensions.py @@ -42,8 +42,11 @@ def wildcard(environment: Environment, wildcard_name: str) -> list[str]: from dynamicprompts.generators import CombinatorialPromptGenerator from dynamicprompts.wildcards import WildcardManager - wm: WildcardManager = environment.globals["wildcard_manager"] # type: ignore - generator: CombinatorialPromptGenerator = environment.globals["generators"]["combinatorial"] # type: ignore + wm = cast(WildcardManager, environment.globals["wildcard_manager"]) + generator = cast( + CombinatorialPromptGenerator, + environment.globals["generators"]["combinatorial"], # type: ignore + ) return [str(r) for r in generator.generate(wm.to_wildcard(wildcard_name))] diff --git a/src/dynamicprompts/parser/config.py b/src/dynamicprompts/parser/config.py index bed26f5..a46721f 100644 --- a/src/dynamicprompts/parser/config.py +++ b/src/dynamicprompts/parser/config.py @@ -10,6 +10,8 @@ class ParserConfig: wildcard_wrap: str = "__" variable_start: str = "${" variable_end: str = "}" + wrap_start: str = "%{" + wrap_end: str = "}" default_parser_config = ParserConfig() diff --git a/src/dynamicprompts/parser/parse.py b/src/dynamicprompts/parser/parse.py index 8aec2d0..5dd0f47 100644 --- a/src/dynamicprompts/parser/parse.py +++ b/src/dynamicprompts/parser/parse.py @@ -23,6 +23,7 @@ ::= + ::= "${" "=" "}" ::= "${" (":" )? "}" + ::= "%{" "$$" "}" Note that whitespace is preserved in case it is significant to the user. """ @@ -44,6 +45,7 @@ VariantCommand, VariantOption, WildcardCommand, + WrapCommand, ) from dynamicprompts.commands.variable_commands import ( VariableAccessCommand, @@ -62,6 +64,8 @@ sampler_cyclical = pp.Char("@") sampler_symbol = sampler_random | sampler_combinatorial | sampler_cyclical +variant_delim = pp.Suppress("$$") + OPT_WS = pp.Opt(pp.White()) # Optional whitespace var_name = pp.Word(pp.alphas + "_-", pp.alphanums + "_-") @@ -75,7 +79,6 @@ def _configure_range() -> pp.ParserElement: hyphen = pp.Suppress("-") - variant_delim = pp.Suppress("$$") # Exclude: # - $, which is used to indicate the end of the separator definition i.e. {1$$ and $$X|Y|Z} @@ -128,6 +131,17 @@ def _configure_wildcard( return wildcard("wildcard").leave_whitespace() +def _configure_wildcard_path( + parser_config: ParserConfig, + variable_ref: pp.ParserElement, +) -> pp.ParserElement: + wildcard_path_literal_re = ( + r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+" + ) + wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace() + return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path") + + def _configure_literal_sequence( parser_config: ParserConfig, is_variant_literal: bool = False, @@ -137,7 +151,12 @@ def _configure_literal_sequence( # - { denotes the start of a variant (or whatever variant_start is set to ) # - # denotes the start of a comment # - $ denotes the start of a variable command (or whatever variable_start is set to) - non_literal_chars = rf"#{parser_config.variant_start}{parser_config.variable_start}" + # - % denotes the start of a wrap command (or whatever wrap_start is set to) + non_literal_chars = ( + rf"#{parser_config.variant_start}" + rf"{parser_config.variable_start}" + rf"{parser_config.wrap_start}" + ) if is_variant_literal: # Inside a variant the following characters are also not allowed @@ -234,6 +253,23 @@ def _configure_variable_assignment( return variable_assignment.leave_whitespace() +def _configure_wrap_command( + parser_config: ParserConfig, + prompt: pp.ParserElement, +) -> pp.ParserElement: + wrap_command = pp.Group( + pp.Suppress(parser_config.wrap_start) + + OPT_WS + + prompt()("wrapper") + + OPT_WS + + variant_delim + + OPT_WS + + prompt()("inner") + + pp.Suppress(parser_config.wrap_end), + ) + return wrap_command.leave_whitespace() + + def _parse_literal_command(parse_result: pp.ParseResults) -> LiteralCommand: s = " ".join(parse_result) return LiteralCommand(s) @@ -386,6 +422,16 @@ def _parse_variable_assignment_command( ) +def _parse_wrap_command( + parse_result: pp.ParseResults, +) -> WrapCommand: + parts = parse_result[0].as_dict() + return WrapCommand( + inner=parts["inner"], + wrapper=parts["wrapper"], + ) + + def create_parser( *, parser_config: ParserConfig, @@ -408,7 +454,11 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) - wildcard = _configure_wildcard( + wrap_command = _configure_wrap_command( + parser_config=parser_config, + prompt=variant_prompt, + ) + wildcard = _configure_wildcard( parser_config=parser_config, prompt=wildcard_prompt, ) @@ -428,9 +478,16 @@ def create_parser( ) chunk = ( - variable_assignment | variable_access | variants | wildcard | literal_sequence + variable_assignment + | variable_access + | wrap_command + | variants + | wildcard + | literal_sequence + ) + variant_chunk = ( + variable_access | wrap_command | variants | wildcard | variant_literal_sequence ) - variant_chunk = variable_access | variants | wildcard | variant_literal_sequence wildcard_chunk = ( wildcard_variable_access | variants @@ -459,6 +516,7 @@ def create_parser( variable_assignment.set_parse_action(_parse_variable_assignment_command) prompt.set_parse_action(_parse_sequence_or_single_command) variant_prompt.set_parse_action(_parse_sequence_or_single_command) + wrap_command.set_parse_action(_parse_wrap_command) wildcard_prompt.set_parse_action(_parse_sequence_or_single_command) return prompt diff --git a/src/dynamicprompts/samplers/base.py b/src/dynamicprompts/samplers/base.py index 1399302..a01f248 100644 --- a/src/dynamicprompts/samplers/base.py +++ b/src/dynamicprompts/samplers/base.py @@ -8,6 +8,7 @@ SequenceCommand, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.commands.variable_commands import ( VariableAccessCommand, @@ -43,6 +44,8 @@ def generator_from_command( ) if isinstance(command, VariableAccessCommand): return self._get_variable(command, context) + if isinstance(command, WrapCommand): + return self._get_wrap(command, context) return self._unsupported_command(command) def _unsupported_command(self, command: Command) -> ResultGen: @@ -100,3 +103,10 @@ def _get_variable( return context.for_sampling_variable(variable).generator_from_command( command_to_sample, ) + + def _get_wrap( + self, + command: WrapCommand, + context: SamplingContext, + ) -> ResultGen: + return self._unsupported_command(command) diff --git a/src/dynamicprompts/samplers/combinatorial.py b/src/dynamicprompts/samplers/combinatorial.py index 54a0b60..022969f 100644 --- a/src/dynamicprompts/samplers/combinatorial.py +++ b/src/dynamicprompts/samplers/combinatorial.py @@ -10,6 +10,7 @@ SequenceCommand, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.command_collection import CommandCollection @@ -160,3 +161,9 @@ def _get_literal( context: SamplingContext, ) -> ResultGen: yield SamplingResult(text=command.literal) + + def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen: + for wrapper_result in context.sample_prompts(command.wrapper): + wrap = wrapper_result.as_wrapper() + for inner in context.sample_prompts(command.inner): + yield wrap(inner) diff --git a/src/dynamicprompts/samplers/cycle.py b/src/dynamicprompts/samplers/cycle.py index a607ef6..f3e8a4f 100644 --- a/src/dynamicprompts/samplers/cycle.py +++ b/src/dynamicprompts/samplers/cycle.py @@ -12,7 +12,7 @@ ) from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( - get_wildcard_not_found_fallback, + get_wildcard_not_found_fallback, wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext diff --git a/src/dynamicprompts/samplers/random.py b/src/dynamicprompts/samplers/random.py index 754a1c1..f882ff6 100644 --- a/src/dynamicprompts/samplers/random.py +++ b/src/dynamicprompts/samplers/random.py @@ -8,6 +8,7 @@ Command, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( @@ -125,3 +126,11 @@ def _get_wildcard( while True: value = next(gen) yield from context.sample_prompts(value, 1) + + def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen: + wrapper_gen = context.generator_from_command(command.wrapper) + inner_gen = context.generator_from_command(command.inner) + wrapper_result: SamplingResult + inner_result: SamplingResult + for wrapper_result, inner_result in zip(wrapper_gen, inner_gen): + yield wrapper_result.as_wrapper()(inner_result) diff --git a/src/dynamicprompts/samplers/utils.py b/src/dynamicprompts/samplers/utils.py index a2a32f4..1d6ea5f 100644 --- a/src/dynamicprompts/samplers/utils.py +++ b/src/dynamicprompts/samplers/utils.py @@ -1,6 +1,9 @@ from __future__ import annotations import logging +from functools import partial + +import pyparsing as pp from dynamicprompts.commands import ( Command, @@ -24,7 +27,8 @@ def wildcard_to_variant( max_bound=1, separator=",", ) -> VariantCommand: - values = context.wildcard_manager.get_values(command.wildcard) + wildcard = next(context.sample_prompts(command.wildcard, 1)).text + values = context.wildcard_manager.get_values(wildcard) min_bound = min(min_bound, len(values)) max_bound = min(max_bound, len(values)) diff --git a/src/dynamicprompts/sampling_result.py b/src/dynamicprompts/sampling_result.py index 3603c6f..eb52a4d 100644 --- a/src/dynamicprompts/sampling_result.py +++ b/src/dynamicprompts/sampling_result.py @@ -3,6 +3,8 @@ import dataclasses from typing import Iterable +from dynamicprompts.commands.wrap_command import split_wrapper_string + @dataclasses.dataclass(frozen=True) class SamplingResult: @@ -26,6 +28,23 @@ def whitespace_squashed(self) -> SamplingResult: return dataclasses.replace(self, text=squash_whitespace(self.text)) + def text_replaced(self, new_text: str) -> SamplingResult: + return dataclasses.replace(self, text=new_text) + + def as_wrapper(self): + """ + Return a function that wraps a SamplingResult with this one, + partitioning this result's text along the wrap marker. + """ + prefix, suffix = split_wrapper_string(self.text) + prefix_res = self.text_replaced(prefix) + suffix_res = self.text_replaced(suffix) + + def wrapper(inner: SamplingResult) -> SamplingResult: + return SamplingResult.joined([prefix_res, inner, suffix_res], separator="") + + return wrapper + @classmethod def joined( cls, diff --git a/tests/generators/test_magicprompt.py b/tests/generators/test_magicprompt.py index d81ee90..beff031 100644 --- a/tests/generators/test_magicprompt.py +++ b/tests/generators/test_magicprompt.py @@ -6,17 +6,20 @@ import pytest -pytest.importorskip("dynamicprompts.generators.magicprompt") +@pytest.fixture(autouse=True) +def mock_import_transformers(monkeypatch): + from dynamicprompts.generators import magicprompt -@pytest.mark.slow -class TestMagicPrompt: - def test_default_generator(self): - from dynamicprompts.generators.dummygenerator import DummyGenerator - from dynamicprompts.generators.magicprompt import MagicPromptGenerator + monkeypatch.setattr(magicprompt, "_import_transformers", MagicMock()) + + +def test_default_generator(): + from dynamicprompts.generators.dummygenerator import DummyGenerator + from dynamicprompts.generators.magicprompt import MagicPromptGenerator - generator = MagicPromptGenerator() - assert isinstance(generator._prompt_generator, DummyGenerator) + generator = MagicPromptGenerator() + assert isinstance(generator._prompt_generator, DummyGenerator) @pytest.mark.parametrize( @@ -121,7 +124,6 @@ def _generator( assert not any(artist in magic_prompt for artist in boring_artists) -@pytest.mark.slow def test_generate_passes_kwargs(): from dynamicprompts.generators.magicprompt import MagicPromptGenerator diff --git a/tests/parser/test_commands.py b/tests/parser/test_commands.py index 565d37b..addf117 100644 --- a/tests/parser/test_commands.py +++ b/tests/parser/test_commands.py @@ -62,7 +62,7 @@ def test_combinations(self): variant_command = VariantCommand.from_literals_and_weights(ONE_TWO_THREE) assert [ - v.literal for v, in variant_command.get_value_combinations(1) + v.literal for (v,) in variant_command.get_value_combinations(1) ] == ONE_TWO_THREE assert [ diff --git a/tests/parser/test_parser.py b/tests/parser/test_parser.py index cec2841..520124e 100644 --- a/tests/parser/test_parser.py +++ b/tests/parser/test_parser.py @@ -47,7 +47,7 @@ def test_literal_characters(self, input: str): "input", [ "colours", - "path/to/colours", + "path/to/colours", "änder", ], ) diff --git a/tests/test_data/wildcards/animal.txt b/tests/test_data/wildcards/animal.txt new file mode 100644 index 0000000..46c0125 --- /dev/null +++ b/tests/test_data/wildcards/animal.txt @@ -0,0 +1 @@ +__animals/${genus:*}/*__ diff --git a/tests/test_data/wildcards/animals/reptiles/lizards.yaml b/tests/test_data/wildcards/animals/reptiles/lizards.yaml new file mode 100644 index 0000000..e6e6fb5 --- /dev/null +++ b/tests/test_data/wildcards/animals/reptiles/lizards.yaml @@ -0,0 +1,2 @@ +- iguana +- gecko diff --git a/tests/test_data/wildcards/animals/reptiles/snakes.txt b/tests/test_data/wildcards/animals/reptiles/snakes.txt new file mode 100644 index 0000000..fe228ad --- /dev/null +++ b/tests/test_data/wildcards/animals/reptiles/snakes.txt @@ -0,0 +1,2 @@ +python +cobra diff --git a/tests/test_data/wildcards/cars.yaml b/tests/test_data/wildcards/cars.yaml new file mode 100644 index 0000000..ec28991 --- /dev/null +++ b/tests/test_data/wildcards/cars.yaml @@ -0,0 +1,31 @@ +cars: + ford: + name: + - Ford + types: + - muscle car + - compact car + - roadster + colors: + - white + - black + - brown + - red + + porsche: + name: + - Porsche + types: + - sports car + colors: + - red + - black + + john_deere: + name: + - John Deere + types: + - tractor + colors: + - green + - yellow diff --git a/tests/test_data/wildcards/wrappers.txt b/tests/test_data/wildcards/wrappers.txt new file mode 100644 index 0000000..b4a23ea --- /dev/null +++ b/tests/test_data/wildcards/wrappers.txt @@ -0,0 +1,2 @@ +Art Deco, ..., sleek, geometric forms, art deco style +Pop Art, ....., vivid colors, flat color, 2D, strong lines, Pop Art diff --git a/tests/test_wrapping.py b/tests/test_wrapping.py new file mode 100644 index 0000000..ea365de --- /dev/null +++ b/tests/test_wrapping.py @@ -0,0 +1,66 @@ +import pytest +from dynamicprompts.enums import SamplingMethod +from dynamicprompts.parser.parse import parse +from dynamicprompts.sampling_context import SamplingContext +from dynamicprompts.wildcards import WildcardManager + +from tests.utils import sample_n + + +# Methods currently supported by wrap command +@pytest.fixture( + params=[ + SamplingMethod.COMBINATORIAL, + SamplingMethod.RANDOM, + ], +) +def scon(request, wildcard_manager: WildcardManager) -> SamplingContext: + return SamplingContext( + default_sampling_method=request.param, + wildcard_manager=wildcard_manager, + ) + + +def test_wrap_with_wildcard(scon: SamplingContext): + cmd = parse("%{__wrappers__$${fox|cow}}") + assert sample_n(cmd, scon, n=4) == { + "Art Deco, cow, sleek, geometric forms, art deco style", + "Art Deco, fox, sleek, geometric forms, art deco style", + "Pop Art, cow, vivid colors, flat color, 2D, strong lines, Pop Art", + "Pop Art, fox, vivid colors, flat color, 2D, strong lines, Pop Art", + } + + +@pytest.mark.parametrize("placeholder", ["…", "᠁", ".........", "..."]) +def test_wrap_with_literal(scon: SamplingContext, placeholder: str): + cmd = parse("%{happy ... on a meadow$${fox|cow}}".replace("...", placeholder)) + assert sample_n(cmd, scon, n=2) == { + "happy fox on a meadow", + "happy cow on a meadow", + } + + +def test_bad_wrap_is_prefix(scon: SamplingContext): + cmd = parse("%{happy $${fox|cow}}") + assert sample_n(cmd, scon, n=2) == { + "happy fox", + "happy cow", + } + + +def test_wrap_suffix(scon: SamplingContext): + cmd = parse("%{... in jail$${fox|cow}}") + assert sample_n(cmd, scon, n=2) == { + "fox in jail", + "cow in jail", + } + + +def test_wrap_with_variant(scon): + cmd = parse("%{ {cool|hot} ...$${fox|cow}}") + assert sample_n(cmd, scon, n=4) == { + "cool fox", + "cool cow", + "hot fox", + "hot cow", + } diff --git a/tests/utils.py b/tests/utils.py index e17a7f6..86ee8c1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,8 @@ from __future__ import annotations +from dynamicprompts.commands import Command +from dynamicprompts.sampling_context import SamplingContext + def cross(list1: list[str], list2: list[str], sep=",") -> list[str]: return [f"{x}{sep}{y}" for x in list1 for y in list2 if x != y] @@ -15,3 +18,15 @@ def interleave(list1: list[str], list2: list[str]) -> list[str]: new_list[1::2] = list2 return new_list + + +def sample_n(cmd: Command, scon: SamplingContext, n: int) -> set[str]: + """ + Sample until we have n unique prompts. + """ + seen = set() + for p in scon.sample_prompts(cmd): + seen.add(str(p)) + if len(seen) == n: + break + return seen diff --git a/tests/wildcard/test_wildcardmanager.py b/tests/wildcard/test_wildcardmanager.py index 7e7a7e4..da04739 100644 --- a/tests/wildcard/test_wildcardmanager.py +++ b/tests/wildcard/test_wildcardmanager.py @@ -123,12 +123,24 @@ def test_hierarchy(wildcard_manager: WildcardManager): root = wildcard_manager.tree.root assert {name for name, item in root.walk_items()} == { "dupes", + "animal", "animals/all-references", "animals/mammals/canine", "animals/mammals/feline", "animals/mystical", + "animals/reptiles/lizards", + "animals/reptiles/snakes", "artists/dutch", "artists/finnish", + "cars/ford/colors", + "cars/ford/name", + "cars/ford/types", + "cars/john_deere/colors", + "cars/john_deere/name", + "cars/john_deere/types", + "cars/porsche/colors", + "cars/porsche/name", + "cars/porsche/types", "clothing", "colors-cold", "colors-warm", @@ -141,15 +153,18 @@ def test_hierarchy(wildcard_manager: WildcardManager): "variant", "weighted-animals/heavy", "weighted-animals/light", + "wrappers", } assert set(root.collections) == { + "animal", "clothing", # from pantry YAML "colors-cold", # .txt "colors-warm", # .txt + "dupes", # .txt "referencing-colors", # .txt "shapes", # flat list YAML "variant", # .txt - "dupes", # .txt + "wrappers", # .txt } assert set(root.child_nodes["animals"].collections) == { "all-references", @@ -159,11 +174,17 @@ def test_hierarchy(wildcard_manager: WildcardManager): "canine", "feline", } + assert set(root.child_nodes["animals"].child_nodes["reptiles"].collections) == { + "lizards", + "snakes", + }, "animals/reptiles does not match" assert set(root.child_nodes["animals"].walk_full_names()) == { "animals/all-references", "animals/mammals/canine", "animals/mammals/feline", "animals/mystical", + "animals/reptiles/lizards", + "animals/reptiles/snakes", } assert set(root.child_nodes["flavors"].collections) == { "sour", # .txt @@ -334,6 +355,8 @@ def test_wcm_roots(): "elaimet/mammals/canine", "elaimet/mammals/feline", "elaimet/mystical", + "elaimet/reptiles/lizards", + "elaimet/reptiles/snakes", "elaimet/sopot", "metasyntactic/fnord", "metasyntactic/foo", @@ -342,9 +365,13 @@ def test_wcm_roots(): v for v in wcm.get_values("elaimet/*").string_values if not v.startswith("_") } == { "cat", + "cobra", "dog", + "gecko", + "iguana", "okapi", "pingviini", + "python", "tiger", "unicorn", "wolf",