add test & evaluation benchmarks
mv-lab committed Feb 23, 2024
1 parent d65ff5c commit d5417e4
Showing 10 changed files with 868 additions and 32 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,3 +1,8 @@
new_samples/
*.ipynb
*.zip
*/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
66 changes: 46 additions & 20 deletions README.md
@@ -33,11 +33,17 @@ Image restoration is a fundamental problem that involves recovering a high-quali…

### TODO / News 🔥

- [ ] Upload all test results for comparisons (ETA 1st Feb)
- [x] [Replicate Demo](https://replicate.com/mv-lab/instructir)
- [x] Upload models to HF 🤗 [(download the models here)](https://huggingface.co/marcosv/InstructIR) [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm.svg)](https://huggingface.co/marcosv/InstructIR)
- [x] 🤗 [Hugging Face Demo](https://huggingface.co/spaces/marcosv/InstructIR) try it now [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/marcosv/InstructIR)
- [ ] Upload model weights and results for other InstructIR variants (3D, 5D).

- [x] [Download all the test datasets](https://drive.google.com/file/d/11wGsKOMDVrBlsle4xtzORPLZAsGhel8c/view?usp=sharing) for all-in-one restoration.

- [x] Check the instructions below to run `eval_instructir.py` and get all the metrics and results for all-in-one restoration.

- [x] Download all the qualitative results: [instructir_results.zip](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip)

- [x] [Google Colab Tutorial](https://colab.research.google.com/drive/1OrTvS-i6uLM2Y8kIkq8ZZRwEQxQFchfq?usp=sharing) (check [demo.ipynb](demo.ipynb))

@@ -51,33 +57,41 @@ Image restoration is a fundamental problem that involves recovering a high-quali…

<a href="https://mv-lab.github.io/InstructIR/"><img src="images/instructir_teaser.png" alt="InstructIR" width=100%></a>

## Results

Check `test.py` and `eval_instructir.py`. The following command computes all the metrics for all the benchmarks using the pre-trained models in `models/`. The InstructIR results are saved in the indicated folder `results/`:

```
python eval_instructir.py --model models/im_instructir-7d.pt --lm models/lm_instructir-7d.pt --device 0 --config configs/eval5d.yml --save results/
```

An example of the output log is:

```
>>> Eval on CBSD68_15 noise 0
CBSD68_15_base 24.84328738380881
CBSD68_15_psnr 33.98722295200123 68
CBSD68_15_ssim 0.9315137801801457
....
```

You can **[download all the test datasets](https://drive.google.com/file/d/11wGsKOMDVrBlsle4xtzORPLZAsGhel8c/view?usp=sharing)** and place them in `test-data/`. Make sure the paths in the config file `configs/eval5d.yml` point to them.
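
For reference, these are the relevant dataset entries in `configs/eval5d.yml` after this commit (see the config diff below); all paths are relative to the repository root:

```
lol_targets: "test-data/LOL/high/"
lol_inputs : "test-data/LOL/low/"

gopro_targets: "test-data/GoPro/target/"
gopro_inputs: "test-data/GoPro/input/"
```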

-------

You can **[download all the paper results](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip)** (check the releases). We evaluate InstructIR on the following benchmarks:

| Dataset | Task | Test Results |
| :---------------- | :------ | ----: |
| BSD68 | Denoising | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) |
| Urban100 | Denoising | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) |
| Rain100 | Deraining | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) |
| [GoPro](https://seungjunnah.github.io/Datasets/gopro) | Deblurring | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) |
| [LOL](https://daooshee.github.io/BMVC2018website/) | Low-light Image Enhancement | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) |
| [MIT5K](https://data.csail.mit.edu/graphics/fivek/) | Image Enhancement | [Download](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip) |

From the releases page, or via the link above, you can download [instructir_results.zip](https://github.com/mv-lab/InstructIR/releases/download/instructir-results/instructir_results.zip), which includes all the qualitative results for those datasets (1.9 GB).


<img src="static/tables/table1.png" width=100%>
@@ -149,6 +163,18 @@ The final result looks indeed stunning 🤗 You can do it yourself in the [demo…

- ***Why aren't you using diffusion-based models?*** (1) We want to keep the solution simple and efficient. (2) Our priority is high fidelity, as in many industry scenarios related to computational photography.

### Gradio Demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
We made a simple [Gradio demo](app.py) that you can run locally on your machine. You need Python>=3.9 and [these requirements](requirements_gradio.txt): `pip install -r requirements_gradio.txt`

```
python app.py
```

<br>
<a href="https://huggingface.co/spaces/marcosv/InstructIR">
<img src="images/gradio.png" alt="InstructIR Gradio">
</a>


### Acknowledgments
This work was partly supported by the Humboldt Foundation (AvH). Marcos Conde is also supported by Sony Interactive Entertainment, FTG.
22 changes: 10 additions & 12 deletions configs/eval5d.yml
@@ -21,20 +21,18 @@ test:
batch_size: 1
num_workers: 3

dn_datapath: "data/denoising_testsets/"
dn_datasets: ["CBSD68", "urban100", "Kodak24", "McMaster"]
dn_datapath: "test-data/denoising_testsets/"
dn_datasets: ["CBSD68", "urban100", "Kodak24"]
dn_sigmas: [15, 25, 50]

rain_targets: ["data/Rain/rain_test/Rain100L/target/"]
rain_inputs: ["data/Rain/rain_test/Rain100L/input/"]
rain_targets: ["test-data/Rain100L/target/"]
rain_inputs: ["test-data/Rain100L/input/"]

haze_targets: "data/SOTS-OUT/GT/"
haze_inputs : "data/SOTS-OUT/IN/"
haze_targets: "test-data/SOTS/GT/"
haze_inputs : "test-data/SOTS/IN/"

lol_targets: "data/LOL/eval15/high/"
lol_inputs : "data/LOL/eval15/low/"
lol_targets: "test-data/LOL/high/"
lol_inputs : "test-data/LOL/low/"

gopro_targets: "data/gopro_test/GoPro/target/"
gopro_inputs: "data/gopro_test/GoPro/input/"


gopro_targets: "test-data/GoPro/target/"
gopro_inputs: "test-data/GoPro/input/"
211 changes: 211 additions & 0 deletions datasets.py
@@ -0,0 +1,211 @@
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import json
import os
from glob import glob

from utils import load_img, modcrop


DEG_MAP = {
"noise": 0,
"blur" : 1,
"rain" : 2,
"haze" : 3,
"lol" : 4,
"sr" : 5,
"en" : 6,
}

DEG2TASK = {
"noise": "denoising",
"blur" : "deblurring",
"rain" : "deraining",
"haze" : "dehazing",
"lol" : "lol",
"sr" : "sr",
"en" : "enhancement"
}
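
# Illustrative: DEG_MAP["rain"] == 2 and DEG2TASK["rain"] == "deraining";
# DEG_MAP maps a degradation keyword to a numeric class, DEG2TASK to a
# human-readable task name.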

def augment_prompt(prompt):
### special prompts
lol_prompts = ["fix the illumination", "increase the exposure of the photo", "the image is too dark to see anything, correct the photo", "poor illumination, improve the shot", "brighten dark regions", "make it HDR", "improve the light of the image", "Can you make the image brighter?"]
sr_prompts = ["I need to enhance the size and quality of this image.", "My photo is lacking size and clarity; can you improve it?", "I'd appreciate it if you could upscale this photo.", "My picture is too little, enlarge it.", "upsample this image", "increase the resolution of this photo", "increase the number of pixels", "upsample this photo", "Add details to this image", "improve the quality of this photo"]
    en_prompts = ["make my image look like DSLR", "improve the colors of my image", "improve the contrast of this photo", "apply tonemapping", "enhance the colors of the image", "retouch the photo like a photographer"]

init = np.random.choice(["Remove the", "Reduce the", "Clean the", "Fix the", "Remove", "Improve the", "Correct the",])
end = np.random.choice(["please", "fast", "now", "in the photo", "in the picture", "in the image", ""])
newp = f"{init} {prompt} {end}"

if "lol" in prompt:
newp = np.random.choice(lol_prompts)
elif "sr" in prompt:
newp = np.random.choice(sr_prompts)
elif "en" in prompt:
newp = np.random.choice(en_prompts)

    newp = newp.strip().replace("  ", " ").replace("\n", "")  # collapse double spaces and newlines
return newp
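
# Illustrative (output is randomized): augment_prompt("rain") may return
# "Remove the rain in the picture", while augment_prompt("lol") picks one of
# the special prompts, e.g. "brighten dark regions".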

def get_deg_name(path):
"""
Get the degradation name from the path
"""

if ("gopro" in path) or ("GoPro" in path) or ("blur" in path) or ("Blur" in path) or ("RealBlur" in path):
return "blur"
elif ("SOTS" in path) or ("haze" in path) or ("sots" in path) or ("RESIDE" in path):
return "haze"
elif ("LOL" in path):
return "lol"
elif ("fiveK" in path):
return "en"
elif ("super" in path) or ("classicalSR" in path):
return "sr"
elif ("Rain100" in path) or ("rain13k" in path) or ("Rain13k" in path):
return "rain"
else:
return "noise"

def crop_img(image, base=16):
"""
    Mod crop: center-crop the image so both spatial dimensions are divisible by `base`. Also done by SwinIR, Restormer and others.
"""
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base  # rows to trim so h becomes divisible by base
    crop_w = w % base  # cols to trim so w becomes divisible by base
    # split the trim evenly between both sides (center crop)
    return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
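
# Illustrative: for a (321, 481, 3) image and base=16, crop_h = 1 and
# crop_w = 1, so the result is a centered (320, 480, 3) crop.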


################# DATASETS


class RefDegImage(Dataset):
"""
Dataset for Image Restoration having low-quality image and the reference image.
Tasks: synthetic denoising, deblurring, super-res, etc.
"""

def __init__(self, hq_img_paths, lq_img_paths, augmentations=None, val=False, name="test", deg_name="noise", deg_class=0):

assert len(hq_img_paths) == len(lq_img_paths)

self.hq_paths = hq_img_paths
self.lq_paths = lq_img_paths
self.totensor = torchvision.transforms.ToTensor()
self.val = val
self.augs = augmentations
self.name = name
self.degradation = deg_name
self.deg_class = deg_class

if self.val:
self.augs = None # No augmentations during validation/test

def __len__(self):
return len(self.hq_paths)

def __getitem__(self, idx):
hq_path = self.hq_paths[idx]
lq_path = self.lq_paths[idx]

hq_image = load_img(hq_path)
lq_image = load_img(lq_path)

if self.val:
            # crop so both dimensions are divisible by 16, e.g. from [321, 189] to [320, 176]
hq_image = crop_img(hq_image)
lq_image = crop_img(lq_image)

hq_image = self.totensor(hq_image.astype(np.float32))
lq_image = self.totensor(lq_image.astype(np.float32))

return hq_image, lq_image, hq_path
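
# Illustrative: each RefDegImage item is (hq, lq, hq_path); hq and lq are
# float32 CHW tensors from ToTensor (value range depends on utils.load_img).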



def create_testsets(testsets, debug=False):
    """
    Given a list of (reference, degraded) path pairs, create a PyTorch dataset for each testset.
    """
assert len(testsets) > 0

    if debug:
        print(20 * '****')
        print("Creating Testsets", len(testsets))

datasets = []
for testdt in testsets:

path_hq , path_lq = testdt[0], testdt[1]
        if debug: print(path_hq, path_lq)

if ("denoising" in path_hq) or ("jpeg" in path_hq):
dataset_name = path_hq.split("/")[-1]
dataset_sigma = path_lq.split("/")[-1].split("_")[-1].split(".")[0]
            dataset_name = dataset_name + f"_{dataset_sigma}"
elif "Rain" in path_hq:
if "Rain100L" in path_hq:
dataset_name = "Rain100L"
else:
dataset_name = path_hq.split("/")[3]

elif ("gopro" in path_hq) or ("GoPro" in path_hq):
dataset_name = "GoPro"
elif "LOL" in path_hq:
dataset_name = "LOL"
elif "SOTS" in path_hq:
dataset_name = "SOTS"
elif "fiveK" in path_hq:
dataset_name = "MIT5K"
else:
assert False, f"{path_hq} - unknown dataset"

hq_img_paths = sorted(glob(os.path.join(path_hq, "*")))
lq_img_paths = sorted(glob(os.path.join(path_lq, "*")))

if "SOTS" in path_hq:
# Haze removal SOTS test dataset
dataset_name = "SOTS"
hq_img_paths = sorted(glob(os.path.join(path_hq, "*.jpg")))
assert len(hq_img_paths) == 500

lq_img_paths = [file.replace("GT", "IN") for file in hq_img_paths]

if "fiveK" in path_hq:
dataset_name = "MIT5K"
testf = "test-data/mit5k/test.txt"
            with open(testf, "r") as f:
                test_ids = [x.strip() for x in f.readlines()]
hq_img_paths = [os.path.join(path_hq, f"{x}.jpg") for x in test_ids]
lq_img_paths = [x.replace("expertC", "input") for x in hq_img_paths]
assert len(hq_img_paths) == 498

if "gopro" in path_hq:
assert len(hq_img_paths) == 1111

if "LOL" in path_hq:
assert len(hq_img_paths) == 15

assert len(hq_img_paths) == len(lq_img_paths)

deg_name = get_deg_name(path_hq)
deg_class = DEG_MAP[deg_name]

valdts = RefDegImage(hq_img_paths = hq_img_paths,
lq_img_paths = lq_img_paths,
val = True, name= dataset_name, deg_name=deg_name, deg_class=deg_class)

datasets.append(valdts)

assert len(datasets) == len(testsets)
    print(20 * '****')

return datasets
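
A minimal usage sketch for `create_testsets` (not part of the commit; it assumes the test datasets above are already downloaded and placed under `test-data/`):

```
from torch.utils.data import DataLoader
from datasets import create_testsets

# (HQ reference dir, LQ degraded dir) pairs, as in configs/eval5d.yml
testsets = [
    ("test-data/GoPro/target/", "test-data/GoPro/input/"),
    ("test-data/LOL/high/", "test-data/LOL/low/"),
]
datasets = create_testsets(testsets, debug=True)

# iterate the first testset image by image
loader = DataLoader(datasets[0], batch_size=1, num_workers=3, shuffle=False)
for hq, lq, hq_path in loader:
    print(hq_path[0], tuple(lq.shape))  # e.g. (1, 3, H, W)
```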
