diff --git a/scripts/benchmarking/eval_dncnn_pnp.py b/scripts/benchmarking/eval_dncnn_pnp.py
new file mode 100644
index 0000000..df54261
--- /dev/null
+++ b/scripts/benchmarking/eval_dncnn_pnp.py
@@ -0,0 +1,141 @@
+# This file is part of LION library
+# License : GPL-3
+#
+# Author: Ferdia Sherry
+# Modifications: -
+# =============================================================================
+
+from LION.CTtools.ct_utils import make_operator
+from LION.experiments.ct_benchmarking_experiments import (
+    FullDataCTRecon,
+    LimitedAngle120CTRecon,
+    LimitedAngle90CTRecon,
+    LimitedAngle60CTRecon,
+    SparseAngle360CTRecon,
+    SparseAngle120CTRecon,
+    SparseAngle60CTRecon,
+    LowDoseCTRecon,
+    BeamHardeningCTRecon,
+)
+from LION.models.LIONmodel import LIONParameter
+from LION.models.PnP import DnCNN
+
+import argparse
+import json
+import os
+from skimage.metrics import structural_similarity as ssim
+import torch
+
+
+def psnr(x, y):
+    return 10 * torch.log10((x**2).max() / ((x - y) ** 2).mean())
+
+
+def my_ssim(x: torch.Tensor, y: torch.Tensor):
+    x = x.cpu().numpy().squeeze()
+    y = y.cpu().numpy().squeeze()
+    return ssim(x, y, data_range=x.max() - x.min())
+
+
+with open("normalisation.json", "r") as in_file:
+    normalisation = json.load(in_file)
+    x_min, x_max = normalisation["x_min"], normalisation["x_max"]
+
+
+def get_denoiser(model):
+    def denoiser(x):
+        x = (x - x_min) / (x_max - x_min)
+        out = model(x)
+        return x_min + (x_max - x_min) * out
+
+    return denoiser
+
+
+def operator_norm(operator, N_iter=500):
+    u = torch.randn(1, 1024, 1024).cuda()
+    for i in range(N_iter):
+        u /= u.norm()
+        u = operator.T(operator(u))
+    return u.norm().sqrt().item()
+
+
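+# Power iteration on A^T A: repeatedly applying u <- A^T A (u / ||u||) drives
+# ||u|| towards the largest eigenvalue of A^T A, i.e. sigma_max(A)^2, so the
+# square root of the final norm estimates the operator norm ||A||_2; a random
+# start vector almost surely has a component along the leading eigenvector.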
+parser = argparse.ArgumentParser("validate_dncnn")
+parser.add_argument("--checkpoint", type=str)
+parser.add_argument("--result_path", type=str, default=".")
+parser.add_argument("--device", type=int, default=0)
+parser.add_argument("--testing", action="store_true")
+params = vars(parser.parse_args())
+print(params)
+
+torch.cuda.set_device(params["device"])
+chkpt = torch.load(params["checkpoint"], map_location="cpu")
+config = chkpt["config"]
+model = DnCNN(
+    LIONParameter(
+        in_channels=1,
+        int_channels=config["int_channels"],
+        kernel_size=(config["kernel_size"], config["kernel_size"]),
+        blocks=config["depth"],
+        residual=True,
+        bias_free=config["bias_free"],
+        act="leaky_relu",
+        enforce_positivity=config["enforce_positivity"],
+        batch_normalisation=True,
+    )
+).cuda()
+model.load_state_dict(chkpt["state_dict"])
+model.eval()
+denoiser = get_denoiser(model)
+
+
+for experiment in [
+    FullDataCTRecon(),
+    LimitedAngle120CTRecon(),
+    LimitedAngle90CTRecon(),
+    LimitedAngle60CTRecon(),
+    SparseAngle360CTRecon(),
+    SparseAngle120CTRecon(),
+    SparseAngle60CTRecon(),
+    LowDoseCTRecon(),
+    BeamHardeningCTRecon(),
+]:
+    print(experiment)
+    operator = make_operator(experiment.geo)
+    op_norm = operator_norm(operator)
+    step_size = 1.0 / op_norm**2
+
+    if params["testing"]:
+        data = experiment.get_testing_dataset()
+        split = "test"
+    else:
+        data = experiment.get_validation_dataset()
+        split = "val"
+    dataloader = torch.utils.data.DataLoader(data, 1, shuffle=False)
+
+    psnrs = []
+    ssims = []
+    for i, (y, x) in enumerate(dataloader):
+        y, x = y.cuda(), x.cuda()
+        recon = torch.zeros_like(x)
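+        # Plug-and-play gradient descent from a zero initialisation: each of
+        # the 100 iterations takes a gradient step on the data fidelity
+        # 0.5 * ||A x - y||^2 and then applies the learned denoiser,
+        #     x_{k+1} = denoiser(x_k - step_size * A^T (A x_k - y)),
+        # where step_size = 1 / ||A||_2^2 keeps the gradient step nonexpansive.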
+        with torch.no_grad():
+            for it in range(100):
+                recon = denoiser(
+                    recon
+                    - step_size * operator.T(operator(recon[0]) - y[0]).unsqueeze(0)
+                )
+        psnrs.append(psnr(x, recon).item())
+        ssims.append(my_ssim(x, recon).item())
+        print(
+            f"It {i + 1} / {len(dataloader)}: PSNR = {psnrs[-1]:.1f} dB, SSIM = {ssims[-1]:.3f}"
+        )
+    psnrs, ssims = torch.tensor(psnrs), torch.tensor(ssims)
+    torch.save(
+        {"psnrs": psnrs, "ssims": ssims},
+        os.path.join(
+            params["result_path"],
+            f"dncnn_{experiment.experiment_params.name.replace(' ', '_')}_{split}_noise_level={config['noise_level']}.pt",
+        ),
+    )
+    print(
+        f"PSNR = {psnrs.mean():.1f} +- {psnrs.std():.1f} dB, SSIM = {ssims.mean():.3f} +- {ssims.std():.3f}"
+    )
diff --git a/scripts/benchmarking/eval_drunet_pnp.py b/scripts/benchmarking/eval_drunet_pnp.py
new file mode 100644
index 0000000..55afbe2
--- /dev/null
+++ b/scripts/benchmarking/eval_drunet_pnp.py
@@ -0,0 +1,141 @@
+# This file is part of LION library
+# License : GPL-3
+#
+# Author: Ferdia Sherry
+# Modifications: -
+# =============================================================================
+
+from LION.CTtools.ct_utils import make_operator
+from LION.experiments.ct_benchmarking_experiments import (
+    FullDataCTRecon,
+    LimitedAngle120CTRecon,
+    LimitedAngle90CTRecon,
+    LimitedAngle60CTRecon,
+    SparseAngle360CTRecon,
+    SparseAngle120CTRecon,
+    SparseAngle60CTRecon,
+    LowDoseCTRecon,
+    BeamHardeningCTRecon,
+)
+from LION.models.LIONmodel import LIONParameter
+from LION.models.PnP import DRUNet
+
+import argparse
+import json
+import os
+from skimage.metrics import structural_similarity as ssim
+import torch
+
+
+def psnr(x, y):
+    return 10 * torch.log10((x**2).max() / ((x - y) ** 2).mean())
+
+
+def my_ssim(x: torch.Tensor, y: torch.Tensor):
+    x = x.cpu().numpy().squeeze()
+    y = y.cpu().numpy().squeeze()
+    return ssim(x, y, data_range=x.max() - x.min())
+
+
+with open("normalisation.json", "r") as in_file:
+    normalisation = json.load(in_file)
+    x_min, x_max = normalisation["x_min"], normalisation["x_max"]
+
+
+def get_denoiser(model):
+    def denoiser(x):
+        x = (x - x_min) / (x_max - x_min)
+        out = model(x)
+        return x_min + (x_max - x_min) * out
+
+    return denoiser
+
+
+def operator_norm(operator, N_iter=500):
+    u = torch.randn(1, 1024, 1024).cuda()
+    for i in range(N_iter):
+        u /= u.norm()
+        u = operator.T(operator(u))
+    return u.norm().sqrt().item()
+
+
+parser = argparse.ArgumentParser("validate_drunet")
+parser.add_argument("--checkpoint", type=str)
+parser.add_argument("--result_path", type=str, default=".")
+parser.add_argument("--device", type=int, default=0)
+parser.add_argument("--testing", action="store_true")
+params = vars(parser.parse_args())
+print(params)
+
+torch.cuda.set_device(params["device"])
+chkpt = torch.load(params["checkpoint"], map_location="cpu")
+config = chkpt["config"]
+model = DRUNet(
+    LIONParameter(
+        in_channels=1,
+        out_channels=1,
+        int_channels=config["int_channels"],
+        kernel_size=(config["kernel_size"], config["kernel_size"]),
+        n_blocks=config["n_blocks"],
+        use_noise_level=False,
+        bias_free=config["bias_free"],
+        act="leaky_relu",
+        enforce_positivity=config["enforce_positivity"],
+    )
+).cuda()
+model.load_state_dict(chkpt["state_dict"])
+model.eval()
+denoiser = get_denoiser(model)
+
+
+for experiment in [
+    FullDataCTRecon(),
+    LimitedAngle120CTRecon(),
+    LimitedAngle90CTRecon(),
+    LimitedAngle60CTRecon(),
+    SparseAngle360CTRecon(),
+    SparseAngle120CTRecon(),
+    SparseAngle60CTRecon(),
+    LowDoseCTRecon(),
+    BeamHardeningCTRecon(),
+]:
+    print(experiment)
+    operator = make_operator(experiment.geo)
+    op_norm = operator_norm(operator)
+    step_size = 1.0 / op_norm**2
+
+    if params["testing"]:
+        data = experiment.get_testing_dataset()
+        split = "test"
+    else:
+        data = experiment.get_validation_dataset()
+        split = "val"
+    dataloader = torch.utils.data.DataLoader(data, 1, shuffle=False)
+
+    psnrs = []
+    ssims = []
+    for i, (y, x) in enumerate(dataloader):
+        y, x = y.cuda(), x.cuda()
+        recon = torch.zeros_like(x)
+        with torch.no_grad():
+            for it in range(100):
+                recon = denoiser(
+                    recon
+                    - step_size * operator.T(operator(recon[0]) - y[0]).unsqueeze(0)
+                )
+        psnrs.append(psnr(x, recon).item())
+        ssims.append(my_ssim(x, recon).item())
+        print(
+            f"It {i + 1} / {len(dataloader)}: PSNR = {psnrs[-1]:.1f} dB, SSIM = {ssims[-1]:.3f}"
+        )
+    psnrs, ssims = torch.tensor(psnrs), torch.tensor(ssims)
+    torch.save(
+        {"psnrs": psnrs, "ssims": ssims},
+        os.path.join(
+            params["result_path"],
+            f"drunet_{experiment.experiment_params.name.replace(' ', '_')}_{split}_noise_level={config['noise_level']}.pt",
+        ),
+    )
+    print(
+        f"PSNR = {psnrs.mean():.1f} +- {psnrs.std():.1f} dB, SSIM = {ssims.mean():.3f} +- {ssims.std():.3f}"
+    )
diff --git a/scripts/benchmarking/eval_gs_drunet_pnp.py b/scripts/benchmarking/eval_gs_drunet_pnp.py
new file mode 100644
index 0000000..c67a048
--- /dev/null
+++ b/scripts/benchmarking/eval_gs_drunet_pnp.py
@@ -0,0 +1,147 @@
+# This file is part of LION library
+# License : GPL-3
+#
+# Author: Ferdia Sherry
+# Modifications: -
+# =============================================================================
+
+from LION.CTtools.ct_utils import make_operator
+from LION.experiments.ct_benchmarking_experiments import (
+    FullDataCTRecon,
+    LimitedAngle120CTRecon,
+    LimitedAngle90CTRecon,
+    LimitedAngle60CTRecon,
+    SparseAngle360CTRecon,
+    SparseAngle120CTRecon,
+    SparseAngle60CTRecon,
+    LowDoseCTRecon,
+    BeamHardeningCTRecon,
+)
+from LION.models.LIONmodel import LIONParameter
+from LION.models.PnP import GSDRUNet
+
+import argparse
+import json
+import os
+from skimage.metrics import structural_similarity as ssim
+import torch
+
+
+def psnr(x, y):
+    return 10 * torch.log10((x**2).max() / ((x - y) ** 2).mean())
+
+
+def my_ssim(x: torch.Tensor, y: torch.Tensor):
+    x = x.cpu().numpy().squeeze()
+    y = y.cpu().numpy().squeeze()
+    return ssim(x, y, data_range=x.max() - x.min())
+
+
+with open("normalisation.json", "r") as in_file:
+    normalisation = json.load(in_file)
+    x_min, x_max = normalisation["x_min"], normalisation["x_max"]
+
+
+def get_denoiser(model):
+    def denoiser(x):
+        x = (x - x_min) / (x_max - x_min)
+        out = model(x)
+        return x_min + (x_max - x_min) * out
+
+    return denoiser
+
+
+def data_obj_grad(op, x, y):
+    res = op(x[0]) - y[0]
+    data_grad = op.T(res).unsqueeze(0)
+    return 0.5 * (res**2).sum(), data_grad
+
+
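+# Data fidelity and its gradient: for f(x) = 0.5 * ||A x - y||^2 the gradient
+# is A^T (A x - y); returning the objective value as well lets the PnP loop
+# below monitor the composite objective at every iteration.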
+def operator_norm(operator, N_iter=500):
+    u = torch.randn(1, 1024, 1024).cuda()
+    for i in range(N_iter):
+        u /= u.norm()
+        u = operator.T(operator(u))
+    return u.norm().sqrt().item()
+
+
+parser = argparse.ArgumentParser("validate_gs_drunet")
+parser.add_argument("--checkpoint", type=str)
+parser.add_argument("--result_path", type=str, default=".")
+parser.add_argument("--device", type=int, default=0)
+parser.add_argument("--testing", action="store_true")
+params = vars(parser.parse_args())
+print(params)
+
+torch.cuda.set_device(params["device"])
+chkpt = torch.load(params["checkpoint"], map_location="cpu")
+config = chkpt["config"]
+model = GSDRUNet(
+    LIONParameter(
+        in_channels=1,
+        out_channels=1,
+        int_channels=config["int_channels"],
+        kernel_size=(config["kernel_size"], config["kernel_size"]),
+        n_blocks=config["n_blocks"],
+        use_noise_level=False,
+        bias_free=config["bias_free"],
+        act="elu",
+        enforce_positivity=False,
+    )
+).cuda()
+model.load_state_dict(chkpt["state_dict"])
+model.eval()
+denoiser = get_denoiser(model)
+
+
+for experiment in [
+    FullDataCTRecon(),
+    LimitedAngle120CTRecon(),
+    LimitedAngle90CTRecon(),
+    LimitedAngle60CTRecon(),
+    SparseAngle360CTRecon(),
+    SparseAngle120CTRecon(),
+    SparseAngle60CTRecon(),
+    LowDoseCTRecon(),
+    BeamHardeningCTRecon(),
+]:
+    print(experiment)
+    operator = make_operator(experiment.geo)
+    op_norm = operator_norm(operator)
+    step_size = 1.0 / op_norm**2
+
+    if params["testing"]:
+        data = experiment.get_testing_dataset()
+        split = "test"
+    else:
+        data = experiment.get_validation_dataset()
+        split = "val"
+    dataloader = torch.utils.data.DataLoader(data, 1, shuffle=False)
+
+    psnrs = []
+    ssims = []
+    for i, (y, x) in enumerate(dataloader):
+        y, x = y.cuda(), x.cuda()
+        recon = torch.zeros_like(x)
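+        # Gradient-step PnP: model.obj_grad evaluates the learned regulariser
+        # and its gradient, so each iteration descends on the sum of the data
+        # fidelity and the regulariser, prints the composite objective for
+        # monitoring, and applies the denoiser to the detached iterate, which
+        # also keeps the autograd graph from growing across iterations.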
params["lr"], - "epochs": params["epochs"], - "noise_level": params["noise_level"], - "commit_hash": commit_hash, - } - elif params["model"] == "drunet": - model_params = LIONParameter( - in_channels=1, - out_channels=1, - int_channels=params["channels"], - kernel_size=(params["kernel_size"], params["kernel_size"]), - n_blocks=params["n_blocks"], - use_noise_level=False, - bias_free=params["bias_free"], - act="leaky_relu", - enforce_positivity=params["enforce_positivity"], +parser.add_argument("--results_path", type=str, default=".") +params = vars(parser.parse_args()) + +if params["debug"]: + os.environ["WANDB_MODE"] = "offline" + +assert params["model"] in ["dncnn", "drunet", "gs_drunet"] + +print(params) +torch.cuda.set_device(params["device"]) + +commit_hash = git.Repo(".", search_parent_directories=True).head.reference.commit.hexsha + +if params["model"] == "dncnn": + model_params = LIONParameter( + in_channels=1, + int_channels=params["channels"], + kernel_size=(params["kernel_size"], params["kernel_size"]), + blocks=params["depth"], + residual=True, + bias_free=params["bias_free"], + act="leaky_relu", + enforce_positivity=params["enforce_positivity"], + batch_normalisation=True, + ) + model = DnCNN(model_params).cuda() + config = { + "depth": params["depth"], + "int_channels": params["channels"], + "kernel_size": params["kernel_size"], + "bias_free": params["bias_free"], + "enforce_positivity": params["enforce_positivity"], + "lr": params["lr"], + "epochs": params["epochs"], + "noise_level": params["noise_level"], + "commit_hash": commit_hash, + } +elif params["model"] == "drunet": + model_params = LIONParameter( + in_channels=1, + out_channels=1, + int_channels=params["channels"], + kernel_size=(params["kernel_size"], params["kernel_size"]), + n_blocks=params["n_blocks"], + use_noise_level=False, + bias_free=params["bias_free"], + act="leaky_relu", + enforce_positivity=params["enforce_positivity"], + ) + model = DRUNet(model_params).cuda() + config = { + "n_blocks": params["n_blocks"], + "int_channels": params["channels"], + "kernel_size": params["kernel_size"], + "bias_free": params["bias_free"], + "enforce_positivity": params["enforce_positivity"], + "lr": params["lr"], + "epochs": params["epochs"], + "noise_level": params["noise_level"], + "commit_hash": commit_hash, + } +elif params["model"] == "gs_drunet": + model_params = LIONParameter( + in_channels=1, + out_channels=1, + int_channels=params["channels"], + kernel_size=(params["kernel_size"], params["kernel_size"]), + n_blocks=params["n_blocks"], + use_noise_level=False, + bias_free=params["bias_free"], + act="elu", + enforce_positivity=params["enforce_positivity"], + ) + model = GSDRUNet(model_params).cuda() + config = { + "n_blocks": params["n_blocks"], + "int_channels": params["channels"], + "kernel_size": params["kernel_size"], + "bias_free": params["bias_free"], + "lr": params["lr"], + "epochs": params["epochs"], + "noise_level": params["noise_level"], + "commit_hash": commit_hash, + } +else: + raise NotImplementedError(f"Model {params['model']} has not been implemented!") + +experiment_id = uuid.uuid1() +experiment_name = f"{params['model']}" +if params["bias_free"]: + experiment_name += "_bias_free" +if "dncnn" in params["model"]: + experiment_name += f"_depth={params['depth']}" +else: + experiment_name += f"_n_blocks={params['n_blocks']}" +experiment_name += f"_noise_level={params['noise_level']}_{experiment_id}" +print(experiment_name) +print(config) +wandb.init(project="benchmarking_ct", config=config, 
+optimiser = torch.optim.Adam(model.parameters(), lr=params["lr"], betas=(0.9, 0.9))
+random_crop = RandomCrop((256, 256))
+random_erasing = RandomErasing()
+experiment = ct_benchmarking.GroundTruthCT()
+
+training_data = experiment.get_training_dataset()
+validation_data = experiment.get_validation_dataset()
+testing_data = experiment.get_testing_dataset()
+
+print(
+    f"N_train={len(training_data)}, N_val={len(validation_data)}, N_test={len(testing_data)}"
+)
+
+batch_size = 1
+
+training_dataloader = torch.utils.data.DataLoader(
+    training_data, batch_size, shuffle=True
+)
+validation_dataloader = torch.utils.data.DataLoader(
+    validation_data, batch_size, shuffle=False
+)
+testing_dataloader = torch.utils.data.DataLoader(
+    testing_data, batch_size, shuffle=False
+)
+
+with open("normalisation.json", "r") as fp:
+    normalisation = json.load(fp)
+    x_min, x_max = normalisation["x_min"], normalisation["x_max"]
+
+best_val_psnr = -inf
+
+losses = []
+val_psnrs = []
+for epoch in range(params["epochs"]):
+    model.train()
+    for it, x in enumerate(training_dataloader):
+        x = x.cuda()
+        x = (x - x_min) / (x_max - x_min)
+        patches = random_erasing(torch.cat([random_crop(x) for _ in range(5)], dim=0))
+        optimiser.zero_grad()
+        y = patches + params["noise_level"] * torch.randn_like(patches)
+        recon = model(y)
+        loss = torch.mean((recon - patches) ** 2)
+        loss.backward()
+        grad_norm = mean_grad_norm(model)
+        losses.append(loss.item())
+        optimiser.step()
+        with torch.no_grad():
+            y_psnr = psnr(patches, y)
+            recon_psnr = psnr(patches, recon)
+        print(
+            f"Epoch {epoch}, it {it}: PSNR(x, y) = {y_psnr.item():.1f} dB, PSNR(x, recon) = {recon_psnr:.1f} dB, loss = {loss.item():.2e}"
         )
-        model = DRUNet(model_params).cuda()
-        config = {
-            "n_blocks": params["n_blocks"],
-            "int_channels": params["channels"],
-            "kernel_size": params["kernel_size"],
-            "bias_free": params["bias_free"],
-            "enforce_positivity": params["enforce_positivity"],
-            "lr": params["lr"],
-            "epochs": params["epochs"],
-            "noise_level": params["noise_level"],
-            "commit_hash": commit_hash,
-        }
-    elif params["model"] == "gs_drunet":
-        model_params = LIONParameter(
-            in_channels=1,
-            out_channels=1,
-            int_channels=params["channels"],
-            kernel_size=(params["kernel_size"], params["kernel_size"]),
-            n_blocks=params["n_blocks"],
-            use_noise_level=False,
-            bias_free=params["bias_free"],
-            act="elu",
-            enforce_positivity=params["enforce_positivity"],
+        wandb.log(
+            {
+                "train_loss": loss.item(),
+                "train_psnr": recon_psnr.item(),
+                "train_psnr_y": y_psnr.item(),
+                "train_psnr_offset": (recon_psnr - y_psnr).item(),
+                "grad_norm": grad_norm,
+            }
         )
-        model = GSDRUNet(model_params).cuda()
-        config = {
-            "n_blocks": params["n_blocks"],
-            "int_channels": params["channels"],
-            "kernel_size": params["kernel_size"],
-            "bias_free": params["bias_free"],
-            "lr": params["lr"],
-            "epochs": params["epochs"],
-            "noise_level": params["noise_level"],
-            "commit_hash": commit_hash,
-        }
-    else:
-        raise NotImplementedError(f"Model {params['model']} has not been implemented!")
-
-    experiment_id = uuid.uuid1()
-    experiment_name = f"{params['model']}"
-    if params["bias_free"]:
-        experiment_name += "_bias_free"
-    if "dncnn" in params["model"]:
-        experiment_name += f"_depth={params['depth']}"
-    else:
-        experiment_name += f"_n_blocks={params['n_blocks']}"
-    experiment_name += f"_noise_level={params['noise_level']}_{experiment_id}"
-    print(experiment_name)
-    print(config)
-    wandb.init(project="benchmarking_ct", config=config, name=experiment_name)
-
-    optimiser = torch.optim.Adam(model.parameters(), lr=params["lr"], betas=(0.9, 0.9))
-    random_crop = RandomCrop((256, 256))
-    random_erasing = RandomErasing()
-    experiment = ct_benchmarking.GroundTruthCT()
-
-    training_data = experiment.get_training_dataset()
-    validation_data = experiment.get_validation_dataset()
-    testing_data = experiment.get_testing_dataset()
-
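+    # Validation pass: gs_drunet is the only model evaluated without
+    # torch.no_grad(), since a gradient-step denoiser computes its output
+    # via autograd through a learned potential.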
+    psnrs = []
+    y_psnrs = []
+    model.eval()
+    for x in validation_dataloader:
+        x = x.cuda()
+        x = (x - x_min) / (x_max - x_min)
+        y = x + params["noise_level"] * torch.randn_like(x)
+        if params["model"] not in ["gs_drunet"]:
+            with torch.no_grad():
+                recon = model(y)
+        else:
+            recon = model(y)
+        with torch.no_grad():
+            psnrs.append(psnr(x, recon).item())
+            y_psnrs.append(psnr(x, y).item())
+    psnrs = torch.tensor(psnrs)
+    y_psnrs = torch.tensor(y_psnrs)
     print(
-        f"N_train={len(training_data)}, N_val={len(validation_data)}, N_test={len(testing_data)}"
+        f"Epoch {epoch}, val PSNR(x, y) = {y_psnrs.mean():.1f} +- {y_psnrs.std():.1f} dB, val PSNR(x, recon) = {psnrs.mean():.1f} +- {psnrs.std():.1f} dB"
     )
-
-    batch_size = 1
-
-    training_dataloader = torch.utils.data.DataLoader(
-        training_data, batch_size, shuffle=True
-    )
-    validation_dataloader = torch.utils.data.DataLoader(
-        validation_data, batch_size, shuffle=False
-    )
-    testing_dataloader = torch.utils.data.DataLoader(
-        testing_data, batch_size, shuffle=False
+    print(
+        f"Epoch {epoch}, val PSNRs: 5%-quantile {psnrs.quantile(0.05):.1f} dB, median {psnrs.quantile(0.5):.1f} dB, 95%-quantile {psnrs.quantile(0.95):.1f} dB"
     )
+    wandb.log({"val_psnrs": psnrs})
 
-    with open("normalisation.json", "r") as fp:
-        normalisation = json.load(fp)
-        x_min, x_max = normalisation["x_min"], normalisation["x_max"]
-
-    best_val_psnr = -inf
-
-    losses = []
-    val_psnrs = []
-    for epoch in range(params["epochs"]):
-        model.train()
-        for it, x in enumerate(training_dataloader):
-            x = x.cuda()
-            x = (x - x_min) / (x_max - x_min)
-            patches = random_erasing(
-                torch.cat([random_crop(x) for _ in range(5)], dim=0)
-            )
-            optimiser.zero_grad()
-            y = patches + params["noise_level"] * torch.randn_like(patches)
-            recon = model(y)
-            loss = torch.mean((recon - patches) ** 2)
-            loss.backward()
-            grad_norm = mean_grad_norm(model)
-            losses.append(loss.item())
-            optimiser.step()
-            with torch.no_grad():
-                y_psnr = psnr(patches, y)
-                recon_psnr = psnr(patches, recon)
-            print(
-                f"Epoch {epoch}, it {it}: PSNR(x, y) = {y_psnr.item():.1f} dB, PSNR(x, recon) = {recon_psnr:.1f} dB, loss = {loss.item():.2e}"
-            )
-            wandb.log(
-                {
-                    "train_loss": loss.item(),
-                    "train_psnr": recon_psnr.item(),
-                    "train_psnr_y": y_psnr.item(),
-                    "train_psnr_offset": (recon_psnr - y_psnr).item(),
-                    "grad_norm": grad_norm,
-                }
-            )
-
-        psnrs = []
-        y_psnrs = []
-        model.eval()
-        for x in validation_dataloader:
-            x = x.cuda()
-            x = (x - x_min) / (x_max - x_min)
-            y = x + params["noise_level"] * torch.randn_like(x)
-            if params["model"] not in ["gs_drunet"]:
-                with torch.no_grad():
-                    recon = model(y)
-            else:
-                recon = model(y)
-            with torch.no_grad():
-                psnrs.append(psnr(x, recon).item())
-                y_psnrs.append(psnr(x, y).item())
-        psnrs = torch.tensor(psnrs)
-        y_psnrs = torch.tensor(y_psnrs)
-        print(
-            f"Epoch {epoch}, val PSNR(x, y) = {y_psnrs.mean():.1f} +- {y_psnrs.std():.1f} dB, val PSNR(x, recon) = {psnrs.mean():.1f} +- {psnrs.std():.1f} dB"
-        )
-        print(
-            f"Epoch {epoch}, val PSNRs: 5%-quantile {psnrs.quantile(0.05):.1f} dB, median {psnrs.quantile(0.5):.1f}, 95%-quantile {psnrs.quantile(0.95):.1f} dB"
-        )
-        wandb.log({"val_psnrs": psnrs})
-
-        if psnrs.mean() > best_val_psnr:
-            best_val_psnr = psnrs.mean().item()
-            torch.save(
-                {"config": config, "state_dict": model.state_dict()},
-                os.path.join(params["results_path"], f"{experiment_name}.pt"),
-            )
-            wandb.log({"best_val_psnr": best_val_psnr})
-            wandb.log({"val_psnr_y": y_psnrs.mean().item()})
-            wandb.log({"val_psnr_offset": best_val_psnr - y_psnrs.mean().item()})
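+    # Model selection: keep the checkpoint whose mean validation PSNR is the
+    # best seen so far, saving the config next to the weights so the PnP
+    # evaluation scripts can rebuild the same architecture from the .pt file.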
+    if psnrs.mean() > best_val_psnr:
+        best_val_psnr = psnrs.mean().item()
+        torch.save(
+            {"config": config, "state_dict": model.state_dict()},
+            os.path.join(params["results_path"], f"{experiment_name}.pt"),
+        )
+        wandb.log({"best_val_psnr": best_val_psnr})
+        wandb.log({"val_psnr_y": y_psnrs.mean().item()})
+        wandb.log({"val_psnr_offset": best_val_psnr - y_psnrs.mean().item()})