Skip to content

Commit

Permalink
New function to find all requested neighbor lists in a model
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 1, 2024
1 parent 5b297c5 commit 54bcca5
Show file tree
Hide file tree
Showing 15 changed files with 149 additions and 34 deletions.
7 changes: 5 additions & 2 deletions examples/programmatic/llpr/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
# how to create a Dataset object from them.

from metatrain.utils.data import Dataset, read_systems, read_targets # noqa: E402
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists # noqa: E402
from metatrain.utils.neighbor_lists import ( # noqa: E402
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)


qm9_systems = read_systems("qm9_reduced_100.xyz")
Expand All @@ -67,7 +70,7 @@
}
targets, _ = read_targets(target_config)

requested_neighbor_lists = model.requested_neighbor_lists()
requested_neighbor_lists = get_requested_neighbor_lists(model)
qm9_systems = [
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in qm9_systems
Expand Down
7 changes: 5 additions & 2 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from ..utils.evaluate_model import evaluate_model
from ..utils.logging import MetricLogger
from ..utils.metrics import RMSEAccumulator
from ..utils.neighbor_lists import get_system_with_neighbor_lists
from ..utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from ..utils.omegaconf import expand_dataset_config
from ..utils.per_atom import average_by_num_atoms
from .formatter import CustomHelpFormatter
Expand Down Expand Up @@ -172,7 +175,7 @@ def _eval_targets(
# if already present (e.g. if this function is called after training)
for sample in dataset:
system = sample["system"]
get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
get_system_with_neighbor_lists(system, get_requested_neighbor_lists(model))

# Infer the device and dtype from the model
model_tensor = next(itertools.chain(model.parameters(), model.buffers()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)

from . import MODEL_HYPERS

Expand All @@ -31,7 +34,8 @@ def test_to(device, dtype):
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
)
system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(exported)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)
system = system.to(device=device, dtype=dtype)

evaluation_options = ModelEvaluationOptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)

from . import MODEL_HYPERS

Expand All @@ -25,7 +28,8 @@ def test_prediction_subset_elements():
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
)
system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)

evaluation_options = ModelEvaluationOptions(
length_unit=dataset_info.length_unit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from metatrain.experimental.alchemical_model import AlchemicalModel
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)

from . import DATASET_PATH, MODEL_HYPERS

Expand All @@ -24,13 +27,15 @@ def test_rotational_invariance():
system = ase.io.read(DATASET_PATH)
original_system = copy.deepcopy(system)
original_system = systems_to_torch(original_system)
requested_neighbor_lists = get_requested_neighbor_lists(model)
original_system = get_system_with_neighbor_lists(
original_system, model.requested_neighbor_lists()
original_system, requested_neighbor_lists
)

system.rotate(48, "y")
system = systems_to_torch(system)
system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)

evaluation_options = ModelEvaluationOptions(
length_unit=dataset_info.length_unit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
read_targets,
)
from metatrain.utils.data.dataset import TargetInfoDict
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)

from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS

Expand All @@ -38,8 +41,9 @@ def test_regression_init():

# Predict on the first five systems
systems = read_systems(DATASET_PATH)[:5]
requested_neighbor_lists = get_requested_neighbor_lists(model)
systems = [
get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in systems
]

Expand Down Expand Up @@ -101,8 +105,9 @@ def test_regression_train():
)
model = AlchemicalModel(MODEL_HYPERS, dataset_info)

requested_neighbor_lists = get_requested_neighbor_lists(model)
systems = [
get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in systems
]

Expand Down
7 changes: 5 additions & 2 deletions src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
from ...utils.metrics import RMSEAccumulator
from ...utils.neighbor_lists import get_system_with_neighbor_lists
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from . import AlchemicalModel
from .utils.composition import calculate_composition_weights
Expand Down Expand Up @@ -68,7 +71,7 @@ def train(

# Calculating the neighbor lists for the training and validation datasets:
logger.info("Calculating neighbor lists for the datasets")
requested_neighbor_lists = model.requested_neighbor_lists()
requested_neighbor_lists = get_requested_neighbor_lists(model)
for dataset in train_datasets + val_datasets:
for i in range(len(dataset)):
system = dataset[i]["system"]
Expand Down
9 changes: 5 additions & 4 deletions src/metatrain/experimental/gap/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from ...utils.additive import remove_additive
from ...utils.data import check_datasets
from ...utils.neighbor_lists import get_system_with_neighbor_lists
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from . import GAP
from .model import torch_tensor_map_to_core

Expand Down Expand Up @@ -72,9 +75,7 @@ def train(
train_structures = [sample["system"] for sample in train_dataset]

logger.info("Calculating neighbor lists for the datasets")
requested_neighbor_lists = (
model._soap_torch_calculator.requested_neighbor_lists()
)
requested_neighbor_lists = get_requested_neighbor_lists(model)
for dataset in train_datasets + val_datasets:
for i in range(len(dataset)):
system = dataset[i]["system"]
Expand Down
8 changes: 6 additions & 2 deletions src/metatrain/experimental/pet/tests/test_exported.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from metatrain.utils.architectures import get_default_hypers
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.export import export
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)


DEFAULT_HYPERS = get_default_hypers("experimental.pet")
Expand Down Expand Up @@ -59,7 +62,8 @@ def test_to(device):
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
)
system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(exported)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)
system = system.to(device=device, dtype=dtype)

