From 87526942a67fd71bb775bc479b0a7449df516dd8 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Fri, 12 Jul 2024 22:56:38 +0800 Subject: [PATCH 1/8] judge image size for using diff interpolation --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..74720fec6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) image_height, image_width = image.shape[0:2] From 2e67978ee243a20f169ce76d7644bb1f9dec9bad Mon Sep 17 00:00:00 2001 From: Millie Date: Thu, 18 Jul 2024 11:52:58 -0700 Subject: [PATCH 2/8] Generate sample images without having CUDA (such as on Macs) --- library/train_util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..9b0397d7d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5229,7 +5229,7 @@ def sample_images_common( clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) - if cuda_rng_state is not None: + if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) @@ -5263,11 +5263,13 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if torch.cuda.is_available(): + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -5302,8 +5304,9 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] From 1f16b80e88b1c4f05d49b4fc328d3b9b105ebcbe Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:35:24 +0800 Subject: [PATCH 3/8] Revert "judge image size for using diff interpolation" This reverts commit 87526942a67fd71bb775bc479b0a7449df516dd8. --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 74720fec6..15c23f3cc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ image_height, image_width = image.shape[0:2] From 9ca7a5b6cc99e25820a1aa6d02a779004d73bca0 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:59:11 +0800 Subject: [PATCH 4/8] instead cv2 LANCZOS4 resize to pil resize --- finetune/tag_images_by_wd14_tagger.py | 8 +++++--- library/train_util.py | 11 ++++++----- library/utils.py | 14 +++++++++++++- tools/detect_face_rotate.py | 7 +++++-- tools/resize_images_to_resolution.py | 11 +++++++---- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd61..6f5bdd36b 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -42,8 +42,10 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + if size > IMAGE_SIZE: + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) + else: + image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) image = image.astype(np.float32) return image diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..160e3b44b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -71,7 +71,7 @@ import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -2028,9 +2028,7 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) + cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2362,7 +2360,10 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) image_height, image_width = image.shape[0:2] diff --git a/library/utils.py b/library/utils.py index 3037c055d..a219f6cb7 100644 --- a/library/utils.py +++ b/library/utils.py @@ -7,7 +7,9 @@ from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - +import cv2 +from PIL import Image +import numpy as np def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +def pil_resize(image, size, interpolation=Image.LANCZOS): + + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # use Pillow resize + resized_pil = pil_image.resize(size, interpolation) + + # return cv2 image + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + return resized_cv2 # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index bbc643edc..d2a4d9cfb 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -172,7 +172,10 @@ def process(args): if scale != 1.0: w = int(w * scale + .5) h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + if scale < 1.0: + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) + else: + face_img = pil_resize(face_img, (w, h)) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index b8069fc1d..0f9e00b1e 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import math from PIL import Image import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Select interpolation method if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 + pil_interpolation = Image.LANCZOS elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC + pil_interpolation = Image.BICUBIC else: cv2_interpolation = cv2.INTER_AREA @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_width = int(img.shape[1] * math.sqrt(scale_factor)) # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + if cv2_interpolation: + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) else: new_height, new_width = img.shape[0:2] From 2a3aefb4e44dce1f189677d0a996ba0244633956 Mon Sep 17 00:00:00 2001 From: Nando Metzger <42088121+nandometzger@users.noreply.github.com> Date: Fri, 30 Aug 2024 08:15:05 +0200 Subject: [PATCH 5/8] Update train_util.py, bug fix --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..0fec565db 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1489,7 +1489,7 @@ def read_caption(img_path, caption_extension, enable_wildcard): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): logger.warning(f"not directory: {subset.image_dir}") - return [], [] + return [], [], [] info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info From 3a6154b7b0dbcae82d24adacf5a76f75288b98f4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 31 Aug 2024 06:21:16 +0000 Subject: [PATCH 6/8] Bump opencv-python from 4.7.0.68 to 4.8.1.78 Bumps [opencv-python](https://github.com/opencv/opencv-python) from 4.7.0.68 to 4.8.1.78. - [Release notes](https://github.com/opencv/opencv-python/releases) - [Commits](https://github.com/opencv/opencv-python/commits) --- updated-dependencies: - dependency-name: opencv-python dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e99775b8a..977c5cd91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers==4.36.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 -opencv-python==4.7.0.68 +opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.43.0 From 1bcf8d600bfb9f4314a41a12a5e7b272a17ceaed Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 01:33:04 +0000 Subject: [PATCH 7/8] Bump crate-ci/typos from 1.19.0 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.19.0 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.19.0...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483f..0149dcdd3 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.24.3 From 0005867ba509d2e1a5674b267e8286b561c0ed71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Sep 2024 10:45:18 +0900 Subject: [PATCH 8/8] update README, format code --- README.md | 5 +++++ library/train_util.py | 4 ++-- library/utils.py | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 81a549378..16ab80e7a 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! + +- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! + - `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! + - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. diff --git a/library/train_util.py b/library/train_util.py index 102d39ed7..1441e74f6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2094,7 +2094,7 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2432,7 +2432,7 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index a219f6cb7..5b7e657b2 100644 --- a/library/utils.py +++ b/library/utils.py @@ -11,6 +11,7 @@ from PIL import Image import numpy as np + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -80,8 +81,8 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) -def pil_resize(image, size, interpolation=Image.LANCZOS): +def pil_resize(image, size, interpolation=Image.LANCZOS): pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # use Pillow resize @@ -92,6 +93,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 + # TODO make inf_utils.py