From 2a779dc86ffdae5a8b71cdf286c783e8ab0fba87 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Tue, 2 Jul 2024 12:16:01 +0200 Subject: [PATCH] Add a general torch CompositionModel --- src/metatrain/utils/composition.py | 196 ++++++++++++++++++++++++++++- 1 file changed, 194 insertions(+), 2 deletions(-) diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/composition.py index c3780c281..274f13f39 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/composition.py @@ -1,9 +1,201 @@ -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import ModelOutput, System -from metatrain.utils.data import Dataset, get_atomic_types +from .data import Dataset, DatasetInfo, get_atomic_types +from .jsonschema import validate + + +class CompositionModel(torch.nn.Module): + """Calculate the energy based on the stoichiometry in a system. + + :param model_hypers: A dictionary of model hyperparameters. The paramater is ignored + and is only present to be consistent with the general model API. + :param dataset_info: An object containing information about the dataset, including + target quantities and atomic types. + + :raises ValueError: If any target quantity in the dataset info is not an energy-like + quantity. + """ + + def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo): + super().__init__() + + # `model_hypers` should be an empty dictionary + validate( + instance=model_hypers, + schema={"type": "object", "additionalProperties": False}, + ) + + # Check capabilities + for target in dataset_info.targets.values(): + if target.quantity != "energy": + raise ValueError( + "CompositionModel only supports energy-like outputs, but a " + f"{target.quantity} output was provided." + ) + + self.dataset_info = dataset_info + self.atomic_types = sorted(dataset_info.atomic_types) + self._weights: Dict[str, torch.Tensor[float]] = {} + + def train( + self, + datasets: List[Union[Dataset, torch.utils.data.Subset]], + ) -> None: + """Train/fit the composition weights for the datasets. + + :param datasets: Datasets to calculate the composition weights for. + :raises ValueError: If the provided datasets contain unknown atomic types. + :raises RuntimeError: If the linear system to calculate the composition weights + cannot be solved. + """ + if not isinstance(datasets, list): + datasets = [datasets] + + missing_types = sorted(get_atomic_types(datasets) - set(self.atomic_types)) + if missing_types: + raise ValueError( + f"Provided `datasets` contains unknown atomic types {missing_types}. " + f"Known types from initilaization are {self.atomic_types}." + ) + + # Compute weights for each target in the dataset info + for target_key in self.dataset_info.targets.keys(): + + # CAVE: What if sample does not contain `target_key` + targets = torch.stack( + [ + sample[target_key].block().values + for dataset in datasets + for sample in dataset + ] + ) + + # remove component and property dimensions + targets = targets.squeeze(dim=(1, 2)) + + structure_list = [ + sample["system"] for dataset in datasets for sample in dataset + ] + + dtype = structure_list[0].positions.dtype + composition_features = torch.zeros( + (len(structure_list), len(self.atomic_types)), dtype=dtype + ) + for i_structure, structure in enumerate(structure_list): + for i_types, atomic_type in enumerate(self.atomic_types): + composition_features[i_structure, i_types] = torch.sum( + structure.types == atomic_type + ) + + regularizer = 1e-20 + while regularizer: + if regularizer > 1e5: + raise RuntimeError( + "Failed to solve the linear system to calculate the " + "composition weights. The dataset is probably too small or " + "ill-conditioned." + ) + try: + self._weights[target_key] = torch.linalg.solve( + composition_features.T @ composition_features + + regularizer + * torch.eye( + composition_features.shape[1], + dtype=composition_features.dtype, + device=composition_features.device, + ), + composition_features.T @ targets, + ) + break + except torch._C._LinAlgError: + regularizer *= 10.0 + + def restart(self, dataset_info: DatasetInfo) -> "CompositionModel": + """Restart the model with a new dataset info. + + :param dataset_info: New dataset information to be used. + """ + return self({}, self.dataset_info.union(dataset_info)) + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + """Compute the targets for each system based on the composition weights. + + :param systems: List of systems to calculate the energy per atom. + :param outputs: Dictionary containing the model outputs. + :param selected_atoms: Optional selection of atoms for which to compute the + targets. + :returns: A dictionary with the computed targets for each system. + + :raises ValueError: If no weights have been computed or if `outputs` keys + contain unsupported keys. + :raises NotImplementedError: If `selected_atoms` is provided (not implemented). + """ + + if not self._weights: + raise ValueError("No weights. Call `compute_weights` method first.") + + if outputs.keys() != self._weights.keys(): + raise ValueError( + f"`outputs` keys ({', '.join(outputs.keys())}) contain unsupported " + f"keys. Supported keys are ({', '.join(self._weights.keys())})." + ) + + if selected_atoms is not None: + raise NotImplementedError("`selected_atoms` is not implemented.") + + # Compute the targets for each system by adding the composition weights times + # number of atoms per atomic type. + targets_out = {} + for target_key, target in self.dataset_info.targets.items(): + weights = self._weights[target_key] + targets: List[float] = [] + sample_values: List[List[int]] = [] + + for i_system, system in enumerate(systems): + target_per_atom = torch.zeros(len(system)) + + for i_type, atomic_type in enumerate(self.atomic_types): + target_per_atom[atomic_type == system.types] = weights[i_type] + + if target.per_atom: + targets += target_per_atom.tolist() + sample_values += [ + [i_system, i_atom] for i_atom in range(len(system)) + ] + + else: + targets += [torch.sum(target_per_atom).tolist()] + sample_values += [[i_system]] + + # Add metadata to the output + if target.per_atom: + sample_names = ["system", "atom"] + else: + sample_names = ["system"] + + block = TensorBlock( + values=torch.tensor(targets).reshape(-1, 1), + samples=Labels( + sample_names, + torch.tensor(sample_values), + ), + components=[], + properties=Labels(target_key, torch.tensor([[0]])), + ) + targets_out[target_key] = TensorMap( + keys=Labels("_", torch.tensor([[0]])), blocks=[block] + ) + + return targets_out def calculate_composition_weights(