From 18c87ada159aff1a3d3b602be8fac2f259cd8512 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 1 Oct 2024 14:15:43 +0200 Subject: [PATCH] Add NL request for ZBL --- src/metatrain/utils/additive/zbl.py | 55 ++++++----------------------- 1 file changed, 11 insertions(+), 44 deletions(-) diff --git a/src/metatrain/utils/additive/zbl.py b/src/metatrain/utils/additive/zbl.py index 3e5000d5..b23f2b97 100644 --- a/src/metatrain/utils/additive/zbl.py +++ b/src/metatrain/utils/additive/zbl.py @@ -4,7 +4,7 @@ import torch from ase.data import covalent_radii from metatensor.torch import Labels, TensorBlock, TensorMap -from metatensor.torch.atomistic import ModelOutput, System +from metatensor.torch.atomistic import ModelOutput, NeighborListOptions, System from ..data import DatasetInfo @@ -101,49 +101,8 @@ def forward( # Assert only one neighbor list for all systems neighbor_lists: List[TensorBlock] = [] for system in systems: - nl_options = system.known_neighbor_lists() - if len(nl_options) != 1: - raise ValueError("ZBL only supports one neighbor list per system.") - nl_option = nl_options[0] - if nl_option.cutoff < 2.0 * self.largest_covalent_radius: - raise ValueError( - "ZBL only supports neighbor lists with a cutoff of at least " - f"{2.0 * self.largest_covalent_radius} Å, since the largest " - f"covalent radius is {self.largest_covalent_radius} Å." - ) - if nl_option.full_list: - nl = system.get_neighbor_list(nl_option) - else: - # convert to full NL - half_nl = system.get_neighbor_list(nl_option) - half_nl_samples = half_nl.samples.values - half_nl_values = half_nl.values - nl = TensorBlock( - samples=Labels( - names=half_nl.samples.names, - values=torch.concatenate( - [ - half_nl_samples, - torch.concatenate( - [ - half_nl_samples[:, 1].unsqueeze(-1), - half_nl_samples[:, 0].unsqueeze(-1), - -half_nl_samples[:, 2:5], - ], - dim=1, - ), - ] - ), - ), - components=half_nl.components, - properties=half_nl.properties, - values=torch.concatenate( - [ - half_nl_values, - -half_nl_values, - ], - ), - ) + nl_options = self.requested_neighbor_lists()[0] + nl = system.get_neighbor_list(nl_options) neighbor_lists.append(nl) # Find the elements of all i and j atoms @@ -269,6 +228,14 @@ def get_pairwise_zbl(self, zi, zj, rij): return e + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [ + NeighborListOptions( + cutoff=2.0 * self.largest_covalent_radius, + full_list=True, + ) + ] + def _phi(r, c, da): phi = torch.sum(c.unsqueeze(-1) * torch.exp(-r * da), dim=0)