Skip to content

Commit

Permalink
Merge branch 'dev' into sd3
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Sep 7, 2024
2 parents 2889108 + 62ec3e6 commit ce14447
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/typos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4

- name: typos-action
uses: crate-ci/typos@v1.21.0
uses: crate-ci/typos@v1.24.3
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

### Working in progress

- When the script enlarges an image (i.e. when the training image is small and `bucket_no_upscale` is not specified), it now uses Pillow's resize with LANCZOS interpolation instead of OpenCV's resize with Lanczos4 interpolation. The quality of the enlarged images may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!

- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!

- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!

- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- Memory usage during training is significantly reduced by integrating the optimizer's step into the backward pass. The training results are the same as before, but if you have plenty of memory, training will be slower than without this option.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available.
Expand Down
8 changes: 5 additions & 3 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tqdm import tqdm

import library.train_util as train_util
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize

setup_logging()
import logging
Expand Down Expand Up @@ -42,8 +42,10 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))

image = image.astype(np.float32)
return image
Expand Down
26 changes: 15 additions & 11 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize

setup_logging()
import logging
Expand Down Expand Up @@ -1708,7 +1708,7 @@ def read_caption(img_path, caption_extension, enable_wildcard):
def load_dreambooth_dir(subset: DreamBoothSubset):
if not os.path.isdir(subset.image_dir):
logger.warning(f"not directory: {subset.image_dir}")
return [], []
return [], [], []

info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
use_cached_info_for_subset = subset.cache_info
Expand Down Expand Up @@ -2263,9 +2263,7 @@ def __getitem__(self, index):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = cv2.resize(
cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
)
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))

if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
Expand Down Expand Up @@ -2659,7 +2657,10 @@ def trim_and_resize_if_required(

if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
if image_width > resized_size[0] and image_height > resized_size[1]:
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)

image_height, image_width = image.shape[0:2]

Expand Down Expand Up @@ -5657,7 +5658,7 @@ def sample_images_common(
clean_memory_on_device(accelerator.device)

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)

Expand Down Expand Up @@ -5691,11 +5692,13 @@ def sample_image_inference(

if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
if torch.cuda.is_available():
torch.cuda.seed()

scheduler = get_my_scheduler(
sample_sampler=sampler_name,
Expand Down Expand Up @@ -5730,8 +5733,9 @@ def sample_image_inference(
controlnet_image=controlnet_image,
)

with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
if torch.cuda.is_available():
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()

image = pipeline.latents_to_image(latents)[0]

Expand Down
14 changes: 14 additions & 0 deletions library/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np


def fire_in_thread(f, *args, **kwargs):
Expand Down Expand Up @@ -301,6 +304,17 @@ def _convert_float8(byte_tensor, dtype_str, shape):
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")

def pil_resize(image, size, interpolation=Image.LANCZOS):
    """Resize an OpenCV-style image array using Pillow.

    The array is converted to a Pillow image, resized, and converted back
    so callers can keep working with OpenCV channel order.

    Args:
        image: numpy array in OpenCV layout: 2-D (single channel),
            height x width x 3 (BGR) or height x width x 4 (BGRA).
        size: target size as ``(width, height)``; passed unchanged to
            Pillow's ``Image.resize``.
        interpolation: Pillow resampling filter. Defaults to
            ``Image.LANCZOS``.

    Returns:
        The resized image as a numpy array with the same channel layout
        (and channel count) as the input.
    """
    if image.ndim == 2:
        # Single-channel image: there is no channel order to swap, and
        # cv2.cvtColor(..., COLOR_BGR2RGB) would reject it.
        resized_gray = Image.fromarray(image).resize(size, interpolation)
        return np.array(resized_gray)

    # A 4-channel input must use the BGRA<->RGBA conversions; the plain
    # BGR<->RGB codes would silently drop the alpha channel.
    has_alpha = image.ndim == 3 and image.shape[2] == 4
    if has_alpha:
        to_pil_code, to_cv2_code = cv2.COLOR_BGRA2RGBA, cv2.COLOR_RGBA2BGRA
    else:
        to_pil_code, to_cv2_code = cv2.COLOR_BGR2RGB, cv2.COLOR_RGB2BGR

    pil_image = Image.fromarray(cv2.cvtColor(image, to_pil_code))

    # use Pillow resize
    resized_pil = pil_image.resize(size, interpolation)

    # return cv2 image
    return cv2.cvtColor(np.array(resized_pil), to_cv2_code)


# TODO make inf_utils.py

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.43.3
Expand Down
7 changes: 5 additions & 2 deletions tools/detect_face_rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -172,7 +172,10 @@ def process(args):
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)
Expand Down
11 changes: 7 additions & 4 deletions tools/resize_images_to_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
Expand All @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi

# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA

Expand Down Expand Up @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_width = int(img.shape[1] * math.sqrt(scale_factor))

# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
else:
new_height, new_width = img.shape[0:2]

Expand Down

0 comments on commit ce14447

Please sign in to comment.