Skip to content

Commit

Permalink
Fixing a bug with self._value_func, which was returning the `func_s…
Browse files Browse the repository at this point in the history
…tate`, rather than just the value of the loss, when `value_func_has_aux=False` and `value_func_has_state=True`.

Closes #131

PiperOrigin-RevId: 576068393
  • Loading branch information
botev authored and KfacJaxDev committed Oct 24, 2023
1 parent b707ab1 commit c701634
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion kfac_jax/_src/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __init__(
self._value_func_has_rng = value_func_has_rng
self._value_func: ValueFunc = convert_value_and_grad_to_value_func(
value_and_grad_func,
has_aux=value_func_has_aux,
has_aux=value_func_has_aux or value_func_has_state,
)
self._l2_reg = jnp.asarray(l2_reg)
self._use_adaptive_learning_rate = use_adaptive_learning_rate
Expand Down

0 comments on commit c701634

Please sign in to comment.