Skip to content

Commit

Permalink
Added pytype None checks to accumulators.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617156170
  • Loading branch information
KfacJaxDev authored and KfacJaxDev committed Mar 26, 2024
1 parent fb6df67 commit 0ee8c4b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions kfac_jax/_src/utils/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def __init__(
self._multi_device = multi_device

@property
def accumulator(self) -> TArrayTree:
def accumulator(self) -> TArrayTree | None:
"""The current value of the underlying not-normalized accumulator."""
return self._accumulator

@property
def weight(self) -> Numeric:
def weight(self) -> Numeric | None:
"""The current normalization weight of the underlying accumulator."""
return self._weight

Expand All @@ -151,7 +151,7 @@ def multi_device(self) -> bool:
return self._multi_device

@property
def value(self) -> TArrayTree:
def value(self) -> TArrayTree | None:
"""The current normalized value of the accumulator."""

if types.tree_is_empty(self.accumulator):
Expand All @@ -167,7 +167,7 @@ def clear(self) -> None:
self._accumulator = None
self._weight = None

def value_and_clear(self) -> TArrayTree:
def value_and_clear(self) -> TArrayTree | None:
"""Retrieves the normalized value of the accumulator and clears it."""

value = self.value
Expand Down

0 comments on commit 0ee8c4b

Please sign in to comment.