Fix multi-node environment training and accelerator related codes + skip file check option #1246

Open · wants to merge 15 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ venv
build
.vscode
wandb
venv*
28 changes: 28 additions & 0 deletions finetune/merge_jsons.py
@@ -0,0 +1,28 @@
# simple merge of json files

import argparse
import json
import glob
from tqdm import tqdm

def main(args):
    json_files = glob.glob(args.jsons)
    merged = {}
    for json_file in json_files:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for key, values in tqdm(data.items()):
            if key not in merged:
                merged[key] = values
            elif "train_resolution" in values:
                # a shard that carries bucket info updates the existing entry
                merged[key].update(values)
    with open(args.out_json, 'w', encoding='utf-8') as f:
        json.dump(merged, f, separators=(',', ':'), ensure_ascii=False)

if __name__ == "__main__":
    # python finetune/merge_jsons.py --jsons ${JSON_RESULT_PATH}_*.json --out_json ${JSON_RESULT_PATH}.json
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsons", type=str, help="json files to merge / マージするjsonファイル")
    parser.add_argument("--out_json", type=str, help="output json file / 出力jsonファイル")
    args = parser.parse_args()
    main(args)
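For reference, a minimal, order-independent sketch of the merge rule (the shard file names meta_0.json / meta_1.json are hypothetical):

# hypothetical two-shard example of the merge rule implemented above
import json

with open("meta_0.json", "w", encoding="utf-8") as f:
    json.dump({"img_a.png": {"caption": "a cat"}}, f)
with open("meta_1.json", "w", encoding="utf-8") as f:
    json.dump({"img_a.png": {"caption": "a cat", "train_resolution": [1024, 1024]}}, f)

# python finetune/merge_jsons.py --jsons "meta_*.json" --out_json meta.json
# -> meta.json holds img_a.png with both caption and train_resolution:
#    whichever shard is globbed first creates the entry, and a shard carrying
#    "train_resolution" updates it (glob order is not guaranteed, hence the
#    duplicate caption in both shards).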
31 changes: 31 additions & 0 deletions finetune/multi_gpu_prepare.sh
@@ -0,0 +1,31 @@
#!/bin/bash
#SBATCH --job-name=cache_latents
#SBATCH --output=O-%x.%j
#SBATCH --error=E-%x.%j
#SBATCH --partition=slurm_rtx4090
#SBATCH --cpus-per-gpu=16
#SBATCH --nodes=1 # number of nodes
#SBATCH --gres=gpu:1 # number of GPUs per node
#SBATCH --time=72:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --qos=gpu_qos

#unset LD_LIBRARY_PATH # uncomment if you have problems with CUDA kernels; lets python use the venv/conda environment's CUDA libraries instead
TRAIN_DATA_DIR=/dataset/path/to/data
JSON_RESULT_PATH=/json/to/save/name
#cd sd-scripts # cd to your sd-scripts installation directory

TOTAL_SPLIT=8 # number of splits; this should match CUDA_DEVICES_NUM
OFFSET=0 # index offset; ping me if you need multi-node caching
CUDA_DEVICES_NUM=8 # number of GPUs
PYTHON=venv/bin/python # set python
MODEL_PATH=model.safetensors
for i in $(seq 0 $((CUDA_DEVICES_NUM-1)))
do
    INDEX=$((i+OFFSET))
    CUDA_VISIBLE_DEVICES=$i $PYTHON finetune/prepare_buckets_latents.py --out_json ${JSON_RESULT_PATH}_${INDEX}.json --split_dataset --n_split $TOTAL_SPLIT --current_index $INDEX --model_name_or_path $MODEL_PATH --max_resolution "1024,1024" --max_bucket_reso 4096 --full_path --recursive --train_data_dir $TRAIN_DATA_DIR &
done

wait

# merge jsons
$PYTHON finetune/merge_jsons.py --jsons "${JSON_RESULT_PATH}_*.json" --out_json ${JSON_RESULT_PATH}.json
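Design note: the script runs as a single-node Slurm job (sbatch finetune/multi_gpu_prepare.sh). Each loop iteration pins one background process to one GPU via CUDA_VISIBLE_DEVICES and writes its own shard ${JSON_RESULT_PATH}_${INDEX}.json; wait blocks until every shard has finished before merge_jsons.py combines them into a single metadata file.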
352 changes: 352 additions & 0 deletions finetune/prepare_buckets_latents_separate.py

Large diffs are not rendered by default.

57 changes: 28 additions & 29 deletions library/sdxl_train_util.py
@@ -25,36 +25,35 @@

def load_target_model(args, accelerator, model_version: str, weight_dtype):
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = _load_target_model(
args.pretrained_model_name_or_path,
args.vae,
model_version,
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
)

# work on low-ram device
if args.lowram:
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
logger.info(f"loading model for process {accelerator.state.local_process_index} {accelerator.state.process_index} /{accelerator.state.num_processes}")
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = _load_target_model(
args.pretrained_model_name_or_path,
args.vae,
model_version,
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
)

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# work on low-ram device
if args.lowram:
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)

clean_memory_on_device(accelerator.device)
logger.info(f"model loaded for process {accelerator.state.local_process_index} {accelerator.state.process_index} /{accelerator.state.num_processes}")
accelerator.wait_for_everyone()
logger.info(f"model loaded for all processes {accelerator.state.local_process_index} {accelerator.state.process_index} /{accelerator.state.num_processes}")

return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
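In short: the previous version loaded the checkpoint one rank at a time (a loop over num_processes with a barrier per iteration), which serializes startup on multi-node runs; the revised version loads on all ranks concurrently and synchronizes once at the end.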

135 changes: 89 additions & 46 deletions library/train_util.py
@@ -135,7 +135,19 @@
)

TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
SKIP_PATH_CHECK = False

def set_skip_path_check(skip: bool):
    global SKIP_PATH_CHECK
    SKIP_PATH_CHECK = skip

def os_path_exists(path):
    """
    Check if the path exists. When SKIP_PATH_CHECK is set, the filesystem
    check is skipped and True is returned.
    """
    if SKIP_PATH_CHECK:  # this is necessary for NFS systems
        return True
    return os.path.exists(path)

class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -1046,7 +1058,7 @@ def cache_text_encoder_outputs(
if not is_main_process: # store to info only
continue

if os.path.exists(te_out_npz):
if os_path_exists(te_out_npz):
continue

image_infos_to_cache.append(info)
@@ -1689,12 +1701,12 @@ def __init__(
continue

tags_list = []
for image_key, img_md in metadata.items():
for image_key, img_md in tqdm(metadata.items(), desc=f"load metadata: {subset.metadata_file}"):
# build the path info
abs_path = None

# look for the image first
if os.path.exists(image_key):
if os_path_exists(image_key):
abs_path = image_key
else:
# fairly ad hoc, but no better way comes to mind
@@ -1704,11 +1716,11 @@

# otherwise look for an npz file
if abs_path is None:
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
if os_path_exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
if os_path_exists(npz_path):
abs_path = npz_path

assert abs_path is not None, f"no image / 画像がありません: {image_key}"
@@ -1839,10 +1851,10 @@ def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + ".npz"

if os.path.exists(npz_file_norm):
if os_path_exists(npz_file_norm):
# image_key is full path
npz_file_flip = base_name + "_flip.npz"
if not os.path.exists(npz_file_flip):
if not os_path_exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip

@@ -1854,10 +1866,10 @@ def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")

if not os.path.exists(npz_file_norm):
if not os_path_exists(npz_file_norm):
npz_file_norm = None
npz_file_flip = None
elif not os.path.exists(npz_file_flip):
elif not os_path_exists(npz_file_flip):
npz_file_flip = None

return npz_file_norm, npz_file_flip
@@ -2120,7 +2132,7 @@ def disable_token_padding(self):
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note: bucket_reso is WxH, hence the swap

if not os.path.exists(npz_path):
if not os_path_exists(npz_path):
return False

npz = np.load(npz_path)
@@ -3087,7 +3099,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
)
parser.add_argument(
"--gradient_accumulation_steps",
@@ -3696,6 +3708,12 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)",
)

def add_skip_check_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--skip_file_existence_check",
        action="store_true",
        help="skip check for images and latents existence, useful if your storage has low random access speed / 画像とlatentの存在チェックをスキップする。ストレージのランダムアクセス速度が遅い場合に有用",
    )
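A minimal sketch of how a training entry point would wire this flag into the helpers above; the actual call sites are outside the hunks shown here, so this wiring is an assumption:

# hypothetical wiring: the training scripts' real call sites are not shown in this diff
import argparse
from library import train_util

parser = argparse.ArgumentParser()
train_util.add_skip_check_arguments(parser)
args, _ = parser.parse_known_args()

# propagate the flag before any dataset is built so that every
# os_path_exists() call can short-circuit on slow (e.g. NFS) storage
train_util.set_skip_path_check(args.skip_file_existence_check)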

def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
if not args.config_file:
@@ -4274,6 +4292,7 @@ def prepare_accelerator(args: argparse.Namespace):
"logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
)
if log_with in ["wandb", "all"]:
os.environ['WANDB__SERVICE_WAIT'] = '300'  # allow the wandb service up to 300s to start (helps on busy multi-node runs)
try:
import wandb
except ImportError:
@@ -4380,24 +4399,22 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une


def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args,
weight_dtype,
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}, {accelerator.process_index}")
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args,
weight_dtype,
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)

clean_memory_on_device(accelerator.device)
logger.info(f"Model loaded for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}, {accelerator.process_index}")
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format


@@ -5198,14 +5215,14 @@ def sample_images_common(
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass

image_paths = []
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
sample_image_inference(
image_paths += [sample_image_inference(  # accumulate across prompts; plain assignment would keep only the last path
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
)]
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
@@ -5216,9 +5233,31 @@
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
image_paths += [sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)]
accelerator.wait_for_everyone()
# only the main process uploads the sample images to wandb
if accelerator.is_main_process:
    try:
        import wandb

        logger.info(image_paths)
        wandb_logger = accelerator.get_tracker("wandb")
        # parse the sample index from each saved file name; the first path is
        # used to discover the sibling images written by the other processes
        for image_path_saved in get_all_paths_like_imagepaths_by_time(image_paths[0]):
            # e.g. 0327_bs768_lion_highres_focus_fixxl4_000020_13_20240329061413_42 -> sample index 13
            file_basename = os.path.basename(image_path_saved).split(".")[0]
            sample_idx = int(file_basename.split("_")[-3])
            logger.info(f"sample_idx: {sample_idx} -> {image_path_saved}")
            wandb_logger.log(
                {f"sample_{sample_idx}": wandb.Image(Image.open(image_path_saved))},
                commit=False,
                step=steps,
            )
    except Exception as e:
        logger.warning(e)

# clear pipeline and cache to reduce vram usage
del pipeline
@@ -5233,6 +5272,22 @@
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)

def get_all_paths_like_imagepaths_by_time(image_path):
    # yield sibling sample images that share this file's fixed name prefix
    # and were written within 60 seconds of it
    file_basename = os.path.basename(image_path).split(".")[0]
    timestamp_str = file_basename.split("_")[-2]
    original_timestamp = datetime.datetime.strptime(timestamp_str, "%Y%m%d%H%M%S")

    front_fixed_part = "_".join(file_basename.split("_")[:-3])

    for root, dirs, files in os.walk(os.path.dirname(image_path)):
        for file in files:
            if front_fixed_part in file:
                timestamp_str = file.split("_")[-2]
                timestamp = datetime.datetime.strptime(timestamp_str, "%Y%m%d%H%M%S")
                # allow 60-second difference
                if abs((timestamp - original_timestamp).total_seconds()) < 60:
                    yield os.path.join(root, file)
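As a quick self-contained check of the filename fields this helper relies on (the example name is taken from the comment in the wandb logging block above; the convention itself comes from sample_image_inference below):

# self-check of the assumed name convention: {output_name}_{num_suffix}_{enum}_{timestamp}{seed_suffix}
name = "0327_bs768_lion_highres_focus_fixxl4_000020_13_20240329061413_42"
parts = name.split("_")
assert parts[-3] == "13"              # prompt index ("enum")
assert parts[-2] == "20240329061413"  # timestamp, %Y%m%d%H%M%S
assert "_".join(parts[:-3]) == "0327_bs768_lion_highres_focus_fixxl4_000020"  # fixed prefix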

def sample_image_inference(
accelerator: Accelerator,
@@ -5316,19 +5371,7 @@ def sample_image_inference(
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))

# send logs only when wandb is enabled
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError:  # checked once beforehand, so this should not raise
raise ImportError("No wandb / wandb がインストールされていないようです")

wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except:  # wandb disabled
pass

return os.path.join(save_dir, img_filename)

# endregion
