Skip to content

Commit

Permalink
Add one more test
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 30, 2023
1 parent c6fb596 commit 22e98ca
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/metatensor_models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def forward(self, features: TensorMap) -> TensorMap:
properties=Labels.range("properties", output_values.shape[-1]),
)
)

return TensorMap(keys=features.keys, blocks=new_blocks)


Expand Down
25 changes: 25 additions & 0 deletions src/metatensor_models/soap_bpnn/tests/test_functionality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os

import ase
import rascaline.torch
import torch
import yaml

from metatensor_models.soap_bpnn import SoapBPNN


path = os.path.dirname(__file__)
hypers_path = os.path.join(path, "../default.yml")
dataset_path = os.path.join(path, "data/qm9_reduced_100.xyz")


def test_prediction_subset():
"""Tests that the model can predict on a subset
of the elements it was trained on."""

all_species = [1, 6, 7, 8, 9]
hypers = yaml.safe_load(open(hypers_path, "r"))
soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64)

structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
soap_bpnn([rascaline.torch.systems_to_torch(structure)])

0 comments on commit 22e98ca

Please sign in to comment.