Skip to content

Commit

Permalink
Updated how prompt seeds are generated
Browse files Browse the repository at this point in the history
These are now returned from the get_seeds function which decides whether if should be the same as image seeds or generated separately. Fixes #535
  • Loading branch information
adieyal committed Jul 8, 2023
1 parent 78d599c commit 2b74c19
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 37 deletions.
15 changes: 9 additions & 6 deletions sd_dynamic_prompts/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,24 +464,26 @@ def process(
else:
negative_generator = generator

all_seeds = None
if num_images and not unlink_seed_from_prompt:
p.all_seeds, p.all_subseeds = get_seeds(
prompt_seeds = p.all_seeds
if num_images:
image_seeds, image_subseeds, prompt_seeds = get_seeds(
p,
num_images,
use_fixed_seed,
is_combinatorial,
combinatorial_batches,
unlink_seed_from_prompt,
)
all_seeds = p.all_seeds
p.all_seeds = image_seeds
p.all_subseeds = image_subseeds

all_prompts, all_negative_prompts = generate_prompts(
generator,
negative_generator,
original_prompt,
original_negative_prompt,
num_images,
all_seeds,
prompt_seeds,
)

except GeneratorException as e:
Expand All @@ -493,12 +495,13 @@ def process(
p.n_iter = math.ceil(updated_count / p.batch_size)

if num_images != updated_count:
p.all_seeds, p.all_subseeds = get_seeds(
p.all_seeds, p.all_subseeds, _ = get_seeds(
p,
updated_count,
use_fixed_seed,
is_combinatorial,
combinatorial_batches,
unlink_seed_from_prompt,
)

if updated_count > 1:
Expand Down
33 changes: 24 additions & 9 deletions sd_dynamic_prompts/helpers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from __future__ import annotations

import logging
import random
from pathlib import Path

from dynamicprompts.generators.promptgenerator import PromptGenerator

logger = logging.getLogger(__name__)


def get_fixed_seed(seed):
# Copied from auto1111 modules/processing.py
if seed is None or seed == "" or seed == -1:
return int(random.randrange(4294967294))

return seed


def get_seeds(
p,
num_seeds,
use_fixed_seed,
is_combinatorial=False,
combinatorial_batches=1,
unlink_seed_from_prompt=False,
):
if p.subseed_strength != 0:
seed = int(p.all_seeds[0])
Expand All @@ -24,22 +34,27 @@ def get_seeds(

if use_fixed_seed:
if is_combinatorial:
all_seeds = []
all_subseeds = [subseed] * num_seeds
image_seeds = []
image_subseeds = [subseed] * num_seeds
for i in range(combinatorial_batches):
all_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
image_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
else:
all_seeds = [seed] * num_seeds
all_subseeds = [subseed] * num_seeds
image_seeds = [seed] * num_seeds
image_subseeds = [subseed] * num_seeds
else:
if p.subseed_strength == 0:
all_seeds = [seed + i for i in range(num_seeds)]
image_seeds = [seed + i for i in range(num_seeds)]
else:
all_seeds = [seed] * num_seeds
image_seeds = [seed] * num_seeds

all_subseeds = [subseed + i for i in range(num_seeds)]
image_subseeds = [subseed + i for i in range(num_seeds)]

if unlink_seed_from_prompt:
prompt_seeds = [get_fixed_seed(None) for _ in range(num_seeds)]
else:
prompt_seeds = image_seeds

return all_seeds, all_subseeds
return image_seeds, image_subseeds, prompt_seeds


def should_freeze_prompt(p):
Expand Down
5 changes: 3 additions & 2 deletions tests/prompts/test_frozenprompt_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@


def test_repeats_correctly():
generator = FrozenPromptGenerator(RandomPromptGenerator())
generator = FrozenPromptGenerator(
RandomPromptGenerator(unlink_seed_from_prompt=True),
)
template = "{A|B|C|D|E|F|G|H|I|J|K}"
prompts = generator.generate(template, 10)

Expand All @@ -15,5 +17,4 @@ def test_repeats_correctly():

assert len(prompts2) == 10
assert len(set(prompts2)) == 1

assert prompts[0] != prompts2[0]
77 changes: 57 additions & 20 deletions tests/prompts/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,67 +22,104 @@ def processing():
def test_get_seeds_with_fixed_seed(processing):
num_seeds = 10

seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
assert seeds == [processing.seed] * num_seeds
assert subseeds == [processing.subseed] * num_seeds
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
)
assert image_seeds == [processing.seed] * num_seeds
assert image_subseeds == [processing.subseed] * num_seeds

processing.subseed_strength = 0.5

seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
assert seeds == [processing.all_seeds[0]] * num_seeds
assert subseeds == [processing.all_subseeds[0]] * num_seeds
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
)
assert image_seeds == [processing.all_seeds[0]] * num_seeds
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds


def test_get_seeds_with_fixed_seed_batched_combinatorial(processing):
num_seeds = 10
combinatorial_batches = 3
seeds, subseeds = get_seeds(
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
is_combinatorial=True,
combinatorial_batches=combinatorial_batches,
)
seed0 = processing.seed
assert seeds == (
assert image_seeds == (
[seed0] * (num_seeds // 3)
+ [seed0 + 1] * (num_seeds // 3)
+ [seed0 + 2] * (num_seeds // 3)
)
assert subseeds == [processing.subseed] * num_seeds
assert image_subseeds == [processing.subseed] * num_seeds

processing.subseed_strength = 0.5

seeds, subseeds = get_seeds(
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
is_combinatorial=True,
combinatorial_batches=combinatorial_batches,
)
seed0 = processing.all_seeds[0]
assert seeds == (
assert image_seeds == (
[seed0] * (num_seeds // 3)
+ [seed0 + 1] * (num_seeds // 3)
+ [seed0 + 2] * (num_seeds // 3)
)
assert subseeds == [processing.all_subseeds[0]] * num_seeds
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds


def test_get_seeds_with_random_seed(processing):
num_seeds = 10

seed, subseed = processing.seed, processing.subseed
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
assert seeds == list(range(seed, seed + num_seeds))
assert subseeds == list(range(subseed, subseed + num_seeds))
image_seeds, image_subseeds = processing.seed, processing.subseed
seeds, subseeds, _ = get_seeds(
processing,
num_seeds=num_seeds,
use_fixed_seed=False,
)
assert seeds == list(range(image_seeds, image_seeds + num_seeds))
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))

processing.subseed_strength = 0.5

seed, subseed = processing.all_seeds[0], processing.all_subseeds[0]
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
assert seeds == [seed] * num_seeds
assert subseeds == list(range(subseed, subseed + num_seeds))
image_seeds, image_subseeds = processing.all_seeds[0], processing.all_subseeds[0]
seeds, subseeds, _ = get_seeds(
processing,
num_seeds=num_seeds,
use_fixed_seed=False,
)
assert seeds == [image_seeds] * num_seeds
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))


@pytest.mark.parametrize("use_fixed_seed", [True, False])
def test_get_with_unlinked_seed(processing, use_fixed_seed):
num_seeds = 10

image_seeds, _, prompt_seeds = get_seeds(
processing,
num_seeds,
use_fixed_seed=use_fixed_seed,
unlink_seed_from_prompt=False,
)
assert image_seeds == prompt_seeds

image_seeds, _, prompt_seeds = get_seeds(
processing,
num_seeds,
use_fixed_seed=use_fixed_seed,
unlink_seed_from_prompt=True,
)
assert image_seeds != prompt_seeds


def test_load_magicprompt_models():
Expand Down

0 comments on commit 2b74c19

Please sign in to comment.