Skip to content

Commit

Permalink
Fix wrong shapes in ase stress parsers
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 25, 2024
1 parent 26a5297 commit cdce03a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/data/readers/targets/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _read_virial_stress_ase(
if is_virial:
values *= -1
else: # is stress
values *= volumes
values *= volumes.reshape(-1, 1, 1)

samples = Labels(["sample"], torch.tensor([[s] for s in range(n_structures)]))

Expand Down
29 changes: 20 additions & 9 deletions tests/utils/data/targets/test_targets_ase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Here we check only the correct values. Correct shape and metadata will be checked
within `test_readers.py`"""
from typing import List

import ase.io
import pytest
import torch
Expand Down Expand Up @@ -27,17 +29,22 @@ def ase_system() -> ase.Atoms:
return atoms


def ase_systems() -> List[ase.Atoms]:
return [ase_system(), ase_system()]


def test_read_energy_ase(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)

filename = "structures.xyz"

structures = ase_system()
structures = ase_systems()
ase.io.write(filename, structures)

result = read_energy_ase(filename=filename, key="true_energy")

expected = torch.tensor([[structures.info["true_energy"]]])
l_expected = [a.info["true_energy"] for a in structures]
expected = torch.tensor(l_expected).reshape(-1, 1)
torch.testing.assert_close(result.values, expected)


Expand All @@ -46,12 +53,13 @@ def test_read_forces_ase(monkeypatch, tmp_path):

filename = "structures.xyz"

structures = ase_system()
structures = ase_systems()
ase.io.write(filename, structures)

result = read_forces_ase(filename=filename, key="forces")

expected = -torch.tensor(structures.get_array("forces")).reshape(-1, 3, 1)
l_expected = [atoms.get_array("forces") for atoms in structures]
expected = -torch.tensor(l_expected).reshape(-1, 3, 1)
torch.testing.assert_close(result.values, expected)


Expand All @@ -60,13 +68,15 @@ def test_read_stress_ase(monkeypatch, tmp_path):

filename = "structures.xyz"

structures = ase_system()
structures = ase_systems()
ase.io.write(filename, structures)

result = read_stress_ase(filename=filename, key="stress-3x3")

expected = torch.tensor(structures.info["stress-3x3"])
expected *= torch.tensor(structures.cell.volume)
l_expected = [atoms.info["stress-3x3"] for atoms in structures]
l_cell = [atoms.cell.volume for atoms in structures]
expected = torch.tensor(l_expected)
expected *= torch.tensor(l_cell).reshape(-1, 1, 1)
expected = expected.reshape(-1, 3, 3, 1)
torch.testing.assert_close(result.values, expected)

Expand All @@ -93,12 +103,13 @@ def test_read_virial_ase(monkeypatch, tmp_path):

filename = "structures.xyz"

structures = ase_system()
structures = ase_systems()
ase.io.write(filename, structures)

result = read_virial_ase(filename=filename, key="stress-3x3")

expected = -torch.tensor(structures.info["stress-3x3"])
l_expected = [atoms.info["stress-3x3"] for atoms in structures]
expected = -torch.tensor(l_expected)
expected = expected.reshape(-1, 3, 3, 1)
torch.testing.assert_close(result.values, expected)

Expand Down

0 comments on commit cdce03a

Please sign in to comment.