Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This CL adds TNT blocks to kfac_jax. TNT blocks are a type of curvature approximation that is generally imposes less structure knowledge of the computation, and mainly relies on the array shape of the parameters. It also adds a RepatedDenseKroneckerFactored which extends the usual DenseTwoKroneckerFacored to the case where a dense layer is applied in parallel over an axis of inputs (like a time axis in sequence models). PiperOrigin-RevId: 662842744
- Loading branch information