Commit cf069fb: Merge branch 'main' into export

frostedoyster committed Feb 1, 2024
2 parents 27e3bc9 + fcc7c1f
Showing 8 changed files with 165 additions and 33 deletions.
14 changes: 10 additions & 4 deletions src/metatensor/models/cli/train_model.py
@@ -92,6 +92,11 @@ def train_model(options: DictConfig) -> None:
necessary options for dataset preparation, model hyperparameters, and training.
"""

# This gives some accuracy improvements. It is very likely that
# this is just due to the preliminary composition fit in the SOAP-BPNN.
# TODO: investigate
torch.set_default_dtype(torch.float64)

# TODO load seed from config
generator = torch.Generator()

@@ -109,8 +114,8 @@ def train_model(options: DictConfig) -> None:
if not isinstance(test_options, float):
test_options = expand_dataset_config(test_options)
test_structures = read_structures(
-        filename=train_options["structures"]["read_from"],
-        fileformat=train_options["structures"]["file_format"],
+        filename=test_options["structures"]["read_from"],
+        fileformat=test_options["structures"]["file_format"],
)
test_targets = read_targets(test_options["targets"])
test_dataset = Dataset(test_structures, test_targets)
@@ -125,8 +130,8 @@ def train_model(options: DictConfig) -> None:
if not isinstance(validation_options, float):
validation_options = expand_dataset_config(validation_options)
validation_structures = read_structures(
-        filename=train_options["structures"]["read_from"],
-        fileformat=train_options["structures"]["file_format"],
+        filename=validation_options["structures"]["read_from"],
+        fileformat=validation_options["structures"]["file_format"],
)
validation_targets = read_targets(validation_options["targets"])
validation_dataset = Dataset(validation_structures, validation_targets)
@@ -179,6 +184,7 @@ def train_model(options: DictConfig) -> None:
for dataset in [train_dataset]: # HACK: only a single train_dataset for now
all_species += get_all_species(dataset)
all_species = list(set(all_species))
all_species.sort()

outputs = {
key: ModelOutput(
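For context, a minimal sketch (not part of this commit) of what the new float64 default means in practice: torch.set_default_dtype only affects tensors created after the call, so data loaded with an explicit float32 dtype still needs a cast.

    import torch

    torch.set_default_dtype(torch.float64)

    x = torch.zeros(3)  # created after the call, so float64
    assert x.dtype == torch.float64

    # Tensors keep an explicitly requested dtype; cast them if they
    # should match the new default:
    y = torch.zeros(3, dtype=torch.float32).to(torch.get_default_dtype())
    assert y.dtype == torch.float64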
79 changes: 72 additions & 7 deletions src/metatensor/models/soap_bpnn/model.py
@@ -66,10 +66,12 @@ def forward(self, features: TensorMap) -> TensorMap:
for i in range(features.keys.values.shape[0])
]

new_keys: List[int] = []
new_blocks: List[TensorBlock] = []
for species_str, network in self.layers.items():
species = int(species_str)
if species in present_blocks:
new_keys.append(species)
block = features.block({"species_center": species})
output_values = network(block.values)
new_blocks.append(
@@ -86,8 +88,60 @@ def forward(self, features: TensorMap) -> TensorMap:
),
)
)
new_keys_labels = Labels(
names=["species_center"],
values=torch.tensor(new_keys).reshape(-1, 1),
)

return TensorMap(keys=new_keys_labels, blocks=new_blocks)


class LayerNormMap(torch.nn.Module):
def __init__(self, all_species: List[int], n_layer: int) -> None:
super().__init__()

# Initialize a layernorm for each species
layernorm_per_species = []
for _ in all_species:
layernorm_per_species.append(torch.nn.LayerNorm((n_layer,)))

# Create a module dict to store the per-species layernorms
self.layernorms = torch.nn.ModuleDict(
{
str(species): layer
for species, layer in zip(all_species, layernorm_per_species)
}
)

def forward(self, features: TensorMap) -> TensorMap:
# Create a list of the blocks that are present in the features:
present_blocks = [
int(features.keys.entry(i).values.item())
for i in range(features.keys.values.shape[0])
]

-        return TensorMap(keys=features.keys, blocks=new_blocks)
new_keys: List[int] = []
new_blocks: List[TensorBlock] = []
for species_str, layer in self.layernorms.items():
species = int(species_str)
if species in present_blocks:
new_keys.append(species)
block = features.block({"species_center": species})
output_values = layer(block.values)
new_blocks.append(
TensorBlock(
values=output_values,
samples=block.samples,
components=block.components,
properties=block.properties,
)
)
new_keys_labels = Labels(
names=["species_center"],
values=torch.tensor(new_keys).reshape(-1, 1),
)

return TensorMap(keys=new_keys_labels, blocks=new_blocks)


class LinearMap(torch.nn.Module):
@@ -114,10 +168,12 @@ def forward(self, features: TensorMap) -> TensorMap:
for i in range(features.keys.values.shape[0])
]

new_keys: List[int] = []
new_blocks: List[TensorBlock] = []
for species_str, layer in self.layers.items():
species = int(species_str)
if species in present_blocks:
new_keys.append(species)
block = features.block({"species_center": species})
output_values = layer(block.values)
new_blocks.append(
@@ -134,8 +190,12 @@ def forward(self, features: TensorMap) -> TensorMap:
),
)
)
new_keys_labels = Labels(
names=["species_center"],
values=torch.tensor(new_keys).reshape(-1, 1),
)

-        return TensorMap(keys=features.keys, blocks=new_blocks)
+        return TensorMap(keys=new_keys_labels, blocks=new_blocks)


class Model(torch.nn.Module):
@@ -150,8 +210,7 @@ def __init__(
if output.quantity != "energy":
raise ValueError(
"SOAP-BPNN only supports energy-like outputs, "
f"but a {next(iter(capabilities.outputs.values())).quantity} "
"was provided"
f"but a {output.quantity} was provided"
)
if output.per_atom:
raise ValueError(
@@ -178,13 +237,17 @@ def __init__(
}

self.soap_calculator = rascaline.torch.SoapPowerSpectrum(**hypers["soap"])
-        hypers_bpnn = hypers["bpnn"]
-        hypers_bpnn["input_size"] = (
+        soap_size = (
len(self.all_species) ** 2
* hypers["soap"]["max_radial"] ** 2
* (hypers["soap"]["max_angular"] + 1)
)

self.layernorm = LayerNormMap(self.all_species, soap_size)

hypers_bpnn = hypers["bpnn"]
hypers_bpnn["input_size"] = soap_size

self.bpnn = MLPMap(self.all_species, hypers_bpnn)
self.neighbor_species_1_labels = Labels(
names=["species_neighbor_1"],
@@ -233,6 +296,8 @@ def forward(
self.neighbor_species_2_labels.to(device)
)

soap_features = self.layernorm(soap_features)

hidden_features = self.bpnn(soap_features)

atomic_energies: Dict[str, TensorMap] = {}
@@ -251,7 +316,7 @@ def forward(
atomic_energy, ["center", "species_center"]
)
# Change the energy label from _ to (0, 1):
-            total_energies[output_name] = metatensor.torch.TensorMap(
+            total_energies[output_name] = TensorMap(
keys=Labels(
names=["lambda", "sigma"],
values=torch.tensor([[0, 1]]),
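For context, a minimal standalone sketch (not part of this commit) of the per-species module pattern that MLPMap and the new LayerNormMap share: one module per chemical species, stored in a torch.nn.ModuleDict keyed by the species number as a string, since ModuleDict keys must be strings. The species list and feature size here are hypothetical.

    import torch

    all_species = [1, 6, 8]  # hypothetical species list (H, C, O)
    n_features = 4

    layernorms = torch.nn.ModuleDict(
        {str(s): torch.nn.LayerNorm((n_features,)) for s in all_species}
    )

    # Apply each species' own normalization to that species' block:
    blocks = {1: torch.randn(10, n_features), 8: torch.randn(5, n_features)}
    normalized = {
        s: layernorms[str(s)](values) for s, values in blocks.items()
    }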
4 changes: 2 additions & 2 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
@@ -43,7 +43,7 @@ def test_regression_init():
{"U0": soap_bpnn.capabilities.outputs["U0"]},
)
expected_output = torch.tensor(
-        [[-0.4615], [-0.4367], [-0.3004], [-0.2606], [-0.2380]],
+        [[-0.1746], [-0.2209], [-0.2426], [-0.2033], [-0.2973]],
dtype=torch.float64,
)

@@ -90,7 +90,7 @@ def test_regression_train():
output = soap_bpnn(structures[:5], {"U0": soap_bpnn.capabilities.outputs["U0"]})

expected_output = torch.tensor(
-        [[-40.1358], [-56.1721], [-76.1576], [-77.1174], [-93.1679]],
+        [[-40.5007], [-56.5529], [-76.4418], [-77.2819], [-93.3743]],
dtype=torch.float64,
)

13 changes: 9 additions & 4 deletions src/metatensor/models/utils/composition.py
@@ -2,7 +2,7 @@

import rascaline.torch
import torch
- from metatensor.torch import TensorBlock, TensorMap
+ from metatensor.torch import Labels, TensorBlock, TensorMap


def calculate_composition_weights(
@@ -80,11 +80,11 @@ def apply_composition_contribution(
Atomic property with the composition contribution applied.
"""

-    # Get the composition for each structure in the dataset

new_keys: List[int] = []
new_blocks: List[TensorBlock] = []
for key, block in atomic_property.items():
atomic_species = int(key.values.item())
new_keys.append(atomic_species)
new_values = block.values + composition_weights[atomic_species]
new_blocks.append(
TensorBlock(
@@ -95,4 +95,9 @@
)
)

-    return TensorMap(keys=atomic_property.keys, blocks=new_blocks)
new_keys_labels = Labels(
names=["species_center"],
values=torch.tensor(new_keys).reshape(-1, 1),
)

return TensorMap(keys=new_keys_labels, blocks=new_blocks)
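For context, a minimal sketch (not part of this commit) of what the composition contribution does: a per-species baseline is added to every atomic property value of that species. Plain tensors stand in for TensorMaps, and the numbers are hypothetical.

    import torch

    # Hypothetical per-species baselines, indexed by atomic number:
    composition_weights = torch.zeros(10)
    composition_weights[1] = -13.6  # hypothetical baseline for H

    # Hypothetical atomic energies for three H atoms (species 1):
    atomic_energies = torch.tensor([[0.1], [0.2], [0.3]])
    shifted = atomic_energies + composition_weights[1]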
45 changes: 31 additions & 14 deletions src/metatensor/models/utils/data/dataset.py
@@ -1,16 +1,21 @@
import os
from typing import Dict, List
from typing import Dict, List, Tuple

import metatensor.torch
import numpy as np
import torch
from metatensor.torch import Labels, TensorMap
from metatensor.torch.atomistic import ModelCapabilities, System


if os.environ.get("METATENSOR_IMPORT_FOR_SPHINX", "0") == "1":
# This is necessary to make the Sphinx documentation build
-    compiled_slice = None
-    compiled_join = None
def compiled_slice(a, b):
pass

def compiled_join(a, axis, remove_tensor_name):
pass

else:
compiled_slice = torch.jit.script(metatensor.torch.slice)
compiled_join = torch.jit.script(metatensor.torch.join)
@@ -55,16 +60,14 @@ def __getitem__(self, index):
"""
structure = self.structures[index]

-        structure_index_samples = Labels(
+        sample_labels = Labels(
            names=["structure"],
-            values=torch.tensor([[index]]),  # must be a 2D-array
+            values=torch.tensor([index]).reshape(1, 1),
)

targets = {}
for name, tensor_map in self.targets.items():
-            targets[name] = compiled_slice(
-                tensor_map, "samples", structure_index_samples
-            )
+            targets[name] = compiled_slice(tensor_map, "samples", sample_labels)

return structure, targets

@@ -125,7 +128,7 @@ def get_all_targets(dataset: Dataset) -> List[str]:
return list(set(target_names))


- def collate_fn(batch):
+ def collate_fn(batch: List[Tuple[System, Dict[str, TensorMap]]]):
"""
Creates a batch from a list of samples.
@@ -137,11 +140,25 @@ def collate_fn(batch):
A tuple containing the structures and targets for the batch.
"""

-    structures = [sample[0] for sample in batch]
-    targets = {}
-    for name in batch[0][1].keys():
-        targets[name] = compiled_join([sample[1][name] for sample in batch], "samples")

structures: List[System] = [sample[0] for sample in batch]
# `join` will reorder the samples based on their structure number.
# Let's reorder the list of structures in the same way:
structure_samples = [
list(sample[1].values())[0].block().samples.values.item() for sample in batch
]
sorting_order = np.argsort(structure_samples)
structures = [structures[index] for index in sorting_order]
# TODO: use metatensor.learn for datasets/dataloaders, making sure the same
# issues are handled correctly

targets: Dict[str, TensorMap] = {}
names = list(batch[0][1].keys())
for name in names:
targets[name] = compiled_join(
[sample[1][name] for sample in batch],
axis="samples",
remove_tensor_name=True,
)
return structures, targets


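For context, a minimal sketch (not part of this commit) of the reordering issue the new collate_fn works around: if join sorts the joined samples by structure index, the list of structures must be permuted the same way. Plain strings stand in for the real System objects.

    import numpy as np

    # Structure indices as they happen to arrive in a batch:
    structure_samples = [7, 2, 5]
    structures = ["system-7", "system-2", "system-5"]

    # Permute the structures to match the sorted sample order:
    sorting_order = np.argsort(structure_samples)
    structures = [structures[i] for i in sorting_order]
    # -> ["system-2", "system-5", "system-7"]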
39 changes: 39 additions & 0 deletions tests/cli/test_train_model.py
@@ -2,7 +2,9 @@
import subprocess
from pathlib import Path

import ase.io
import pytest
from omegaconf import OmegaConf


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"
@@ -26,6 +28,43 @@ def test_train(monkeypatch, tmp_path, output):
assert Path(output).is_file()


@pytest.mark.parametrize("test_set_file", (True, False))
@pytest.mark.parametrize("validation_set_file", (True, False))
@pytest.mark.parametrize("output", [None, "mymodel.pt"])
def test_train_explicit_validation_test(
monkeypatch, tmp_path, test_set_file, validation_set_file, output
):
"""Test that training via the training cli runs without an error raise
also when the validation and test sets are provided explicitly."""
monkeypatch.chdir(tmp_path)

structures = ase.io.read(RESOURCES_PATH / "qm9_reduced_100.xyz", ":")
options = OmegaConf.load(RESOURCES_PATH / "options.yaml")

ase.io.write("qm9_reduced_100.xyz", structures[:50])

if test_set_file:
ase.io.write("test.xyz", structures[50:80])
options["test_set"] = options["training_set"].copy()
options["test_set"]["structures"]["read_from"] = "test.xyz"

if validation_set_file:
ase.io.write("validation.xyz", structures[80:])
options["validation_set"] = options["training_set"].copy()
options["validation_set"]["structures"]["read_from"] = "validation.xyz"

OmegaConf.save(config=options, f="options.yaml")
command = ["metatensor-models", "train", "options.yaml"]

if output is not None:
command += ["-o", output]
else:
output = "model.pt"

subprocess.check_call(command)
assert Path(output).is_file()


def test_yml_error():
"""Test error raise of the option file is not a .yaml file."""
try:
Binary file modified tests/resources/bpnn-model.pt
4 changes: 2 additions & 2 deletions tests/resources/options.yaml
@@ -4,8 +4,8 @@ defaults:

architecture:
training:
-  batch_size: 2
-  num_epochs: 1
+    batch_size: 2
+    num_epochs: 1

# Section defining the parameters for structure and target data
training_set:
