Skip to content

Commit

Permalink
Use compare_state_dicts everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed May 23, 2024
1 parent fb6ac4b commit 1cf85bd
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 51 deletions.
26 changes: 3 additions & 23 deletions test/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from math import sqrt
from test.utils import cast_input
from test.utils import cast_input, compare_state_dicts
from typing import Iterable, List, Tuple, Union

import torch
Expand Down Expand Up @@ -657,25 +657,6 @@ def test_KFAC_inverse_damped_torch_matvec(
report_nonclose(inv_KFAC @ x.cpu().numpy(), inv_KFAC_x.cpu().numpy())


def compare_state_dicts(state_dict: dict, state_dict_new: dict):
"""Compare two state dicts recursively."""
assert len(state_dict) == len(state_dict_new)
for value, value_new in zip(state_dict.values(), state_dict_new.values()):
if isinstance(value, torch.Tensor):
assert torch.allclose(value, value_new)
elif isinstance(value, dict):
compare_state_dicts(value, value_new)
elif isinstance(value, tuple):
assert len(value) == len(value_new)
assert all(isinstance(v, type(v2)) for v, v2 in zip(value, value_new))
assert all(
torch.allclose(torch.as_tensor(v), torch.as_tensor(v2))
for v, v2 in zip(value, value_new)
)
else:
assert value == value_new


def test_KFAC_inverse_save_and_load_state_dict():
"""Test that KFACInverseLinearOperator can be saved and loaded from state dict."""
torch.manual_seed(0)
Expand Down Expand Up @@ -713,15 +694,14 @@ def test_KFAC_inverse_save_and_load_state_dict():
# create new inverse KFAC and load state dict
inv_kfac_new = KFACInverseLinearOperator(kfac)
inv_kfac_new.load_state_dict(torch.load("inv_kfac_state_dict.pt"))
# clean up
os.remove("inv_kfac_state_dict.pt")

# check that the two inverse KFACs are equal
compare_state_dicts(inv_kfac.state_dict(), inv_kfac_new.state_dict())
test_vec = torch.rand(inv_kfac.shape[1])
report_nonclose(inv_kfac @ test_vec, inv_kfac_new @ test_vec)

# clean up
os.remove("inv_kfac_state_dict.pt")


def test_KFAC_inverse_from_state_dict():
"""Test that KFACInverseLinearOperator can be created from state dict."""
Expand Down
32 changes: 5 additions & 27 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
WeightShareModel,
binary_classification_targets,
classification_targets,
compare_state_dicts,
ggn_block_diagonal,
regression_targets,
)
Expand Down Expand Up @@ -1288,26 +1289,14 @@ def test_save_and_load_state_dict():
check_deterministic=False, # turn off to avoid computing KFAC again
)
kfac_new.load_state_dict(load("kfac_state_dict.pt"))
# clean up
os.remove("kfac_state_dict.pt")

# check that the two KFACs are equal
assert len(kfac.state_dict()) == len(kfac_new.state_dict())
for value, value_new in zip(
kfac.state_dict().values(), kfac_new.state_dict().values()
):
if isinstance(value, Tensor):
assert allclose(value, value_new)
elif isinstance(value, dict):
for key, val in value.items():
assert allclose(val, value_new[key])
else:
assert value == value_new

compare_state_dicts(kfac.state_dict(), kfac_new.state_dict())
test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)

# clean up
os.remove("kfac_state_dict.pt")


def test_from_state_dict():
"""Test that KFACLinearOperator can be created from state dict."""
Expand All @@ -1334,17 +1323,6 @@ def test_from_state_dict():
kfac_new = KFACLinearOperator.from_state_dict(state_dict, model, params, [(X, y)])

# check that the two KFACs are equal
assert len(kfac.state_dict()) == len(kfac_new.state_dict())
for value, value_new in zip(
kfac.state_dict().values(), kfac_new.state_dict().values()
):
if isinstance(value, Tensor):
assert allclose(value, value_new)
elif isinstance(value, dict):
for key, val in value.items():
assert allclose(val, value_new[key])
else:
assert value == value_new

compare_state_dicts(kfac.state_dict(), kfac_new.state_dict())
test_vec = rand(kfac.shape[1])
report_nonclose(kfac @ test_vec, kfac_new @ test_vec)
39 changes: 38 additions & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from numpy import eye, ndarray
from torch import Tensor, cat, cuda, device, dtype, from_numpy, rand, randint
from torch import (
Tensor,
allclose,
as_tensor,
cat,
cuda,
device,
dtype,
from_numpy,
rand,
randint,
)
from torch.nn import (
AdaptiveAvgPool2d,
BCEWithLogitsLoss,
Expand Down Expand Up @@ -367,3 +378,29 @@ def batch_size_fn(X: MutableMapping) -> int:
batch_size: The first dimension size of the tensor.
"""
return X["x"].shape[0]


def compare_state_dicts(state_dict: dict, state_dict_new: dict):
"""Compare two state dicts recursively.
Args:
state_dict (dict): The first state dict to compare.
state_dict_new (dict): The second state dict to compare.
Raises:
AssertionError: If the state dicts are not equal.
"""
assert len(state_dict) == len(state_dict_new)
for value, value_new in zip(state_dict.values(), state_dict_new.values()):
if isinstance(value, Tensor):
assert allclose(value, value_new)
elif isinstance(value, dict):
compare_state_dicts(value, value_new)
elif isinstance(value, tuple):
assert len(value) == len(value_new)
assert all(isinstance(v, type(v2)) for v, v2 in zip(value, value_new))
assert all(
allclose(as_tensor(v), as_tensor(v2)) for v, v2 in zip(value, value_new)
)
else:
assert value == value_new

0 comments on commit 1cf85bd

Please sign in to comment.