Skip to content

Commit

Permalink
Merge pull request #109 from akx/faster-ci
Browse files Browse the repository at this point in the history
Faster CI
  • Loading branch information
adieyal authored Nov 17, 2023
2 parents b66bea8 + c9df45e commit 1c3ae20
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 29 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
cache: "pip"
cache-dependency-path: |
pyproject.toml
- run: python -m pip install mypy -e .[dev,attentiongrabber,magicprompt,feelinglucky]
- run: python -m pip install mypy -e .[dev,attentiongrabber,feelinglucky,yaml]
- run: mypy --install-types --non-interactive src
test:
runs-on: ${{ matrix.os }}
Expand All @@ -45,7 +45,7 @@ jobs:
cache-dependency-path: |
pyproject.toml
- name: Install dependencies
run: python -m pip install -e .[dev,attentiongrabber,magicprompt,feelinglucky]
run: python -m pip install -e .[dev,attentiongrabber,feelinglucky,yaml]
- run: pytest --cov --cov-report=term-missing --cov-report=xml .
env:
PYPARSINGENABLEALLWARNINGS: 1
Expand All @@ -67,9 +67,6 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
pyproject.toml
- run: python -m pip install hatch
- run: hatch build -t wheel
- name: Publish package distributions to PyPI
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
attentiongrabber = [] # empty list for backwards compatibility (no "no extra" warnings)
magicprompt = ["transformers[torch]~=4.19"]
feelinglucky = ["requests~=2.28"]
yaml = ["pyyaml~=6.0"]
dev = [
"pytest-cov~=4.0",
"pytest-lazy-fixture~=0.6",
Expand Down Expand Up @@ -84,6 +85,10 @@ exclude = "tests"
module = "transformers"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "torch"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "spacy.*"
ignore_missing_imports = true
Expand Down
35 changes: 20 additions & 15 deletions src/dynamicprompts/generators/magicprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,13 @@

logger = logging.getLogger(__name__)

try:
if TYPE_CHECKING:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Pipeline,
pipeline,
set_seed,
)
except ImportError as ie:
raise ImportError(
"You need to install the transformers library to use the MagicPrompt generator. "
"You can do this by running `pip install -U dynamicprompts[magicprompt]`.",
) from ie

if TYPE_CHECKING:
import torch

DEFAULT_MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion"
MAX_SEED = 2**32 - 1
Expand Down Expand Up @@ -71,6 +62,18 @@ def clean_up_magic_prompt(orig_prompt: str, prompt: str) -> str:
return prompt


def _import_transformers(): # pragma: no cover
try:
import transformers

return transformers
except ImportError as ie:
raise ImportError(
"You need to install the transformers library to use the MagicPrompt generator. "
"You can do this by running `pip install -U dynamicprompts[magicprompt]`.",
) from ie


class MagicPromptGenerator(PromptGenerator):
generator: Pipeline | None = None
tokenizer: AutoTokenizer | None = None
Expand All @@ -83,13 +86,14 @@ def _load_pipeline(self, model_name: str) -> Pipeline:
logger.warning("First load of MagicPrompt may take a while.")

if MagicPromptGenerator.generator is None:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
transformers = _import_transformers()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer.pad_token_id = model.config.eos_token_id

MagicPromptGenerator.tokenizer = tokenizer
MagicPromptGenerator.model = model
MagicPromptGenerator.generator = pipeline(
MagicPromptGenerator.generator = transformers.pipeline(
task="text-generation",
tokenizer=tokenizer,
model=model,
Expand Down Expand Up @@ -123,6 +127,7 @@ def __init__(
:param blocklist_regex: A regex to use to filter out prompts that match it.
:param batch_size: The batch size to use when generating prompts.
"""
transformers = _import_transformers()
self._device = device
self.set_model(model_name)

Expand All @@ -140,7 +145,7 @@ def __init__(
self._blocklist_regex = None

if seed is not None:
set_seed(int(seed))
transformers.set_seed(int(seed))

self._batch_size = batch_size

Expand Down
20 changes: 11 additions & 9 deletions tests/generators/test_magicprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@

import pytest

pytest.importorskip("dynamicprompts.generators.magicprompt")

@pytest.fixture(autouse=True)
def mock_import_transformers(monkeypatch):
from dynamicprompts.generators import magicprompt

@pytest.mark.slow
class TestMagicPrompt:
def test_default_generator(self):
from dynamicprompts.generators.dummygenerator import DummyGenerator
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
monkeypatch.setattr(magicprompt, "_import_transformers", MagicMock())


def test_default_generator():
from dynamicprompts.generators.dummygenerator import DummyGenerator
from dynamicprompts.generators.magicprompt import MagicPromptGenerator

generator = MagicPromptGenerator()
assert isinstance(generator._prompt_generator, DummyGenerator)
generator = MagicPromptGenerator()
assert isinstance(generator._prompt_generator, DummyGenerator)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -121,7 +124,6 @@ def _generator(
assert not any(artist in magic_prompt for artist in boring_artists)


@pytest.mark.slow
def test_generate_passes_kwargs():
from dynamicprompts.generators.magicprompt import MagicPromptGenerator

Expand Down

0 comments on commit 1c3ae20

Please sign in to comment.