Skip to content

Commit

Permalink
Test torch.save/load as well and fix order of equivalence checks
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 23, 2024
1 parent 04a3259 commit d5cecfc
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
15 changes: 10 additions & 5 deletions test/test_inverse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Contains tests for ``curvlinops/inverse``."""

import os
from math import sqrt
from test.utils import cast_input
from typing import Iterable, List, Tuple, Union
Expand Down Expand Up @@ -666,7 +667,7 @@ def compare_state_dicts(state_dict: dict, state_dict_new: dict):
compare_state_dicts(value, value_new)
elif isinstance(value, tuple):
assert len(value) == len(value_new)
assert all([isinstance(v, type(v2)) for v, v2 in zip(value, value_new)])
assert all(isinstance(v, type(v2)) for v, v2 in zip(value, value_new))
assert all(
torch.allclose(torch.as_tensor(v), torch.as_tensor(v2))
for v, v2 in zip(value, value_new)
Expand Down Expand Up @@ -701,21 +702,25 @@ def test_KFAC_inverse_save_and_load_state_dict():

# save state dict
state_dict = inv_kfac.state_dict()
torch.save(state_dict, "inv_kfac_state_dict.pt")

# create new inverse KFAC with different linop input and try to load state dict
wrong_kfac = KFACLinearOperator(model, CrossEntropyLoss(), params, [(X, y)])
inv_kfac_wrong = KFACInverseLinearOperator(wrong_kfac)
with raises(ValueError, match="mismatch"):
inv_kfac_wrong.load_state_dict(state_dict)
inv_kfac_wrong.load_state_dict(torch.load(state_dict))

# create new inverse KFAC and load state dict
inv_kfac_new = KFACInverseLinearOperator(kfac)
inv_kfac_new.load_state_dict(state_dict)
inv_kfac_new.load_state_dict(torch.load(state_dict))

# check that the two inverse KFACs are equal
compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
test_vec = torch.rand(inv_kfac.shape[1])
report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec)
compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())

# clean up
os.remove("inv_kfac_state_dict.pt")


def test_KFAC_inverse_from_state_dict():
Expand Down Expand Up @@ -746,6 +751,6 @@ def test_KFAC_inverse_from_state_dict():
inv_kfac_new = KFACInverseLinearOperator.from_state_dict(state_dict, kfac)

# check that the two inverse KFACs are equal
compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
test_vec = torch.rand(kfac.shape[1])
report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec)
compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
27 changes: 17 additions & 10 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Contains tests for ``curvlinops.kfac``."""

import os
from test.cases import DEVICES, DEVICES_IDS
from test.utils import (
Conv2dModel,
Expand All @@ -20,7 +21,7 @@
from scipy.linalg import block_diag
from torch import Tensor, allclose, cat, cuda, device
from torch import eye as torch_eye
from torch import isinf, isnan, manual_seed, rand, rand_like, randperm
from torch import isinf, isnan, load, manual_seed, rand, rand_like, randperm, save
from torch.nn import (
BCEWithLogitsLoss,
CrossEntropyLoss,
Expand Down Expand Up @@ -1242,6 +1243,7 @@ def test_save_and_load_state_dict():

# save state dict
state_dict = kfac.state_dict()
save(state_dict, "kfac_state_dict.pt")

# create new KFAC with different loss function and try to load state dict
kfac_new = KFACLinearOperator(
Expand All @@ -1251,7 +1253,7 @@ def test_save_and_load_state_dict():
[(X, y)],
)
with raises(ValueError, match="loss"):
kfac_new.load_state_dict(state_dict)
kfac_new.load_state_dict(load(state_dict))

# create new KFAC with different loss reduction and try to load state dict
kfac_new = KFACLinearOperator(
Expand All @@ -1261,7 +1263,7 @@ def test_save_and_load_state_dict():
[(X, y)],
)
with raises(ValueError, match="reduction"):
kfac_new.load_state_dict(state_dict)
kfac_new.load_state_dict(load(state_dict))

# create new KFAC with different model and try to load state dict
wrong_model = Sequential(Linear(D_in, 10), ReLU(), Linear(10, D_out))
Expand All @@ -1274,7 +1276,7 @@ def test_save_and_load_state_dict():
loss_average=None,
)
with raises(RuntimeError, match="loading state_dict"):
kfac_new.load_state_dict(state_dict)
kfac_new.load_state_dict(load(state_dict))

# create new KFAC and load state dict
kfac_new = KFACLinearOperator(
Expand All @@ -1283,12 +1285,11 @@ def test_save_and_load_state_dict():
params,
[(X, y)],
loss_average=None,
check_deterministic=False, # turn off to avoid computing KFAC again
)
kfac_new.load_state_dict(load(state_dict))

# check that the two KFACs are equal
test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)

assert len(kfac.state_dict()) == len(kfac_new.state_dict())
for value, value_new in zip(
kfac.state_dict().values(), kfac_new.state_dict().values()
Expand All @@ -1301,6 +1302,12 @@ def test_save_and_load_state_dict():
else:
assert value == value_new

test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)

# clean up
os.remove("kfac_state_dict.pt")


def test_from_state_dict():
"""Test that KFACLinearOperator can be created from state dict."""
Expand All @@ -1327,9 +1334,6 @@ def test_from_state_dict():
kfac_new = KFACLinearOperator.from_state_dict(state_dict, model, params, [(X, y)])

# check that the two KFACs are equal
test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)

assert len(kfac.state_dict()) == len(kfac_new.state_dict())
for value, value_new in zip(
kfac.state_dict().values(), kfac_new.state_dict().values()
Expand All @@ -1341,3 +1345,6 @@ def test_from_state_dict():
assert allclose(val, value_new[key])
else:
assert value == value_new

test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)

0 comments on commit d5cecfc

Please sign in to comment.