Skip to content

Commit

Permalink
Included more united tests to improve code coverage (#253)
Browse files Browse the repository at this point in the history
* model version for Potential class is added

* model version for Potential class is modified

* Enable the smooth version of Spherical Bessel function in TensorNet

* max_n, max_l for SphericalBessel radial basis functions are included in TensorNet class

* adding united tests for improving the coverage score

* little clean up in _so3.py and so3.py

* remove unnecessary data storage in dgl graphs

* update pymatgen version to fix the bug

* refractor all include_states into include_state for consistency

* change include_states into include_state in test_graph_conv.py

* Ensure the state attr from molecule graph is consistent with matgl.float_th and including linear layer in TensorNet to match the original implementations

* Fix the jupyter-notebook for M3GNet training

* included more united tests to improve code coverage
  • Loading branch information
kenko911 authored May 6, 2024
1 parent 2e84398 commit 64f5f8c
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 7 deletions.
2 changes: 0 additions & 2 deletions src/matgl/graph/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ def collate_fn_pes(batch, include_stress: bool = True, include_line_graph: bool
if include_magmom:
return g, torch.squeeze(lat), l_g, state_attr, e, f, s, m
return g, torch.squeeze(lat), l_g, state_attr, e, f, s
if include_magmom:
return g, torch.squeeze(lat), state_attr, e, f, s, m
return g, torch.squeeze(lat), state_attr, e, f, s


Expand Down
33 changes: 33 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,39 @@ def Mo():
return Structure(Lattice.cubic(3.17), ["Mo", "Mo"], [[0.01, 0, 0], [0.5, 0.5, 0.5]])


@pytest.fixture(scope="session")
def MoS2():
return Structure(
[
[3.18430383, 0.0, 1.9498237464610788e-16],
[-1.5921519149999994, 2.757688010148085, 1.9498237464610788e-16],
[0.0, 0.0, 19.44629514],
],
[
"Mo",
"Mo",
"Mo",
"S",
"S",
"S",
"S",
"S",
"S",
],
[
[0.00000000e00, 0.00000000e00, 1.94419205e01],
[1.59215192e00, 9.19229337e-01, 6.47772374e00],
[4.44089210e-16, 1.83845867e00, 1.29598221e01],
[0.00000000e00, 0.00000000e00, 4.92566372e00],
[1.59215192e00, 9.19229337e-01, 1.14077621e01],
[4.44089210e-16, 1.83845867e00, 1.78898605e01],
[0.00000000e00, 0.00000000e00, 8.02996293e00],
[1.59215192e00, 9.19229337e-01, 1.45120613e01],
[4.44089210e-16, 1.83845867e00, 1.54786455e00],
],
)


@pytest.fixture(scope="session")
def graph_Mo(Mo):
return get_graph(Mo, 5.0)
Expand Down
10 changes: 5 additions & 5 deletions tests/ext/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,14 @@ def test_get_graph_from_atoms_mol():
assert np.allclose(state, [0.0, 0.0])


def test_molecular_dynamics(MoS):
def test_molecular_dynamics(MoS2):
pot = load_model("pretrained_models/M3GNet-MP-2021.2.8-PES/")
for ensemble in ["nvt", "nvt_langevin", "nvt_andersen", "npt", "npt_berendsen", "npt_nose_hoover"]:
md = MolecularDynamics(MoS, potential=pot, ensemble=ensemble, taut=0.1, taup=0.1, compressibility_au=10)
md = MolecularDynamics(MoS2, potential=pot, ensemble=ensemble, taut=0.1, taup=0.1, compressibility_au=10)
md.run(10)
assert md.dyn is not None
md.set_atoms(MoS)
md = MolecularDynamics(MoS, potential=pot, ensemble=ensemble, taut=None, taup=None, compressibility_au=10)
md.set_atoms(MoS2)
md = MolecularDynamics(MoS2, potential=pot, ensemble=ensemble, taut=None, taup=None, compressibility_au=10)
md.run(10)
with pytest.raises(ValueError, match="Ensemble not supported"):
MolecularDynamics(MoS, potential=pot, ensemble="notanensemble")
MolecularDynamics(MoS2, potential=pot, ensemble="notanensemble")
139 changes: 139 additions & 0 deletions tests/graph/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,142 @@ def test_mgl_dataloader_with_magmom(self, LiFePO4, BaNiO3):
assert len(train_loader) == 8
assert len(val_loader) == 1
assert len(test_loader) == 1

def test_mgl_dataloader_without_collate_fn(self, LiFePO4, BaNiO3):
structures = [LiFePO4, BaNiO3] * 10
energies = np.zeros(20).tolist()
f1 = np.zeros((28, 3)).tolist()
f2 = np.zeros((10, 3)).tolist()
s = np.zeros((3, 3)).tolist()
m1 = np.zeros(28).tolist()
m2 = np.zeros(10).tolist()
np.zeros((3, 3)).tolist()
forces = [f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2]
stresses = [s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s]
magmoms = [m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2]

labels = {
"energies": energies,
"forces": forces,
"stresses": stresses,
"magmoms": magmoms,
}
element_types = get_element_list(structures)
cry_graph = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = MGLDataset(
threebody_cutoff=3.0,
structures=structures,
converter=cry_graph,
labels=labels,
include_line_graph=True,
clear_processed=True,
save_cache=False,
)

train_data, val_data, test_data = split_dataset(
dataset,
frac_list=[0.8, 0.1, 0.1],
shuffle=True,
random_state=42,
)
# This modification is required for M3GNet property dataset
train_loader, val_loader, test_loader = MGLDataLoader(
train_data=train_data,
val_data=val_data,
test_data=test_data,
batch_size=2,
num_workers=1,
)

assert len(train_loader) == 8
assert len(val_loader) == 1
assert len(test_loader) == 1

labels.pop("magmoms")
dataset = MGLDataset(
threebody_cutoff=3.0,
structures=structures,
converter=cry_graph,
labels=labels,
include_line_graph=True,
clear_processed=True,
save_cache=False,
)

train_data, val_data, test_data = split_dataset(
dataset,
frac_list=[0.8, 0.1, 0.1],
shuffle=True,
random_state=42,
)
# This modification is required for M3GNet property dataset
train_loader, val_loader, test_loader = MGLDataLoader(
train_data=train_data,
val_data=val_data,
test_data=test_data,
batch_size=2,
num_workers=1,
)

assert len(train_loader) == 8
assert len(val_loader) == 1
assert len(test_loader) == 1

labels.pop("stresses")
dataset = MGLDataset(
threebody_cutoff=3.0,
structures=structures,
converter=cry_graph,
labels=labels,
include_line_graph=True,
clear_processed=True,
save_cache=False,
)

train_data, val_data, test_data = split_dataset(
dataset,
frac_list=[0.8, 0.1, 0.1],
shuffle=True,
random_state=42,
)
# This modification is required for M3GNet property dataset
train_loader, val_loader, test_loader = MGLDataLoader(
train_data=train_data,
val_data=val_data,
test_data=test_data,
batch_size=2,
num_workers=1,
)

assert len(train_loader) == 8
assert len(val_loader) == 1
assert len(test_loader) == 1
labels.pop("forces")
dataset = MGLDataset(
threebody_cutoff=3.0,
structures=structures,
converter=cry_graph,
labels=labels,
include_line_graph=True,
clear_processed=True,
save_cache=False,
)

train_data, val_data, test_data = split_dataset(
dataset,
frac_list=[0.8, 0.1, 0.1],
shuffle=True,
random_state=42,
)
# This modification is required for M3GNet property dataset
train_loader, val_loader, test_loader = MGLDataLoader(
train_data=train_data,
val_data=val_data,
test_data=test_data,
batch_size=2,
num_workers=1,
)

assert len(train_loader) == 8
assert len(val_loader) == 1
assert len(test_loader) == 1
6 changes: 6 additions & 0 deletions tests/utils/test_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ def test_generate_clebsch_gordan_rsh():
expected_cg_rsh_0 = torch.eye(1).unsqueeze(0) # Expected cg_rsh result
assert torch.allclose(cg_rsh_0, expected_cg_rsh_0, atol=1e-4) # Use torch.allclose for numerical comparisons

cg_rsh_0_without_pi = generate_clebsch_gordan_rsh(lmax_0, False)
expected_cg_rsh_0_without_pi = torch.eye(1).unsqueeze(0) # Expected cg_rsh without parity invariance result
assert torch.allclose(
cg_rsh_0_without_pi, expected_cg_rsh_0_without_pi, atol=1e-4
) # Use torch.allclose for numerical comparisons

# Test case 2: lmax = 1
lmax_1 = 1
cg_rsh_1 = generate_clebsch_gordan_rsh(lmax_1)
Expand Down

0 comments on commit 64f5f8c

Please sign in to comment.