[Bug] AffineAutoregressive transform leads to exploding gradients #85
Comments
Hi @francois-rozet, thanks for raising this. I agree, we should have a softplus non-linearity. I usually use …

@stefanwebb what about letting the user choose which non-linearity is used for the positive mapping of the parameters? Something like `layer = AffineLayer(positive_map='softplus')`. I think that for …
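A minimal sketch of what such an option could look like; `AffineLayer` and `positive_map` are the names proposed in the comment above, not an existing FlowTorch API:

```python
import torch
import torch.nn.functional as F

# Hypothetical name-to-function table mirroring the suggestion above.
POSITIVE_MAPS = {
    "softplus": F.softplus,    # smooth and strictly positive
    "exp": torch.exp,          # strictly positive, but can overflow
    "sigmoid": torch.sigmoid,  # positive and bounded in (0, 1)
}

class AffineLayer(torch.nn.Module):
    def __init__(self, positive_map: str = "softplus"):
        super().__init__()
        if positive_map not in POSITIVE_MAPS:
            raise ValueError(f"unknown positive_map: {positive_map!r}")
        self._positive_map = POSITIVE_MAPS[positive_map]

    def scale(self, unconstrained: torch.Tensor) -> torch.Tensor:
        # Map the unconstrained network output to a positive scale.
        return self._positive_map(unconstrained)
```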
Side note: the method behind this, where all modifications are done in place:

```python
def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    """
    This helper function clamps values but still passes through the
    gradient in clamped regions
    """
    # clamp_ mutates x.data in place, bypassing autograd entirely.
    x.data.clamp_(min, max)
    return x
```
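For comparison, here is a minimal out-of-place sketch, assuming one wants the same clamp-but-pass-gradients behavior without mutating `x` (my own sketch, not necessarily the library's current implementation):

```python
import torch

def clamp_preserve_gradients(x: torch.Tensor, min: float, max: float) -> torch.Tensor:
    # Straight-through trick: the forward pass returns the clamped value,
    # but the correction term is detached, so the backward pass sees the
    # identity and gradients flow through unchanged; x is never mutated.
    return x + (x.clamp(min, max) - x).detach()
```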
Summary:

### Motivation

As pointed out in #85, it may be preferable to use `softplus` rather than `exp` to calculate the scale parameter of the affine map in `bij.ops.Affine`.

### Changes proposed

Another PR #92 by vmoens implements `softplus`, `sigmoid`, and `exp` options for the scale parameter. I have factored that out and simplified some of the design in order to make #92 easier to review. `softplus` is now the default option for `Affine`.

Pull Request resolved: #109

Test Plan: `pytest tests/`
Reviewed By: vmoens
Differential Revision: D36169529
Pulled By: stefanwebb
fbshipit-source-id: 625387e10399291a5a404c28f4ada743d0945649
Issue Description
In the `Affine` bijector, the scale parameter is obtained by clamping the parameters (the network's output). In some of my experiments this results in very unstable behavior and exploding gradients, especially in low-entropy settings. I believe this is due to the discontinuities that the clamp operation introduces in the gradients. Instead of clamping, the `nflows` package applies `softplus` to the network's output, which also bounds the scale from below while keeping the gradients smooth. In my experiments with Pyro, `softplus` works better than clamping and, importantly, is not subject to exploding gradients. I would suggest replacing clamping with `softplus`.
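For illustration (my own example, not code from either library): with a plain `torch.clamp` the gradient dies outside the bounds, and even the pass-through variant quoted above leaves a kink at the boundary, whereas `softplus` stays smooth everywhere:

```python
import torch
import torch.nn.functional as F

raw = torch.tensor([-10.0, 0.0, 10.0], requires_grad=True)

# Scale via exp of a clamped log-scale: gradient vanishes where clamping is active.
scale = raw.clamp(-5.0, 5.0).exp()
scale.sum().backward()
print(raw.grad)  # tensor([0., 1., 0.]) — dead gradients at the clamped entries

# Scale via softplus: the gradient is sigmoid(raw), smooth and nonzero everywhere.
raw.grad = None
scale = F.softplus(raw)
scale.sum().backward()
print(raw.grad)  # ≈ tensor([4.5e-05, 0.5, 1.0])
```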
Expected Behavior

Avoid exploding gradients. I have implemented the replacement of clamping with `softplus` for FlowTorch (https://github.com/francois-rozet/flowtorch/commit/9bf41e5b67a8993aa6173d6341f9d99ae5e7178b) but haven't had the time to test it properly.

Additional Context
This issue is a replica of pyro-ppl/pyro#2998
Merry Christmas 🎄