Skip to content

Commit

Permalink
Merge remote-tracking branch 'source/main' into feat/dynamic_wildcards
Browse files Browse the repository at this point in the history
  • Loading branch information
mwootendev committed Dec 19, 2023
2 parents 3524d37 + ff47441 commit 7127ee1
Show file tree
Hide file tree
Showing 26 changed files with 361 additions and 50 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: pre-commit/[email protected]
env:
RUFF_OUTPUT_FORMAT: github
mypy:
runs-on: ubuntu-latest
steps:
Expand All @@ -26,7 +28,7 @@ jobs:
cache: "pip"
cache-dependency-path: |
pyproject.toml
- run: python -m pip install mypy -e .[dev,attentiongrabber,magicprompt,feelinglucky]
- run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky,yaml]
- run: mypy --install-types --non-interactive src
test:
runs-on: ${{ matrix.os }}
Expand All @@ -45,7 +47,7 @@ jobs:
cache-dependency-path: |
pyproject.toml
- name: Install dependencies
run: python -m pip install -e .[dev,attentiongrabber,magicprompt,feelinglucky]
run: python -m pip install -e .[dev,attentiongrabber,feelinglucky,yaml]
- run: pytest --cov --cov-report=term-missing --cov-report=xml .
env:
PYPARSINGENABLEALLWARNINGS: 1
Expand All @@ -67,9 +69,6 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
pyproject.toml
- run: python -m pip install hatch
- run: hatch build -t wheel
- name: Publish package distributions to PyPI
Expand Down
11 changes: 3 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.290
rev: v0.1.6
hooks:
- id: ruff
args:
- --fix
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
- id: black
args:
- --quiet
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
attentiongrabber = [] # empty list for backwards compatibility (no "no extra" warnings)
magicprompt = ["transformers[torch]~=4.19"]
feelinglucky = ["requests~=2.28"]
yaml = ["pyyaml~=6.0"]
dev = [
"pytest-cov~=4.0",
"pytest-lazy-fixture~=0.6",
Expand Down Expand Up @@ -84,6 +85,10 @@ exclude = "tests"
module = "transformers"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "torch"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "spacy.*"
ignore_missing_imports = true
Expand Down
4 changes: 3 additions & 1 deletion src/dynamicprompts/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from dynamicprompts.commands.sequence_command import SequenceCommand
from dynamicprompts.commands.variant_command import VariantCommand, VariantOption
from dynamicprompts.commands.wildcard_command import WildcardCommand
from dynamicprompts.commands.wrap_command import WrapCommand

__all__ = [
"Command",
"LiteralCommand",
"SamplingMethod",
"SequenceCommand",
"VariantCommand",
"VariantOption",
"WildcardCommand",
"SamplingMethod",
"WrapCommand",
]
45 changes: 45 additions & 0 deletions src/dynamicprompts/commands/wrap_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

import dataclasses
import logging
import re

from dynamicprompts.commands import Command
from dynamicprompts.enums import SamplingMethod

log = logging.getLogger(__name__)

WRAP_MARKER_CHARACTERS = {
"\u1801", # Mongolian ellipsis
"\u2026", # Horizontal ellipsis
"\u22EE", # Vertical ellipsis
"\u22EF", # Midline horizontal ellipsis
"\u22F0", # Up right diagonal ellipsis
"\u22F1", # Down right diagonal ellipsis
"\uFE19", # Presentation form for vertical horizontal ellipsis
}

WRAP_MARKER_RE = re.compile(
f"[{''.join(WRAP_MARKER_CHARACTERS)}]+" # One or more wrap marker characters
"|"
r"\.{3,}", # ASCII ellipsis of 3 or more dots
)


def split_wrapper_string(s: str) -> tuple[str, str]:
"""
Split a string into a prefix and suffix at the first wrap marker.
"""
match = WRAP_MARKER_RE.search(s)
if match is None:
log.warning("Found no wrap marker in string %r", s)
return s, ""
else:
return s[: match.start()], s[match.end() :]


@dataclasses.dataclass(frozen=True)
class WrapCommand(Command):
wrapper: Command
inner: Command
sampling_method: SamplingMethod | None = None
35 changes: 20 additions & 15 deletions src/dynamicprompts/generators/magicprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,13 @@

logger = logging.getLogger(__name__)

try:
if TYPE_CHECKING:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Pipeline,
pipeline,
set_seed,
)
except ImportError as ie:
raise ImportError(
"You need to install the transformers library to use the MagicPrompt generator. "
"You can do this by running `pip install -U dynamicprompts[magicprompt]`.",
) from ie

