From 719ab7142d293311b7186aa898f48c3966150b9f Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:02:08 +0200 Subject: [PATCH 1/7] CI: don't use big cache for publish step --- .github/workflows/ci.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ffa966e..68ad75f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,9 +67,6 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.11" - cache: "pip" - cache-dependency-path: | - pyproject.toml - run: python -m pip install hatch - run: hatch build -t wheel - name: Publish package distributions to PyPI From f919d45e31d1a15f197c25a5086e93a07b7ee5ad Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:14:36 +0200 Subject: [PATCH 2/7] Late-import transformers for magic prompt --- pyproject.toml | 4 +++ src/dynamicprompts/generators/magicprompt.py | 35 +++++++++++--------- tests/generators/test_magicprompt.py | 20 ++++++----- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0e00396..804b99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,10 @@ exclude = "tests" module = "transformers" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "torch" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "spacy.*" ignore_missing_imports = true diff --git a/src/dynamicprompts/generators/magicprompt.py b/src/dynamicprompts/generators/magicprompt.py index 96edd6d..7b9b780 100644 --- a/src/dynamicprompts/generators/magicprompt.py +++ b/src/dynamicprompts/generators/magicprompt.py @@ -9,22 +9,13 @@ logger = logging.getLogger(__name__) -try: +if TYPE_CHECKING: + import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, Pipeline, - pipeline, - set_seed, ) -except ImportError as ie: - raise ImportError( - "You need to install the transformers library to use the MagicPrompt generator. " - "You can do this by running `pip install -U dynamicprompts[magicprompt]`.", - ) from ie - -if TYPE_CHECKING: - import torch DEFAULT_MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion" MAX_SEED = 2**32 - 1 @@ -71,6 +62,18 @@ def clean_up_magic_prompt(orig_prompt: str, prompt: str) -> str: return prompt +def _import_transformers(): # pragma: no cover + try: + import transformers + + return transformers + except ImportError as ie: + raise ImportError( + "You need to install the transformers library to use the MagicPrompt generator. " + "You can do this by running `pip install -U dynamicprompts[magicprompt]`.", + ) from ie + + class MagicPromptGenerator(PromptGenerator): generator: Pipeline | None = None tokenizer: AutoTokenizer | None = None @@ -83,13 +86,14 @@ def _load_pipeline(self, model_name: str) -> Pipeline: logger.warning("First load of MagicPrompt may take a while.") if MagicPromptGenerator.generator is None: - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name) + transformers = _import_transformers() + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + model = transformers.AutoModelForCausalLM.from_pretrained(model_name) tokenizer.pad_token_id = model.config.eos_token_id MagicPromptGenerator.tokenizer = tokenizer MagicPromptGenerator.model = model - MagicPromptGenerator.generator = pipeline( + MagicPromptGenerator.generator = transformers.pipeline( task="text-generation", tokenizer=tokenizer, model=model, @@ -123,6 +127,7 @@ def __init__( :param blocklist_regex: A regex to use to filter out prompts that match it. :param batch_size: The batch size to use when generating prompts. """ + transformers = _import_transformers() self._device = device self.set_model(model_name) @@ -140,7 +145,7 @@ def __init__( self._blocklist_regex = None if seed is not None: - set_seed(int(seed)) + transformers.set_seed(int(seed)) self._batch_size = batch_size diff --git a/tests/generators/test_magicprompt.py b/tests/generators/test_magicprompt.py index d81ee90..beff031 100644 --- a/tests/generators/test_magicprompt.py +++ b/tests/generators/test_magicprompt.py @@ -6,17 +6,20 @@ import pytest -pytest.importorskip("dynamicprompts.generators.magicprompt") +@pytest.fixture(autouse=True) +def mock_import_transformers(monkeypatch): + from dynamicprompts.generators import magicprompt -@pytest.mark.slow -class TestMagicPrompt: - def test_default_generator(self): - from dynamicprompts.generators.dummygenerator import DummyGenerator - from dynamicprompts.generators.magicprompt import MagicPromptGenerator + monkeypatch.setattr(magicprompt, "_import_transformers", MagicMock()) + + +def test_default_generator(): + from dynamicprompts.generators.dummygenerator import DummyGenerator + from dynamicprompts.generators.magicprompt import MagicPromptGenerator - generator = MagicPromptGenerator() - assert isinstance(generator._prompt_generator, DummyGenerator) + generator = MagicPromptGenerator() + assert isinstance(generator._prompt_generator, DummyGenerator) @pytest.mark.parametrize( @@ -121,7 +124,6 @@ def _generator( assert not any(artist in magic_prompt for artist in boring_artists) -@pytest.mark.slow def test_generate_passes_kwargs(): from dynamicprompts.generators.magicprompt import MagicPromptGenerator From 93980f816248c74bd4118d3881bc9e28126bbf16 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:15:06 +0200 Subject: [PATCH 3/7] CI: don't bother with transformers --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68ad75f..54e9248 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: cache: "pip" cache-dependency-path: | pyproject.toml - - run: python -m pip install mypy -e .[dev,attentiongrabber,magicprompt,feelinglucky] + - run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky] - run: mypy --install-types --non-interactive src test: runs-on: ${{ matrix.os }} @@ -45,7 +45,7 @@ jobs: cache-dependency-path: | pyproject.toml - name: Install dependencies - run: python -m pip install -e .[dev,attentiongrabber,magicprompt,feelinglucky] + run: python -m pip install -e .[dev,attentiongrabber,feelinglucky] - run: pytest --cov --cov-report=term-missing --cov-report=xml . env: PYPARSINGENABLEALLWARNINGS: 1 From c9df45e246ccec0bff52fa24ab73e0db8e0336fc Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:27:47 +0200 Subject: [PATCH 4/7] Add `yaml` extra (now that it's not implicitly required via transformers) --- .github/workflows/ci.yml | 4 ++-- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 54e9248..1532dea 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: cache: "pip" cache-dependency-path: | pyproject.toml - - run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky] + - run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky,yaml] - run: mypy --install-types --non-interactive src test: runs-on: ${{ matrix.os }} @@ -45,7 +45,7 @@ jobs: cache-dependency-path: | pyproject.toml - name: Install dependencies - run: python -m pip install -e .[dev,attentiongrabber,feelinglucky] + run: python -m pip install -e .[dev,attentiongrabber,feelinglucky,yaml] - run: pytest --cov --cov-report=term-missing --cov-report=xml . env: PYPARSINGENABLEALLWARNINGS: 1 diff --git a/pyproject.toml b/pyproject.toml index 804b99c..ff19746 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ attentiongrabber = [] # empty list for backwards compatibility (no "no extra" warnings) magicprompt = ["transformers[torch]~=4.19"] feelinglucky = ["requests~=2.28"] +yaml = ["pyyaml~=6.0"] dev = [ "pytest-cov~=4.0", "pytest-lazy-fixture~=0.6", From e3afdf4d9664585859c2cda671d1dcd2b18730ed Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 20 Nov 2023 08:42:10 +0200 Subject: [PATCH 5/7] Upgrade CI/lint tools (#111) * Update pre-commit tools * Switch formatting from black to ruff-format * CI: use GitHub output format for Ruff --- .github/workflows/ci.yml | 2 ++ .pre-commit-config.yaml | 11 +++-------- src/dynamicprompts/jinja_extensions.py | 7 +++++-- tests/parser/test_commands.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1532dea..f3474b9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: pre-commit/action@v3.0.0 + env: + RUFF_OUTPUT_FORMAT: github mypy: runs-on: ubuntu-latest steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ddf7bc2..0fe058d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,18 +1,13 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.290 + rev: v0.1.6 hooks: - id: ruff args: - --fix + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/psf/black - rev: 23.9.1 - hooks: - - id: black - args: - - --quiet diff --git a/src/dynamicprompts/jinja_extensions.py b/src/dynamicprompts/jinja_extensions.py index b6c37f7..75bc57b 100644 --- a/src/dynamicprompts/jinja_extensions.py +++ b/src/dynamicprompts/jinja_extensions.py @@ -42,8 +42,11 @@ def wildcard(environment: Environment, wildcard_name: str) -> list[str]: from dynamicprompts.generators import CombinatorialPromptGenerator from dynamicprompts.wildcards import WildcardManager - wm: WildcardManager = environment.globals["wildcard_manager"] # type: ignore - generator: CombinatorialPromptGenerator = environment.globals["generators"]["combinatorial"] # type: ignore + wm = cast(WildcardManager, environment.globals["wildcard_manager"]) + generator = cast( + CombinatorialPromptGenerator, + environment.globals["generators"]["combinatorial"], # type: ignore + ) return [str(r) for r in generator.generate(wm.to_wildcard(wildcard_name))] diff --git a/tests/parser/test_commands.py b/tests/parser/test_commands.py index 565d37b..addf117 100644 --- a/tests/parser/test_commands.py +++ b/tests/parser/test_commands.py @@ -62,7 +62,7 @@ def test_combinations(self): variant_command = VariantCommand.from_literals_and_weights(ONE_TWO_THREE) assert [ - v.literal for v, in variant_command.get_value_combinations(1) + v.literal for (v,) in variant_command.get_value_combinations(1) ] == ONE_TWO_THREE assert [ From 510d1e958c26276512188532c6cfae4306d291d7 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 20 Nov 2023 13:33:42 +0200 Subject: [PATCH 6/7] Add wrap command (#102) --- src/dynamicprompts/commands/__init__.py | 4 +- src/dynamicprompts/commands/wrap_command.py | 45 +++++++++++++ src/dynamicprompts/parser/config.py | 2 + src/dynamicprompts/parser/parse.py | 55 ++++++++++++++-- src/dynamicprompts/samplers/base.py | 10 +++ src/dynamicprompts/samplers/combinatorial.py | 7 +++ src/dynamicprompts/samplers/random.py | 9 +++ src/dynamicprompts/sampling_result.py | 19 ++++++ tests/test_data/wildcards/wrappers.txt | 2 + tests/test_wrapping.py | 66 ++++++++++++++++++++ tests/utils.py | 15 +++++ tests/wildcard/test_wildcardmanager.py | 4 +- 12 files changed, 232 insertions(+), 6 deletions(-) create mode 100644 src/dynamicprompts/commands/wrap_command.py create mode 100644 tests/test_data/wildcards/wrappers.txt create mode 100644 tests/test_wrapping.py diff --git a/src/dynamicprompts/commands/__init__.py b/src/dynamicprompts/commands/__init__.py index 26e0c3e..2e82c8d 100644 --- a/src/dynamicprompts/commands/__init__.py +++ b/src/dynamicprompts/commands/__init__.py @@ -3,13 +3,15 @@ from dynamicprompts.commands.sequence_command import SequenceCommand from dynamicprompts.commands.variant_command import VariantCommand, VariantOption from dynamicprompts.commands.wildcard_command import WildcardCommand +from dynamicprompts.commands.wrap_command import WrapCommand __all__ = [ "Command", "LiteralCommand", + "SamplingMethod", "SequenceCommand", "VariantCommand", "VariantOption", "WildcardCommand", - "SamplingMethod", + "WrapCommand", ] diff --git a/src/dynamicprompts/commands/wrap_command.py b/src/dynamicprompts/commands/wrap_command.py new file mode 100644 index 0000000..4f58d93 --- /dev/null +++ b/src/dynamicprompts/commands/wrap_command.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import dataclasses +import logging +import re + +from dynamicprompts.commands import Command +from dynamicprompts.enums import SamplingMethod + +log = logging.getLogger(__name__) + +WRAP_MARKER_CHARACTERS = { + "\u1801", # Mongolian ellipsis + "\u2026", # Horizontal ellipsis + "\u22EE", # Vertical ellipsis + "\u22EF", # Midline horizontal ellipsis + "\u22F0", # Up right diagonal ellipsis + "\u22F1", # Down right diagonal ellipsis + "\uFE19", # Presentation form for vertical horizontal ellipsis +} + +WRAP_MARKER_RE = re.compile( + f"[{''.join(WRAP_MARKER_CHARACTERS)}]+" # One or more wrap marker characters + "|" + r"\.{3,}", # ASCII ellipsis of 3 or more dots +) + + +def split_wrapper_string(s: str) -> tuple[str, str]: + """ + Split a string into a prefix and suffix at the first wrap marker. + """ + match = WRAP_MARKER_RE.search(s) + if match is None: + log.warning("Found no wrap marker in string %r", s) + return s, "" + else: + return s[: match.start()], s[match.end() :] + + +@dataclasses.dataclass(frozen=True) +class WrapCommand(Command): + wrapper: Command + inner: Command + sampling_method: SamplingMethod | None = None diff --git a/src/dynamicprompts/parser/config.py b/src/dynamicprompts/parser/config.py index bed26f5..a46721f 100644 --- a/src/dynamicprompts/parser/config.py +++ b/src/dynamicprompts/parser/config.py @@ -10,6 +10,8 @@ class ParserConfig: wildcard_wrap: str = "__" variable_start: str = "${" variable_end: str = "}" + wrap_start: str = "%{" + wrap_end: str = "}" default_parser_config = ParserConfig() diff --git a/src/dynamicprompts/parser/parse.py b/src/dynamicprompts/parser/parse.py index 50829d8..e727762 100644 --- a/src/dynamicprompts/parser/parse.py +++ b/src/dynamicprompts/parser/parse.py @@ -23,6 +23,7 @@ ::= + ::= "${" "=" "}" ::= "${" (":" )? "}" + ::= "%{" "$$" "}" Note that whitespace is preserved in case it is significant to the user. """ @@ -44,6 +45,7 @@ VariantCommand, VariantOption, WildcardCommand, + WrapCommand, ) from dynamicprompts.commands.variable_commands import ( VariableAccessCommand, @@ -62,6 +64,8 @@ sampler_cyclical = pp.Char("@") sampler_symbol = sampler_random | sampler_combinatorial | sampler_cyclical +variant_delim = pp.Suppress("$$") + OPT_WS = pp.Opt(pp.White()) # Optional whitespace var_name = pp.Word(pp.alphas + "_-", pp.alphanums + "_-") @@ -75,7 +79,6 @@ def _configure_range() -> pp.ParserElement: hyphen = pp.Suppress("-") - variant_delim = pp.Suppress("$$") # Exclude: # - $, which is used to indicate the end of the separator definition i.e. {1$$ and $$X|Y|Z} @@ -136,7 +139,12 @@ def _configure_literal_sequence( # - { denotes the start of a variant (or whatever variant_start is set to ) # - # denotes the start of a comment # - $ denotes the start of a variable command (or whatever variable_start is set to) - non_literal_chars = rf"#{parser_config.variant_start}{parser_config.variable_start}" + # - % denotes the start of a wrap command (or whatever wrap_start is set to) + non_literal_chars = ( + rf"#{parser_config.variant_start}" + rf"{parser_config.variable_start}" + rf"{parser_config.wrap_start}" + ) if is_variant_literal: # Inside a variant the following characters are also not allowed @@ -227,6 +235,23 @@ def _configure_variable_assignment( return variable_assignment.leave_whitespace() +def _configure_wrap_command( + parser_config: ParserConfig, + prompt: pp.ParserElement, +) -> pp.ParserElement: + wrap_command = pp.Group( + pp.Suppress(parser_config.wrap_start) + + OPT_WS + + prompt()("wrapper") + + OPT_WS + + variant_delim + + OPT_WS + + prompt()("inner") + + pp.Suppress(parser_config.wrap_end), + ) + return wrap_command.leave_whitespace() + + def _parse_literal_command(parse_result: pp.ParseResults) -> LiteralCommand: s = " ".join(parse_result) return LiteralCommand(s) @@ -365,6 +390,16 @@ def _parse_variable_assignment_command( ) +def _parse_wrap_command( + parse_result: pp.ParseResults, +) -> WrapCommand: + parts = parse_result[0].as_dict() + return WrapCommand( + inner=parts["inner"], + wrapper=parts["wrapper"], + ) + + def create_parser( *, parser_config: ParserConfig, @@ -382,6 +417,10 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) + wrap_command = _configure_wrap_command( + parser_config=parser_config, + prompt=variant_prompt, + ) wildcard = _configure_wildcard(parser_config=parser_config) literal_sequence = _configure_literal_sequence(parser_config=parser_config) variant_literal_sequence = _configure_literal_sequence( @@ -395,9 +434,16 @@ def create_parser( ) chunk = ( - variable_assignment | variable_access | variants | wildcard | literal_sequence + variable_assignment + | variable_access + | wrap_command + | variants + | wildcard + | literal_sequence + ) + variant_chunk = ( + variable_access | wrap_command | variants | wildcard | variant_literal_sequence ) - variant_chunk = variable_access | variants | wildcard | variant_literal_sequence prompt <<= pp.ZeroOrMore(chunk)("prompt") variant_prompt <<= pp.ZeroOrMore(variant_chunk)("prompt") @@ -417,6 +463,7 @@ def create_parser( variable_assignment.set_parse_action(_parse_variable_assignment_command) prompt.set_parse_action(_parse_sequence_or_single_command) variant_prompt.set_parse_action(_parse_sequence_or_single_command) + wrap_command.set_parse_action(_parse_wrap_command) return prompt diff --git a/src/dynamicprompts/samplers/base.py b/src/dynamicprompts/samplers/base.py index 1399302..a01f248 100644 --- a/src/dynamicprompts/samplers/base.py +++ b/src/dynamicprompts/samplers/base.py @@ -8,6 +8,7 @@ SequenceCommand, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.commands.variable_commands import ( VariableAccessCommand, @@ -43,6 +44,8 @@ def generator_from_command( ) if isinstance(command, VariableAccessCommand): return self._get_variable(command, context) + if isinstance(command, WrapCommand): + return self._get_wrap(command, context) return self._unsupported_command(command) def _unsupported_command(self, command: Command) -> ResultGen: @@ -100,3 +103,10 @@ def _get_variable( return context.for_sampling_variable(variable).generator_from_command( command_to_sample, ) + + def _get_wrap( + self, + command: WrapCommand, + context: SamplingContext, + ) -> ResultGen: + return self._unsupported_command(command) diff --git a/src/dynamicprompts/samplers/combinatorial.py b/src/dynamicprompts/samplers/combinatorial.py index dfa5913..2e838ac 100644 --- a/src/dynamicprompts/samplers/combinatorial.py +++ b/src/dynamicprompts/samplers/combinatorial.py @@ -10,6 +10,7 @@ SequenceCommand, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.command_collection import CommandCollection @@ -158,3 +159,9 @@ def _get_literal( context: SamplingContext, ) -> ResultGen: yield SamplingResult(text=command.literal) + + def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen: + for wrapper_result in context.sample_prompts(command.wrapper): + wrap = wrapper_result.as_wrapper() + for inner in context.sample_prompts(command.inner): + yield wrap(inner) diff --git a/src/dynamicprompts/samplers/random.py b/src/dynamicprompts/samplers/random.py index c1070b7..8523606 100644 --- a/src/dynamicprompts/samplers/random.py +++ b/src/dynamicprompts/samplers/random.py @@ -8,6 +8,7 @@ Command, VariantCommand, WildcardCommand, + WrapCommand, ) from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( @@ -124,3 +125,11 @@ def _get_wildcard( while True: value = next(gen) yield from context.sample_prompts(value, 1) + + def _get_wrap(self, command: WrapCommand, context: SamplingContext) -> ResultGen: + wrapper_gen = context.generator_from_command(command.wrapper) + inner_gen = context.generator_from_command(command.inner) + wrapper_result: SamplingResult + inner_result: SamplingResult + for wrapper_result, inner_result in zip(wrapper_gen, inner_gen): + yield wrapper_result.as_wrapper()(inner_result) diff --git a/src/dynamicprompts/sampling_result.py b/src/dynamicprompts/sampling_result.py index 3603c6f..eb52a4d 100644 --- a/src/dynamicprompts/sampling_result.py +++ b/src/dynamicprompts/sampling_result.py @@ -3,6 +3,8 @@ import dataclasses from typing import Iterable +from dynamicprompts.commands.wrap_command import split_wrapper_string + @dataclasses.dataclass(frozen=True) class SamplingResult: @@ -26,6 +28,23 @@ def whitespace_squashed(self) -> SamplingResult: return dataclasses.replace(self, text=squash_whitespace(self.text)) + def text_replaced(self, new_text: str) -> SamplingResult: + return dataclasses.replace(self, text=new_text) + + def as_wrapper(self): + """ + Return a function that wraps a SamplingResult with this one, + partitioning this result's text along the wrap marker. + """ + prefix, suffix = split_wrapper_string(self.text) + prefix_res = self.text_replaced(prefix) + suffix_res = self.text_replaced(suffix) + + def wrapper(inner: SamplingResult) -> SamplingResult: + return SamplingResult.joined([prefix_res, inner, suffix_res], separator="") + + return wrapper + @classmethod def joined( cls, diff --git a/tests/test_data/wildcards/wrappers.txt b/tests/test_data/wildcards/wrappers.txt new file mode 100644 index 0000000..b4a23ea --- /dev/null +++ b/tests/test_data/wildcards/wrappers.txt @@ -0,0 +1,2 @@ +Art Deco, ..., sleek, geometric forms, art deco style +Pop Art, ....., vivid colors, flat color, 2D, strong lines, Pop Art diff --git a/tests/test_wrapping.py b/tests/test_wrapping.py new file mode 100644 index 0000000..ea365de --- /dev/null +++ b/tests/test_wrapping.py @@ -0,0 +1,66 @@ +import pytest +from dynamicprompts.enums import SamplingMethod +from dynamicprompts.parser.parse import parse +from dynamicprompts.sampling_context import SamplingContext +from dynamicprompts.wildcards import WildcardManager + +from tests.utils import sample_n + + +# Methods currently supported by wrap command +@pytest.fixture( + params=[ + SamplingMethod.COMBINATORIAL, + SamplingMethod.RANDOM, + ], +) +def scon(request, wildcard_manager: WildcardManager) -> SamplingContext: + return SamplingContext( + default_sampling_method=request.param, + wildcard_manager=wildcard_manager, + ) + + +def test_wrap_with_wildcard(scon: SamplingContext): + cmd = parse("%{__wrappers__$${fox|cow}}") + assert sample_n(cmd, scon, n=4) == { + "Art Deco, cow, sleek, geometric forms, art deco style", + "Art Deco, fox, sleek, geometric forms, art deco style", + "Pop Art, cow, vivid colors, flat color, 2D, strong lines, Pop Art", + "Pop Art, fox, vivid colors, flat color, 2D, strong lines, Pop Art", + } + + +@pytest.mark.parametrize("placeholder", ["…", "᠁", ".........", "..."]) +def test_wrap_with_literal(scon: SamplingContext, placeholder: str): + cmd = parse("%{happy ... on a meadow$${fox|cow}}".replace("...", placeholder)) + assert sample_n(cmd, scon, n=2) == { + "happy fox on a meadow", + "happy cow on a meadow", + } + + +def test_bad_wrap_is_prefix(scon: SamplingContext): + cmd = parse("%{happy $${fox|cow}}") + assert sample_n(cmd, scon, n=2) == { + "happy fox", + "happy cow", + } + + +def test_wrap_suffix(scon: SamplingContext): + cmd = parse("%{... in jail$${fox|cow}}") + assert sample_n(cmd, scon, n=2) == { + "fox in jail", + "cow in jail", + } + + +def test_wrap_with_variant(scon): + cmd = parse("%{ {cool|hot} ...$${fox|cow}}") + assert sample_n(cmd, scon, n=4) == { + "cool fox", + "cool cow", + "hot fox", + "hot cow", + } diff --git a/tests/utils.py b/tests/utils.py index e17a7f6..86ee8c1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,8 @@ from __future__ import annotations +from dynamicprompts.commands import Command +from dynamicprompts.sampling_context import SamplingContext + def cross(list1: list[str], list2: list[str], sep=",") -> list[str]: return [f"{x}{sep}{y}" for x in list1 for y in list2 if x != y] @@ -15,3 +18,15 @@ def interleave(list1: list[str], list2: list[str]) -> list[str]: new_list[1::2] = list2 return new_list + + +def sample_n(cmd: Command, scon: SamplingContext, n: int) -> set[str]: + """ + Sample until we have n unique prompts. + """ + seen = set() + for p in scon.sample_prompts(cmd): + seen.add(str(p)) + if len(seen) == n: + break + return seen diff --git a/tests/wildcard/test_wildcardmanager.py b/tests/wildcard/test_wildcardmanager.py index 7e7a7e4..0f19aca 100644 --- a/tests/wildcard/test_wildcardmanager.py +++ b/tests/wildcard/test_wildcardmanager.py @@ -141,15 +141,17 @@ def test_hierarchy(wildcard_manager: WildcardManager): "variant", "weighted-animals/heavy", "weighted-animals/light", + "wrappers", } assert set(root.collections) == { "clothing", # from pantry YAML "colors-cold", # .txt "colors-warm", # .txt + "dupes", # .txt "referencing-colors", # .txt "shapes", # flat list YAML "variant", # .txt - "dupes", # .txt + "wrappers", # .txt } assert set(root.child_nodes["animals"].collections) == { "all-references", From ff474419aafb0959894295a9c468e5a4a5b4c216 Mon Sep 17 00:00:00 2001 From: Michael Wooten Date: Thu, 23 Nov 2023 08:55:42 -0500 Subject: [PATCH 7/7] Add Variable Support in Wildcard Paths (#110) Adds support for the issue described in https://github.com/adieyal/sd-dynamic-prompts/issues/614. The patch adds support for allowing variable refences within wildcard paths (e.g. `__cars/${car}/types__`). During the parsing phase, the variables will be retained in the path. During sampling, before resolving the wildcard, the variable is first resolved and replaced in the wildcard. The update should allow support for variable replacement, including providing default values using the normal `:default` syntax. Some examples: * `${gender=male}${section=upper} __clothes/${gender}/${section}__` * `__clothes/${gender:female}/${section:lower}__` * `__clothes/${gender:female}/${section:*}__` --- .../commands/wildcard_command.py | 3 + src/dynamicprompts/parser/parse.py | 43 ++++++++- src/dynamicprompts/samplers/combinatorial.py | 2 + src/dynamicprompts/samplers/cycle.py | 2 + src/dynamicprompts/samplers/random.py | 2 + src/dynamicprompts/samplers/utils.py | 73 ++++++++++++++- tests/parser/test_parser.py | 4 + tests/samplers/test_common.py | 88 +++++++++++++++++++ tests/samplers/test_utils.py | 53 ++++++++++- tests/test_data/wildcards/animal.txt | 1 + .../wildcards/animals/reptiles/lizards.yaml | 2 + .../wildcards/animals/reptiles/snakes.txt | 2 + tests/test_data/wildcards/cars.yaml | 31 +++++++ tests/wildcard/test_wildcardmanager.py | 25 ++++++ 14 files changed, 324 insertions(+), 7 deletions(-) create mode 100644 tests/test_data/wildcards/animal.txt create mode 100644 tests/test_data/wildcards/animals/reptiles/lizards.yaml create mode 100644 tests/test_data/wildcards/animals/reptiles/snakes.txt create mode 100644 tests/test_data/wildcards/cars.yaml diff --git a/src/dynamicprompts/commands/wildcard_command.py b/src/dynamicprompts/commands/wildcard_command.py index 7958d27..ef0052c 100644 --- a/src/dynamicprompts/commands/wildcard_command.py +++ b/src/dynamicprompts/commands/wildcard_command.py @@ -15,3 +15,6 @@ class WildcardCommand(Command): def __post_init__(self): if not isinstance(self.wildcard, str): raise TypeError(f"Wildcard must be a string, not {type(self.wildcard)}") + + def with_content(self, content: str) -> WildcardCommand: + return dataclasses.replace(self, wildcard=content) diff --git a/src/dynamicprompts/parser/parse.py b/src/dynamicprompts/parser/parse.py index e727762..7b8ca35 100644 --- a/src/dynamicprompts/parser/parse.py +++ b/src/dynamicprompts/parser/parse.py @@ -109,9 +109,13 @@ def _configure_range() -> pp.ParserElement: def _configure_wildcard( parser_config: ParserConfig, + variable_ref: pp.ParserElement, ) -> pp.ParserElement: - wildcard_path_re = r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^({}#])+" - wildcard_path = pp.Regex(wildcard_path_re)("path").leave_whitespace() + wildcard = _configure_wildcard_path( + parser_config=parser_config, + variable_ref=variable_ref, + ) + wildcard_enclosure = pp.Suppress(parser_config.wildcard_wrap) wildcard_variable_spec = ( OPT_WS @@ -123,7 +127,7 @@ def _configure_wildcard( wildcard = ( wildcard_enclosure + pp.Opt(sampler_symbol)("sampling_method") - + wildcard_path + + wildcard + pp.Opt(wildcard_variable_spec) + wildcard_enclosure ) @@ -131,6 +135,17 @@ def _configure_wildcard( return wildcard("wildcard").leave_whitespace() +def _configure_wildcard_path( + parser_config: ParserConfig, + variable_ref: pp.ParserElement, +) -> pp.ParserElement: + wildcard_path_literal_re = ( + r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+" + ) + wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace() + return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path") + + def _configure_literal_sequence( parser_config: ParserConfig, is_variant_literal: bool = False, @@ -413,6 +428,23 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) + + variable_ref = pp.Combine( + _configure_variable_access( + parser_config=parser_config, + prompt=pp.SkipTo(parser_config.variable_end), + ), + ).leave_whitespace() + + # by default, the variable starting "${" and end "}" characters are + # stripped with variable access. This restores them so that the + # variable can be properly recognized at a later stage + variable_ref.set_parse_action( + lambda string, + location, + token: f"{parser_config.variable_start}{''.join(token)}{parser_config.variable_end}", + ) + variable_assignment = _configure_variable_assignment( parser_config=parser_config, prompt=variant_prompt, @@ -421,7 +453,10 @@ def create_parser( parser_config=parser_config, prompt=variant_prompt, ) - wildcard = _configure_wildcard(parser_config=parser_config) + wildcard = _configure_wildcard( + parser_config=parser_config, + variable_ref=variable_ref, + ) literal_sequence = _configure_literal_sequence(parser_config=parser_config) variant_literal_sequence = _configure_literal_sequence( is_variant_literal=True, diff --git a/src/dynamicprompts/samplers/combinatorial.py b/src/dynamicprompts/samplers/combinatorial.py index 2e838ac..b939330 100644 --- a/src/dynamicprompts/samplers/combinatorial.py +++ b/src/dynamicprompts/samplers/combinatorial.py @@ -16,6 +16,7 @@ from dynamicprompts.samplers.command_collection import CommandCollection from dynamicprompts.samplers.utils import ( get_wildcard_not_found_fallback, + replace_wildcard_variables, wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext @@ -143,6 +144,7 @@ def _get_wildcard( context: SamplingContext, ) -> ResultGen: # TODO: doesn't support weights + command = replace_wildcard_variables(command=command, context=context) context = context.with_variables(command.variables) values = context.wildcard_manager.get_values(command.wildcard) if not values: diff --git a/src/dynamicprompts/samplers/cycle.py b/src/dynamicprompts/samplers/cycle.py index 4d46e67..0c916ea 100644 --- a/src/dynamicprompts/samplers/cycle.py +++ b/src/dynamicprompts/samplers/cycle.py @@ -13,6 +13,7 @@ from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( get_wildcard_not_found_fallback, + replace_wildcard_variables, wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext @@ -100,6 +101,7 @@ def _get_wildcard( context: SamplingContext, ) -> ResultGen: # TODO: doesn't support weights + command = replace_wildcard_variables(command=command, context=context) wc_values = context.wildcard_manager.get_values(command.wildcard) new_context = context.with_variables( command.variables, diff --git a/src/dynamicprompts/samplers/random.py b/src/dynamicprompts/samplers/random.py index 8523606..b8b3aec 100644 --- a/src/dynamicprompts/samplers/random.py +++ b/src/dynamicprompts/samplers/random.py @@ -13,6 +13,7 @@ from dynamicprompts.samplers.base import Sampler from dynamicprompts.samplers.utils import ( get_wildcard_not_found_fallback, + replace_wildcard_variables, wildcard_to_variant, ) from dynamicprompts.sampling_context import SamplingContext @@ -114,6 +115,7 @@ def _get_wildcard( command: WildcardCommand, context: SamplingContext, ) -> ResultGen: + command = replace_wildcard_variables(command=command, context=context) context = context.with_variables(command.variables) values = context.wildcard_manager.get_values(command.wildcard) diff --git a/src/dynamicprompts/samplers/utils.py b/src/dynamicprompts/samplers/utils.py index 35f6554..55a8757 100644 --- a/src/dynamicprompts/samplers/utils.py +++ b/src/dynamicprompts/samplers/utils.py @@ -1,9 +1,21 @@ from __future__ import annotations import logging +from functools import partial -from dynamicprompts.commands import VariantCommand, VariantOption, WildcardCommand -from dynamicprompts.parser.parse import parse +import pyparsing as pp + +from dynamicprompts.commands import ( + LiteralCommand, + VariantCommand, + VariantOption, + WildcardCommand, +) +from dynamicprompts.parser.parse import ( + _configure_variable_access, + _configure_wildcard_path, + parse, +) from dynamicprompts.sampling_context import SamplingContext from dynamicprompts.sampling_result import SamplingResult from dynamicprompts.types import ResultGen @@ -19,6 +31,7 @@ def wildcard_to_variant( max_bound=1, separator=",", ) -> VariantCommand: + command = replace_wildcard_variables(command=command, context=context) values = context.wildcard_manager.get_values(command.wildcard) min_bound = min(min_bound, len(values)) max_bound = min(max_bound, len(values)) @@ -50,3 +63,59 @@ def get_wildcard_not_found_fallback( res = SamplingResult(text=wrapped_wildcard) while True: yield res + + +def replace_wildcard_variables( + command: WildcardCommand, + *, + context: SamplingContext, +) -> WildcardCommand: + if context.parser_config.variable_start not in command.wildcard: + return command + + prompt = pp.SkipTo(context.parser_config.variable_end) + variable_access = _configure_variable_access( + parser_config=context.parser_config, + prompt=prompt, + ) + variable_access.set_parse_action( + partial(_replace_variable, variables=context.variables), + ) + wildcard = _configure_wildcard_path( + parser_config=context.parser_config, + variable_ref=variable_access, + ) + + try: + wildcard_result = wildcard.parse_string(command.wildcard) + return command.with_content("".join(wildcard_result)) + except Exception: + logger.warning("Unable to parse wildcard %r", command.wildcard, exc_info=True) + return command + + +def _replace_variable(string, location, token, *, variables: dict): + if isinstance(token, pp.ParseResults): + var_parts = token[0].as_dict() + var_name = var_parts.get("name") + if var_name: + var_name = var_name.strip() + + default = var_parts.get("default") + if default: + default = default.strip() + + else: + var_name = token + + variable = None + if var_name: + variable = variables.get(var_name) + + if isinstance(variable, LiteralCommand): + variable = variable.literal + if variable and not isinstance(variable, str): + raise NotImplementedError( + "evaluating complex commands within wildcards is not supported right now", + ) + return variable or default or var_name diff --git a/tests/parser/test_parser.py b/tests/parser/test_parser.py index e01debc..a8603c8 100644 --- a/tests/parser/test_parser.py +++ b/tests/parser/test_parser.py @@ -48,6 +48,10 @@ def test_literal_characters(self, input: str): [ "colours", "path/to/colours", + "path/to/${subject}", + "${pallette}_colours", + "${pallette:warm}_colours", + "locations/${room: dining room }/furniture", "änder", ], ) diff --git a/tests/samplers/test_common.py b/tests/samplers/test_common.py index df561c3..c92f170 100644 --- a/tests/samplers/test_common.py +++ b/tests/samplers/test_common.py @@ -626,6 +626,94 @@ def test_wildcard_nested_in_wildcard( ps = [str(p) for p in islice(gen, len(expected))] assert ps == expected + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_nested_variable(self, sampling_context: SamplingContext): + cmd = parse("${temp=cold}wearing __colors-${temp}__ suede shoes") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ( + "wearing blue suede shoes", + "wearing green suede shoes", + ) + else: + assert resolved_value == "wearing blue suede shoes" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_default_variable(self, sampling_context: SamplingContext): + cmd = parse("wearing __colors-${temp:cold}__ suede shoes") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ( + "wearing blue suede shoes", + "wearing green suede shoes", + ) + else: + assert resolved_value == "wearing blue suede shoes" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_undefined_variable(self, sampling_context: SamplingContext): + cmd = parse("wearing __colors-${temp}__ suede shoes") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + assert ( + resolved_value + == f"wearing {sampling_context.wildcard_manager.wildcard_wrap}colors-temp{sampling_context.wildcard_manager.wildcard_wrap} suede shoes" + ) + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_multiple_variables(self, sampling_context: SamplingContext): + cmd = parse( + "${genus=mammals}${species=feline}__animals/${genus:reptiles}/${species:snakes}__", + ) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cat", "tiger") + else: + assert resolved_value == "cat" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_variable_and_glob(self, sampling_context: SamplingContext): + cmd = parse("${genus=reptiles}__animals/${genus}/*__") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cobra", "gecko", "iguana", "python") + else: + assert resolved_value == "cobra" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_wildcard_with_variable_in_nested_wildcard( + self, + sampling_context: SamplingContext, + ): + cmd = parse("${genus=mammals}__animal__") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cat", "tiger", "dog", "wolf") + else: + assert resolved_value == "cat" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_nested_wildcard_with_parameterized_variable( + self, + sampling_context: SamplingContext, + ): + cmd = parse("__animal(genus=mammals)__") + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + if isinstance(sampling_context.default_sampler, RandomSampler): + assert resolved_value in ("cat", "tiger", "dog", "wolf") + else: + assert resolved_value == "cat" + + @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) + def test_nested_wildcards_in_single_file(self, sampling_context: SamplingContext): + cmd = parse( + "${car=!{porsche|john_deere}}a __cars/${car}/types__ made by __cars/${car}/name__", + ) + resolved_value = str(next(sampling_context.generator_from_command(cmd))) + assert resolved_value in ( + "a sports car made by Porsche", + "a tractor made by John Deere", + ) + class TestVariableCommands: @pytest.mark.parametrize("sampling_context", sampling_context_lazy_fixtures) diff --git a/tests/samplers/test_utils.py b/tests/samplers/test_utils.py index 01fca1d..5fae98f 100644 --- a/tests/samplers/test_utils.py +++ b/tests/samplers/test_utils.py @@ -2,7 +2,10 @@ from dynamicprompts.commands import ( WildcardCommand, ) -from dynamicprompts.samplers.utils import wildcard_to_variant +from dynamicprompts.samplers.utils import ( + replace_wildcard_variables, + wildcard_to_variant, +) from dynamicprompts.sampling_context import SamplingContext from pytest_lazyfixture import lazy_fixture @@ -30,3 +33,51 @@ def test_wildcard_to_variant(sampling_context: SamplingContext): ) assert variant_command.separator == "-" assert variant_command.sampling_method == wildcard_command.sampling_method + + +@pytest.mark.parametrize( + ("sampling_context", "initial_wildcard", "expected_wildcard", "variables"), + [ + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish }", + "colors-warm-finish", + {}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-warm-matte", + {}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-cold-matte", + {"temp": "cold"}, + ), + ( + lazy_fixture("random_sampling_context"), + "colors-${temp: warm}-${ finish: matte }", + "colors-cold-glossy", + {"temp": "cold", "finish": "glossy"}, + ), + ], +) +def test_replace_wildcard_variables_multi_variable( + sampling_context: SamplingContext, + initial_wildcard: str, + expected_wildcard: str, + variables: dict, +): + var_sampling_context = sampling_context.with_variables(variables=variables) + wildcard_command = WildcardCommand(initial_wildcard) + updated_command = replace_wildcard_variables( + command=wildcard_command, + context=var_sampling_context, + ) + assert isinstance( + updated_command, + WildcardCommand, + ), "updated command is also a WildcardCommand" + assert updated_command.wildcard == expected_wildcard diff --git a/tests/test_data/wildcards/animal.txt b/tests/test_data/wildcards/animal.txt new file mode 100644 index 0000000..46c0125 --- /dev/null +++ b/tests/test_data/wildcards/animal.txt @@ -0,0 +1 @@ +__animals/${genus:*}/*__ diff --git a/tests/test_data/wildcards/animals/reptiles/lizards.yaml b/tests/test_data/wildcards/animals/reptiles/lizards.yaml new file mode 100644 index 0000000..e6e6fb5 --- /dev/null +++ b/tests/test_data/wildcards/animals/reptiles/lizards.yaml @@ -0,0 +1,2 @@ +- iguana +- gecko diff --git a/tests/test_data/wildcards/animals/reptiles/snakes.txt b/tests/test_data/wildcards/animals/reptiles/snakes.txt new file mode 100644 index 0000000..fe228ad --- /dev/null +++ b/tests/test_data/wildcards/animals/reptiles/snakes.txt @@ -0,0 +1,2 @@ +python +cobra diff --git a/tests/test_data/wildcards/cars.yaml b/tests/test_data/wildcards/cars.yaml new file mode 100644 index 0000000..ec28991 --- /dev/null +++ b/tests/test_data/wildcards/cars.yaml @@ -0,0 +1,31 @@ +cars: + ford: + name: + - Ford + types: + - muscle car + - compact car + - roadster + colors: + - white + - black + - brown + - red + + porsche: + name: + - Porsche + types: + - sports car + colors: + - red + - black + + john_deere: + name: + - John Deere + types: + - tractor + colors: + - green + - yellow diff --git a/tests/wildcard/test_wildcardmanager.py b/tests/wildcard/test_wildcardmanager.py index 0f19aca..da04739 100644 --- a/tests/wildcard/test_wildcardmanager.py +++ b/tests/wildcard/test_wildcardmanager.py @@ -123,12 +123,24 @@ def test_hierarchy(wildcard_manager: WildcardManager): root = wildcard_manager.tree.root assert {name for name, item in root.walk_items()} == { "dupes", + "animal", "animals/all-references", "animals/mammals/canine", "animals/mammals/feline", "animals/mystical", + "animals/reptiles/lizards", + "animals/reptiles/snakes", "artists/dutch", "artists/finnish", + "cars/ford/colors", + "cars/ford/name", + "cars/ford/types", + "cars/john_deere/colors", + "cars/john_deere/name", + "cars/john_deere/types", + "cars/porsche/colors", + "cars/porsche/name", + "cars/porsche/types", "clothing", "colors-cold", "colors-warm", @@ -144,6 +156,7 @@ def test_hierarchy(wildcard_manager: WildcardManager): "wrappers", } assert set(root.collections) == { + "animal", "clothing", # from pantry YAML "colors-cold", # .txt "colors-warm", # .txt @@ -161,11 +174,17 @@ def test_hierarchy(wildcard_manager: WildcardManager): "canine", "feline", } + assert set(root.child_nodes["animals"].child_nodes["reptiles"].collections) == { + "lizards", + "snakes", + }, "animals/reptiles does not match" assert set(root.child_nodes["animals"].walk_full_names()) == { "animals/all-references", "animals/mammals/canine", "animals/mammals/feline", "animals/mystical", + "animals/reptiles/lizards", + "animals/reptiles/snakes", } assert set(root.child_nodes["flavors"].collections) == { "sour", # .txt @@ -336,6 +355,8 @@ def test_wcm_roots(): "elaimet/mammals/canine", "elaimet/mammals/feline", "elaimet/mystical", + "elaimet/reptiles/lizards", + "elaimet/reptiles/snakes", "elaimet/sopot", "metasyntactic/fnord", "metasyntactic/foo", @@ -344,9 +365,13 @@ def test_wcm_roots(): v for v in wcm.get_values("elaimet/*").string_values if not v.startswith("_") } == { "cat", + "cobra", "dog", + "gecko", + "iguana", "okapi", "pingviini", + "python", "tiger", "unicorn", "wolf",