From 002d75179ae5a3b165a65c5cf49c00bf8f98e2df Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 29 Jul 2024 23:18:34 +0900 Subject: [PATCH] sample images for training --- library/sd3_train_utils.py | 348 ++++++++++++++++++++++++++++++++++++- sd3_train.py | 51 +++--- 2 files changed, 367 insertions(+), 32 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 8f99d9474..da0729506 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,14 +1,18 @@ import argparse -import glob import math import os -from typing import List, Optional, Tuple, Union +import toml +import json +import time +from typing import Dict, List, Optional, Tuple, Union import torch from safetensors.torch import save_file -from accelerate import Accelerator +from accelerate import Accelerator, PartialState +from tqdm import tqdm +from PIL import Image -from library import sd3_models, sd3_utils, train_util +from library import sd3_models, sd3_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -276,10 +280,342 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin ) -def sample_images(*args, **kwargs): - return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +# temporary copied from sd3_minimal_inferece.py +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + latent = latent.to(dtype).to(device) + + # noise = get_noise(seed, latent).to(device) + if seed is not None: + generator = torch.manual_seed(seed) + noise = ( + torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") + .to(latent.dtype) + .to(device) + ) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_sigmas(model_sampling, steps).to(device) + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + return x + + +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + mmdit, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + mmdit = accelerator.unwrap_model(mmdit) + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + org_vae_device = vae.device # will be on cpu + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + vae.to(org_vae_device) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + mmdit: sd3_models.MMDiT, + text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]], + vae: sd3_models.SDVAE, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + te_outputs = sample_prompts_te_outputs[prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) + te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = te_outputs + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # encode negative prompts + if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs: + neg_te_outputs = sample_prompts_te_outputs[negative_prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) + neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = neg_te_outputs + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # sample image + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + + # latent to image + with torch.no_grad(): + image = vae.decode(latents) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + + image = Image.fromarray(decoded_np) + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + # region Diffusers diff --git a/sd3_train.py b/sd3_train.py index 617e30271..2f4ea8cb2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -299,6 +299,7 @@ def train(args): t5xxl.eval() # cache text encoder outputs + sample_prompts_te_outputs = None if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad here clip_l.to(accelerator.device) @@ -321,6 +322,22 @@ def train(args): with accelerator.autocast(): train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = sd3_train_utils.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_list = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + ) + accelerator.wait_for_everyone() # load MMDIT @@ -635,10 +652,8 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) - # # For --sample_at_first - # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit - # ) + # For --sample_at_first + sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) # following function will be moved to sd3_train_utils @@ -831,17 +846,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): progress_bar.update(1) global_step += 1 - # sdxl_train_util.sample_images( - # accelerator, - # args, - # None, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -900,17 +907,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): vae, ) - # sdxl_train_util.sample_images( - # accelerator, - # args, - # epoch + 1, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) is_main_process = accelerator.is_main_process # if is_main_process: