
Commit

Update submodules as well since MPS changes were needed
JoeyOverby committed Sep 16, 2024
1 parent b755ebd commit 4e56a20
Showing 26 changed files with 5,785 additions and 187 deletions.
170 changes: 159 additions & 11 deletions README.md

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions fine_tune.py
@@ -310,7 +310,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
-       accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+       accelerator.init_trackers(
+           "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+           config=train_util.get_sanitized_config_or_none(args),
+           init_kwargs=init_kwargs,
+       )

    # For --sample_at_first
    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
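Note: the new config= argument forwards a sanitized copy of the run arguments to the tracker backends. A minimal sketch of what such a sanitizer might do (hypothetical; the actual get_sanitized_config_or_none lives in library/train_util.py):

    import argparse

    def get_sanitized_config_or_none(args: argparse.Namespace) -> dict:
        # Trackers such as wandb can only log JSON-serializable values, so
        # coerce everything else (Paths, enums, callables) to strings.
        safe_types = (int, float, str, bool, type(None))
        return {k: v if isinstance(v, safe_types) else str(v) for k, v in vars(args).items()}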
@@ -354,7 +358,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                # Sample noise, sample a random timestep for each image, and add noise to the latents,
                # with noise offset and/or multires noise if specified
-               noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+               noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                   args, noise_scheduler, latents
+               )

                # Predict the noise residual
                with accelerator.autocast():
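Note: the reflowed call above bundles noise sampling, timestep sampling, and noising in one helper. A rough sketch of its core behavior (hedged; noise offset, multires noise, and the huber_c schedule in library/train_util.py are more involved than this):

    import torch

    def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
        b = latents.shape[0]
        noise = torch.randn_like(latents)
        if args.noise_offset:
            # offset noise: add a per-sample, per-channel constant shift
            noise = noise + args.noise_offset * torch.randn(
                (b, latents.shape[1], 1, 1), device=latents.device
            )
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (b,), device=latents.device
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        huber_c = 0.1  # placeholder; the real helper derives this from args and timesteps
        return noise, noisy_latents, timesteps, huber_c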
@@ -368,7 +374,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                    # do not mean over batch dimension for snr weight or scale v-pred loss
-                   loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                   loss = train_util.conditional_loss(
+                       noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                   )
                    loss = loss.mean([1, 2, 3])

                    if args.min_snr_gamma:
@@ -380,7 +388,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                    loss = loss.mean()  # mean over batch dimension
                else:
-                   loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                   loss = train_util.conditional_loss(
+                       noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                   )

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
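Note: both branches above now route through train_util.conditional_loss, which dispatches on args.loss_type. A simplified sketch of such a dispatcher (an assumption, not the library's exact code; the huber variant shown is the smooth pseudo-Huber form):

    import torch
    import torch.nn.functional as F

    def conditional_loss(pred, target, reduction="mean", loss_type="l2", huber_c=0.1):
        if loss_type == "l2":
            return F.mse_loss(pred, target, reduction=reduction)
        if loss_type == "huber":
            # pseudo-Huber: quadratic near zero, linear in the tails
            loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
            return loss.mean() if reduction == "mean" else loss
        raise NotImplementedError(f"Unknown loss type: {loss_type}")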
@@ -471,7 +481,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

    accelerator.end_training()

-   if is_main_process and (args.save_state or args.save_state_on_train_end):
+   if is_main_process and (args.save_state or args.save_state_on_train_end):
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # delete this since memory is used afterwards
55 changes: 40 additions & 15 deletions gen_img.py
@@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple):
    clip_prompt: str
    guide_image: Any
    raw_prompt: str
+   file_name: Optional[str]


class BatchDataExt(NamedTuple):
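Note: because BatchDataBase is a NamedTuple, the added file_name field changes the tuple's arity, which is why every positional unpacking below gains an extra slot. Giving the new field a default would at least have kept old constructors valid, as in this hypothetical variant:

    from typing import NamedTuple, Optional

    class BatchDataBase(NamedTuple):
        # ... existing fields ...
        raw_prompt: str
        file_name: Optional[str] = None  # trailing default keeps old constructor calls valid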
@@ -2316,7 +2317,7 @@ def scale_and_round(x):
    # extract the information of this batch
    (
        return_latents,
-       (step_first, _, _, _, init_image, mask_image, _, guide_image, _),
+       (step_first, _, _, _, init_image, mask_image, _, guide_image, _, _),
        (
            width,
            height,
@@ -2339,6 +2340,7 @@ def scale_and_round(x):
    prompts = []
    negative_prompts = []
    raw_prompts = []
+   filenames = []
    start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
    noises = [
        torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2371,14 +2373,15 @@ def scale_and_round(x):
    all_guide_images_are_same = True
    for i, (
        _,
-       (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
+       (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename),
        _,
    ) in enumerate(batch):
        prompts.append(prompt)
        negative_prompts.append(negative_prompt)
        seeds.append(seed)
        clip_prompts.append(clip_prompt)
        raw_prompts.append(raw_prompt)
+       filenames.append(filename)

        if init_image is not None:
            init_images.append(init_image)
@@ -2478,8 +2481,8 @@ def scale_and_round(x):
        # save image
        highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
        ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
-       for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
-           zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
+       for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate(
+           zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames)
        ):
            if highres_fix:
                seed -= 1  # record original seed
@@ -2505,17 +2508,23 @@ def scale_and_round(x):
                metadata.add_text("crop-top", str(crop_top))
                metadata.add_text("crop-left", str(crop_left))

-           if args.use_original_file_name and init_images is not None:
-               if type(init_images) is list:
-                   fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
-               else:
-                   fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
-           elif args.sequential_file_name:
-               fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
+           if filename is not None:
+               fln = filename
            else:
-               fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
+               if args.use_original_file_name and init_images is not None:
+                   if type(init_images) is list:
+                       fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
+                   else:
+                       fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
+               elif args.sequential_file_name:
+                   fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
+               else:
+                   fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"

-           image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
+           if fln.endswith(".webp"):
+               image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100)  # lossy
+           else:
+               image.save(os.path.join(args.outdir, fln), pnginfo=metadata)

            if not args.no_preview and not highres_1st and args.interactive:
                try:
@@ -2562,6 +2571,7 @@ def scale_and_round(x):
        # repeat prompt
        for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
            raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
+           filename = None

            if pi == 0 or len(raw_prompts) > 1:
                # parse prompt: if prompt is not changed, skip parsing
@@ -2783,6 +2793,12 @@ def scale_and_round(x):
                        logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
                        continue

+                       m = re.match(r"f (.+)", parg, re.IGNORECASE)
+                       if m:  # filename
+                           filename = m.group(1)
+                           logger.info(f"filename: {filename}")
+                           continue
+
                    except ValueError as ex:
                        logger.error(f"Exception in parsing / 解析エラー: {parg}")
                        logger.error(f"{ex}")
@@ -2873,7 +2889,16 @@ def scale_and_round(x):
            b1 = BatchData(
                False,
                BatchDataBase(
-                   global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
+                   global_step,
+                   prompt,
+                   negative_prompt,
+                   seed,
+                   init_image,
+                   mask_image,
+                   clip_prompt,
+                   guide_image,
+                   raw_prompt,
+                   filename,
                ),
                BatchDataExt(
                    width,
@@ -2916,7 +2941,7 @@ def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    add_logging_arguments(parser)

    parser.add_argument(
        "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
    )
Binary file added library/.DS_Store
Binary file not shown.
106 changes: 106 additions & 0 deletions library/adafactor_fused.py
@@ -0,0 +1,106 @@
import math
import torch
from transformers import Adafactor


@torch.no_grad()
def adafactor_step_param(self, p, group):
    if p.grad is None:
        return
    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("Adafactor does not support sparse gradients.")

    state = self.state[p]
    grad_shape = grad.shape

    factored, use_first_moment = Adafactor._get_options(group, grad_shape)
    # State Initialization
    if len(state) == 0:
        state["step"] = 0

        if use_first_moment:
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(grad)
        if factored:
            state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
            state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
        else:
            state["exp_avg_sq"] = torch.zeros_like(grad)

        state["RMS"] = 0
    else:
        if use_first_moment:
            state["exp_avg"] = state["exp_avg"].to(grad)
        if factored:
            state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
            state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
        else:
            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    state["step"] += 1
    state["RMS"] = Adafactor._rms(p_data_fp32)
    lr = Adafactor._get_lr(group, state)

    beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
    update = (grad**2) + group["eps"][0]
    if factored:
        exp_avg_sq_row = state["exp_avg_sq_row"]
        exp_avg_sq_col = state["exp_avg_sq_col"]

        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

        # Approximation of exponential moving average of square of gradient
        update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
        update.mul_(grad)
    else:
        exp_avg_sq = state["exp_avg_sq"]

        exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
        update = exp_avg_sq.rsqrt().mul_(grad)

    update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
    update.mul_(lr)

    if use_first_moment:
        exp_avg = state["exp_avg"]
        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
        update = exp_avg

    if group["weight_decay"] != 0:
        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

    p_data_fp32.add_(-update)

    if p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_data_fp32)


@torch.no_grad()
def adafactor_step(self, closure=None):
    """
    Performs a single optimization step
    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adafactor_step_param(self, p, group)

    return loss


def patch_adafactor_fused(optimizer: Adafactor):
    optimizer.step_param = adafactor_step_param.__get__(optimizer)
    optimizer.step = adafactor_step.__get__(optimizer)
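Note: a sketch of how this patch is typically wired up for a fused backward pass (assumes PyTorch 2.1+ for register_post_accumulate_grad_hook and an existing model; names are illustrative):

    import torch
    from transformers import Adafactor

    optimizer = Adafactor(model.parameters(), lr=1e-4, relative_step=False, scale_parameter=False)
    patch_adafactor_fused(optimizer)

    for param_group in optimizer.param_groups:
        for parameter in param_group["params"]:
            if parameter.requires_grad:

                def grad_hook(tensor: torch.Tensor, param_group=param_group):
                    # step this parameter as soon as its gradient is accumulated,
                    # then free the gradient to cut peak memory
                    optimizer.step_param(tensor, param_group)
                    tensor.grad = None

                parameter.register_post_accumulate_grad_hook(grad_hook)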
7 changes: 7 additions & 0 deletions library/config_util.py
@@ -86,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"
    cache_info: bool = False
+   alpha_mask: bool = False


@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    metadata_file: Optional[str] = None
+   alpha_mask: bool = False


@dataclass
@@ -191,6 +193,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
        "keep_tokens": int,
        "keep_tokens_separator": str,
        "secondary_separator": str,
+       "caption_separator": str,
        "enable_wildcard": bool,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
@@ -212,11 +215,13 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
+       "alpha_mask": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
+       "alpha_mask": bool,
    }
    CN_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
@@ -523,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
  shuffle_caption: {subset.shuffle_caption}
  keep_tokens: {subset.keep_tokens}
  keep_tokens_separator: {subset.keep_tokens_separator}
+ caption_separator: {subset.caption_separator}
  secondary_separator: {subset.secondary_separator}
  enable_wildcard: {subset.enable_wildcard}
  caption_dropout_rate: {subset.caption_dropout_rate}
@@ -536,6 +542,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
  random_crop: {subset.random_crop}
  token_warmup_min: {subset.token_warmup_min},
  token_warmup_step: {subset.token_warmup_step},
+ alpha_mask: {subset.alpha_mask},
"""
            ),
            " ",
14 changes: 11 additions & 3 deletions library/custom_train_functions.py
@@ -480,12 +480,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):


def apply_masked_loss(loss, batch):
-   # mask image is -1 to 1. we need to convert it to 0 to 1
-   mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+   if "conditioning_images" in batch:
+       # conditioning image is -1 to 1. we need to convert it to 0 to 1
+       mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+       mask_image = mask_image / 2 + 0.5
+       # print(f"conditioning_image: {mask_image.shape}")
+   elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
+       # alpha mask is 0 to 1
+       mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)  # add channel dimension
+       # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
+   else:
+       return loss

    # resize to the same size as the loss
    mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
-   mask_image = mask_image / 2 + 0.5
    loss = loss * mask_image
    return loss
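Note: a minimal shape-level sketch of the new alpha-mask branch (hypothetical tensors; alpha masks arrive at image resolution and are area-resized down to the latent-space loss):

    import torch

    loss = torch.rand(2, 4, 64, 64)  # per-element latent loss (B, C, H, W)
    batch = {"alpha_masks": torch.ones(2, 512, 512)}  # alpha in [0, 1], no conditioning images
    masked = apply_masked_loss(loss, batch)  # mask -> (B, 1, 64, 64) via area interpolation
    assert masked.shape == loss.shape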

2 changes: 1 addition & 1 deletion library/ipex/attention.py
@@ -5,7 +5,7 @@

# pylint: disable=protected-access, missing-function-docstring, line-too-long

-# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
+# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers

sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
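Note: the corrected comment describes the slicing strategy itself; conceptually it looks like this (illustrative only, not the IPEX implementation):

    import torch
    import torch.nn.functional as F

    def sliced_sdpa(q, k, v, slice_size):
        # run scaled-dot-product attention in batch slices so no single
        # kernel call materializes an intermediate larger than the 4GB limit
        out = torch.empty_like(q)
        for i in range(0, q.shape[0], slice_size):
            out[i : i + slice_size] = F.scaled_dot_product_attention(
                q[i : i + slice_size], k[i : i + slice_size], v[i : i + slice_size]
            )
        return out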
