Skip to content

Commit

Permalink
Added pytype None checks to accumulators.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617156170
  • Loading branch information
KfacJaxDev authored and KfacJaxDev committed Aug 16, 2024
1 parent 9623ab4 commit 6c700d4
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions kfac_jax/_src/utils/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@ class WeightedMovingAverage(Generic[TArrayTree], misc.State):

@property
def ndim(self) -> int:
assert self.raw_value is not None
return self.raw_value.ndim

@property
def shape(self) -> Shape:
assert self.raw_value is not None
return self.raw_value.shape

@property
def dtype(self) -> DType:
assert self.raw_value is not None
return self.raw_value.dtype

@property
Expand Down Expand Up @@ -148,12 +151,12 @@ def __init__(
self._multi_device = multi_device

@property
def accumulator(self) -> TArrayTree:
def accumulator(self) -> TArrayTree | None:
"""The current value of the underlying not-normalized accumulator."""
return self._accumulator

@property
def weight(self) -> Numeric:
def weight(self) -> Numeric | None:
"""The current normalization weight of the underlying accumulator."""
return self._weight

Expand All @@ -163,7 +166,7 @@ def multi_device(self) -> bool:
return self._multi_device

@property
def value(self) -> TArrayTree:
def value(self) -> TArrayTree | None:
"""The current normalized value of the accumulator."""

if types.tree_is_empty(self.accumulator):
Expand All @@ -179,7 +182,7 @@ def clear(self) -> None:
self._accumulator = None
self._weight = None

def value_and_clear(self) -> TArrayTree:
def value_and_clear(self) -> TArrayTree | None:
"""Retrieves the normalized value of the accumulator and clears it."""

value = self.value
Expand Down

0 comments on commit 6c700d4

Please sign in to comment.