diff --git a/src/metatensor/models/utils/data/readers/targets/ase.py b/src/metatensor/models/utils/data/readers/targets/ase.py index 437fb1217..4486cde37 100644 --- a/src/metatensor/models/utils/data/readers/targets/ase.py +++ b/src/metatensor/models/utils/data/readers/targets/ase.py @@ -145,7 +145,7 @@ def _read_virial_stress_ase( if is_virial: values *= -1 else: # is stress - values *= volumes + values *= volumes.reshape(-1, 1, 1) samples = Labels(["sample"], torch.tensor([[s] for s in range(n_structures)])) diff --git a/tests/utils/data/targets/test_targets_ase.py b/tests/utils/data/targets/test_targets_ase.py index 94c7927bd..552b7e2f2 100644 --- a/tests/utils/data/targets/test_targets_ase.py +++ b/tests/utils/data/targets/test_targets_ase.py @@ -1,5 +1,7 @@ """Here we check only the correct values. Correct shape and metadata will be checked within `test_readers.py`""" +from typing import List + import ase.io import pytest import torch @@ -27,17 +29,22 @@ def ase_system() -> ase.Atoms: return atoms +def ase_systems() -> List[ase.Atoms]: + return [ase_system(), ase_system()] + + def test_read_energy_ase(monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) filename = "structures.xyz" - structures = ase_system() + structures = ase_systems() ase.io.write(filename, structures) result = read_energy_ase(filename=filename, key="true_energy") - expected = torch.tensor([[structures.info["true_energy"]]]) + l_expected = [a.info["true_energy"] for a in structures] + expected = torch.tensor(l_expected).reshape(-1, 1) torch.testing.assert_close(result.values, expected) @@ -46,12 +53,13 @@ def test_read_forces_ase(monkeypatch, tmp_path): filename = "structures.xyz" - structures = ase_system() + structures = ase_systems() ase.io.write(filename, structures) result = read_forces_ase(filename=filename, key="forces") - expected = -torch.tensor(structures.get_array("forces")).reshape(-1, 3, 1) + l_expected = [atoms.get_array("forces") for atoms in structures] + expected = -torch.tensor(l_expected).reshape(-1, 3, 1) torch.testing.assert_close(result.values, expected) @@ -60,13 +68,15 @@ def test_read_stress_ase(monkeypatch, tmp_path): filename = "structures.xyz" - structures = ase_system() + structures = ase_systems() ase.io.write(filename, structures) result = read_stress_ase(filename=filename, key="stress-3x3") - expected = torch.tensor(structures.info["stress-3x3"]) - expected *= torch.tensor(structures.cell.volume) + l_expected = [atoms.info["stress-3x3"] for atoms in structures] + l_cell = [atoms.cell.volume for atoms in structures] + expected = torch.tensor(l_expected) + expected *= torch.tensor(l_cell).reshape(-1, 1, 1) expected = expected.reshape(-1, 3, 3, 1) torch.testing.assert_close(result.values, expected) @@ -93,12 +103,13 @@ def test_read_virial_ase(monkeypatch, tmp_path): filename = "structures.xyz" - structures = ase_system() + structures = ase_systems() ase.io.write(filename, structures) result = read_virial_ase(filename=filename, key="stress-3x3") - expected = -torch.tensor(structures.info["stress-3x3"]) + l_expected = [atoms.info["stress-3x3"] for atoms in structures] + expected = -torch.tensor(l_expected) expected = expected.reshape(-1, 3, 3, 1) torch.testing.assert_close(result.values, expected)