Skip to content

Commit

Permalink
Add a general torch CompositionModel
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jul 2, 2024
1 parent 4e64c26 commit c9a93a4
Showing 1 changed file with 194 additions and 2 deletions.
196 changes: 194 additions & 2 deletions src/metatrain/utils/composition.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,201 @@
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import ModelOutput, System

from metatrain.utils.data import Dataset, get_atomic_types
from .data import Dataset, DatasetInfo, get_atomic_types
from .jsonschema import validate


class CompositionModel(torch.nn.Module):
    """Calculate the energy based on the stoichiometry in a system.

    One weight per atomic type is fitted for every target with a (regularized)
    linear least-squares fit of the targets against the number of atoms of each
    type. Predictions are obtained by summing the weights of the atoms present in
    a system.

    :param model_hypers: A dictionary of model hyperparameters. The parameter is
        ignored and is only present to be consistent with the general model API.
    :param dataset_info: An object containing information about the dataset,
        including target quantities and atomic types.
    :raises ValueError: If any target quantity in the dataset info is not an
        energy-like quantity.
    """

    def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo):
        super().__init__()

        # `model_hypers` should be an empty dictionary
        validate(
            instance=model_hypers,
            schema={"type": "object", "additionalProperties": False},
        )

        # Check capabilities
        for target in dataset_info.targets.values():
            if target.quantity != "energy":
                raise ValueError(
                    "CompositionModel only supports energy-like outputs, but a "
                    f"{target.quantity} output was provided."
                )

        self.dataset_info = dataset_info
        # Sorted so that the column order of the fitted weights is deterministic.
        self.atomic_types = sorted(dataset_info.atomic_types)
        # Maps each target name to a 1D tensor with one weight per entry of
        # `self.atomic_types`. Empty until `train` has been called.
        self._weights: Dict[str, torch.Tensor] = {}

    def train(
        self,
        datasets: List[Union[Dataset, torch.utils.data.Subset]],
    ) -> None:
        """Train/fit the composition weights for the datasets.

        NOTE(review): this method shadows ``torch.nn.Module.train(mode)``, which is
        also what ``torch.nn.Module.eval()`` calls internally. Calling ``eval()`` on
        this model therefore ends up here with a boolean argument -- consider
        renaming this method in a follow-up.

        :param datasets: Datasets to calculate the composition weights for.
        :raises ValueError: If the provided datasets contain unknown atomic types.
        :raises RuntimeError: If the datasets contain no samples or if the linear
            system to calculate the composition weights cannot be solved.
        """
        if not isinstance(datasets, list):
            datasets = [datasets]

        missing_types = sorted(get_atomic_types(datasets) - set(self.atomic_types))
        if missing_types:
            raise ValueError(
                f"Provided `datasets` contains unknown atomic types {missing_types}. "
                f"Known types from initilaization are {self.atomic_types}."
            )

        target_keys = list(self.dataset_info.targets.keys())
        if not target_keys:
            # Nothing to fit.
            return

        samples = [sample for dataset in datasets for sample in dataset]
        if not samples:
            raise RuntimeError(
                "Failed to calculate the composition weights: the provided "
                "`datasets` contain no samples."
            )

        # The composition features only depend on the systems, not on the target:
        # build them (and the normal-equation matrix) once for all targets.
        systems = [sample["system"] for sample in samples]
        dtype = systems[0].positions.dtype
        n_types = len(self.atomic_types)

        composition_features = torch.zeros((len(systems), n_types), dtype=dtype)
        for i_system, system in enumerate(systems):
            for i_type, atomic_type in enumerate(self.atomic_types):
                composition_features[i_system, i_type] = torch.sum(
                    system.types == atomic_type
                )

        gram = composition_features.T @ composition_features
        identity = torch.eye(
            n_types,
            dtype=composition_features.dtype,
            device=composition_features.device,
        )

        # Compute weights for each target in the dataset info
        for target_key in target_keys:
            # CAVE: What if sample does not contain `target_key`
            targets = torch.stack(
                [sample[target_key].block().values for sample in samples]
            )

            # remove component and property dimensions
            targets = targets.squeeze(dim=(1, 2))

            rhs = composition_features.T @ targets

            # Increase the regularization by factors of ten until the normal
            # equations become solvable; give up once it gets unreasonably large.
            regularizer = 1e-20
            while True:
                if regularizer > 1e5:
                    raise RuntimeError(
                        "Failed to solve the linear system to calculate the "
                        "composition weights. The dataset is probably too small or "
                        "ill-conditioned."
                    )
                try:
                    self._weights[target_key] = torch.linalg.solve(
                        gram + regularizer * identity, rhs
                    )
                    break
                except torch.linalg.LinAlgError:
                    regularizer *= 10.0

    def restart(self, dataset_info: DatasetInfo) -> "CompositionModel":
        """Restart the model with a new dataset info.

        A new, untrained model is created from the union of the current and the
        new dataset information.

        :param dataset_info: New dataset information to be used.
        :returns: A new :py:class:`CompositionModel` instance.
        """
        # Instantiate the class; calling `self(...)` would invoke `forward`.
        return self.__class__({}, self.dataset_info.union(dataset_info))

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        """Compute the targets for each system based on the composition weights.

        :param systems: List of systems to calculate the energy per atom.
        :param outputs: Dictionary containing the model outputs. Only the targets
            listed here are computed.
        :param selected_atoms: Optional selection of atoms for which to compute the
            targets.
        :returns: A dictionary with the computed targets for each system.
        :raises ValueError: If no weights have been computed or if `outputs` keys
            contain unsupported keys.
        :raises NotImplementedError: If `selected_atoms` is provided (not
            implemented).
        """

        if not self._weights:
            raise ValueError("No weights. Call `train` method first.")

        unsupported_keys = [key for key in outputs.keys() if key not in self._weights]
        if unsupported_keys:
            raise ValueError(
                f"`outputs` keys ({', '.join(outputs.keys())}) contain unsupported "
                f"keys. Supported keys are ({', '.join(self._weights.keys())})."
            )

        if selected_atoms is not None:
            raise NotImplementedError("`selected_atoms` is not implemented.")

        # Compute the targets for each system by adding the composition weights times
        # number of atoms per atomic type.
        targets_out: Dict[str, TensorMap] = {}
        for target_key, target in self.dataset_info.targets.items():
            if target_key not in outputs:
                continue

            weights = self._weights[target_key]
            values: List[torch.Tensor] = []
            sample_values: List[List[int]] = []

            for i_system, system in enumerate(systems):
                # Keep the dtype of the fitted weights instead of falling back to
                # the default dtype, which would silently lose precision.
                target_per_atom = torch.zeros(len(system), dtype=weights.dtype)

                for i_type, atomic_type in enumerate(self.atomic_types):
                    target_per_atom[atomic_type == system.types] = weights[i_type]

                if target.per_atom:
                    values.append(target_per_atom)
                    sample_values += [
                        [i_system, i_atom] for i_atom in range(len(system))
                    ]
                else:
                    values.append(torch.sum(target_per_atom).reshape(1))
                    sample_values += [[i_system]]

            # Add metadata to the output
            if target.per_atom:
                sample_names = ["system", "atom"]
            else:
                sample_names = ["system"]

            if values:
                block_values = torch.cat(values).reshape(-1, 1)
            else:
                # `torch.cat` rejects an empty list: build the empty block by hand.
                block_values = torch.zeros((0, 1), dtype=weights.dtype)

            # Explicit shape so that an empty sample list still yields a 2D tensor.
            samples = torch.tensor(sample_values, dtype=torch.int32).reshape(
                len(sample_values), len(sample_names)
            )

            block = TensorBlock(
                values=block_values,
                samples=Labels(sample_names, samples),
                components=[],
                properties=Labels(target_key, torch.tensor([[0]])),
            )
            targets_out[target_key] = TensorMap(
                keys=Labels("_", torch.tensor([[0]])), blocks=[block]
            )

        return targets_out


def calculate_composition_weights(
Expand Down

0 comments on commit c9a93a4

Please sign in to comment.