Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API Endpoints [POC] #611

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 152 additions & 69 deletions sd_dynamic_prompts/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@

import dynamicprompts
import gradio as gr
import modules.scripts as scripts
import torch
from fastapi import FastAPI, Body
from dynamicprompts.generators.promptgenerator import GeneratorException
from dynamicprompts.parser.parse import ParserConfig
from dynamicprompts.wildcards import WildcardManager
from modules.processing import fix_seed
from modules.shared import opts

from sd_dynamic_prompts import __version__, callbacks
from sd_dynamic_prompts.element_ids import make_element_id
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
from sd_dynamic_prompts.helpers import (
generate_prompts,
generate_prompts as generate_prompts_helper,
get_seeds,
load_magicprompt_models,
repeat_iterable_to_length,
Expand All @@ -33,6 +30,11 @@
from sd_dynamic_prompts.pnginfo_saver import PngInfoSaver
from sd_dynamic_prompts.prompt_writer import PromptWriter

import modules.scripts as scripts
from modules.script_callbacks import on_app_started
from modules.processing import fix_seed, StableDiffusionProcessing
from modules.shared import opts

VERSION = __version__

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -372,24 +374,6 @@ def process(
logger.debug("Dynamic prompts disabled - exiting")
return p

ignore_whitespace = opts.dp_ignore_whitespace

self._pnginfo_saver.enabled = opts.dp_write_raw_template
self._prompt_writer.enabled = opts.dp_write_prompts_to_file
self._limit_jinja_prompts = opts.dp_limit_jinja_prompts
self._auto_purge_cache = opts.dp_auto_purge_cache
self._wildcard_manager.dedup_wildcards = not opts.dp_wildcard_manager_no_dedupe
self._wildcard_manager.sort_wildcards = not opts.dp_wildcard_manager_no_sort
self._wildcard_manager.shuffle_wildcards = opts.dp_wildcard_manager_shuffle

magicprompt_batch_size = opts.dp_magicprompt_batch_size

parser_config = ParserConfig(
variant_start=opts.dp_parser_variant_start,
variant_end=opts.dp_parser_variant_end,
wildcard_wrap=opts.dp_parser_wildcard_wrap,
)

fix_seed(p)

# Save original prompts before we touch `p.prompt`/`p.hr_prompt` etc.
Expand Down Expand Up @@ -417,16 +401,128 @@ def process(
original_seed = p.seed
num_images = p.n_iter * p.batch_size

combinatorial_batches = int(combinatorial_batches)
self._auto_purge_cache = opts.dp_auto_purge_cache
if self._auto_purge_cache:
self._wildcard_manager.clear_cache()

all_prompts, all_negative_prompts = self.generate_prompts(
p=p,
original_prompt=original_prompt,
original_negative_prompt=original_negative_prompt,
original_seed=original_seed,
num_images=num_images,
is_combinatorial=is_combinatorial,
combinatorial_batches=combinatorial_batches,
is_magic_prompt=is_magic_prompt,
is_feeling_lucky=is_feeling_lucky,
is_attention_grabber=is_attention_grabber,
min_attention=min_attention,
max_attention=max_attention,
magic_prompt_length=magic_prompt_length,
magic_temp_value=magic_temp_value,
use_fixed_seed=use_fixed_seed,
unlink_seed_from_prompt=unlink_seed_from_prompt,
disable_negative_prompt=disable_negative_prompt,
enable_jinja_templates=enable_jinja_templates,
max_generations=max_generations,
magic_model=magic_model,
magic_blocklist_regex=magic_blocklist_regex,
)

updated_count = len(all_prompts)
p.n_iter = math.ceil(updated_count / p.batch_size)

if num_images != updated_count:
p.all_seeds, p.all_subseeds = get_seeds(
p,
updated_count,
use_fixed_seed,
is_combinatorial,
combinatorial_batches,
)

if updated_count > 1:
logger.info(
f"Prompt matrix will create {updated_count} images in a total of {p.n_iter} batches.",
)

self._prompt_writer.set_data(
positive_template=original_prompt,
negative_template=original_negative_prompt,
positive_prompts=all_prompts,
negative_prompts=all_negative_prompts,
)

p.all_prompts = all_prompts
p.all_negative_prompts = all_negative_prompts
if no_image_generation:
logger.debug("No image generation requested - exiting")
# Need a minimum of batch size images to avoid errors
p.batch_size = 1
p.all_prompts = all_prompts[0:1]

p.prompt_for_display = original_prompt
p.prompt = original_prompt

if hr_fix_enabled:
p.all_hr_prompts = _get_hr_fix_prompts(
all_prompts,
original_hr_prompt,
original_prompt,
)
p.all_hr_negative_prompts = _get_hr_fix_prompts(
all_negative_prompts,
original_negative_hr_prompt,
original_negative_prompt,
)

def generate_prompts(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just generating prompts shouldn't require the full p object, I don't think..?

Further, maybe this should be made a free function, and it should probably be kwarg-only ((*, original_prompt...)) because otherwise it's super easy to pass arguments in the wrong order.

Copy link
Author

@ArrowM ArrowM Aug 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p is used for its additional seed settings. see around dynamic_prompting.py#558 and around dynamic_prompting.py#572.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but those probably ought to be moved on from this function?

Copy link
Author

@ArrowM ArrowM Aug 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not seeing an easy way to move them out, any suggestions?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, GeneratorBuilder().set_context(p) is a bit problematic, but that's only used for Jinja context – the p-to-context code in create_jinja_generator should be refactored out. (I think Jinja over the API would fail right now since p.sd_model will not get set.)

With that done, I think the p.all_seeds, p.all_subseeds = get_seeds(...) manipulation could be done outside the prompt-generation function, couldn't it?

self,
*,
p,
original_prompt: str,
original_negative_prompt: str,
original_seed: int,
num_images: int = 1,
is_combinatorial: bool = False,
combinatorial_batches: int = 1,
is_magic_prompt: bool = False,
is_feeling_lucky: bool = False,
is_attention_grabber: bool = False,
min_attention: float = 1.1,
max_attention: float = 1.5,
magic_prompt_length: int = 100,
magic_temp_value: float = 0.7,
use_fixed_seed: bool = False,
unlink_seed_from_prompt: bool = False,
disable_negative_prompt: bool = False,
enable_jinja_templates: bool = False,
max_generations: int = 0,
magic_model: str | None = "Gustavosta/MagicPrompt-Stable-Diffusion",
magic_blocklist_regex: str | None = "",
):
self._limit_jinja_prompts = opts.dp_limit_jinja_prompts
self._pnginfo_saver.enabled = opts.dp_write_raw_template
self._prompt_writer.enabled = opts.dp_write_prompts_to_file
self._wildcard_manager.dedup_wildcards = not opts.dp_wildcard_manager_no_dedupe
self._wildcard_manager.sort_wildcards = not opts.dp_wildcard_manager_no_sort
self._wildcard_manager.shuffle_wildcards = opts.dp_wildcard_manager_shuffle

ignore_whitespace = opts.dp_ignore_whitespace
magicprompt_batch_size = opts.dp_magicprompt_batch_size
parser_config = ParserConfig(
variant_start=opts.dp_parser_variant_start,
variant_end=opts.dp_parser_variant_end,
wildcard_wrap=opts.dp_parser_wildcard_wrap,
)

if is_combinatorial:
if max_generations == 0:
num_images = None
else:
num_images = max_generations

combinatorial_batches = int(combinatorial_batches)
if self._auto_purge_cache:
self._wildcard_manager.clear_cache()

try:
logger.debug("Creating generator")

Expand Down Expand Up @@ -482,7 +578,7 @@ def process(
)
all_seeds = p.all_seeds

all_prompts, all_negative_prompts = generate_prompts(
all_prompts, all_negative_prompts = generate_prompts_helper(
prompt_generator=generator,
negative_prompt_generator=negative_generator,
prompt=original_prompt,
Expand All @@ -496,49 +592,36 @@ def process(
all_prompts = [str(e)]
all_negative_prompts = [str(e)]

updated_count = len(all_prompts)
p.n_iter = math.ceil(updated_count / p.batch_size)
return all_prompts, all_negative_prompts

if num_images != updated_count:
p.all_seeds, p.all_subseeds = get_seeds(
p,
updated_count,
use_fixed_seed,
is_combinatorial,
combinatorial_batches,
)

if updated_count > 1:
logger.info(
f"Prompt matrix will create {updated_count} images in a total of {p.n_iter} batches.",
)

self._prompt_writer.set_data(
positive_template=original_prompt,
negative_template=original_negative_prompt,
positive_prompts=all_prompts,
negative_prompts=all_negative_prompts,
def api(_: gr.Blocks, app: FastAPI):
    """Register the dynamic-prompts HTTP endpoints on the WebUI's FastAPI app.

    Intended to be passed to ``on_app_started``; the Gradio ``Blocks``
    argument is unused (hence the ``_`` name).
    """

    @app.post("/dynamicprompts/evaluate")
    async def evaluate(
        prompt: str = Body("", title="Prompt"),
        negative_prompt: str = Body("", title="Negative Prompt"),
        is_combinatorial: bool = Body(False, title="Is combinatorial"),
        combinatorial_batches: int = Body(1, title="Combinatorial batches"),
        batch_size: int = Body(1, title="Batch size"),
        max_generations: int = Body(0, title="Max generations"),
        seed: int = Body(1, title="Seed"),
    ):
        """Expand a dynamic-prompt template and return the generated prompts
        without performing any image generation."""
        # NOTE(review): a fresh Script is built per request, and
        # generate_prompts reads self._auto_purge_cache, which process()
        # normally assigns — confirm Script.__init__ initializes it, or this
        # endpoint may raise AttributeError.
        script = Script()

        # An empty StableDiffusionProcessing stands in for `p`; it is only
        # consulted for its seed-related fields inside generate_prompts.
        all_prompts, all_negative_prompts = script.generate_prompts(
            p=StableDiffusionProcessing(),
            original_prompt=prompt,
            original_negative_prompt=negative_prompt,
            original_seed=seed,
            num_images=batch_size,
            is_combinatorial=is_combinatorial,
            combinatorial_batches=combinatorial_batches,
            max_generations=max_generations,
        )
        return {
            "all_prompts": all_prompts,
            "all_negative_prompts": all_negative_prompts,
        }

p.all_prompts = all_prompts
p.all_negative_prompts = all_negative_prompts
if no_image_generation:
logger.debug("No image generation requested - exiting")
# Need a minimum of batch size images to avoid errors
p.batch_size = 1
p.all_prompts = all_prompts[0:1]

p.prompt_for_display = original_prompt
p.prompt = original_prompt

if hr_fix_enabled:
p.all_hr_prompts = _get_hr_fix_prompts(
all_prompts,
original_hr_prompt,
original_prompt,
)
p.all_hr_negative_prompts = _get_hr_fix_prompts(
all_negative_prompts,
original_negative_hr_prompt,
original_negative_prompt,
)
# Hook the endpoint registration into the WebUI startup lifecycle.
on_app_started(api)