Skip to content

Commit

Permalink
Fix composition weights schema (#314)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: frostedoyster <[email protected]>
Co-authored-by: Filippo Bigi <[email protected]>
  • Loading branch information
3 people authored Jul 26, 2024
1 parent 81e5836 commit 80c50d3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 6 deletions.
9 changes: 7 additions & 2 deletions docs/src/architectures/soap-bpnn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,15 @@ The parameters for training are
:param learning_rate: learning rate
:param log_interval: number of epochs that elapse between reporting new training results
:param checkpoint_interval: Interval to save a checkpoint to disk.
:param per_atom_targets: Specifies whether the model should be trained on a per-atom
:param fixed_composition_weights: allows to set fixed isolated atom energies from
outside. These are per target name and per (integer) atom type. For example,
``fixed_composition_weights: {"energy": {1: -396.0, 6: -500.0}, "mtt::U0": {1: 0.0,
6: 0.0}}`` sets the isolated atom energies for H (``1``) and O (``6``) to different
values for the two distinct targets.
:param per_atom_targets: specifies whether the model should be trained on a per-atom
loss. In that case, the logger will also output per-atom metrics for that target. In
any case, the final summary will be per-structure.
:param loss_weights: Specifies the weights to be used in the loss for each target. The
:param loss_weights: specifies the weights to be used in the loss for each target. The
weights should be a dictionary of floats, one for each target. All missing targets
are assigned a weight of 1.0.

Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
* (self.hypers["soap"]["max_angular"] + 1)
)

hypers_bpnn = self.hypers["bpnn"]
hypers_bpnn = {**self.hypers["bpnn"]}
hypers_bpnn["input_size"] = soap_size

if hypers_bpnn["layernorm"]:
Expand Down
10 changes: 8 additions & 2 deletions src/metatrain/experimental/soap_bpnn/schema-hypers.json
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,14 @@
"fixed_composition_weights": {
"type": "object",
"patternProperties": {
"^[0-9]+$": {
"type": "number"
"^.*$": {
"type": "object",
"propertyNames": {
"pattern": "^[0-9]+$"
},
"additionalProperties": {
"type": "number"
}
}
},
"additionalProperties": false
Expand Down
36 changes: 35 additions & 1 deletion src/metatrain/experimental/soap_bpnn/tests/test_functionality.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import metatensor.torch
import pytest
import torch
from jsonschema.exceptions import ValidationError
from metatensor.torch.atomistic import ModelOutput, System
from omegaconf import OmegaConf

from metatrain.experimental.soap_bpnn import SoapBpnn
from metatrain.utils.architectures import check_architecture_options
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict

from . import MODEL_HYPERS
from . import DEFAULT_HYPERS, MODEL_HYPERS


def test_prediction_subset_elements():
Expand Down Expand Up @@ -195,3 +199,33 @@ def test_output_per_atom():

assert outputs["energy"].block().samples.names == ["system", "atom"]
assert outputs["energy"].block().values.shape == (4, 1)


def test_fixed_composition_weights():
"""Tests the correctness of the json schema for fixed_composition_weights"""

hypers = DEFAULT_HYPERS.copy()
hypers["training"]["fixed_composition_weights"] = {
"energy": {
1: 1.0,
6: 0.0,
7: 0.0,
8: 0.0,
9: 3000.0,
}
}
hypers = OmegaConf.create(hypers)
check_architecture_options(
name="experimental.soap_bpnn", options=OmegaConf.to_container(hypers)
)


def test_fixed_composition_weights_error():
"""Test that only inputd of type Dict[str, Dict[int, float]] are allowed."""
hypers = DEFAULT_HYPERS.copy()
hypers["training"]["fixed_composition_weights"] = {"energy": {"H": 300.0}}
hypers = OmegaConf.create(hypers)
with pytest.raises(ValidationError, match=r"'H' does not match '\^\[0-9\]\+\$'"):
check_architecture_options(
name="experimental.soap_bpnn", options=OmegaConf.to_container(hypers)
)

0 comments on commit 80c50d3

Please sign in to comment.