if TYPE_CHECKING:
import torch

DEFAULT_MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion"
MAX_SEED = 2**32 - 1
Expand Down Expand Up @@ -71,6 +62,18 @@ def clean_up_magic_prompt(orig_prompt: str, prompt: str) -> str:
return prompt


def _import_transformers(): # pragma: no cover
try:
import transformers

return transformers
except ImportError as ie:
raise ImportError(
"You need to install the transformers library to use the MagicPrompt generator. "
"You can do this by running `pip install -U dynamicprompts[magicprompt]`.",
) from ie


class MagicPromptGenerator(PromptGenerator):
generator: Pipeline | None = None
tokenizer: AutoTokenizer | None = None
Expand All @@ -83,13 +86,14 @@ def _load_pipeline(self, model_name: str) -> Pipeline:
logger.warning("First load of MagicPrompt may take a while.")

if MagicPromptGenerator.generator is None:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
transformers = _import_transformers()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer.pad_token_id = model.config.eos_token_id

MagicPromptGenerator.tokenizer = tokenizer
MagicPromptGenerator.model = model
MagicPromptGenerator.generator = pipeline(
MagicPromptGenerator.generator = transformers.pipeline(
task="text-generation",
tokenizer=tokenizer,
model=model,
Expand Down Expand Up @@ -123,6 +127,7 @@ def __init__(
:param blocklist_regex: A regex to use to filter out prompts that match it.
:param batch_size: The batch size to use when generating prompts.
"""
transformers = _import_transformers()
self._device = device
self.set_model(model_name)

Expand All @@ -140,7 +145,7 @@ def __init__(
self._blocklist_regex = None

if seed is not None:
set_seed(int(seed))
transformers.set_seed(int(seed))

self._batch_size = batch_size

Expand Down
7 changes: 5 additions & 2 deletions src/dynamicprompts/jinja_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ def wildcard(environment: Environment, wildcard_name: str) -> list[str]:
from dynamicprompts.generators import CombinatorialPromptGenerator
from dynamicprompts.wildcards import WildcardManager

wm: WildcardManager = environment.globals["wildcard_manager"] # type: ignore
generator: CombinatorialPromptGenerator = environment.globals["generators"]["combinatorial"] # type: ignore
wm = cast(WildcardManager, environment.globals["wildcard_manager"])
generator = cast(
CombinatorialPromptGenerator,
environment.globals["generators"]["combinatorial"], # type: ignore
)

return [str(r) for r in generator.generate(wm.to_wildcard(wildcard_name))]

Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/parser/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class ParserConfig:
wildcard_wrap: str = "__"
variable_start: str = "${"
variable_end: str = "}"
wrap_start: str = "%{"
wrap_end: str = "}"


default_parser_config = ParserConfig()
68 changes: 63 additions & 5 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<variant_literal_sequence> ::= <variant_literal>+
<variable_assignment> ::= "${" <variable_name> "=" <variant_chunk> "}"
<variable_access> ::= "${" <variable_name> (":" <variant_chunk>)? "}"
<wrap_command> ::= "%{" <variant_chunk> "$$" <variant_chunk> "}"
Note that whitespace is preserved in case it is significant to the user.
"""
Expand All @@ -44,6 +45,7 @@
VariantCommand,
VariantOption,
WildcardCommand,
WrapCommand,
)
from dynamicprompts.commands.variable_commands import (
VariableAccessCommand,
Expand All @@ -62,6 +64,8 @@
sampler_cyclical = pp.Char("@")
sampler_symbol = sampler_random | sampler_combinatorial | sampler_cyclical

variant_delim = pp.Suppress("$$")

OPT_WS = pp.Opt(pp.White()) # Optional whitespace

var_name = pp.Word(pp.alphas + "_-", pp.alphanums + "_-")
Expand All @@ -75,7 +79,6 @@

def _configure_range() -> pp.ParserElement:
hyphen = pp.Suppress("-")
variant_delim = pp.Suppress("$$")

# Exclude:
# - $, which is used to indicate the end of the separator definition i.e. {1$$ and $$X|Y|Z}
Expand Down Expand Up @@ -128,6 +131,17 @@ def _configure_wildcard(
return wildcard("wildcard").leave_whitespace()


def _configure_wildcard_path(
parser_config: ParserConfig,
variable_ref: pp.ParserElement,
) -> pp.ParserElement:
wildcard_path_literal_re = (
r"((?!" + re.escape(parser_config.wildcard_wrap) + r")[^(${}#])+"
)
wildcard_path = pp.Regex(wildcard_path_literal_re).leave_whitespace()
return pp.Combine(pp.OneOrMore(variable_ref | wildcard_path))("path")


def _configure_literal_sequence(
parser_config: ParserConfig,
is_variant_literal: bool = False,
Expand All @@ -137,7 +151,12 @@ def _configure_literal_sequence(
# - { denotes the start of a variant (or whatever variant_start is set to )
# - # denotes the start of a comment
# - $ denotes the start of a variable command (or whatever variable_start is set to)
non_literal_chars = rf"#{parser_config.variant_start}{parser_config.variable_start}"
# - % denotes the start of a wrap command (or whatever wrap_start is set to)
non_literal_chars = (
rf"#{parser_config.variant_start}"
rf"{parser_config.variable_start}"
rf"{parser_config.wrap_start}"
)

if is_variant_literal:
# Inside a variant the following characters are also not allowed
Expand Down Expand Up @@ -234,6 +253,23 @@ def _configure_variable_assignment(
return variable_assignment.leave_whitespace()


def _configure_wrap_command(
parser_config: ParserConfig,
prompt: pp.ParserElement,
) -> pp.ParserElement:
wrap_command = pp.Group(
pp.Suppress(parser_config.wrap_start)
+ OPT_WS
+ prompt()("wrapper")
+ OPT_WS
+ variant_delim
+ OPT_WS
+ prompt()("inner")
+ pp.Suppress(parser_config.wrap_end),
)
return wrap_command.leave_whitespace()


def _parse_literal_command(parse_result: pp.ParseResults) -> LiteralCommand:
s = " ".join(parse_result)
return LiteralCommand(s)
Expand Down Expand Up @@ -386,6 +422,16 @@ def _parse_variable_assignment_command(
)


def _parse_wrap_command(
parse_result: pp.ParseResults,
) -> WrapCommand:
parts = parse_result[0].as_dict()
return WrapCommand(
inner=parts["inner"],
wrapper=parts["wrapper"],
)


def create_parser(
*,
parser_config: ParserConfig,
Expand All @@ -408,7 +454,11 @@ def create_parser(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(
wrap_command = _configure_wrap_command(
parser_config=parser_config,
prompt=variant_prompt,
)
wildcard = _configure_wildcard(
parser_config=parser_config,
prompt=wildcard_prompt,
)
Expand All @@ -428,9 +478,16 @@ def create_parser(
)

chunk = (
variable_assignment | variable_access | variants | wildcard | literal_sequence
variable_assignment
| variable_access
| wrap_command
| variants
| wildcard
| literal_sequence
)
variant_chunk = (
variable_access | wrap_command | variants | wildcard | variant_literal_sequence
)
variant_chunk = variable_access | variants | wildcard | variant_literal_sequence
wildcard_chunk = (
wildcard_variable_access
| variants
Expand Down Expand Up @@ -459,6 +516,7 @@ def create_parser(
variable_assignment.set_parse_action(_parse_variable_assignment_command)
prompt.set_parse_action(_parse_sequence_or_single_command)
variant_prompt.set_parse_action(_parse_sequence_or_single_command)
wrap_command.set_parse_action(_parse_wrap_command)
wildcard_prompt.set_parse_action(_parse_sequence_or_single_command)
return prompt

Expand Down
Loading

0 comments on commit 7127ee1

Please sign in to comment.