diff --git a/.gitignore b/.gitignore index 68bc17f..309450b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +new_samples/ +*.ipynb +*.zip +*/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index 470fb9e..661b00c 100644 --- a/README.md +++ b/README.md @@ -33,11 +33,17 @@ Image restoration is a fundamental problem that involves recovering a high-quali ### TODO / News 🔥 -- [ ] Upload all test results for comparisons (ETA 1st Feb) -- [x] [Replicate Demo](https://replicate.com/mv-lab/instructir) -- [x] Upload models to HF 🤗 [(download the models here)](https://huggingface.co/marcosv/InstructIR) [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm.svg)](https://huggingface.co/marcosv/InstructIR) -- [x] 🤗 [Hugging Face Demo](https://huggingface.co/spaces/marcosv/InstructIR) try it now [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/marcosv/InstructIR) +- [ ] Upload Model weights and results for other InstructIR variants (3D, 5D). +- [x] [download all the test datasets](https://drive.google.com/file/d/11wGsKOMDVrBlsle4xtzORPLZAsGhel8c/view?usp=sharing) for all-in-one restoration. + +- [x] check the instructions below to run `eval_instructir.py` and get all the metrics and results for all-in-one restoration. + +- [x] You can download all the qualitative results here [instructir_results.zip](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) + +- [x] Upload models to HF 🤗 [(download the models here)](https://huggingface.co/marcosv/InstructIR) + +- [x] 🤗 [Hugging Face Demo](https://huggingface.co/spaces/marcosv/InstructIR) try it now - [x] [Google Colab Tutorial](https://colab.research.google.com/drive/1OrTvS-i6uLM2Y8kIkq8ZZRwEQxQFchfq?usp=sharing) (check [demo.ipynb](demo.ipynb)) @@ -51,33 +57,41 @@ Image restoration is a fundamental problem that involves recovering a high-quali InstructIR -### Gradio Demo -We made a simple [Gradio demo](app.py) you can run (locally) on your machine [here](app.py). You need Python>=3.9 and [these requirements](requirements_gradio.txt) for it: `pip install -r requirements_gradio.txt` +## Results + +Check `test.py` and `eval_instructir.py`. The following command provides all the metric for all the benchmarks using the pre-trained models in `models/`. The results from InstructIR are saved in the indicated folder `results/` ``` -python app.py +python eval_instructir.py --model models/im_instructir-7d.pt --lm models/lm_instructir-7d.pt --device 0 --config configs/eval5d.yml --save results/ ``` -
- -InstructIR Gradio - +An example of the output log is: +``` +>>> Eval on CBSD68_15 noise 0 +CBSD68_15_base 24.84328738380881 +CBSD68_15_psnr 33.98722295200123 68 +CBSD68_15_ssim 0.9315137801801457 -## Results +.... +``` + +You can **[download all the test datasets](https://drive.google.com/file/d/11wGsKOMDVrBlsle4xtzORPLZAsGhel8c/view?usp=sharing)**, and locate them in `test-data/`. Make sure the paths are updated in the config file `configs/eval5d.yml`. -You can download the paper results from here. We test InstructIR in the following benchmarks: +------- + +You can **[download all the paper results](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip)** -check releases-. We test InstructIR in the following benchmarks: | Dataset | Task | Test Results | | :---------------- | :------ | ----: | -| BSD68 | Denoising | [Download]() | -| Urban100 | Denoising | [Download]() | -| Rain100 | Deraining | [Download]() | -| [GoPro](https://seungjunnah.github.io/Datasets/gopro) | Deblurring | [Download]() | -| [LOL](https://daooshee.github.io/BMVC2018website/) | Lol Image Enhancement | [Download]() | -| [MIT5K](https://data.csail.mit.edu/graphics/fivek/) | Image Enhancement | [Download]() | +| BSD68 | Denoising | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) | +| Urban100 | Denoising | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) | +| Rain100 | Deraining | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) | +| [GoPro](https://seungjunnah.github.io/Datasets/gopro) | Deblurring | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) | +| [LOL](https://daooshee.github.io/BMVC2018website/) | Lol Image Enhancement | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) | +| [MIT5K](https://data.csail.mit.edu/graphics/fivek/) | Image Enhancement | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) | -TODO: Add download links for all the benchmarks. +In releases or clicking the link above you can download [instructir_results.zip](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) which includes all the qualitative results for those datasets [1.9 Gbs]. @@ -149,6 +163,18 @@ The final result looks indeed stunning 🤗 You can do it yourself in the [demo - ***Why aren't you using diffusion-based models?*** (1) We want to keep the solution simple and efficient. (2) Our priority is high-fidelity --as in many industry scenarios realted to computational photography--. +### Gradio Demo +We made a simple [Gradio demo](app.py) you can run (locally) on your machine [here](app.py). You need Python>=3.9 and [these requirements](requirements_gradio.txt) for it: `pip install -r requirements_gradio.txt` + +``` +python app.py +``` + +
+ +InstructIR Gradio + + ### Acknowledgments This work was partly supported by the The Humboldt Foundation (AvH). Marcos Conde is also supported by Sony Interactive Entertainment, FTG. diff --git a/configs/eval5d.yml b/configs/eval5d.yml index ec8651e..6c07e99 100644 --- a/configs/eval5d.yml +++ b/configs/eval5d.yml @@ -21,20 +21,18 @@ test: batch_size: 1 num_workers: 3 - dn_datapath: "data/denoising_testsets/" - dn_datasets: ["CBSD68", "urban100", "Kodak24", "McMaster"] + dn_datapath: "test-data/denoising_testsets/" + dn_datasets: ["CBSD68", "urban100", "Kodak24"] dn_sigmas: [15, 25, 50] - rain_targets: ["data/Rain/rain_test/Rain100L/target/"] - rain_inputs: ["data/Rain/rain_test/Rain100L/input/"] + rain_targets: ["test-data/Rain100L/target/"] + rain_inputs: ["test-data/Rain100L/input/"] - haze_targets: "data/SOTS-OUT/GT/" - haze_inputs : "data/SOTS-OUT/IN/" + haze_targets: "test-data/SOTS/GT/" + haze_inputs : "test-data/SOTS/IN/" - lol_targets: "data/LOL/eval15/high/" - lol_inputs : "data/LOL/eval15/low/" + lol_targets: "test-data/LOL/high/" + lol_inputs : "test-data/LOL/low/" - gopro_targets: "data/gopro_test/GoPro/target/" - gopro_inputs: "data/gopro_test/GoPro/input/" - - \ No newline at end of file + gopro_targets: "test-data/GoPro/target/" + gopro_inputs: "test-data/GoPro/input/" \ No newline at end of file diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000..0521f14 --- /dev/null +++ b/datasets.py @@ -0,0 +1,211 @@ +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +import torchvision +import torchvision.transforms.functional as TF +import numpy as np +import json +import os +from glob import glob + +from utils import load_img, modcrop + + +DEG_MAP = { + "noise": 0, + "blur" : 1, + "rain" : 2, + "haze" : 3, + "lol" : 4, + "sr" : 5, + "en" : 6, +} + +DEG2TASK = { + "noise": "denoising", + "blur" : "deblurring", + "rain" : "deraining", + "haze" : "dehazing", + "lol" : "lol", + "sr" : "sr", + "en" : "enhancement" +} + +def augment_prompt (prompt): + ### special prompts + lol_prompts = ["fix the illumination", "increase the exposure of the photo", "the image is too dark to see anything, correct the photo", "poor illumination, improve the shot", "brighten dark regions", "make it HDR", "improve the light of the image", "Can you make the image brighter?"] + sr_prompts = ["I need to enhance the size and quality of this image.", "My photo is lacking size and clarity; can you improve it?", "I'd appreciate it if you could upscale this photo.", "My picture is too little, enlarge it.", "upsample this image", "increase the resolution of this photo", "increase the number of pixels", "upsample this photo", "Add details to this image", "improve the quality of this photo"] + en_prompts = ["make my image look like DSLR", "improve the colors of my image", "improve the contrast of this photo", "apply tonemapping", "enhance the colors of the image", "retouch the photo like a photograper"] + + init = np.random.choice(["Remove the", "Reduce the", "Clean the", "Fix the", "Remove", "Improve the", "Correct the",]) + end = np.random.choice(["please", "fast", "now", "in the photo", "in the picture", "in the image", ""]) + newp = f"{init} {prompt} {end}" + + if "lol" in prompt: + newp = np.random.choice(lol_prompts) + elif "sr" in prompt: + newp = np.random.choice(sr_prompts) + elif "en" in prompt: + newp = np.random.choice(en_prompts) + + newp = newp.strip().replace(" ", " ").replace("\n", "") + return newp + +def get_deg_name(path): + """ + Get the degradation name from the path + """ + + if ("gopro" in path) or ("GoPro" in path) or ("blur" in path) or ("Blur" in path) or ("RealBlur" in path): + return "blur" + elif ("SOTS" in path) or ("haze" in path) or ("sots" in path) or ("RESIDE" in path): + return "haze" + elif ("LOL" in path): + return "lol" + elif ("fiveK" in path): + return "en" + elif ("super" in path) or ("classicalSR" in path): + return "sr" + elif ("Rain100" in path) or ("rain13k" in path) or ("Rain13k" in path): + return "rain" + else: + return "noise" + +def crop_img(image, base=16): + """ + Mod crop the image to ensure the dimension is divisible by base. Also done by SwinIR, Restormer and others. + """ + h = image.shape[0] + w = image.shape[1] + crop_h = h % base + crop_w = w % base + return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :] + + +################# DATASETS + + +class RefDegImage(Dataset): + """ + Dataset for Image Restoration having low-quality image and the reference image. + Tasks: synthetic denoising, deblurring, super-res, etc. + """ + + def __init__(self, hq_img_paths, lq_img_paths, augmentations=None, val=False, name="test", deg_name="noise", deg_class=0): + + assert len(hq_img_paths) == len(lq_img_paths) + + self.hq_paths = hq_img_paths + self.lq_paths = lq_img_paths + self.totensor = torchvision.transforms.ToTensor() + self.val = val + self.augs = augmentations + self.name = name + self.degradation = deg_name + self.deg_class = deg_class + + if self.val: + self.augs = None # No augmentations during validation/test + + def __len__(self): + return len(self.hq_paths) + + def __getitem__(self, idx): + hq_path = self.hq_paths[idx] + lq_path = self.lq_paths[idx] + + hq_image = load_img(hq_path) + lq_image = load_img(lq_path) + + if self.val: + # if an image has an odd number dimension we trim for example from [321, 189] to [320, 188]. + hq_image = crop_img(hq_image) + lq_image = crop_img(lq_image) + + hq_image = self.totensor(hq_image.astype(np.float32)) + lq_image = self.totensor(lq_image.astype(np.float32)) + + return hq_image, lq_image, hq_path + + + +def create_testsets (testsets, debug=False): + """ + Given a list of testsets create pytorch datasets for each. + The method requires the paths to references and noisy images. + """ + assert len(testsets) > 0 + + if debug: + print (20*'****') + print ("Creating Testsets", len(testsets)) + + datasets = [] + for testdt in testsets: + + path_hq , path_lq = testdt[0], testdt[1] + if debug: print (path_hq , path_lq) + + if ("denoising" in path_hq) or ("jpeg" in path_hq): + dataset_name = path_hq.split("/")[-1] + dataset_sigma = path_lq.split("/")[-1].split("_")[-1].split(".")[0] + dataset_name = dataset_name+ f"_{dataset_sigma}" + elif "Rain" in path_hq: + if "Rain100L" in path_hq: + dataset_name = "Rain100L" + else: + dataset_name = path_hq.split("/")[3] + + elif ("gopro" in path_hq) or ("GoPro" in path_hq): + dataset_name = "GoPro" + elif "LOL" in path_hq: + dataset_name = "LOL" + elif "SOTS" in path_hq: + dataset_name = "SOTS" + elif "fiveK" in path_hq: + dataset_name = "MIT5K" + else: + assert False, f"{path_hq} - unknown dataset" + + hq_img_paths = sorted(glob(os.path.join(path_hq, "*"))) + lq_img_paths = sorted(glob(os.path.join(path_lq, "*"))) + + if "SOTS" in path_hq: + # Haze removal SOTS test dataset + dataset_name = "SOTS" + hq_img_paths = sorted(glob(os.path.join(path_hq, "*.jpg"))) + assert len(hq_img_paths) == 500 + + lq_img_paths = [file.replace("GT", "IN") for file in hq_img_paths] + + if "fiveK" in path_hq: + dataset_name = "MIT5K" + testf = "test-data/mit5k/test.txt" + f = open(testf, "r") + test_ids = f.readlines() + test_ids = [x.strip() for x in test_ids] + f.close() + hq_img_paths = [os.path.join(path_hq, f"{x}.jpg") for x in test_ids] + lq_img_paths = [x.replace("expertC", "input") for x in hq_img_paths] + assert len(hq_img_paths) == 498 + + if "gopro" in path_hq: + assert len(hq_img_paths) == 1111 + + if "LOL" in path_hq: + assert len(hq_img_paths) == 15 + + assert len(hq_img_paths) == len(lq_img_paths) + + deg_name = get_deg_name(path_hq) + deg_class = DEG_MAP[deg_name] + + valdts = RefDegImage(hq_img_paths = hq_img_paths, + lq_img_paths = lq_img_paths, + val = True, name= dataset_name, deg_name=deg_name, deg_class=deg_class) + + datasets.append(valdts) + + assert len(datasets) == len(testsets) + print (20*'****') + + return datasets \ No newline at end of file diff --git a/eval_instructir.py b/eval_instructir.py new file mode 100644 index 0000000..8e65854 --- /dev/null +++ b/eval_instructir.py @@ -0,0 +1,204 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset +import torchvision +import torchvision.transforms.functional as TF + +import json +import os +from PIL import Image +import numpy as np +import matplotlib.pyplot as plt +import yaml +import random +import gc + +from utils import * +from models import instructir + +from text.models import LanguageModel, LMHead + +from test import test_model + + +def seed_everything(SEED=42): + random.seed(SEED) + np.random.seed(SEED) + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + torch.cuda.manual_seed_all(SEED) + torch.backends.cudnn.benchmark = True + + +if __name__=="__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default='configs/eval5d.yml', help='Path to config file') + parser.add_argument('--model', type=str, default="models/im_instructir-7d.pt", help='Path to the image model weights') + parser.add_argument('--lm', type=str, default="models/lm_instructir-7d.pt", help='Path to the language model weights') + parser.add_argument('--promptify', type=str, default="simple_augment") + parser.add_argument('--device', type=int, default=0, help="GPU device") + parser.add_argument('--debug', action='store_true', help="Debug mode") + parser.add_argument('--save', type=str, default='results/', help="Path to save the resultant images") + args = parser.parse_args() + + SEED=42 + seed_everything(SEED=SEED) + torch.backends.cudnn.deterministic = True + + GPU = args.device + DEBUG = args.debug + MODEL_NAME = args.model + CONFIG = args.config + LM_MODEL = args.lm + SAVE_PATH = args.save + + print ('CUDA GPU available: ', torch.cuda.is_available()) + + torch.cuda.set_device(f'cuda:{GPU}') + device = torch.device(f'cuda:{GPU}' if torch.cuda.is_available() else "cpu") + print ('CUDA visible devices: ' + str(torch.cuda.device_count())) + print ('CUDA current device: ', torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device())) + + # parse config file + with open(os.path.join(CONFIG), "r") as f: + config = yaml.safe_load(f) + + cfg = dict2namespace(config) + + + print (20*"****") + print ("EVALUATION") + print (MODEL_NAME, LM_MODEL, device, DEBUG, CONFIG, args.promptify) + print (20*"****") + + ################### TESTING DATASET + + TESTSETS = [] + dn_testsets = [] + rain_testsets = [] + + # Denoising + try: + for testset in cfg.test.dn_datasets: + for sigma in cfg.test.dn_sigmas: + noisy_testpath = os.path.join(cfg.test.dn_datapath, testset+ f"_{sigma}") + clean_testpath = os.path.join(cfg.test.dn_datapath, testset) + #print (clean_testpath, noisy_testpath) + dn_testsets.append([clean_testpath, noisy_testpath]) + except: + dn_testsets = [] + + # RAIN + try: + for noisy_testpath, clean_testpath in zip(cfg.test.rain_inputs, cfg.test.rain_targets): + rain_testsets.append([clean_testpath, noisy_testpath]) + except: + rain_testsets = [] + + # HAZE + try: + haze_testsets = [[cfg.test.haze_targets, cfg.test.haze_inputs]] + except: + haze_testsets = [] + + # BLUR + try: + blur_testsets = [[cfg.test.gopro_targets, cfg.test.gopro_inputs]] + except: + blur_testsets = [] + + # LOL + try: + lol_testsets = [[cfg.test.lol_targets, cfg.test.lol_inputs]] + except: + lol_testsets = [] + + # MIT5K + try: + mit_testsets = [[cfg.test.mit_targets, cfg.test.mit_inputs]] + except: + mit_testsets = [] + + TESTSETS += dn_testsets + TESTSETS += rain_testsets + TESTSETS += haze_testsets + TESTSETS += blur_testsets + TESTSETS += lol_testsets + TESTSETS += mit_testsets + + # print ("Tests:", TESTSETS) + print ("TOTAL TESTSET:", len(TESTSETS)) + print (20 * "----") + + + ################### RESTORATION MODEL + + print ("Creating InstructIR") + model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks, + middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim) + + ################### LOAD IMAGE MODEL + + assert MODEL_NAME, "Model weights required for evaluation" + + print ("IMAGE MODEL CKPT:", MODEL_NAME) + model.load_state_dict(torch.load(MODEL_NAME), strict=True) + + model = model.to(device) + + nparams = count_params (model) + print ("Loaded weights!", nparams / 1e6) + + ################### LANGUAGE MODEL + + try: + PROMPT_DB = cfg.llm.text_db + except: + PROMPT_DB = None + + if cfg.model.use_text: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Initialize the LanguageModel class + LMODEL = cfg.llm.model + language_model = LanguageModel(model=LMODEL) + lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses) + lm_head = lm_head.to(device) + lm_nparams = count_params (lm_head) + + print ("LMHEAD MODEL CKPT:", LM_MODEL) + lm_head.load_state_dict(torch.load(LM_MODEL), strict=True) + print ("Loaded weights!") + + else: + LMODEL = None + language_model = None + lm_head = None + lm_nparams = 0 + + print (20 * "----") + + ################### TESTING !! + + from datasets import RefDegImage, augment_prompt, create_testsets + + if args.promptify == "simple_augment": + promptify = augment_prompt + elif args.promptify == "chatgpt": + prompts = json.load(open(cfg.llm.text_db)) + def promptify(deg): + + return np.random.choice(prompts[deg]) + else: + def promptify(deg): + return args.promptify + + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + test_datasets = create_testsets(TESTSETS, debug=True) + + test_model (model, language_model, lm_head, test_datasets, device, promptify, savepath=SAVE_PATH) + \ No newline at end of file diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..58847b4 --- /dev/null +++ b/metrics.py @@ -0,0 +1,278 @@ +import numpy as np +import math +import cv2 +import torch + + +def np_psnr(y_true, y_pred, maxval=1.): + mse = np.mean((y_true - y_pred) ** 2) + if(mse == 0): + return np.inf + + psnr = 20 * np.log10(maxval / np.sqrt(mse)) + return psnr + +def pt_psnr(y_true, y_pred, maxval=1.): + mse = torch.mean((y_true - y_pred) ** 2, dim=(1, 2, 3)) + psnr = 20 * torch.log10(maxval / torch.sqrt(mse)) + return psnr.unsqueeze(1) + + +############# SWINIR METRICS +# https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/utils/util_calculate_psnr_ssim.py#L243 + + +def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + mse = np.mean((img1 - img2) ** 2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. / np.sqrt(mse)) + + +def _ssim(img1, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. + """ + + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + ssims = [] + for i in range(img1.shape[2]): + ssims.append(_ssim(img1[..., i], img2[..., i])) + return np.array(ssims).mean() + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + if np.max(img) > 1.: + img = img.astype(np.float32) / 255. + + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + diff --git a/results/.gitkeep b/results/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test-data/.gitkeep b/test-data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test.py b/test.py new file mode 100644 index 0000000..d6fa984 --- /dev/null +++ b/test.py @@ -0,0 +1,100 @@ +import os +import gc +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +from metrics import pt_psnr, calculate_ssim, calculate_psnr +from pytorch_msssim import ssim +from utils import save_rgb + + +def test_model (model, language_model, lm_head, testsets, device, promptify, savepath="results/"): + + model.eval() + if language_model: + language_model.eval() + lm_head.eval() + + DEG_ACC = [] + derain_datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800'] + + with torch.no_grad(): + + for testset in testsets: + + if savepath: + dt_results_path = os.path.join(savepath, testset.name) + if not os.path.exists(dt_results_path): + os.mkdir(dt_results_path) + + print (">>> Eval on", testset.name, testset.degradation, testset.deg_class) + + testset_name = testset.name + test_dataloader = DataLoader(testset, batch_size=1, num_workers=4, drop_last=True, shuffle=False) + psnr_dataset = [] + ssim_dataset = [] + psnr_noisy = [] + use_y_channel= False + + if testset.name in derain_datasets: + use_y_channel = True + psnr_y_dataset = [] + ssim_y_dataset = [] + + for idx, batch in enumerate(test_dataloader): + + x = batch[0].to(device) # HQ image + y = batch[1].to(device) # LQ image + f = batch[2][0] # filename + t = [promptify(testset.degradation) for _ in range(x.shape[0])] + + if language_model: + if idx < 5: + # print the input prompt for debugging + print("\tInput prompt:", t) + + lm_embd = language_model(t) + lm_embd = lm_embd.to(device) + text_embd, deg_pred = lm_head (lm_embd) + + x_hat = model(y, text_embd) + + psnr_restore = torch.mean(pt_psnr(x, x_hat)) + psnr_dataset.append(psnr_restore.item()) + ssim_restore = ssim(x, x_hat, data_range=1., size_average=True) + ssim_dataset.append(ssim_restore.item()) + psnr_base = torch.mean(pt_psnr(x, y)) + psnr_noisy.append(psnr_base.item()) + + if use_y_channel: + _x_hat = np.clip(x_hat[0].permute(1,2,0).cpu().detach().numpy(), 0, 1).astype(np.float32) + _x = np.clip(x[0].permute(1,2,0).cpu().detach().numpy(), 0, 1).astype(np.float32) + _x_hat = (_x_hat*255).astype(np.uint8) + _x = (_x*255).astype(np.uint8) + + psnr_y = calculate_psnr(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True) + ssim_y = calculate_ssim(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True) + psnr_y_dataset.append(psnr_y) + ssim_y_dataset.append(ssim_y) + + ## SAVE RESULTS + if savepath: + restored_img = np.clip(x_hat[0].permute(1,2,0).cpu().detach().numpy(), 0, 1).astype(np.float32) + img_name = f.split("/")[-1] + save_rgb (restored_img, os.path.join(dt_results_path, img_name)) + + + print(f"{testset_name}_base", np.mean(psnr_noisy), "Total images:", len(psnr_dataset)) + print(f"{testset_name}_psnr", np.mean(psnr_dataset)) + print(f"{testset_name}_ssim", np.mean(ssim_dataset)) + if use_y_channel: + print(f"{testset_name}_psnr-Y", np.mean(psnr_y_dataset), len(psnr_y_dataset)) + print(f"{testset_name}_ssim-Y", np.mean(ssim_y_dataset)) + + print (); print (25 * "***") + + del test_dataloader,psnr_dataset, psnr_noisy; gc.collect() + + + # END OF FUNCTION \ No newline at end of file diff --git a/utils.py b/utils.py index b9dc643..e0794d5 100644 --- a/utils.py +++ b/utils.py @@ -48,6 +48,20 @@ def plot_all (images, figsize=(20,10), axis='off', names=None): axs[i].axis(axis) plt.show() +def modcrop(img_in, scale=2): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img def dict2namespace(config): namespace = argparse.Namespace()