Skip to content

Commit

Permalink
Merge pull request pfnet-research#13 from mwata/refactor_230602
Browse files Browse the repository at this point in the history
Refactor 230602
  • Loading branch information
masakiwatanabe authored and GitHub Enterprise committed Jun 5, 2023
2 parents 64473bc + 662f3c7 commit 5870c9c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 26 deletions.
4 changes: 0 additions & 4 deletions torch_dftd/nn/dftd3_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ def __init__(
rcov = torch.tensor(d3_params["rcov"], dtype=dtype)
r2r4 = torch.tensor(d3_params["r2r4"], dtype=dtype)

# check c6ab parameter structure
assert torch.all((c6ab[:, :, :, :, 1] == c6ab[:, :, :, 0:1, 1]) | (c6ab[:, :, :, :, 0] < 0)), "c6ab(1) is not constant along row"
assert torch.all((c6ab[:, :, :, :, 2] == c6ab[:, :, 0:1, :, 2]) | (c6ab[:, :, :, :, 0] < 0)), "c6ab(2) is not constant along column"

# (95, 95, 5, 5, 3) c0, c1, c2 for coordination number dependent c6ab term.
self.register_buffer("c6ab", c6ab)
self.register_buffer("r0ab", r0ab) # atom pair distance (95, 95)
Expand Down
32 changes: 10 additions & 22 deletions torch_dftd_static/functions/dftd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
from torch import Tensor
from torch_dftd.functions.smoothing import poly_smoothing

# conversion factors used in grimme d3 code

from torch_dftd.functions.dftd3 import d3_k1, d3_k3

def edisp( # calculate edisp by all-pair computation
Z: Tensor,
pos: Tensor, # (n_atoms, 3)
shift_vecs: Tensor, # half of shift vectors (all shift vecs = shift_vecs + -shift_vecs + [(0,0,0)])
shift_vecs: Tensor, # half of shift vectors (eg. shift_vecs = [0, v, 2v] -> ghosts located at [-2v, -v, 0, v, 2v])
c6ab: Tensor,
r0ab: Tensor,
rcov: Tensor,
Expand All @@ -31,10 +29,7 @@ def edisp( # calculate edisp by all-pair computation
assert shift_mask is not None

n_atoms = len(Z)
#assert torch.all(shift_vecs[0] == 0.0)
#triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.arange(len(shift_vecs)) > 0)[None, None, :])
triu_mask = (torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :])[:, :, None] | ((torch.any(shift_vecs != 0.0, axis=-1))[None, None, :])

triu_mask = triu_mask & atom_mask[:, None, None] & atom_mask[None, :, None]
triu_mask = triu_mask & shift_mask[None, None, :]

Expand All @@ -45,8 +40,8 @@ def edisp( # calculate edisp by all-pair computation

# calculate coordination numbers (n_atoms,)
rco = rcov[Z][:, None] + rcov[Z][None, :] # (n_atoms, n_atoms)
rr = rco[:, :, None] / r # (n_atoms, n_atoms, 1+n_shift)
damp = torch.sigmoid(k1 * (rr - 1.0)) # (n_atoms, n_atoms, 1+n_shift)
rr = rco[:, :, None] / r # (n_atoms, n_atoms, n_shift)
damp = torch.sigmoid(k1 * (rr - 1.0)) # (n_atoms, n_atoms, n_shift)
if cnthr is not None and cutoff_smoothing == "poly":
damp *= poly_smoothing(r, cnthr)
if cnthr is not None:
Expand All @@ -61,11 +56,12 @@ def edisp( # calculate edisp by all-pair computation
cn2 = c6ab[1, Z, 0, :, 2] # (n_atoms, 5)
k3_rnc_1 = torch.where(cn1 >= 0.0, k3 * (nc[:, None] - cn1) ** 2, torch.tensor(-1.0e20))
k3_rnc_2 = torch.where(cn2 >= 0.0, k3 * (nc[:, None] - cn2) ** 2, torch.tensor(-1.0e20))
r_ratio_1 = torch.softmax(k3_rnc_1, dim=-1).to(torch.float32)
r_ratio_2 = torch.softmax(k3_rnc_2, dim=-1).to(torch.float32)
r_ratio_1 = torch.softmax(k3_rnc_1, dim=-1)
r_ratio_2 = torch.softmax(k3_rnc_2, dim=-1)
c6 = (cn0 * r_ratio_1[:, None, :, None] * r_ratio_2[None, :, None, :]).sum(dim=(-1,-2))
c8 = 3 * c6 * r2r4[Z][:, None] * r2r4[Z][None, :]

c8c6_ratio = 3 * r2r4[Z][:, None] * r2r4[Z][None, :]
c8 = c6 * c8c6_ratio

# calculate energy
s6 = params["s6"]
s8 = params["s18"]
Expand All @@ -74,10 +70,9 @@ def edisp( # calculate edisp by all-pair computation
if damping in ["bj", "bjm"]:
a1 = params["rs6"]
a2 = params["rs18"]

# Becke-Johnson damping, zero-damping introduces spurious repulsion
# and is therefore not supported/implemented
tmp = a1 * torch.sqrt(c8 / c6) + a2
tmp = a1 * torch.sqrt(c8c6_ratio) + a2
tmp2 = tmp ** 2
tmp6 = tmp2 ** 3
tmp8 = tmp6 * tmp2
Expand All @@ -97,11 +92,4 @@ def edisp( # calculate edisp by all-pair computation

e68 = torch.where(triu_mask, e68, torch.tensor(0.0))

return torch.sum(e68.to(torch.float64).sum()) * 2.0

#e68_same_cell = e68[:, :, 0]
#e68_same_cell = torch.where(torch.arange(n_atoms)[:, None] < torch.arange(n_atoms)[None, :], e68_same_cell, torch.tensor(0.0))
#e68_diff_cell = e68[:, :, 1:]
#e_same_cell = torch.sum(e68_same_cell.to(torch.float64).sum()) * 2.0
#e_diff_cell = torch.sum(e68_diff_cell.to(torch.float64).sum()) * 2.0
#return e_same_cell + e_diff_cell
return e68.to(torch.float64).sum() * 2.0
20 changes: 20 additions & 0 deletions torch_dftd_static/nn/dftd3_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@
from torch_dftd.functions.dftd3 import d3_autoang, d3_autoev
from torch_dftd_static.functions.dftd3 import edisp as edisp_triu


"""
Check that c6ab array (shape = (95,95,5,5,3)) has the following structure,
as assumed in edisp function:
- c6ab[..., 1] is constant along axes 1 (Z of second atom) and 3 (cn of second atom),
except for Z=0 which does not represent valid atom.
Second condition in torch.all(...) below exists because there can be
rows/columns without valid values (because not every pair of atoms has full 5x5 table),
and such missing values are represented by -1.
- c6ab[..., 2] is constant along axes 0 (Z of first atom) and 2 (cn of second atom)
except for Z=0.
https://docs.google.com/presentation/d/15J3jDALiD_tDPT9DVi2GcIBTMUT8QlfQe9ET4iuCE0M/edit#slide=id.g226c3535966_0_34
"""
def _check_c6ab_structure(c6ab):
assert torch.all((c6ab[:, 1:, :, :, 1] == c6ab[:, 1:2, :, 0:1, 1]) | (c6ab[:, 1:, :, :, 0] < 0)), "c6ab[..., 1] is not constant along second atom"
assert torch.all((c6ab[1:, :, :, :, 2] == c6ab[1:2, :, 0:1, :, 2]) | (c6ab[1:, :, :, :, 0] < 0)), "c6ab[..., 2] is not constant along first atom"

class DFTD3ModuleStatic(torch.nn.Module):
"""DFTD3ModuleStatic
Args:
Expand Down Expand Up @@ -48,6 +66,8 @@ def __init__(
self.register_buffer("rcov", rcov) # atom covalent distance (95)
self.register_buffer("r2r4", r2r4) # (95,)

_check_c6ab_structure(c6ab)

if cnthr > cutoff:
print(
f"WARNING: cnthr {cnthr} is larger than cutoff {cutoff}. "
Expand Down

0 comments on commit 5870c9c

Please sign in to comment.