diff --git a/kfac_jax/_src/utils/accumulators.py b/kfac_jax/_src/utils/accumulators.py index 259c076..fbbc20a 100644 --- a/kfac_jax/_src/utils/accumulators.py +++ b/kfac_jax/_src/utils/accumulators.py @@ -136,12 +136,12 @@ def __init__( self._multi_device = multi_device @property - def accumulator(self) -> TArrayTree: + def accumulator(self) -> TArrayTree | None: """The current value of the underlying not-normalized accumulator.""" return self._accumulator @property - def weight(self) -> Numeric: + def weight(self) -> Numeric | None: """The current normalization weight of the underlying accumulator.""" return self._weight @@ -151,7 +151,7 @@ def multi_device(self) -> bool: return self._multi_device @property - def value(self) -> TArrayTree: + def value(self) -> TArrayTree | None: """The current normalized value of the accumulator.""" if types.tree_is_empty(self.accumulator): @@ -167,7 +167,7 @@ def clear(self) -> None: self._accumulator = None self._weight = None - def value_and_clear(self) -> TArrayTree: + def value_and_clear(self) -> TArrayTree | None: """Retrieves the normalized value of the accumulator and clears it.""" value = self.value