
PnP denoiser models #105

Open · wants to merge 7 commits into main

Conversation

@fsherry (Member) commented Jun 26, 2024

Attached are some LION implementations of deep denoisers that are commonly used for plug-and-play image restoration.
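
For context, a minimal sketch of how such a denoiser is typically plugged into an iterative reconstruction (hypothetical interface and parameter values, not the API added in this PR):

import torch

def pnp_pgd(y, forward_op, adjoint_op, denoiser, step_size=1e-3, noise_level=0.05, n_iters=50):
    # Plug-and-play proximal gradient descent: a gradient step on the data
    # fidelity 0.5 * ||A x - y||^2, followed by a denoising step standing in
    # for the proximal operator of an (implicit) regulariser.
    x = adjoint_op(y)
    for _ in range(n_iters):
        grad = adjoint_op(forward_op(x) - y)             # gradient of the data term
        x = denoiser(x - step_size * grad, noise_level)  # denoiser replaces the prox
    return x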

@AnderBiguri (Member):

Hey @fsherry, thanks a lot!

Can you also add the scripts you use to train and evaluate them? Even if they are messy, it's good for me to have a reference.

@fsherry (Member Author) commented Jun 27, 2024

Ok, my training scripts and evaluation scripts are in there now too :)

assert (
    isinstance(noise_level, float) and noise_level >= 0.0
), "`noise_level` must be a non-negative float"
x0 = torch.cat((x0, noise_level * torch.ones_like(x0)), dim=1)
@AnderBiguri (Member):

Question: so here you concatenate a number for the noise level? Is this like a "label" for the noise?

@fsherry (Member Author):

Yes. I didn't use it in the benchmarking paper, but the paper that introduced DRUNet suggests this: you can then train on multiple noise levels, feeding in both the noisy image and the noise level, with the hope that the model generalises to a spectrum of noise levels.
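
A minimal sketch of that conditioning, assuming a hypothetical backbone that takes 2 * C input channels (illustration only, not the PR's DRUNet implementation):

import torch
import torch.nn as nn

class NoiseConditionedDenoiser(nn.Module):
    # The noise level is broadcast to a constant map and appended as an extra
    # channel, so a single network can be trained across a range of noise levels.
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone  # assumed to expect 2 * C input channels

    def forward(self, x, noise_level: float):
        noise_map = noise_level * torch.ones_like(x)
        return self.backbone(torch.cat((x, noise_map), dim=1))

# Training across noise levels (sketch): sample sigma per batch, add noise, regress.
# sigma = float(torch.empty(1).uniform_(0.0, 0.2))
# loss = ((model(clean + sigma * torch.randn_like(clean), sigma) - clean) ** 2).mean()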

Comment on lines +62 to +63
x = x.detach()
x.requires_grad = True
@AnderBiguri (Member):

Can you explain this? So you remove it from the computational graph, but then you start a new computational graph by requiring grad?

Given that forward itself has no loops, how is this doing anything? Or is it to do with something about the way you trained?

@fsherry (Member Author):

This gives you a new tensor with the same underlying data, but one that doesn't interfere with anything previously done to x. If I removed line 62, I would be modifying the x that the user passed in (undesirable), whereas now, after line 62, x refers to a new tensor, so I can do with it what I want (except modify it in place). I could just as well have written y = x.detach() on line 62 and used y for the rest. The screenshot shows what I mean:
[Screenshot: 2024-07-15 at 23:21:18]
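
A minimal, self-contained sketch of that behaviour (hypothetical names, not code from the PR):

import torch

x_user = torch.randn(4, requires_grad=True)

def obj_grad_sketch(x):
    # Rebinding x to a detached copy leaves the caller's tensor untouched;
    # the copy becomes a fresh leaf with its own graph for the internal backward.
    x = x.detach()
    x.requires_grad = True
    obj = (x ** 2).sum()
    obj.backward()          # populates grad on the detached copy only
    return obj, x.grad

obj, grad = obj_grad_sketch(x_user)
print(x_user.grad)  # None: the user's tensor and its graph are unaffected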

@AnderBiguri (Member):

Sorry, I am still confused :D. Yes, I do understand what it does in principle, but when we make ML models we don't need to detach the input of the model, right? The computational graph understands that it's the input of the model and starts the graph there. Given that obj_grad is just used in forward, why is this line needed here?

@fsherry (Member Author):

You need to do this because we do a backward pass inside obj_grad; otherwise you mess up the input x.
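
For illustration, a hedged sketch of that pattern: a gradient-step denoiser whose forward pass differentiates a learned potential internally (hypothetical names, not the PR's implementation):

import torch
import torch.nn as nn

class GradientStepDenoiser(nn.Module):
    # D(x) = x - grad_x R(x): computing the output requires a backward pass
    # through the potential R inside forward, hence the detach discussed above.
    def __init__(self, potential: nn.Module):
        super().__init__()
        self.potential = potential  # maps an image batch to scalar values

    def obj_grad(self, x):
        x = x.detach()              # work on a copy so the caller's x is untouched
        x.requires_grad = True
        obj = self.potential(x).sum()
        # create_graph=True keeps this inner graph so an outer training loss can
        # still differentiate through the gradient step.
        (grad,) = torch.autograd.grad(obj, x, create_graph=True)
        return obj, grad

    def forward(self, x):
        _, grad = self.obj_grad(x)
        return x - grad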

optimiser = torch.optim.Adam(model.parameters(), lr=params["lr"], betas=(0.9, 0.9))
random_crop = RandomCrop((256, 256))
random_erasing = RandomErasing()
experiment = ct_benchmarking.GroundTruthCT()
@AnderBiguri (Member):

You may have forgotten to add some files?

@fsherry (Member Author):

No, I'm not sure why you think so? RandomCrop and RandomErasing are from kornia.augmentation and imported on line 16.

@AnderBiguri (Member):

I think you introduced ct_benchmarking.GroundTruthCT; it's not in main.

@fsherry (Member Author):

Yes it is, it was added in pull request #93.

@AnderBiguri (Member):

Man, I literally went to the file and Ctrl+F'd it and could not find it. Not my brightest day.

Comment on lines +187 to +197
x = x.cuda()
x = (x - x_min) / (x_max - x_min)
patches = random_erasing(torch.cat([random_crop(x) for _ in range(5)], dim=0))
optimiser.zero_grad()
y = patches + params["noise_level"] * torch.randn_like(patches)
recon = model(y)
loss = torch.mean((recon - patches) ** 2)
loss.backward()
grad_norm = mean_grad_norm(model)
losses.append(loss.item())
optimiser.step()
@AnderBiguri (Member):

So you erase some patches, then step on the loss, and then test on the patches + noise? Or does random_erasing give you the patches?

You did this because of memory, right?

@fsherry (Member Author):

random_crop selects random square patches from the image, while random_erasing erases random rectangles from the patches (intended to counteract overfitting). After applying those transforms, I add noise and pass the noisy cropped+erased patches through the denoiser to train. Training on patches was necessary at least for the gradient-step denoisers because of memory constraints.
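
A standalone sketch of that pipeline (dummy shapes and an example noise level, not the PR's training script):

import torch
from kornia.augmentation import RandomCrop, RandomErasing

random_crop = RandomCrop((256, 256))
random_erasing = RandomErasing()

x = torch.rand(1, 1, 512, 512)                                   # a dummy normalised slice
patches = torch.cat([random_crop(x) for _ in range(5)], dim=0)   # 5 random 256x256 patches
patches = random_erasing(patches)                                # erase rectangles (anti-overfitting)
noisy = patches + 0.1 * torch.randn_like(patches)                # example noise_level = 0.1
# `noisy` is the denoiser input; `patches` is the regression target for the MSE loss.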


def forward(self, x, noise_level: Optional[float] = None):
    _, grad = self.obj_grad(x, noise_level)
    return x - grad
@AnderBiguri (Member):

Ah I see, this is the one that gets messed up, x-grad. Makes sense now, thanks :)