evaluation_options = ModelEvaluationOptions(
Expand Down
14 changes: 10 additions & 4 deletions src/metatrain/experimental/pet/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from metatrain.utils.architectures import get_default_hypers
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.jsonschema import validate
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)


DEFAULT_HYPERS = get_default_hypers("experimental.pet")
Expand Down Expand Up @@ -74,7 +77,8 @@ def test_prediction():
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
)
system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)

evaluation_options = ModelEvaluationOptions(
length_unit=dataset_info.length_unit,
Expand Down Expand Up @@ -123,7 +127,8 @@ def test_per_atom_predictions_functionality():
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
)
system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)

evaluation_options = ModelEvaluationOptions(
length_unit=dataset_info.length_unit,
Expand Down Expand Up @@ -173,7 +178,8 @@ def test_selected_atoms_functionality():
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
)
system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(model)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)

evaluation_options = ModelEvaluationOptions(
length_unit=dataset_info.length_unit,
Expand Down
8 changes: 6 additions & 2 deletions src/metatrain/experimental/soap_bpnn/tests/test_exported.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from metatrain.experimental.soap_bpnn import SoapBpnn
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)

from . import MODEL_HYPERS

Expand All @@ -31,7 +34,8 @@ def test_to(device, dtype):
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
)
system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists())
requested_neighbor_lists = get_requested_neighbor_lists(exported)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)
system = system.to(device=device, dtype=dtype)

evaluation_options = ModelEvaluationOptions(
Expand Down
7 changes: 5 additions & 2 deletions src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
from ...utils.metrics import RMSEAccumulator
from ...utils.neighbor_lists import get_system_with_neighbor_lists
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from .model import SoapBpnn

Expand Down Expand Up @@ -90,7 +93,7 @@ def train(
# needs to happen before the additive models are trained, as they
# might need them):
logger.info("Calculating neighbor lists for the datasets")
requested_neighbor_lists = model.soap_calculator.requested_neighbor_lists()
requested_neighbor_lists = get_requested_neighbor_lists(model)
for dataset in train_datasets + val_datasets:
for i in range(len(dataset)):
system = dataset[i]["system"]
Expand Down
63 changes: 63 additions & 0 deletions src/metatrain/utils/neighbor_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,69 @@
from .data.system_to_ase import system_to_ase


def get_requested_neighbor_lists(
module: torch.nn.Module,
length_unit: str = "",
) -> List[NeighborListOptions]:
"""Get the neighbor lists requested by a module and its children.
:param module: The module for which to get the requested neighbor lists.
:param length_unit: The length units to be registered in the returned
`NeighborListOptions`.
:return: A list of `NeighborListOptions` objects requested by the module.
"""
requested: List[NeighborListOptions] = []
_get_requested_neighbor_lists_in_place(
module=module,
module_name="",
requested=requested,
length_unit=length_unit,
)
return requested


def _get_requested_neighbor_lists_in_place(
module: torch.nn.Module,
module_name: str,
requested: List[NeighborListOptions],
length_unit: str,
):
# copied from
# metatensor/python/metatensor-torch/metatensor/torch/atomistic/model.py
# and just removed the length units

if hasattr(module, "requested_neighbor_lists"):
for new_options in module.requested_neighbor_lists():
new_options.add_requestor(module_name)

already_requested = False
for existing in requested:
if existing == new_options:
already_requested = True
for requestor in new_options.requestors():
existing.add_requestor(requestor)

if not already_requested:
if new_options.length_unit not in ["", length_unit]:
raise ValueError(
f"NeighborsListOptions from {module_name} already have a "
f"length unit ('{new_options.length_unit}') which does not "
f"match the model length units ('{length_unit}')"
)

new_options.length_unit = length_unit
requested.append(new_options)

for child_name, child in module.named_children():
_get_requested_neighbor_lists_in_place(
module=child,
module_name=module_name + "." + child_name,
requested=requested,
length_unit=length_unit,
)


def get_system_with_neighbor_lists(
system: System, neighbor_lists: List[NeighborListOptions]
) -> System:
Expand Down
8 changes: 6 additions & 2 deletions tests/utils/test_evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems
from metatrain.utils.evaluate_model import evaluate_model
from metatrain.utils.export import export
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)

from . import MODEL_HYPERS, RESOURCES_PATH

Expand Down Expand Up @@ -45,8 +48,9 @@ def test_evaluate_model(training, exported):
)

model = export(model, capabilities)
requested_neighbor_lists = get_requested_neighbor_lists(model)
systems = [
get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
get_system_with_neighbor_lists(system, requested_neighbor_lists)
for system in systems
]

Expand Down
Loading

0 comments on commit 54bcca5

Please sign in to comment.