From 719ab7142d293311b7186aa898f48c3966150b9f Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:02:08 +0200 Subject: [PATCH 1/4] CI: don't use big cache for publish step --- .github/workflows/ci.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ffa966e..68ad75f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,9 +67,6 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.11" - cache: "pip" - cache-dependency-path: | - pyproject.toml - run: python -m pip install hatch - run: hatch build -t wheel - name: Publish package distributions to PyPI From f919d45e31d1a15f197c25a5086e93a07b7ee5ad Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:14:36 +0200 Subject: [PATCH 2/4] Late-import transformers for magic prompt --- pyproject.toml | 4 +++ src/dynamicprompts/generators/magicprompt.py | 35 +++++++++++--------- tests/generators/test_magicprompt.py | 20 ++++++----- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0e00396..804b99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,10 @@ exclude = "tests" module = "transformers" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "torch" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "spacy.*" ignore_missing_imports = true diff --git a/src/dynamicprompts/generators/magicprompt.py b/src/dynamicprompts/generators/magicprompt.py index 96edd6d..7b9b780 100644 --- a/src/dynamicprompts/generators/magicprompt.py +++ b/src/dynamicprompts/generators/magicprompt.py @@ -9,22 +9,13 @@ logger = logging.getLogger(__name__) -try: +if TYPE_CHECKING: + import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, Pipeline, - pipeline, - set_seed, ) -except ImportError as ie: - raise ImportError( - "You need to install the transformers library to use the MagicPrompt generator. " - "You can do this by running `pip install -U dynamicprompts[magicprompt]`.", - ) from ie - -if TYPE_CHECKING: - import torch DEFAULT_MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion" MAX_SEED = 2**32 - 1 @@ -71,6 +62,18 @@ def clean_up_magic_prompt(orig_prompt: str, prompt: str) -> str: return prompt +def _import_transformers(): # pragma: no cover + try: + import transformers + + return transformers + except ImportError as ie: + raise ImportError( + "You need to install the transformers library to use the MagicPrompt generator. " + "You can do this by running `pip install -U dynamicprompts[magicprompt]`.", + ) from ie + + class MagicPromptGenerator(PromptGenerator): generator: Pipeline | None = None tokenizer: AutoTokenizer | None = None @@ -83,13 +86,14 @@ def _load_pipeline(self, model_name: str) -> Pipeline: logger.warning("First load of MagicPrompt may take a while.") if MagicPromptGenerator.generator is None: - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name) + transformers = _import_transformers() + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + model = transformers.AutoModelForCausalLM.from_pretrained(model_name) tokenizer.pad_token_id = model.config.eos_token_id MagicPromptGenerator.tokenizer = tokenizer MagicPromptGenerator.model = model - MagicPromptGenerator.generator = pipeline( + MagicPromptGenerator.generator = transformers.pipeline( task="text-generation", tokenizer=tokenizer, model=model, @@ -123,6 +127,7 @@ def __init__( :param blocklist_regex: A regex to use to filter out prompts that match it. :param batch_size: The batch size to use when generating prompts. """ + transformers = _import_transformers() self._device = device self.set_model(model_name) @@ -140,7 +145,7 @@ def __init__( self._blocklist_regex = None if seed is not None: - set_seed(int(seed)) + transformers.set_seed(int(seed)) self._batch_size = batch_size diff --git a/tests/generators/test_magicprompt.py b/tests/generators/test_magicprompt.py index d81ee90..beff031 100644 --- a/tests/generators/test_magicprompt.py +++ b/tests/generators/test_magicprompt.py @@ -6,17 +6,20 @@ import pytest -pytest.importorskip("dynamicprompts.generators.magicprompt") +@pytest.fixture(autouse=True) +def mock_import_transformers(monkeypatch): + from dynamicprompts.generators import magicprompt -@pytest.mark.slow -class TestMagicPrompt: - def test_default_generator(self): - from dynamicprompts.generators.dummygenerator import DummyGenerator - from dynamicprompts.generators.magicprompt import MagicPromptGenerator + monkeypatch.setattr(magicprompt, "_import_transformers", MagicMock()) + + +def test_default_generator(): + from dynamicprompts.generators.dummygenerator import DummyGenerator + from dynamicprompts.generators.magicprompt import MagicPromptGenerator - generator = MagicPromptGenerator() - assert isinstance(generator._prompt_generator, DummyGenerator) + generator = MagicPromptGenerator() + assert isinstance(generator._prompt_generator, DummyGenerator) @pytest.mark.parametrize( @@ -121,7 +124,6 @@ def _generator( assert not any(artist in magic_prompt for artist in boring_artists) -@pytest.mark.slow def test_generate_passes_kwargs(): from dynamicprompts.generators.magicprompt import MagicPromptGenerator From 93980f816248c74bd4118d3881bc9e28126bbf16 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:15:06 +0200 Subject: [PATCH 3/4] CI: don't bother with transformers --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68ad75f..54e9248 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: cache: "pip" cache-dependency-path: | pyproject.toml - - run: python -m pip install mypy -e .[dev,attentiongrabber,magicprompt,feelinglucky] + - run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky] - run: mypy --install-types --non-interactive src test: runs-on: ${{ matrix.os }} @@ -45,7 +45,7 @@ jobs: cache-dependency-path: | pyproject.toml - name: Install dependencies - run: python -m pip install -e .[dev,attentiongrabber,magicprompt,feelinglucky] + run: python -m pip install -e .[dev,attentiongrabber,feelinglucky] - run: pytest --cov --cov-report=term-missing --cov-report=xml . env: PYPARSINGENABLEALLWARNINGS: 1 From c9df45e246ccec0bff52fa24ab73e0db8e0336fc Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 13 Nov 2023 12:27:47 +0200 Subject: [PATCH 4/4] Add `yaml` extra (now that it's not implicitly required via transformers) --- .github/workflows/ci.yml | 4 ++-- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 54e9248..1532dea 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: cache: "pip" cache-dependency-path: | pyproject.toml - - run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky] + - run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky,yaml] - run: mypy --install-types --non-interactive src test: runs-on: ${{ matrix.os }} @@ -45,7 +45,7 @@ jobs: cache-dependency-path: | pyproject.toml - name: Install dependencies - run: python -m pip install -e .[dev,attentiongrabber,feelinglucky] + run: python -m pip install -e .[dev,attentiongrabber,feelinglucky,yaml] - run: pytest --cov --cov-report=term-missing --cov-report=xml . env: PYPARSINGENABLEALLWARNINGS: 1 diff --git a/pyproject.toml b/pyproject.toml index 804b99c..ff19746 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ attentiongrabber = [] # empty list for backwards compatibility (no "no extra" warnings) magicprompt = ["transformers[torch]~=4.19"] feelinglucky = ["requests~=2.28"] +yaml = ["pyyaml~=6.0"] dev = [ "pytest-cov~=4.0", "pytest-lazy-fixture~=0.6",