Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 15, 2024
1 parent f873494 commit dfeb257
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 29 deletions.
7 changes: 5 additions & 2 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ def state_dict(self) -> Dict[str, Any]:
"loss_reduction": self._loss_func.reduction,
# Attributes
"progressbar": self._progressbar,
"shape": self._shape,
"shape": self.shape,
"seed": self._seed,
"fisher_type": self._fisher_type,
"mc_samples": self._mc_samples,
Expand Down Expand Up @@ -1137,7 +1137,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]):

# Set attributes
self._progressbar = state_dict["progressbar"]
self._shape = state_dict["shape"]
self.shape = state_dict["shape"]
self._seed = state_dict["seed"]
self._fisher_type = state_dict["fisher_type"]
self._mc_samples = state_dict["mc_samples"]
Expand Down Expand Up @@ -1182,6 +1182,9 @@ def from_state_dict(
Returns:
Linear operator of KFAC approximation.
Raises:
RuntimeError: If the check for deterministic behavior fails.
"""
loss_func = {
"MSELoss": MSELoss,
Expand Down
6 changes: 3 additions & 3 deletions test/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,10 +683,10 @@ def test_KFAC_inverse_save_and_load_state_dict():
# save state dict
state_dict = inv_kfac.state_dict()

# create new inverse KFAC with different linop input and try to load state dict
wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)])
inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac)
with raises(ValueError, match="mismatch"):
# create new inverse KFAC with different linop input and try to load state dict
wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)])
inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac)
inv_kfac_wrong.load_state_dict(state_dict)

# create new inverse KFAC and load state dict
Expand Down
48 changes: 24 additions & 24 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,37 +1243,37 @@ def test_save_and_load_state_dict():
# save state dict
state_dict = kfac.state_dict()

# create new KFAC with different loss function and try to load state dict
kfac_new = KFACLinearOperator(
model,
CrossEntropyLoss(),
params,
[(X, y)],
)
with raises(ValueError, match="loss"):
# create new KFAC with different loss function and try to load state dict
kfac_new = KFACLinearOperator(
model,
CrossEntropyLoss(),
params,
[(X, y)],
)
kfac_new.load_state_dict(state_dict)

# create new KFAC with different loss reduction and try to load state dict
kfac_new = KFACLinearOperator(
model,
MSELoss(),
params,
[(X, y)],
)
with raises(ValueError, match="reduction"):
# create new KFAC with different loss reduction and try to load state dict
kfac_new = KFACLinearOperator(
model,
MSELoss(),
params,
[(X, y)],
)
kfac_new.load_state_dict(state_dict)

# create new KFAC with different model and try to load state dict
wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out))
wrong_params = list(wrong_model.parameters())
kfac_new = KFACLinearOperator(
wrong_model,
MSELoss(reduction="sum"),
wrong_params,
[(X, y)],
loss_average=None,
)
with raises(RuntimeError, match="loading state_dict"):
# create new KFAC with different model and try to load state dict
wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out))
wrong_params = list(wrong_model.parameters())
kfac_new = KFACLinearOperator(
wrong_model,
MSELoss(reduction="sum"),
wrong_params,
[(X, y)],
loss_average=None,
)
kfac_new.load_state_dict(state_dict)

# create new KFAC and load state dict
Expand Down

0 comments on commit dfeb257

Please sign in to comment.