Skip to content

Commit

Permalink
Add TNT blocks to kfac_jax.
Browse files Browse the repository at this point in the history
This CL adds TNT blocks to kfac_jax. TNT blocks are a type of curvature approximation that is generally imposes less structure knowledge of the computation, and mainly relies on the array shape of the parameters.

It also adds a RepatedDenseKroneckerFactored which extends the usual DenseTwoKroneckerFacored to the case where a dense layer is applied in parallel over an axis of inputs (like a time axis in sequence models).

PiperOrigin-RevId: 665965233
  • Loading branch information
botev authored and KfacJaxDev committed Aug 21, 2024
1 parent 9618725 commit 60644ef
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 0 deletions.
8 changes: 8 additions & 0 deletions kfac_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,16 @@
KroneckerFactored = curvature_blocks.KroneckerFactored
NaiveDiagonal = curvature_blocks.NaiveDiagonal
NaiveFull = curvature_blocks.NaiveFull
NaiveTNT = curvature_blocks.NaiveTNT
DenseDiagonal = curvature_blocks.DenseDiagonal
DenseFull = curvature_blocks.DenseFull
DenseTwoKroneckerFactored = curvature_blocks.DenseTwoKroneckerFactored
RepeatedDenseKroneckerFactored = curvature_blocks.RepeatedDenseKroneckerFactored
DenseTNT = curvature_blocks.DenseTNT
Conv2DDiagonal = curvature_blocks.Conv2DDiagonal
Conv2DFull = curvature_blocks.Conv2DFull
Conv2DTwoKroneckerFactored = curvature_blocks.Conv2DTwoKroneckerFactored
Conv2DTNT = curvature_blocks.Conv2DTNT
ScaleAndShiftDiagonal = curvature_blocks.ScaleAndShiftDiagonal
ScaleAndShiftFull = curvature_blocks.ScaleAndShiftFull
set_max_parallel_elements = curvature_blocks.set_max_parallel_elements
Expand Down Expand Up @@ -165,12 +169,16 @@
"KroneckerFactored",
"NaiveDiagonal",
"NaiveFull",
"NaiveTNT",
"DenseDiagonal",
"DenseFull",
"DenseTwoKroneckerFactored",
"RepeatedDenseKroneckerFactored",
"DenseTNT",
"Conv2DDiagonal",
"Conv2DFull",
"Conv2DTwoKroneckerFactored",
"Conv2DTNT",
"ScaleAndShiftDiagonal",
"ScaleAndShiftFull",
"set_max_parallel_elements",
Expand Down
Loading

0 comments on commit 60644ef

Please sign in to comment.