Skip to content

Commit

Permalink
Merge pull request #1426 from sdbds/resize
Browse files Browse the repository at this point in the history
Replacing CV2 resize to Pil resize
  • Loading branch information
kohya-ss committed Sep 7, 2024
2 parents 319e4d9 + 9ca7a5b commit 16bb569
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 15 deletions.
8 changes: 5 additions & 3 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tqdm import tqdm

import library.train_util as train_util
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize

setup_logging()
import logging
Expand Down Expand Up @@ -42,8 +42,10 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))

image = image.astype(np.float32)
return image
Expand Down
11 changes: 6 additions & 5 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize

setup_logging()
import logging
Expand Down Expand Up @@ -2094,9 +2094,7 @@ def __getitem__(self, index):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = cv2.resize(
cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
)
cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0])))

if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
Expand Down Expand Up @@ -2459,7 +2457,10 @@ def trim_and_resize_if_required(

if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
if image_width > resized_size[0] and image_height > resized_size[1]:
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)

image_height, image_width = image.shape[0:2]

Expand Down
14 changes: 13 additions & 1 deletion library/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput

import cv2
from PIL import Image
import numpy as np

def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
Expand Down Expand Up @@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)

def pil_resize(image, size, interpolation=Image.LANCZOS):

pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

# use Pillow resize
resized_pil = pil_image.resize(size, interpolation)

# return cv2 image
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)

return resized_cv2

# TODO make inf_utils.py

Expand Down
7 changes: 5 additions & 2 deletions tools/detect_face_rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -172,7 +172,10 @@ def process(args):
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)
Expand Down
11 changes: 7 additions & 4 deletions tools/resize_images_to_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
Expand All @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi

# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA

Expand Down Expand Up @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_width = int(img.shape[1] * math.sqrt(scale_factor))

# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
else:
new_height, new_width = img.shape[0:2]

Expand Down

0 comments on commit 16bb569

Please sign in to comment.