To implement Hebbian descent (as a tinkerer) I want to skip the gradients of the activation functions.
refs on the idea:
https://arxiv.org/abs/1905.10585 Hebbian-Descent
https://arxiv.org/abs/1905.12937 A Hippocampus Model for Online One-Shot Storage of Pattern Sequences
TLDR: skip activation function gradients and center the network
How do you make a wrapper that skips the gradients of the activation functions? I'm a bit of a noob at this, but I assume I need to pass the tangents along?
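Concretely, I have in mind something along the lines of the sketch below: a wrapper built on jax.custom_vjp whose backward pass is the identity, so the activation's own gradient never enters the chain rule (the name skip_grad and the sigmoid example are just mine, for illustration):

```python
import jax


def skip_grad(act_fn):
    """Apply act_fn in the forward pass, but use an identity Jacobian in the
    backward pass, i.e. forward the incoming cotangent unchanged and skip the
    activation's own gradient."""
    @jax.custom_vjp
    def wrapped(x):
        return act_fn(x)

    def fwd(x):
        return act_fn(x), None   # nothing needs to be saved for the backward pass

    def bwd(_, g):
        return (g,)              # identity: pass the cotangent straight through

    wrapped.defvjp(fwd, bwd)
    return wrapped


# A sigmoid whose gradient is skipped:
skip_sigmoid = skip_grad(jax.nn.sigmoid)
print(jax.grad(skip_sigmoid)(0.3))   # 1.0 instead of sigmoid'(0.3) ~= 0.24
```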
A wrapper along those lines works, but only if you don't use it inside a flax.nn.Module (otherwise you get an `assert self.master.trace_type is StagingJaxprTrace` assertion error), and I'm not sure whether it wipes earlier gradients. So I went to abstract it into a function transform.
That gives an AssertionError about

`--> 291 assert self.master.trace_type is StagingJaxprTrace`

both inside and outside of flax.nn.Module definitions.
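In case it helps pin down what I'm asking: the only formulation of "skip the activation's gradient" I know of that avoids custom-gradient machinery entirely is the stop_gradient straight-through trick, sketched below (again just an illustration, not the transform that fails above). Would that be the saner route here, or is there a proper way to make a custom-VJP wrapper compose with flax modules?

```python
import jax
import jax.numpy as jnp


def straight_through(act_fn):
    """Same forward values as act_fn, but the gradient of the identity:
    the nonlinear part is hidden behind stop_gradient, so autodiff only
    differentiates the leading `x +` term."""
    def wrapped(x):
        return x + jax.lax.stop_gradient(act_fn(x) - x)
    return wrapped


st_tanh = straight_through(jnp.tanh)
print(st_tanh(2.0))            # same value as jnp.tanh(2.0)
print(jax.grad(st_tanh)(2.0))  # 1.0 instead of 1 - tanh(2.0)**2
```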