
Add a general torch CompositionModel #280

Merged: 27 commits (Sep 7, 2024)
Commits:

95c9153  Add a general torch CompositionModel (PicoCentauri, Jul 2, 2024)
4791d2e  Finish composition model (frostedoyster, Aug 16, 2024)
e7ceecf  Merge branch 'main' into composition-model (frostedoyster, Aug 16, 2024)
7c9ac4a  Integrate with SOAP-BPNN (frostedoyster, Aug 16, 2024)
eb5515e  Also use the new CompositionModel in GAP (frostedoyster, Aug 16, 2024)
fc71849  Add test for `remove_composition` (frostedoyster, Aug 16, 2024)
9a7718b  Exclude `mtt::aux::` quantities from composition models (frostedoyster, Aug 16, 2024)
043effe  Remove composition from original SOAP-BPNN (frostedoyster, Aug 18, 2024)
bb756d7  Fix bug (frostedoyster, Aug 18, 2024)
f5b3527  Update metatensor (frostedoyster, Aug 29, 2024)
e297fad  `._module` -> `.module` (frostedoyster, Aug 29, 2024)
b6caa7b  Update dataset (frostedoyster, Aug 29, 2024)
32c24be  Fix alchemical model (frostedoyster, Sep 1, 2024)
0b35fed  Add tests for errors (frostedoyster, Sep 1, 2024)
661ebe5  Only warn if atomic types are present in the validation dataset but n… (frostedoyster, Sep 2, 2024)
a2445d9  Merge branch 'update-metatensor' into composition-model (frostedoyster, Sep 2, 2024)
b69b189  Fix test (frostedoyster, Sep 2, 2024)
021b7c2  Merge branch 'main' into composition-model (frostedoyster, Sep 2, 2024)
7a406d1  Debugg (frostedoyster, Sep 3, 2024)
ab41dc2  Merge branch 'main' into composition-model (Luthaf, Sep 4, 2024)
d29920c  Do not import metatensor operation on the top level (Luthaf, Sep 4, 2024)
b1fd48d  Selected atoms for composition model (frostedoyster, Sep 5, 2024)
92b4735  Merge branch 'composition-model' of https://github.com/lab-cosmo/meta… (frostedoyster, Sep 5, 2024)
fe8e536  Test selected atoms (frostedoyster, Sep 5, 2024)
269c93c  Merge branch 'main' into composition-model (frostedoyster, Sep 5, 2024)
112ed89  More testing (frostedoyster, Sep 7, 2024)
e285cbf  Merge branch 'composition-model' of https://github.com/lab-cosmo/meta… (frostedoyster, Sep 7, 2024)
Files changed:
5 changes: 3 additions & 2 deletions docs/src/conf.py
@@ -5,14 +5,15 @@
 
 import tomli  # Replace by tomllib from std library once docs are built with Python 3.11
 
-import metatrain
-
 
 # When importing metatensor-torch, this will change the definition of the classes
 # to include the documentation
 os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
 os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1"
 
+import metatrain  # noqa: E402
+
+
 ROOT = os.path.abspath(os.path.join("..", ".."))
 
 # We use a second (pseudo) sphinx project located in `docs/generate_examples` to run the
2 changes: 1 addition & 1 deletion src/metatrain/experimental/alchemical_model/trainer.py
@@ -5,7 +5,6 @@
 import torch
 from metatensor.learn.data import DataLoader
 
-from ...utils.composition import calculate_composition_weights
 from ...utils.data import (
     CombinedDataLoader,
     Dataset,
@@ -23,6 +22,7 @@
 from ...utils.neighbor_lists import get_system_with_neighbor_lists
 from ...utils.per_atom import average_by_num_atoms
 from . import AlchemicalModel
+from .utils.composition import calculate_composition_weights
 from .utils.normalize import (
     get_average_number_of_atoms,
     get_average_number_of_neighbors,
69 changes: 69 additions & 0 deletions src/metatrain/experimental/alchemical_model/utils/composition.py (new file)

@@ -0,0 +1,69 @@
from typing import List, Tuple, Union

import torch

from ....utils.data.dataset import Dataset, get_atomic_types


Review thread on `calculate_composition_weights`:

Contributor Author: I think this is not needed anymore?

Collaborator: Unfortunately this needs to be kept for the alchemical model (notice that it changed directories), which works in a way that doesn't allow me to change things without changing the alchemical model code itself.

Contributor Author: Okay!

def calculate_composition_weights(
    datasets: Union[Dataset, List[Dataset]], property: str
) -> Tuple[torch.Tensor, List[int]]:
    """Calculate the composition weights for a dataset.

    It assumes per-system properties.

    :param datasets: Dataset (or list of datasets) to calculate the
        composition weights for.
    :param property: Name of the target property to fit the weights to.
    :returns: Composition weights for the dataset, as well as the
        list of species that the weights correspond to.
    """
    if not isinstance(datasets, list):
        datasets = [datasets]

    # Note: `atomic_types` are sorted, and the composition weights are sorted as
    # well, because the species are sorted in the composition features.
    atomic_types = sorted(get_atomic_types(datasets))

    targets = torch.stack(
        [sample[property].block().values for dataset in datasets for sample in dataset]
    )
    targets = targets.squeeze(dim=(1, 2))  # remove component and property dimensions

    total_num_structures = sum([len(dataset) for dataset in datasets])
    dtype = datasets[0][0]["system"].positions.dtype
    composition_features = torch.empty(
        (total_num_structures, len(atomic_types)), dtype=dtype
    )
    structure_index = 0
    for dataset in datasets:
        for sample in dataset:
            structure = sample["system"]
            for j, s in enumerate(atomic_types):
                composition_features[structure_index, j] = torch.sum(
                    structure.types == s
                )
            structure_index += 1

    regularizer = 1e-20
    while regularizer:
        if regularizer > 1e5:
            raise RuntimeError(
                "Failed to solve the linear system to calculate the "
                "composition weights. The dataset is probably too small "
                "or ill-conditioned."
            )
        try:
            solution = torch.linalg.solve(
                composition_features.T @ composition_features
                + regularizer
                * torch.eye(
                    composition_features.shape[1],
                    dtype=composition_features.dtype,
                    device=composition_features.device,
                ),
                composition_features.T @ targets,
            )
            break
        except torch._C._LinAlgError:
            regularizer *= 10.0

    return solution, atomic_types
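The loop above is a ridge-regularized least-squares fit: with X the matrix of per-structure atom counts and y the targets, it solves (XᵀX + λI)w = Xᵀy, multiplying λ by 10 until the solve succeeds. A self-contained toy version of the same solve (all values made up for illustration):

```python
import torch

# X counts the atoms of each species in each structure; y holds the
# per-structure target energies (both made up for illustration)
X = torch.tensor([[2.0, 1.0], [1.0, 3.0], [4.0, 0.0]])  # 3 structures, 2 species
y = torch.tensor([-10.0, -14.0, -16.0])

# same normal-equations solve as above, with a tiny ridge term
regularizer = 1e-20
w = torch.linalg.solve(
    X.T @ X + regularizer * torch.eye(X.shape[1], dtype=X.dtype),
    X.T @ y,
)
print(w)  # w[j] is the fitted energy contribution of one atom of species j
```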
22 changes: 20 additions & 2 deletions src/metatrain/experimental/gap/model.py
@@ -21,6 +21,7 @@
 
 from metatrain.utils.data.dataset import DatasetInfo
 
+from ...utils.composition import CompositionModel
 from ...utils.export import export
 
 
@@ -127,6 +128,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
         )
         self._species_labels: TorchLabels = TorchLabels.empty("_")
 
+        self.composition_model = CompositionModel(
+            model_hypers={},
+            dataset_info=dataset_info,
+        )
+
     def restart(self, dataset_info: DatasetInfo) -> "GAP":
         raise ValueError("GAP does not allow restarting training")
 
@@ -201,8 +207,20 @@ def forward(
         soap_features = TorchTensorMap(self._keys, soap_features.blocks())
         output_key = list(outputs.keys())[0]
         energies = self._subset_of_regressors_torch(soap_features)
-        out_tensor = self.apply_composition_weights(systems, energies)
-        return {output_key: out_tensor}
+        return_dict = {output_key: energies}
+
+        # apply composition model
+        composition_energies = self.composition_model(
+            systems, {output_key: ModelOutput("energy", per_atom=True)}, selected_atoms
+        )
+        composition_energies[output_key] = metatensor.torch.sum_over_samples(
+            composition_energies[output_key], "atom"
+        )
+        return_dict[output_key] = metatensor.torch.add(
+            return_dict[output_key], composition_energies[output_key]
+        )
+
+        return return_dict
 
     def export(self) -> MetatensorAtomisticModel:
         capabilities = ModelCapabilities(
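The GAP `forward` above asks the composition model for per-atom energies and then reduces them to per-system energies with `sum_over_samples`. A minimal toy illustration of that reduction, using what I take to be the standard metatensor-torch constructors (all values made up):

```python
import torch
import metatensor.torch
from metatensor.torch import Labels, TensorBlock, TensorMap

# toy per-atom energies: two systems with two atoms each (made-up values)
block = TensorBlock(
    values=torch.tensor([[1.0], [2.0], [3.0], [4.0]]),
    samples=Labels(
        ["system", "atom"],
        torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.int32),
    ),
    components=[],
    properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)),
)
per_atom = TensorMap(Labels.single(), [block])

# summing over the "atom" sample dimension leaves one energy per system
per_system = metatensor.torch.sum_over_samples(per_atom, "atom")
print(per_system.block().values)  # tensor([[3.], [7.]])
```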
31 changes: 6 additions & 25 deletions src/metatrain/experimental/gap/trainer.py
@@ -9,7 +9,7 @@
 
 from metatrain.utils.data import Dataset
 
-from ...utils.composition import calculate_composition_weights
+from ...utils.composition import remove_composition
 from ...utils.data import check_datasets
 from . import GAP
 from .model import torch_tensor_map_to_core
 
@@ -52,10 +52,7 @@ def train(
 
     # Calculate and set the composition weights:
     logger.info("Calculating composition weights")
-    composition_weights, species = calculate_composition_weights(
-        train_datasets, target_name
-    )
-    model.set_composition_weights(target_name, composition_weights, species)
+    model.composition_model.train_model(train_datasets)
 
     logger.info("Setting up data loaders")
     if len(train_datasets[0][0][output_name].keys) > 1:
 
@@ -72,26 +69,10 @@
     model._keys = train_y.keys
     train_structures = [sample["system"] for sample in train_dataset]
 
-    logger.info("Fitting composition energies")
-    composition_energies = torch.zeros(len(train_y.block().values), dtype=dtype)
-    for i, structure in enumerate(train_structures):
-        for j, s in enumerate(species):
-            composition_energies[i] += (
-                torch.sum(structure.types == s) * composition_weights[j]
-            )
-    train_y_values = train_y.block().values
-    train_y_values = train_y_values - composition_energies.reshape(-1, 1)
-    train_block = metatensor.torch.TensorBlock(
-        values=train_y_values,
-        samples=train_y.block().samples,
-        components=train_y.block().components,
-        properties=train_y.block().properties,
-    )
-    if len(train_y[0].gradients_list()) > 0:
-        train_block.add_gradient("positions", train_y[0].gradient("positions"))
-    train_y = metatensor.torch.TensorMap(
-        train_y.keys,
-        [train_block],
+    logger.info("Subtracting composition energies")
+    # this acts in-place on train_y
+    remove_composition(
+        train_structures, {target_name: train_y}, model.composition_model
     )
 
     logger.info("Calculating SOAP features")
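The new `remove_composition` call replaces all of the hand-rolled subtraction deleted above. Stripped of the metatensor bookkeeping, the idea is to subtract a per-species linear baseline from each structure's target; a sketch in plain torch (the helper name and all values here are illustrative, not the metatrain implementation):

```python
import torch

def remove_composition_sketch(
    counts: torch.Tensor,   # (n_structures, n_species) atom counts
    weights: torch.Tensor,  # (n_species,) fitted composition weights
    targets: torch.Tensor,  # (n_structures, 1) target energies
) -> torch.Tensor:
    # the composition energy of each structure is the dot product of its
    # atom counts with the per-species weights
    baseline = counts @ weights
    return targets - baseline.unsqueeze(-1)

counts = torch.tensor([[2.0, 1.0], [1.0, 3.0]])
weights = torch.tensor([-2.0, -4.0])
targets = torch.tensor([[-8.5], [-14.2]])
print(remove_composition_sketch(counts, weights, targets))
# tensor([[-0.5000], [-0.2000]])
```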
55 changes: 25 additions & 30 deletions src/metatrain/experimental/soap_bpnn/model.py
@@ -16,7 +16,7 @@
 
 from metatrain.utils.data.dataset import DatasetInfo
 
-from ...utils.composition import apply_composition_contribution
+from ...utils.composition import CompositionModel
 from ...utils.dtype import dtype_to_str
 from ...utils.export import export
 
@@ -123,14 +123,6 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
             unit="unitless", per_atom=True
         )
 
-        # creates a composition weight tensor that can be directly indexed by species,
-        # this can be left as a tensor of zero or set from the outside using
-        # set_composition_weights (recommended for better accuracy)
-        n_outputs = len(self.outputs)
-        self.register_buffer(
-            "composition_weights",
-            torch.zeros((n_outputs, max(self.atomic_types) + 1)),
-        )
         # buffers cannot be indexed by strings (torchscript), so we create a single
         # tensor for all output. Due to this, we need to slice the tensor when we use
         # it and use the output name to select the correct slice via a dictionary
@@ -195,6 +187,11 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
             }
         )
 
+        self.composition_model = CompositionModel(
+            model_hypers={},
+            dataset_info=dataset_info,
+        )
+
     def restart(self, dataset_info: DatasetInfo) -> "SoapBpnn":
         # merge old and new dataset info
         merged_info = self.dataset_info.union(dataset_info)
@@ -261,12 +258,7 @@ def forward(
         atomic_energies: Dict[str, TensorMap] = {}
         for output_name, output_layer in self.last_layers.items():
             if output_name in outputs:
-                atomic_energies[output_name] = apply_composition_contribution(
-                    output_layer(last_layer_features),
-                    self.composition_weights[  # type: ignore
-                        self.output_to_index[output_name]
-                    ],
-                )
+                atomic_energies[output_name] = output_layer(last_layer_features)
 
         # Sum the atomic energies coming from the BPNN to get the total energy
         for output_name, atomic_energy in atomic_energies.items():
@@ -281,6 +273,19 @@
                 atomic_energy, ["atom", "center_type"]
             )
 
+        if not self.training:
+            # at evaluation, we also add the composition contributions
+            composition_contributions = self.composition_model(
+                systems, outputs, selected_atoms
+            )
+            for name in return_dict:
+                if name.startswith("mtt::aux::"):
+                    continue  # skip auxiliary outputs (not targets)
+                return_dict[name] = metatensor.torch.add(
+                    return_dict[name],
+                    composition_contributions[name],
+                )
 
         return return_dict
 
     @classmethod

Review comment on lines +276 to +277 (the new `if not self.training:` branch). Contributor Author: NICE!

@@ -303,6 +308,11 @@ def export(self) -> MetatensorAtomisticModel:
         if dtype not in self.__supported_dtypes__:
             raise ValueError(f"unsupported dtype {self.dtype} for SoapBpnn")
 
+        # Make sure the model is all in the same dtype
+        # For example, at this point, the composition model within the SOAP-BPNN is
+        # still float64
+        self.to(dtype)
+
         capabilities = ModelCapabilities(
             outputs=self.outputs,
             atomic_types=self.atomic_types,
@@ -314,21 +324,6 @@
 
         return export(model=self, model_capabilities=capabilities)
 
-    def set_composition_weights(
-        self,
-        output_name: str,
-        input_composition_weights: torch.Tensor,
-        atomic_types: List[int],
-    ) -> None:
-        """Set the composition weights for a given output."""
-        # all species that are not present retain their weight of zero
-        self.composition_weights[self.output_to_index[output_name]][  # type: ignore
-            atomic_types
-        ] = input_composition_weights.to(
-            dtype=self.composition_weights.dtype,  # type: ignore
-            device=self.composition_weights.device,  # type: ignore
-        )
-
     def add_output(self, output_name: str) -> None:
        """Add a new output to the model."""
         # add a new row to the composition weights tensor
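The `if not self.training:` branch in `forward` above implements a common pattern: train on baseline-subtracted targets, then add the fixed baseline back at evaluation. A minimal, hypothetical torch module showing the same pattern in isolation (`BaselinedModel` is illustrative and not part of metatrain):

```python
import torch

class BaselinedModel(torch.nn.Module):
    def __init__(self, backbone: torch.nn.Module, baseline: torch.nn.Module):
        super().__init__()
        self.backbone = backbone
        self.baseline = baseline  # e.g. a fitted composition model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.backbone(x)
        if not self.training:
            # only at evaluation: add the baseline contribution back
            out = out + self.baseline(x)
        return out

model = BaselinedModel(torch.nn.Linear(4, 1), torch.nn.Linear(4, 1))
x = torch.randn(3, 4)
model.train()
y_train = model(x)  # backbone only: targets had the baseline removed
model.eval()
y_eval = model(x)   # backbone + baseline: full physical prediction
```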
6 changes: 6 additions & 0 deletions src/metatrain/experimental/soap_bpnn/tests/test_continue.py
@@ -44,6 +44,9 @@ def test_continue(monkeypatch, tmp_path):
         }
     }
     targets, _ = read_targets(OmegaConf.create(conf))
+
+    # systems in float64 are required for training
+    systems = [system.to(torch.float64) for system in systems]
     dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]})
 
     hypers = DEFAULT_HYPERS.copy()
@@ -63,6 +66,9 @@
         checkpoint_dir=".",
     )
 
+    # evaluation
+    systems = [system.to(torch.float32) for system in systems]
+
     # Predict on the first five systems
     output_before = model_before(
         systems[:5], {"mtt::U0": model_before.outputs["mtt::U0"]}
26 changes: 16 additions & 10 deletions src/metatrain/experimental/soap_bpnn/tests/test_regression.py
@@ -39,12 +39,18 @@ def test_regression_init():
     )
 
     expected_output = torch.tensor(
-        [[-0.03860], [0.11137], [0.09112], [-0.05634], [-0.02549]]
+        [
+            [-0.038599025458],
+            [0.111374437809],
+            [0.091115802526],
+            [-0.056339077652],
+            [-0.025491207838],
+        ]
    )
 
     # if you need to change the hardcoded values:
-    # torch.set_printoptions(precision=12)
-    # print(output["mtt::U0"].block().values)
+    torch.set_printoptions(precision=12)
+    print(output["mtt::U0"].block().values)
 
     torch.testing.assert_close(
         output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
 
@@ -100,17 +106,17 @@ def test_regression_train():
 
     expected_output = torch.tensor(
         [
-            [-40.592571258545],
-            [-56.522350311279],
-            [-76.571365356445],
-            [-77.384849548340],
-            [-93.445365905762],
+            [-0.106249026954],
+            [0.039981484413],
+            [-0.142682999372],
+            [-0.031701669097],
+            [-0.016210660338],
         ]
     )
 
     # if you need to change the hardcoded values:
-    # torch.set_printoptions(precision=12)
-    # print(output["mtt::U0"].block().values)
+    torch.set_printoptions(precision=12)
+    print(output["mtt::U0"].block().values)
 
     torch.testing.assert_close(
         output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
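A note on the tolerances in these tests (plain torch.testing behavior, nothing metatrain-specific): with rtol=1e-5 and atol=1e-5, five-decimal literals like the old ones already pass against the full-precision outputs; the twelve-decimal literals mainly make the values round-trip exactly when regenerated with precision=12. A quick check:

```python
import torch

full = torch.tensor([[-0.038599025458]])  # twelve-decimal reference value
short = torch.tensor([[-0.03860]])        # the old five-decimal literal
# passes because |full - short| <= atol + rtol * |short|
torch.testing.assert_close(full, short, rtol=1e-5, atol=1e-5)
print("within tolerance")
```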