diff --git a/examples/programmatic/llpr/llpr.py b/examples/programmatic/llpr/llpr.py index 8db135c8..10857aaa 100644 --- a/examples/programmatic/llpr/llpr.py +++ b/examples/programmatic/llpr/llpr.py @@ -48,7 +48,10 @@ # how to create a Dataset object from them. from metatrain.utils.data import Dataset, read_systems, read_targets # noqa: E402 -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists # noqa: E402 +from metatrain.utils.neighbor_lists import ( # noqa: E402 + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) qm9_systems = read_systems("qm9_reduced_100.xyz") @@ -67,7 +70,7 @@ } targets, _ = read_targets(target_config) -requested_neighbor_lists = model.requested_neighbor_lists() +requested_neighbor_lists = get_requested_neighbor_lists(model) qm9_systems = [ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py index 93adb162..59a4f9ef 100644 --- a/src/metatrain/cli/eval.py +++ b/src/metatrain/cli/eval.py @@ -23,7 +23,10 @@ from ..utils.evaluate_model import evaluate_model from ..utils.logging import MetricLogger from ..utils.metrics import RMSEAccumulator -from ..utils.neighbor_lists import get_system_with_neighbor_lists +from ..utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ..utils.omegaconf import expand_dataset_config from ..utils.per_atom import average_by_num_atoms from .formatter import CustomHelpFormatter @@ -172,7 +175,7 @@ def _eval_targets( # if already present (e.g. if this function is called after training) for sample in dataset: system = sample["system"] - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, get_requested_neighbor_lists(model)) # Infer the device and dtype from the model model_tensor = next(itertools.chain(model.parameters(), model.buffers())) diff --git a/src/metatrain/experimental/alchemical_model/tests/test_exported.py b/src/metatrain/experimental/alchemical_model/tests/test_exported.py index 3be00244..89198369 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_exported.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_exported.py @@ -4,7 +4,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -31,7 +34,8 @@ def test_to(device, dtype): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py index b3c42d81..7ee3331a 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_functionality.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_functionality.py @@ -3,7 +3,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -25,7 +28,8 @@ def test_prediction_subset_elements(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py index f6492584..9d9a84dd 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_invariance.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_invariance.py @@ -6,7 +6,10 @@ from metatrain.experimental.alchemical_model import AlchemicalModel from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import DATASET_PATH, MODEL_HYPERS @@ -24,13 +27,15 @@ def test_rotational_invariance(): system = ase.io.read(DATASET_PATH) original_system = copy.deepcopy(system) original_system = systems_to_torch(original_system) + requested_neighbor_lists = get_requested_neighbor_lists(model) original_system = get_system_with_neighbor_lists( - original_system, model.requested_neighbor_lists() + original_system, requested_neighbor_lists ) system.rotate(48, "y") system = systems_to_torch(system) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/alchemical_model/tests/test_regression.py b/src/metatrain/experimental/alchemical_model/tests/test_regression.py index 4dbc6ed0..648c91c1 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_regression.py @@ -14,7 +14,10 @@ read_targets, ) from metatrain.utils.data.dataset import TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS @@ -38,8 +41,9 @@ def test_regression_init(): # Predict on the first five systems systems = read_systems(DATASET_PATH)[:5] + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] @@ -101,8 +105,9 @@ def test_regression_train(): ) model = AlchemicalModel(MODEL_HYPERS, dataset_info) + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index 20b1de5e..0dd13a2b 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -20,7 +20,10 @@ from ...utils.logging import MetricLogger from ...utils.loss import TensorMapDictLoss from ...utils.metrics import RMSEAccumulator -from ...utils.neighbor_lists import get_system_with_neighbor_lists +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ...utils.per_atom import average_by_num_atoms from . import AlchemicalModel from .utils.composition import calculate_composition_weights @@ -68,7 +71,7 @@ def train( # Calculating the neighbor lists for the training and validation datasets: logger.info("Calculating neighbor lists for the datasets") - requested_neighbor_lists = model.requested_neighbor_lists() + requested_neighbor_lists = get_requested_neighbor_lists(model) for dataset in train_datasets + val_datasets: for i in range(len(dataset)): system = dataset[i]["system"] diff --git a/src/metatrain/experimental/gap/trainer.py b/src/metatrain/experimental/gap/trainer.py index 489eb98b..5daef836 100644 --- a/src/metatrain/experimental/gap/trainer.py +++ b/src/metatrain/experimental/gap/trainer.py @@ -11,7 +11,10 @@ from ...utils.additive import remove_additive from ...utils.data import check_datasets -from ...utils.neighbor_lists import get_system_with_neighbor_lists +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import GAP from .model import torch_tensor_map_to_core @@ -72,9 +75,7 @@ def train( train_structures = [sample["system"] for sample in train_dataset] logger.info("Calculating neighbor lists for the datasets") - requested_neighbor_lists = ( - model._soap_torch_calculator.requested_neighbor_lists() - ) + requested_neighbor_lists = get_requested_neighbor_lists(model) for dataset in train_datasets + val_datasets: for i in range(len(dataset)): system = dataset[i]["system"] diff --git a/src/metatrain/experimental/pet/tests/test_exported.py b/src/metatrain/experimental/pet/tests/test_exported.py index a72eb88d..f67a15e4 100644 --- a/src/metatrain/experimental/pet/tests/test_exported.py +++ b/src/metatrain/experimental/pet/tests/test_exported.py @@ -13,7 +13,10 @@ from metatrain.utils.architectures import get_default_hypers from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.export import export -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -59,7 +62,8 @@ def test_to(device): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/pet/tests/test_functionality.py b/src/metatrain/experimental/pet/tests/test_functionality.py index 74a47b07..ddf52760 100644 --- a/src/metatrain/experimental/pet/tests/test_functionality.py +++ b/src/metatrain/experimental/pet/tests/test_functionality.py @@ -20,7 +20,10 @@ from metatrain.utils.architectures import get_default_hypers from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict from metatrain.utils.jsonschema import validate -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) DEFAULT_HYPERS = get_default_hypers("experimental.pet") @@ -74,7 +77,8 @@ def test_prediction(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, @@ -123,7 +127,8 @@ def test_per_atom_predictions_functionality(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, @@ -173,7 +178,8 @@ def test_selected_atoms_functionality(): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(model) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) evaluation_options = ModelEvaluationOptions( length_unit=dataset_info.length_unit, diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py index cc41a360..63242161 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_exported.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_exported.py @@ -4,7 +4,10 @@ from metatrain.experimental.soap_bpnn import SoapBpnn from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS @@ -31,7 +34,8 @@ def test_to(device, dtype): positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), cell=torch.zeros(3, 3), ) - system = get_system_with_neighbor_lists(system, exported.requested_neighbor_lists()) + requested_neighbor_lists = get_requested_neighbor_lists(exported) + system = get_system_with_neighbor_lists(system, requested_neighbor_lists) system = system.to(device=device, dtype=dtype) evaluation_options = ModelEvaluationOptions( diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index f56780a3..b67d9dd8 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -18,7 +18,10 @@ from ...utils.logging import MetricLogger from ...utils.loss import TensorMapDictLoss from ...utils.metrics import RMSEAccumulator -from ...utils.neighbor_lists import get_system_with_neighbor_lists +from ...utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from ...utils.per_atom import average_by_num_atoms from .model import SoapBpnn @@ -90,7 +93,7 @@ def train( # needs to happen before the additive models are trained, as they # might need them): logger.info("Calculating neighbor lists for the datasets") - requested_neighbor_lists = model.soap_calculator.requested_neighbor_lists() + requested_neighbor_lists = get_requested_neighbor_lists(model) for dataset in train_datasets + val_datasets: for i in range(len(dataset)): system = dataset[i]["system"] diff --git a/src/metatrain/utils/neighbor_lists.py b/src/metatrain/utils/neighbor_lists.py index b76d836f..e9d9304f 100644 --- a/src/metatrain/utils/neighbor_lists.py +++ b/src/metatrain/utils/neighbor_lists.py @@ -15,6 +15,69 @@ from .data.system_to_ase import system_to_ase +def get_requested_neighbor_lists( + module: torch.nn.Module, + length_unit: str = "", +) -> List[NeighborListOptions]: + """Get the neighbor lists requested by a module and its children. + + :param module: The module for which to get the requested neighbor lists. + :param length_unit: The length units to be registered in the returned + `NeighborListOptions`. + + :return: A list of `NeighborListOptions` objects requested by the module. + """ + requested: List[NeighborListOptions] = [] + _get_requested_neighbor_lists_in_place( + module=module, + module_name="", + requested=requested, + length_unit=length_unit, + ) + return requested + + +def _get_requested_neighbor_lists_in_place( + module: torch.nn.Module, + module_name: str, + requested: List[NeighborListOptions], + length_unit: str, +): + # copied from + # metatensor/python/metatensor-torch/metatensor/torch/atomistic/model.py + # and just removed the length units + + if hasattr(module, "requested_neighbor_lists"): + for new_options in module.requested_neighbor_lists(): + new_options.add_requestor(module_name) + + already_requested = False + for existing in requested: + if existing == new_options: + already_requested = True + for requestor in new_options.requestors(): + existing.add_requestor(requestor) + + if not already_requested: + if new_options.length_unit not in ["", length_unit]: + raise ValueError( + f"NeighborsListOptions from {module_name} already have a " + f"length unit ('{new_options.length_unit}') which does not " + f"match the model length units ('{length_unit}')" + ) + + new_options.length_unit = length_unit + requested.append(new_options) + + for child_name, child in module.named_children(): + _get_requested_neighbor_lists_in_place( + module=child, + module_name=module_name + "." + child_name, + requested=requested, + length_unit=length_unit, + ) + + def get_system_with_neighbor_lists( system: System, neighbor_lists: List[NeighborListOptions] ) -> System: diff --git a/tests/utils/test_evaluate_model.py b/tests/utils/test_evaluate_model.py index 72826cd7..e2bd81ec 100644 --- a/tests/utils/test_evaluate_model.py +++ b/tests/utils/test_evaluate_model.py @@ -6,7 +6,10 @@ from metatrain.utils.data import DatasetInfo, TargetInfo, read_systems from metatrain.utils.evaluate_model import evaluate_model from metatrain.utils.export import export -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import MODEL_HYPERS, RESOURCES_PATH @@ -45,8 +48,9 @@ def test_evaluate_model(training, exported): ) model = export(model, capabilities) + requested_neighbor_lists = get_requested_neighbor_lists(model) systems = [ - get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in systems ] diff --git a/tests/utils/test_llpr.py b/tests/utils/test_llpr.py index f1887ef5..189e7c2a 100644 --- a/tests/utils/test_llpr.py +++ b/tests/utils/test_llpr.py @@ -9,7 +9,10 @@ from metatrain.utils.data import Dataset, collate_fn, read_systems, read_targets from metatrain.utils.llpr import LLPRUncertaintyModel -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists, +) from . import RESOURCES_PATH @@ -37,7 +40,7 @@ def test_llpr(tmpdir): }, } targets, _ = read_targets(target_config) - requested_neighbor_lists = model.requested_neighbor_lists() + requested_neighbor_lists = get_requested_neighbor_lists(model) qm9_systems = [ get_system_with_neighbor_lists(system, requested_neighbor_lists) for system in qm9_systems