feat: reduce memory usage and add memory efficient option for model saving
kohya-ss committed Aug 19, 2024
1 parent 6e72a79 commit 486fe8f
Showing 4 changed files with 100 additions and 7 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

Aug 19, 2024:
In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set them to the same value unless there is a reason not to.

An experimental option `--mem_eff_save` has also been added. When specified, it can reduce memory consumption even further (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code.
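
For example, a run combining these settings might look like the following (illustrative only; the dataset, model, and other required arguments are omitted): `accelerate launch flux_train.py --mixed_precision bf16 --save_precision bf16 --mem_eff_save ...`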

Aug 18, 2024:
Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.

6 changes: 6 additions & 0 deletions flux_train.py
@@ -759,6 +759,12 @@ def setup_parser() -> argparse.ArgumentParser:
    add_custom_train_arguments(parser)  # TODO remove this from here
    flux_train_utils.add_flux_train_arguments(parser)

+    parser.add_argument(
+        "--mem_eff_save",
+        action="store_true",
+        help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
+    )
+
    parser.add_argument(
        "--fused_optimizer_groups",
        type=int,
21 changes: 15 additions & 6 deletions library/flux_train_utils.py
@@ -20,7 +20,7 @@

init_ipex()

-from .utils import setup_logging
+from .utils import setup_logging, mem_eff_save_file

setup_logging()
import logging
@@ -409,27 +409,36 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
    return model_pred, weighting


-def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None):
+def save_models(
+    ckpt_path: str,
+    flux: flux_models.Flux,
+    sai_metadata: Optional[dict],
+    save_dtype: Optional[torch.dtype] = None,
+    use_mem_eff_save: bool = False,
+):
    state_dict = {}

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
-            if save_dtype is not None:
+            if save_dtype is not None and v.dtype != save_dtype:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

update_sd("", flux.state_dict())

-    save_file(state_dict, ckpt_path, metadata=sai_metadata)
+    if not use_mem_eff_save:
+        save_file(state_dict, ckpt_path, metadata=sai_metadata)
+    else:
+        mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)


def save_flux_model_on_train_end(
    args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
-        save_models(ckpt_file, flux, sai_metadata, save_dtype)
+        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)

    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)

@@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise(
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
-        save_models(ckpt_file, flux, sai_metadata, save_dtype)
+        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
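The memory saving from matching `--save_precision` to `--mixed_precision` comes from the dtype guard added to `update_sd` above: tensors already in the target dtype are kept as-is, and only mismatched tensors are cloned to CPU and converted. A minimal, self-contained sketch with hypothetical tensors (not part of the repo):

```python
import torch

# Hypothetical tensors standing in for a Flux state dict.
sd = {"w_bf16": torch.zeros(4, dtype=torch.bfloat16), "w_fp32": torch.zeros(4, dtype=torch.float32)}
save_dtype = torch.bfloat16

state_dict = {}
for k, v in sd.items():
    # Same guard as update_sd: skip the clone when the dtype already matches.
    if save_dtype is not None and v.dtype != save_dtype:
        v = v.detach().clone().to("cpu").to(save_dtype)
    state_dict[k] = v

assert state_dict["w_bf16"] is sd["w_bf16"]  # stored by reference, no extra copy
assert state_dict["w_fp32"].dtype == torch.bfloat16  # converted copy
```
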
75 changes: 74 additions & 1 deletion library/utils.py
@@ -1,9 +1,12 @@
import logging
import sys
import threading
+from typing import *
+import json
+import struct
+
import torch
from torchvision import transforms
-from typing import *
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
@@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False):
        logger.info(msg_init)


def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """
    memory efficient save file
    """

    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    print(f"Using memory efficient save file: {filename}")

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)

        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # Direct GPU to disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)
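
Files written by this function follow the safetensors layout, so they should load back with the standard loader. A minimal round-trip sketch (assuming the `safetensors` package is installed, the repo root is on the import path, and `test.safetensors` is an illustrative filename):

```python
import torch
from safetensors.torch import load_file

from library.utils import mem_eff_save_file

tensors = {"weight": torch.randn(2, 3, dtype=torch.float16)}
mem_eff_save_file(tensors, "test.safetensors", metadata={"format": "pt"})

loaded = load_file("test.safetensors")  # standard safetensors reader
assert torch.equal(loaded["weight"], tensors["weight"])
```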

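For reference, the on-disk layout the function writes (an 8-byte little-endian header length, a space-padded JSON header aligned to 256 bytes, then raw tensor bytes) can also be inspected by hand; a small sketch reusing the file from the previous example:

```python
import json
import struct

with open("test.safetensors", "rb") as f:
    (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian length prefix
    header = json.loads(f.read(header_len))         # JSON tolerates the space padding

# e.g. {'dtype': 'F16', 'shape': [2, 3], 'data_offsets': [0, 12]}
print(header["weight"])
```
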

# TODO make inf_utils.py

