Skip to content

Commit

Permalink
Add NL request for ZBL
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 1, 2024
1 parent 43dbc23 commit 18c87ad
Showing 1 changed file with 11 additions and 44 deletions.
55 changes: 11 additions & 44 deletions src/metatrain/utils/additive/zbl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from ase.data import covalent_radii
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import ModelOutput, System
from metatensor.torch.atomistic import ModelOutput, NeighborListOptions, System

from ..data import DatasetInfo

Expand Down Expand Up @@ -101,49 +101,8 @@ def forward(
# Assert only one neighbor list for all systems
neighbor_lists: List[TensorBlock] = []
for system in systems:
nl_options = system.known_neighbor_lists()
if len(nl_options) != 1:
raise ValueError("ZBL only supports one neighbor list per system.")
nl_option = nl_options[0]
if nl_option.cutoff < 2.0 * self.largest_covalent_radius:
raise ValueError(
"ZBL only supports neighbor lists with a cutoff of at least "
f"{2.0 * self.largest_covalent_radius} Å, since the largest "
f"covalent radius is {self.largest_covalent_radius} Å."
)
if nl_option.full_list:
nl = system.get_neighbor_list(nl_option)
else:
# convert to full NL
half_nl = system.get_neighbor_list(nl_option)
half_nl_samples = half_nl.samples.values
half_nl_values = half_nl.values
nl = TensorBlock(
samples=Labels(
names=half_nl.samples.names,
values=torch.concatenate(
[
half_nl_samples,
torch.concatenate(
[
half_nl_samples[:, 1].unsqueeze(-1),
half_nl_samples[:, 0].unsqueeze(-1),
-half_nl_samples[:, 2:5],
],
dim=1,
),
]
),
),
components=half_nl.components,
properties=half_nl.properties,
values=torch.concatenate(
[
half_nl_values,
-half_nl_values,
],
),
)
nl_options = self.requested_neighbor_lists()[0]
nl = system.get_neighbor_list(nl_options)
neighbor_lists.append(nl)

# Find the elements of all i and j atoms
Expand Down Expand Up @@ -269,6 +228,14 @@ def get_pairwise_zbl(self, zi, zj, rij):

return e

def requested_neighbor_lists(self) -> List[NeighborListOptions]:
return [
NeighborListOptions(
cutoff=2.0 * self.largest_covalent_radius,
full_list=True,
)
]


def _phi(r, c, da):
phi = torch.sum(c.unsqueeze(-1) * torch.exp(-r * da), dim=0)
Expand Down

0 comments on commit 18c87ad

Please sign in to comment.