Add support for batched inference of images of the same size #814

Open · wants to merge 2 commits into base: master
realesrgan/utils.py (172 changes: 95 additions & 77 deletions)
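The change generalizes `RealESRGANer.enhance` to accept either a single numpy image or a list of same-sized images, which are padded individually and stacked into one batch for a single forward pass. A minimal usage sketch of the proposed API (the constructor arguments shown here are illustrative, not prescribed by this PR; the zipped return shape is taken from the diff below):

```python
import cv2
from realesrgan import RealESRGANer

# Assumed setup: model_path and scale are illustrative values.
upsampler = RealESRGANer(scale=4, model_path='weights/RealESRGAN_x4plus.pth')

img1 = cv2.imread('a.png', cv2.IMREAD_UNCHANGED)
img2 = cv2.imread('b.png', cv2.IMREAD_UNCHANGED)  # must match img1's size

# Batched call: a list of same-sized images goes through one forward pass.
# The patched enhance() returns zip(*results): a tuple of outputs and a
# tuple of image modes, rather than a single (output, img_mode) pair.
outputs, img_modes = upsampler.enhance([img1, img2], outscale=4)

# Single-image calls still work; the input is wrapped into a one-element list.
(out,), (mode,) = upsampler.enhance(img1, outscale=4)
```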
@@ -1,9 +1,11 @@
-import cv2
 import math
-import numpy as np
 import os
 import queue
 import threading
+from typing import Union
+
+import cv2
+import numpy as np
 import torch
 from basicsr.utils.download_util import load_file_from_url
 from torch.nn import functional as F
@@ -88,30 +90,35 @@ def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
     def pre_process(self, img):
         """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
         """
-        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
-        self.img = img.unsqueeze(0).to(self.device)
+        img = [torch.from_numpy(np.transpose(i, (2, 0, 1))).float() for i in img]
+        self.img = [i.unsqueeze(0).to(self.device) for i in img]
         if self.half:
-            self.img = self.img.half()
+            self.img = [i.half() for i in self.img]
 
         # pre_pad
         if self.pre_pad != 0:
-            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+            self.img = [F.pad(i, (0, self.pre_pad, 0, self.pre_pad), 'reflect') for i in self.img]
         # mod pad for divisible borders
         if self.scale == 2:
             self.mod_scale = 2
         elif self.scale == 1:
             self.mod_scale = 4
         if self.mod_scale is not None:
             self.mod_pad_h, self.mod_pad_w = 0, 0
-            _, _, h, w = self.img.size()
-            if (h % self.mod_scale != 0):
-                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
-            if (w % self.mod_scale != 0):
-                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
-            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+            padded_imgs = []
+            for im in self.img:
+                _, _, h, w = im.size()
+                if (h % self.mod_scale != 0):
+                    self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+                if (w % self.mod_scale != 0):
+                    self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+                im = F.pad(im, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+                padded_imgs.append(im)
+            self.img = padded_imgs

     def process(self):
         # model inference
+        self.img = torch.cat(self.img, dim=0)
         self.output = self.model(self.img)

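The hunk above shows the key constraint: each image is reflect-padded until its height and width are divisible by `mod_scale`, and `torch.cat` then stacks the padded tensors along the batch dimension, which only works when all inputs share the same spatial size (hence the PR title). A self-contained sketch of both steps, with illustrative shapes:

```python
import torch
from torch.nn import functional as F

mod_scale = 4
imgs = [torch.rand(1, 3, 253, 509) for _ in range(2)]  # two same-sized images

padded = []
for im in imgs:
    _, _, h, w = im.size()
    pad_h = (mod_scale - h % mod_scale) % mod_scale  # 253 -> pad 3 -> 256
    pad_w = (mod_scale - w % mod_scale) % mod_scale  # 509 -> pad 3 -> 512
    padded.append(F.pad(im, (0, pad_w, 0, pad_h), 'reflect'))

batch = torch.cat(padded, dim=0)
print(batch.shape)  # torch.Size([2, 3, 256, 512]); cat fails on mixed sizes
```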
@@ -176,8 +183,8 @@ def tile_process(self):

             # put tile into output image
             self.output[:, :, output_start_y:output_end_y,
-                        output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
-                                                                   output_start_x_tile:output_end_x_tile]
+                        output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                  output_start_x_tile:output_end_x_tile]

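For orientation, this hunk only re-aligns a continuation line; the surrounding tile loop places each processed tile into the full output by scaling its input coordinates. A rough illustration of that mapping, with made-up values:

```python
# Illustrative only: an input tile spanning columns 100..200 lands at
# columns 400..800 in a 4x-upscaled output.
scale = 4
input_start_x, input_end_x = 100, 200
output_start_x = input_start_x * scale  # 400
output_end_x = input_end_x * scale      # 800
```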
@@ -191,76 +198,87 @@ def post_process(self):
         return self.output

     @torch.no_grad()
-    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
-        h_input, w_input = img.shape[0:2]
-        # img: numpy
-        img = img.astype(np.float32)
-        if np.max(img) > 256:  # 16-bit image
-            max_range = 65535
-            print('\tInput is a 16-bit image')
-        else:
-            max_range = 255
-        img = img / max_range
-        if len(img.shape) == 2:  # gray image
-            img_mode = 'L'
-            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
-        elif img.shape[2] == 4:  # RGBA image with alpha channel
-            img_mode = 'RGBA'
-            alpha = img[:, :, 3]
-            img = img[:, :, 0:3]
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            if alpha_upsampler == 'realesrgan':
-                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
-        else:
-            img_mode = 'RGB'
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    def enhance(self, img: Union[np.ndarray, list[np.ndarray]], outscale=None, alpha_upsampler='realesrgan'):
+        if isinstance(img, np.ndarray):  # bs=1
+            img = [img]
+        h_input = [i.shape[0] for i in img]
+        w_input = [i.shape[1] for i in img]
+        img = [i.astype(np.float32) for i in img]
+        max_range = [65535 if np.max(i) > 256 else 255 for i in img]
+        if any(i > 256 for i in max_range):
+            print('\tInput contains 16-bit images')
+
+        img = [i / m_range for i, m_range in zip(img, max_range)]
+
+        img_modes = []
+        for idx, im in enumerate(img):
+            if len(im.shape) == 2:  # gray image
+                img_mode = 'L'
+                img[idx] = cv2.cvtColor(im, cv2.COLOR_GRAY2RGB)
+            elif im.shape[2] == 4:  # RGBA image with alpha channel
+                img_mode = 'RGBA'
+                alpha = im[:, :, 3]
+                im = im[:, :, 0:3]  # fix: rebind the local image; `img = im[...]` would clobber the list
+                img[idx] = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+                if alpha_upsampler == 'realesrgan':
+                    alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+            else:
+                img_mode = 'RGB'
+                img[idx] = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+
+            img_modes.append(img_mode)
+

         # ------------------- process image (without the alpha channel) ------------------- #
         self.pre_process(img)
         if self.tile_size > 0:
             self.tile_process()
         else:
             self.process()
-        output_img = self.post_process()
-        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
-        if img_mode == 'L':
-            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
-
-        # ------------------- process the alpha channel if necessary ------------------- #
-        if img_mode == 'RGBA':
-            if alpha_upsampler == 'realesrgan':
-                self.pre_process(alpha)
-                if self.tile_size > 0:
-                    self.tile_process()
-                else:
-                    self.process()
-                output_alpha = self.post_process()
-                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
-                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
-            else:  # use the cv2 resize for alpha channel
-                h, w = alpha.shape[0:2]
-                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
-
-            # merge the alpha channel
-            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
-            output_img[:, :, 3] = output_alpha
-
-        # ------------------------------ return ------------------------------ #
-        if max_range == 65535:  # 16-bit image
-            output = (output_img * 65535.0).round().astype(np.uint16)
-        else:
-            output = (output_img * 255.0).round().astype(np.uint8)
-
-        if outscale is not None and outscale != float(self.scale):
-            output = cv2.resize(
-                output, (
-                    int(w_input * outscale),
-                    int(h_input * outscale),
-                ), interpolation=cv2.INTER_LANCZOS4)
-
-        return output, img_mode
+        output_imgs = self.post_process()
+        output_imgs = output_imgs.data.float().cpu().clamp_(0, 1).numpy()  # fix: no squeeze, keep the batch dim for iteration
+        output_imgs = [o for o in output_imgs]
+
+        final_results = []
+        for output_img, img_mode, max_r, h_i, w_i in zip(output_imgs, img_modes, max_range, h_input, w_input):
+            output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+            if img_mode == 'L':
+                output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+            # ------------------- process the alpha channel if necessary ------------------- #
+            if img_mode == 'RGBA':
+                # note: `alpha` is whatever the preprocessing loop bound last, so only
+                # one RGBA image per batch is handled correctly for now
+                if alpha_upsampler == 'realesrgan':
+                    self.pre_process([alpha])  # fix: pre_process now expects a list of images
+                    if self.tile_size > 0:
+                        self.tile_process()
+                    else:
+                        self.process()
+                    output_alpha = self.post_process()
+                    output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                    output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                    output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+                else:  # use the cv2 resize for alpha channel
+                    h, w = alpha.shape[0:2]
+                    output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+                # merge the alpha channel
+                output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+                output_img[:, :, 3] = output_alpha
+
+            # ------------------------------ return ------------------------------ #
+            if max_r == 65535:  # 16-bit image
+                output = (output_img * 65535.0).round().astype(np.uint16)
+            else:
+                output = (output_img * 255.0).round().astype(np.uint8)
+
+            if outscale is not None and outscale != float(self.scale):
+                output = cv2.resize(
+                    output, (
+                        int(w_i * outscale),
+                        int(h_i * outscale),
+                    ), interpolation=cv2.INTER_LANCZOS4)
+
+            final_results.append((output, img_mode))
+        return zip(*final_results)


 class PrefetchReader(threading.Thread):
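
A final note on the API: `return zip(*final_results)` transposes the per-image `(output, img_mode)` pairs into two parallel tuples, so the batched call is unpacked once rather than per image. A tiny illustration:

```python
# What zip(*final_results) does to the per-image results:
final_results = [('out_a', 'RGB'), ('out_b', 'RGBA')]
outputs, img_modes = zip(*final_results)
print(outputs)    # ('out_a', 'out_b')
print(img_modes)  # ('RGB', 'RGBA')
```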