From 2b74c19caeb81634fd843e5bb95ebb3263619aa7 Mon Sep 17 00:00:00 2001 From: Adi Eyal Date: Sat, 8 Jul 2023 22:27:30 +0200 Subject: [PATCH] Updated how prompt seeds are generated These are now returned from the get_seeds function which decides whether if should be the same as image seeds or generated separately. Fixes #535 --- sd_dynamic_prompts/dynamic_prompting.py | 15 ++-- sd_dynamic_prompts/helpers.py | 33 ++++++--- tests/prompts/test_frozenprompt_generator.py | 5 +- tests/prompts/test_helpers.py | 77 +++++++++++++++----- 4 files changed, 93 insertions(+), 37 deletions(-) diff --git a/sd_dynamic_prompts/dynamic_prompting.py b/sd_dynamic_prompts/dynamic_prompting.py index 69189835..0cc28776 100644 --- a/sd_dynamic_prompts/dynamic_prompting.py +++ b/sd_dynamic_prompts/dynamic_prompting.py @@ -464,16 +464,18 @@ def process( else: negative_generator = generator - all_seeds = None - if num_images and not unlink_seed_from_prompt: - p.all_seeds, p.all_subseeds = get_seeds( + prompt_seeds = p.all_seeds + if num_images: + image_seeds, image_subseeds, prompt_seeds = get_seeds( p, num_images, use_fixed_seed, is_combinatorial, combinatorial_batches, + unlink_seed_from_prompt, ) - all_seeds = p.all_seeds + p.all_seeds = image_seeds + p.all_subseeds = image_subseeds all_prompts, all_negative_prompts = generate_prompts( generator, @@ -481,7 +483,7 @@ def process( original_prompt, original_negative_prompt, num_images, - all_seeds, + prompt_seeds, ) except GeneratorException as e: @@ -493,12 +495,13 @@ def process( p.n_iter = math.ceil(updated_count / p.batch_size) if num_images != updated_count: - p.all_seeds, p.all_subseeds = get_seeds( + p.all_seeds, p.all_subseeds, _ = get_seeds( p, updated_count, use_fixed_seed, is_combinatorial, combinatorial_batches, + unlink_seed_from_prompt, ) if updated_count > 1: diff --git a/sd_dynamic_prompts/helpers.py b/sd_dynamic_prompts/helpers.py index 19dd4036..84307f3e 100644 --- a/sd_dynamic_prompts/helpers.py +++ b/sd_dynamic_prompts/helpers.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import random from pathlib import Path from dynamicprompts.generators.promptgenerator import PromptGenerator @@ -8,12 +9,21 @@ logger = logging.getLogger(__name__) +def get_fixed_seed(seed): + # Copied from auto1111 modules/processing.py + if seed is None or seed == "" or seed == -1: + return int(random.randrange(4294967294)) + + return seed + + def get_seeds( p, num_seeds, use_fixed_seed, is_combinatorial=False, combinatorial_batches=1, + unlink_seed_from_prompt=False, ): if p.subseed_strength != 0: seed = int(p.all_seeds[0]) @@ -24,22 +34,27 @@ def get_seeds( if use_fixed_seed: if is_combinatorial: - all_seeds = [] - all_subseeds = [subseed] * num_seeds + image_seeds = [] + image_subseeds = [subseed] * num_seeds for i in range(combinatorial_batches): - all_seeds.extend([seed + i] * (num_seeds // combinatorial_batches)) + image_seeds.extend([seed + i] * (num_seeds // combinatorial_batches)) else: - all_seeds = [seed] * num_seeds - all_subseeds = [subseed] * num_seeds + image_seeds = [seed] * num_seeds + image_subseeds = [subseed] * num_seeds else: if p.subseed_strength == 0: - all_seeds = [seed + i for i in range(num_seeds)] + image_seeds = [seed + i for i in range(num_seeds)] else: - all_seeds = [seed] * num_seeds + image_seeds = [seed] * num_seeds - all_subseeds = [subseed + i for i in range(num_seeds)] + image_subseeds = [subseed + i for i in range(num_seeds)] + + if unlink_seed_from_prompt: + prompt_seeds = [get_fixed_seed(None) for _ in range(num_seeds)] + else: + prompt_seeds = image_seeds - return all_seeds, all_subseeds + return image_seeds, image_subseeds, prompt_seeds def should_freeze_prompt(p): diff --git a/tests/prompts/test_frozenprompt_generator.py b/tests/prompts/test_frozenprompt_generator.py index c196906a..21fdf341 100644 --- a/tests/prompts/test_frozenprompt_generator.py +++ b/tests/prompts/test_frozenprompt_generator.py @@ -4,7 +4,9 @@ def test_repeats_correctly(): - generator = FrozenPromptGenerator(RandomPromptGenerator()) + generator = FrozenPromptGenerator( + RandomPromptGenerator(unlink_seed_from_prompt=True), + ) template = "{A|B|C|D|E|F|G|H|I|J|K}" prompts = generator.generate(template, 10) @@ -15,5 +17,4 @@ def test_repeats_correctly(): assert len(prompts2) == 10 assert len(set(prompts2)) == 1 - assert prompts[0] != prompts2[0] diff --git a/tests/prompts/test_helpers.py b/tests/prompts/test_helpers.py index 74bcbef1..d5038c0a 100644 --- a/tests/prompts/test_helpers.py +++ b/tests/prompts/test_helpers.py @@ -22,21 +22,29 @@ def processing(): def test_get_seeds_with_fixed_seed(processing): num_seeds = 10 - seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True) - assert seeds == [processing.seed] * num_seeds - assert subseeds == [processing.subseed] * num_seeds + image_seeds, image_subseeds, _ = get_seeds( + processing, + num_seeds, + use_fixed_seed=True, + ) + assert image_seeds == [processing.seed] * num_seeds + assert image_subseeds == [processing.subseed] * num_seeds processing.subseed_strength = 0.5 - seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True) - assert seeds == [processing.all_seeds[0]] * num_seeds - assert subseeds == [processing.all_subseeds[0]] * num_seeds + image_seeds, image_subseeds, _ = get_seeds( + processing, + num_seeds, + use_fixed_seed=True, + ) + assert image_seeds == [processing.all_seeds[0]] * num_seeds + assert image_subseeds == [processing.all_subseeds[0]] * num_seeds def test_get_seeds_with_fixed_seed_batched_combinatorial(processing): num_seeds = 10 combinatorial_batches = 3 - seeds, subseeds = get_seeds( + image_seeds, image_subseeds, _ = get_seeds( processing, num_seeds, use_fixed_seed=True, @@ -44,16 +52,16 @@ def test_get_seeds_with_fixed_seed_batched_combinatorial(processing): combinatorial_batches=combinatorial_batches, ) seed0 = processing.seed - assert seeds == ( + assert image_seeds == ( [seed0] * (num_seeds // 3) + [seed0 + 1] * (num_seeds // 3) + [seed0 + 2] * (num_seeds // 3) ) - assert subseeds == [processing.subseed] * num_seeds + assert image_subseeds == [processing.subseed] * num_seeds processing.subseed_strength = 0.5 - seeds, subseeds = get_seeds( + image_seeds, image_subseeds, _ = get_seeds( processing, num_seeds, use_fixed_seed=True, @@ -61,28 +69,57 @@ def test_get_seeds_with_fixed_seed_batched_combinatorial(processing): combinatorial_batches=combinatorial_batches, ) seed0 = processing.all_seeds[0] - assert seeds == ( + assert image_seeds == ( [seed0] * (num_seeds // 3) + [seed0 + 1] * (num_seeds // 3) + [seed0 + 2] * (num_seeds // 3) ) - assert subseeds == [processing.all_subseeds[0]] * num_seeds + assert image_subseeds == [processing.all_subseeds[0]] * num_seeds def test_get_seeds_with_random_seed(processing): num_seeds = 10 - seed, subseed = processing.seed, processing.subseed - seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False) - assert seeds == list(range(seed, seed + num_seeds)) - assert subseeds == list(range(subseed, subseed + num_seeds)) + image_seeds, image_subseeds = processing.seed, processing.subseed + seeds, subseeds, _ = get_seeds( + processing, + num_seeds=num_seeds, + use_fixed_seed=False, + ) + assert seeds == list(range(image_seeds, image_seeds + num_seeds)) + assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds)) processing.subseed_strength = 0.5 - seed, subseed = processing.all_seeds[0], processing.all_subseeds[0] - seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False) - assert seeds == [seed] * num_seeds - assert subseeds == list(range(subseed, subseed + num_seeds)) + image_seeds, image_subseeds = processing.all_seeds[0], processing.all_subseeds[0] + seeds, subseeds, _ = get_seeds( + processing, + num_seeds=num_seeds, + use_fixed_seed=False, + ) + assert seeds == [image_seeds] * num_seeds + assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds)) + + +@pytest.mark.parametrize("use_fixed_seed", [True, False]) +def test_get_with_unlinked_seed(processing, use_fixed_seed): + num_seeds = 10 + + image_seeds, _, prompt_seeds = get_seeds( + processing, + num_seeds, + use_fixed_seed=use_fixed_seed, + unlink_seed_from_prompt=False, + ) + assert image_seeds == prompt_seeds + + image_seeds, _, prompt_seeds = get_seeds( + processing, + num_seeds, + use_fixed_seed=use_fixed_seed, + unlink_seed_from_prompt=True, + ) + assert image_seeds != prompt_seeds def test_load_magicprompt_models():