PnP denoiser models #105
base: main
Conversation
Hey @fsherry Thanks a lot! Can you also add scripts showing how you train and evaluate them? Even if they are messy, it's good for me to have a reference.
Ok, my training scripts and evaluation scripts are in there now too :)
assert (
    isinstance(noise_level, float) and noise_level >= 0.0
), "`noise_level` must be a non-negative float"
x0 = torch.cat((x0, noise_level * torch.ones_like(x0)), dim=1)
Question: so here you concatenate a number for the noise level? Is this like a "label" for the noise?
Yes, I didn't use it in the benchmarking paper, but the paper that introduced DRUNet suggests this. Then you can train on multiple noise levels, while inputting both noisy images and the noise level, and you hope that it generalises to a spectrum of noise levels.
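For reference, here is a minimal sketch (not the exact code from this PR) of that kind of noise-level conditioning: a constant channel holding the noise level is concatenated to a single-channel input, so one network can be trained across a range of noise levels. The function name and shapes below are illustrative.

import torch

def add_noise_level_channel(x: torch.Tensor, noise_level: float) -> torch.Tensor:
    # x has shape (batch, 1, height, width); the extra channel is filled with
    # the scalar noise level so the network can adapt its denoising strength.
    return torch.cat((x, noise_level * torch.ones_like(x)), dim=1)

x = torch.randn(4, 1, 64, 64)
x_cond = add_noise_level_channel(x, 0.05)  # shape (4, 2, 64, 64)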
x = x.detach()
x.requires_grad = True
Can you explain this? So you remove it from the computational graph, but then you start a new computational graph by requiring grad? Given that `forward` itself has no loops, how is this doing anything? Or is it to do with something about the way you trained?
This gives you a new tensor with the same underlying data but which doesn't interfere with previous things done on `x`. If I removed line 62, I would be messing with the `x` that the user has input (undesirable), whereas now after line 62 `x` refers to a new tensor, so I can do with it what I want (except modify it). I could just as well have said `y = x.detach()` in line 62 and then used `y` for the rest of it. The screenshot shows what I mean.
Sorry, I am still confused :D. Yes, I do understand what it does in principle, but when we make ML models, we don't need to detach the input of the model, right? The computational graph understands that it's the input of the model and starts the graph there. Given that `obj_grad` is just used in `forward`, why is this line needed here?
You need to do this because we do a backward pass inside `obj_grad`. Otherwise you mess up the input `x`.
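To illustrate why, here is a hedged sketch (not the implementation in this PR) of a gradient-step denoiser: `obj_grad` differentiates a learned scalar potential with respect to its input, so the input is first detached and given `requires_grad = True` to build a fresh graph instead of extending whatever graph the caller's tensor already belongs to. The class and attribute names (`GradientStepDenoiserSketch`, `potential`) are placeholders.

from typing import Optional

import torch

class GradientStepDenoiserSketch(torch.nn.Module):
    def __init__(self, potential: torch.nn.Module):
        super().__init__()
        self.potential = potential  # scalar-valued network acting as a learned objective

    def obj_grad(self, x: torch.Tensor, noise_level: Optional[float] = None):
        # noise_level is unused in this sketch.
        x = x.detach()          # new tensor, same data, outside the caller's graph
        x.requires_grad = True  # start a fresh graph rooted at this copy
        obj = self.potential(x).sum()
        # Differentiate the potential w.r.t. the detached input; keep the graph
        # during training so the denoising loss can backpropagate through it.
        (grad,) = torch.autograd.grad(obj, x, create_graph=self.training)
        return obj, grad

    def forward(self, x, noise_level: Optional[float] = None):
        _, grad = self.obj_grad(x, noise_level)
        return x - grad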
optimiser = torch.optim.Adam(model.parameters(), lr=params["lr"], betas=(0.9, 0.9))
random_crop = RandomCrop((256, 256))
random_erasing = RandomErasing()
experiment = ct_benchmarking.GroundTruthCT()
You may have forgotten to add some files?
No, I'm not sure why you think so? `RandomCrop` and `RandomErasing` are from `kornia.augmentation` and imported on line 16.
I think you introduced `ct_benchmarking.GroundTruthCT`; it's not in main.
Yes it is, it was added in pull request #93.
Man, I literally went to the file and Ctrl+F'd it and could not find it. Not my brightest day.
x = x.cuda()
x = (x - x_min) / (x_max - x_min)
patches = random_erasing(torch.cat([random_crop(x) for _ in range(5)], dim=0))
optimiser.zero_grad()
y = patches + params["noise_level"] * torch.randn_like(patches)
recon = model(y)
loss = torch.mean((recon - patches) ** 2)
loss.backward()
grad_norm = mean_grad_norm(model)
losses.append(loss.item())
optimiser.step()
So you erase some patches, then step on the loss, and then test on the patches + noise? Or does `random_erasing` give you the patches?? You did this because of memory, right?
`random_crop` selects random square patches from the image, while `random_erasing` erases random rectangles from the patches (intended to counteract overfitting). After applying those transforms, I add noise and pass the noisy cropped+erased patches through the denoiser to train. Training on patches was necessary, at least for the gradient-step denoisers, because of memory constraints.
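In isolation, that augmentation step looks roughly like this (a sketch assuming kornia's `RandomCrop` and `RandomErasing`; the image shape and noise level are illustrative, not values from the PR):

import torch
from kornia.augmentation import RandomCrop, RandomErasing

random_crop = RandomCrop((256, 256))  # random 256x256 patches
random_erasing = RandomErasing()      # erases a random rectangle per patch
noise_level = 0.05                    # illustrative value

x = torch.rand(1, 1, 512, 512)        # one normalised CT slice
# Five random crops per image, stacked along the batch dimension, then erased.
patches = random_erasing(torch.cat([random_crop(x) for _ in range(5)], dim=0))
# The denoiser is trained to map the noisy patches back to the clean patches.
noisy = patches + noise_level * torch.randn_like(patches)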
def forward(self, x, noise_level: Optional[float] = None):
    _, grad = self.obj_grad(x, noise_level)
    return x - grad
Ah I see, this is the one that gets messed up, `x - grad`. Makes sense now, thanks :)
Attached are some LION implementations of deep denoisers that are commonly used for plug-and-play image restoration.
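As a rough illustration of how such denoisers are typically used downstream (a sketch, not LION's API: `forward_op`, `adjoint_op`, the step size, and the iteration count are placeholders), a plug-and-play proximal gradient loop lets the trained denoiser stand in for the proximal operator of the regulariser:

import torch

def pnp_proximal_gradient(y, denoiser, forward_op, adjoint_op, step_size=1e-3, n_iters=100):
    # y: measured data; forward_op/adjoint_op: linear forward operator and its adjoint.
    x = adjoint_op(y)  # simple initialisation from the measurements
    for _ in range(n_iters):
        grad = adjoint_op(forward_op(x) - y)  # gradient of the quadratic data fit
        # The denoiser plays the role of the prox of the prior; detach so the
        # computational graph does not grow across iterations.
        x = denoiser(x - step_size * grad).detach()
    return x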