diff --git a/README.md b/README.md
index 9a603b281..51e4635bb 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
 The command to install PyTorch is as follows:
 `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
 
+Aug 19, 2024:
+In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set them to the same value unless there is a reason not to.
+
+An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code.
+
 Aug 18, 2024:
 Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
 
diff --git a/flux_train.py b/flux_train.py
index b294ce42a..669963856 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -759,6 +759,12 @@ def setup_parser() -> argparse.ArgumentParser:
     add_custom_train_arguments(parser)  # TODO remove this from here
     flux_train_utils.add_flux_train_arguments(parser)
 
+    parser.add_argument(
+        "--mem_eff_save",
+        action="store_true",
+        help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
+    )
+
     parser.add_argument(
         "--fused_optimizer_groups",
         type=int,
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 167d61c7e..3f9e8660f 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -20,7 +20,7 @@
 
 init_ipex()
 
-from .utils import setup_logging
+from .utils import setup_logging, mem_eff_save_file
 
 setup_logging()
 import logging
@@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
     return model_pred, weighting
 
 
-def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None):
+def save_models(
+    ckpt_path: str,
+    flux: flux_models.Flux,
+    sai_metadata: Optional[dict],
+    save_dtype: Optional[torch.dtype] = None,
+    use_mem_eff_save: bool = False,
+):
     state_dict = {}
 
     def update_sd(prefix, sd):
         for k, v in sd.items():
             key = prefix + k
-            if save_dtype is not None:
+            if save_dtype is not None and v.dtype != save_dtype:
                 v = v.detach().clone().to("cpu").to(save_dtype)
             state_dict[key] = v
 
     update_sd("", flux.state_dict())
 
-    save_file(state_dict, ckpt_path, metadata=sai_metadata)
+    if not use_mem_eff_save:
+        save_file(state_dict, ckpt_path, metadata=sai_metadata)
+    else:
+        mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
 
 
 def save_flux_model_on_train_end(
@@ -429,7 +438,7 @@
 ):
     def sd_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
-        save_models(ckpt_file, flux, sai_metadata, save_dtype)
+        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
 
     train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
 
@@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise(
 ):
     def sd_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
-        save_models(ckpt_file, flux, sai_metadata, save_dtype)
+        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
 
     train_util.save_sd_model_on_epoch_end_or_stepwise_common(
         args,
diff --git a/library/utils.py b/library/utils.py
index 3037c055d..7de22d5a9 100644
--- a/library/utils.py
+++ b/library/utils.py
@@ -1,9 +1,12 @@
 import logging
 import sys
 import threading
+from typing import *
+import json
+import struct
+
 import torch
 from torchvision import transforms
-from typing import *
 from diffusers import EulerAncestralDiscreteScheduler
 import diffusers.schedulers.scheduling_euler_ancestral_discrete
 from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
@@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False):
         logger.info(msg_init)
 
 
+def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
+    """
+    memory efficient save file
+    """
+
+    _TYPES = {
+        torch.float64: "F64",
+        torch.float32: "F32",
+        torch.float16: "F16",
+        torch.bfloat16: "BF16",
+        torch.int64: "I64",
+        torch.int32: "I32",
+        torch.int16: "I16",
+        torch.int8: "I8",
+        torch.uint8: "U8",
+        torch.bool: "BOOL",
+        getattr(torch, "float8_e5m2", None): "F8_E5M2",
+        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
+    }
+    _ALIGN = 256
+
+    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
+        validated = {}
+        for key, value in metadata.items():
+            if not isinstance(key, str):
+                raise ValueError(f"Metadata key must be a string, got {type(key)}")
+            if not isinstance(value, str):
+                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
+                validated[key] = str(value)
+            else:
+                validated[key] = value
+        return validated
+
+    print(f"Using memory efficient save file: {filename}")
+
+    header = {}
+    offset = 0
+    if metadata:
+        header["__metadata__"] = validate_metadata(metadata)
+    for k, v in tensors.items():
+        if v.numel() == 0:  # empty tensor
+            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
+        else:
+            size = v.numel() * v.element_size()
+            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
+            offset += size
+
+    hjson = json.dumps(header).encode("utf-8")
+    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
+
+    with open(filename, "wb") as f:
+        f.write(struct.pack("<Q", len(hjson)))
+        f.write(hjson)
+
+        for k, v in tensors.items():
+            if v.numel() == 0:  # empty tensor, no bytes to write
+                continue
+
+            if v.is_cuda:
+                # direct GPU-to-disk save, streaming one tensor at a time
+                with torch.cuda.device(v.device):
+                    if v.dim() == 0:  # scalar needs a dimension for view()
+                        v = v.unsqueeze(0)
+                    tensor_bytes = v.contiguous().view(torch.uint8)
+                    tensor_bytes.cpu().numpy().tofile(f)
+            else:
+                # CPU tensor save: write raw bytes without a second full copy
+                if v.dim() == 0:  # scalar needs a dimension for view()
+                    v = v.unsqueeze(0)
+                v.contiguous().view(torch.uint8).numpy().tofile(f)
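
The file written by `mem_eff_save_file` follows the standard safetensors layout, so it should load with the reference `safetensors` reader. As a quick sanity check (not part of the patch above), a minimal round-trip sketch; the tensor values and the `check.safetensors` path are hypothetical, and the `safetensors` package is assumed to be installed:

    # Round-trip sketch: save with mem_eff_save_file, load back with the
    # reference safetensors reader. Tensors and the file path are
    # hypothetical examples.
    import torch
    from safetensors.torch import load_file

    from library.utils import mem_eff_save_file

    tensors = {
        "weight": torch.randn(4, 4, dtype=torch.bfloat16),
        "bias": torch.zeros(4, dtype=torch.float32),
    }
    mem_eff_save_file(tensors, "check.safetensors", metadata={"format": "pt"})

    loaded = load_file("check.safetensors")
    for name, value in tensors.items():
        assert torch.equal(loaded[name], value)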
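
On the format the function builds by hand: a safetensors file begins with an 8-byte little-endian unsigned integer giving the header size, followed by the JSON header (padded with spaces here so that prefix plus header is a multiple of 256 bytes), then the raw tensor bytes at the recorded `data_offsets`. A short sketch of parsing that header back, reusing the hypothetical `check.safetensors` from the previous example:

    # Parse the 8-byte length prefix and JSON header written by
    # mem_eff_save_file (standard safetensors layout).
    import json
    import struct

    with open("check.safetensors", "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))

    # e.g. {"dtype": "BF16", "shape": [4, 4], "data_offsets": [0, 32]}
    print(header["weight"])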