Skip to content

Commit

Permalink
Fix and update unittests.
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 15, 2023
1 parent b97c462 commit bd62fcc
Show file tree
Hide file tree
Showing 40 changed files with 7,840 additions and 582 deletions.
2 changes: 2 additions & 0 deletions dmff/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .hamiltonian import Hamiltonian
from .topology import DMFFTopology
61 changes: 61 additions & 0 deletions dmff/api/gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import List, Tuple
from pathlib import Path
import numpy as np
from .topology import DMFFTopology


EMBED_W1 = np.random.random((117, 117))

elem_to_index = {'EP': 0, 'H': 1, 'HE': 2, 'LI': 3, 'BE': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
'F': 9, 'NE': 10, 'NA': 11, 'MG': 12, 'AL': 13, 'SI': 14, 'P': 15, 'S': 16,
'CL': 17, 'AR': 18, 'K': 19, 'CA': 20, 'SC': 21, 'TI': 22, 'V': 23, 'CR': 24,
'MN': 25, 'FE': 26, 'CO': 27, 'NI': 28, 'CU': 29, 'ZN': 30, 'GA': 31, 'GE': 32,
'AS': 33, 'SE': 34, 'BR': 35, 'KR': 36, 'RB': 37, 'SR': 38, 'Y': 39, 'ZR': 40,
'NB': 41, 'MO': 42, 'TC': 43, 'RU': 44, 'RH': 45, 'PD': 46, 'AG': 47, 'CD': 48,
'IN': 49, 'SN': 50, 'SB': 51, 'TE': 52, 'I': 53, 'XE': 54, 'CS': 55, 'BA': 56,
'LA': 57, 'CE': 58, 'PR': 59, 'ND': 60, 'PM': 61, 'SM': 62, 'EU': 63, 'GD': 64,
'TB': 65, 'DY': 66, 'HO': 67, 'ER': 68, 'TM': 69, 'YB': 70, 'LU': 71, 'HF': 72,
'TA': 73, 'W': 74, 'RE': 75, 'OS': 76, 'IR': 77, 'PT': 78, 'AU': 79, 'HG': 80,
'TL': 81, 'PB': 82, 'BI': 83, 'PO': 84, 'AT': 85, 'RN': 86, 'FR': 87, 'RA': 88,
'AC': 89, 'TH': 90, 'PA': 91, 'U': 92, 'NP': 93, 'PU': 94, 'AM': 95, 'CM': 96,
'BK': 97, 'CF': 98, 'ES': 99, 'FM': 100, 'MD': 101, 'NO': 102, 'LR': 103,
'RF': 104, 'DB': 105, 'SG': 106, 'BH': 107, 'HS': 108, 'MT': 109, 'DS': 110,
'RG': 111, 'UUB': 112, 'UUT': 113, 'UUQ': 114, 'UUP': 115, 'UUH': 116}


def mol_to_graph_matrix(topdata: DMFFTopology) -> Tuple[np.ndarray, np.ndarray]:
num_atom = topdata.getNumAtoms()
adj = np.zeros((num_atom, num_atom))
node_features = np.zeros((num_atom, 117))
for i in range(num_atom):
adj[i, i] = 1
node_features[i, elem_to_index[topdata._atom[i].element.upper()]] = 1
for bond in topdata.bonds():
adj[bond.atom1.index, bond.atom2.index] = 1
adj[bond.atom2.index, bond.atom1.index] = 1
return adj, node_features


def get_embed(topdata: DMFFTopology):
adj, node = mol_to_graph_matrix(topdata)
natom = adj.shape[0]
support = np.dot(node[:, :117], EMBED_W1[:117, :117])
out = np.dot(adj, support)
out = np.concatenate((out, node[:, 117:]), axis=1)
return out


def get_eqv_atoms(topdata: DMFFTopology):
embed = get_embed(topdata)
natom, nfeat = embed.shape[0], embed.shape[1]
dist = np.power(
embed.reshape((natom, 1, nfeat)) - embed.reshape((1, natom, nfeat)), 2
).sum(axis=2)
eqv_list = []
for na in range(natom):
eqv_list.append([na])
for nb in range(natom):
if dist[na, nb] < 1e-2 and na != nb:
eqv_list[-1].append(nb)
return eqv_list

1 change: 1 addition & 0 deletions dmff/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

def matchTemplate(graph, template):
if graph.number_of_nodes() != template.number_of_nodes():
print("Node with different number of nodes.")
return False, {}, {}

def match_func(n1, n2):
Expand Down
18 changes: 12 additions & 6 deletions dmff/api/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .paramset import ParamSet
from .topology import DMFFTopology
from ..operators.templatetype import TemplateATypeOperator
from ..operators.templatevsite import TemplateVSiteOperator
from ..utils import DMFFException
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -48,7 +49,9 @@ def __init__(self, topology: DMFFTopology, update_func: Callable):
def add(self, name, func):
self.dmff_potentials[name] = func

def getPotentialFunc(self, names=[]):
def getPotentialFunc(self, names: list=[]):
if isinstance(names, str):
names = [names]
if len(names) == 0:
names = self.dmff_potentials.keys()
def efunc(positions, box, pairs, prms):
Expand Down Expand Up @@ -83,22 +86,25 @@ def __init__(self, *args):
else:
self.generators[key] = _DMFFGenerators[key](
ffinfo, self.paramset)

def getGenerators(self):
return [g for g in self.generators.values()]

def createJaxPotential(self, topdata: Union[DMFFTopology, app.Topology], nonbondedMethod=app.NoCutoff,
nonbondedCutoff=1.0 * unit.nanometer, args={}, forces=None):
def createPotential(self, topdata: Union[DMFFTopology, app.Topology], nonbondedMethod=app.NoCutoff,
nonbondedCutoff=1.0 * unit.nanometer, **kwargs):
if isinstance(topdata, app.Topology):
topdata = DMFFTopology(from_top=topdata)
# initialize template operator
vsite = TemplateVSiteOperator(self.ffinfo)
topdata = vsite(topdata)
template = TemplateATypeOperator(self.ffinfo)
topdata = template(topdata)

efuncs = {}
for key in self.generators:
gen = self.generators[key]
if forces is not None and gen.getName() not in forces:
continue
efuncs[gen.getName()] = gen.createPotential(topdata, nonbondedMethod,
nonbondedCutoff, args)
nonbondedCutoff, **kwargs)

update_func = topdata.buildVSiteUpdateFunction()
potential = Potential(topdata, update_func)
Expand Down
91 changes: 67 additions & 24 deletions dmff/api/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@
'ARG', 'ASP', 'GLN', 'GLY', 'ILE', 'LYS', 'PHE', 'SER', 'TRP', 'VAL',
'A', 'G', 'C', 'U', 'I', 'DA', 'DG', 'DC', 'DT', 'DI']

EMBED_W1 = np.random.random((117, 117))

elem_to_index = {'EP': 0, 'H': 1, 'HE': 2, 'LI': 3, 'BE': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
'F': 9, 'NE': 10, 'NA': 11, 'MG': 12, 'AL': 13, 'SI': 14, 'P': 15, 'S': 16,
'CL': 17, 'AR': 18, 'K': 19, 'CA': 20, 'SC': 21, 'TI': 22, 'V': 23, 'CR': 24,
'MN': 25, 'FE': 26, 'CO': 27, 'NI': 28, 'CU': 29, 'ZN': 30, 'GA': 31, 'GE': 32,
'AS': 33, 'SE': 34, 'BR': 35, 'KR': 36, 'RB': 37, 'SR': 38, 'Y': 39, 'ZR': 40,
'NB': 41, 'MO': 42, 'TC': 43, 'RU': 44, 'RH': 45, 'PD': 46, 'AG': 47, 'CD': 48,
'IN': 49, 'SN': 50, 'SB': 51, 'TE': 52, 'I': 53, 'XE': 54, 'CS': 55, 'BA': 56,
'LA': 57, 'CE': 58, 'PR': 59, 'ND': 60, 'PM': 61, 'SM': 62, 'EU': 63, 'GD': 64,
'TB': 65, 'DY': 66, 'HO': 67, 'ER': 68, 'TM': 69, 'YB': 70, 'LU': 71, 'HF': 72,
'TA': 73, 'W': 74, 'RE': 75, 'OS': 76, 'IR': 77, 'PT': 78, 'AU': 79, 'HG': 80,
'TL': 81, 'PB': 82, 'BI': 83, 'PO': 84, 'AT': 85, 'RN': 86, 'FR': 87, 'RA': 88,
'AC': 89, 'TH': 90, 'PA': 91, 'U': 92, 'NP': 93, 'PU': 94, 'AM': 95, 'CM': 96,
'BK': 97, 'CF': 98, 'ES': 99, 'FM': 100, 'MD': 101, 'NO': 102, 'LR': 103,
'RF': 104, 'DB': 105, 'SG': 106, 'BH': 107, 'HS': 108, 'MT': 109, 'DS': 110,
'RG': 111, 'UUB': 112, 'UUT': 113, 'UUQ': 114, 'UUP': 115, 'UUH': 116}

class DMFFTopology:
def __init__(self, from_top=None, from_sdf=None, from_rdmol=None, residue_name="MOL"):
Expand Down Expand Up @@ -192,7 +209,11 @@ def updateMolecules(self, sanitize=True):
resname = atoms[ind[0]].residue.name
self._molecules.append(top2rdmol(self, ind))
if sanitize and resname not in _standardResidues:
self.regularize_aromaticity(self._molecules[-1])
try:
self.regularize_aromaticity(self._molecules[-1])
except BaseException as e:
print(e)
print("Warning: aromaticity regularize failed for residue %s" % resname)

def parseSMARTS(self, parser, resname=[]):
atoms = [a for a in self.atoms()]
Expand Down Expand Up @@ -492,29 +513,8 @@ def addVSiteToPos(self, positions):
return update_func(new_pos)

def getEquivalentAtoms(self):
graph = nx.Graph()
for atom in self.atoms():
elem = atom.meta["element"]
if elem == "EP":
continue
graph.add_node(atom.index, elem=elem)
for bond in self.bonds():
a1, a2 = bond.atom1, bond.atom2
graph.add_edge(a1.index, a2.index, btype="bond")

def match_node(n1, n2):
return n1["elem"] == n2["elem"]

ismags = nx.isomorphism.ISMAGS(graph, graph, node_match=match_node)
isomorphisms = list(ismags.isomorphisms_iter(symmetry=False))
eq_atoms = {}
for atom in self.atoms():
elem = atom.meta["element"]
if elem == "EP":
eq_atoms[atom.index] = [atom.index]
eq_atoms[atom.index] = list(
set([i[atom.index] for i in isomorphisms]))
return eq_atoms
eqv_atoms = get_eqv_atoms(self)
return eqv_atoms

def getPeriodicBoxVectors(self, use_jax=True):
if use_jax:
Expand Down Expand Up @@ -709,3 +709,46 @@ def top2rdmol(top, indices) -> Chem.rdchem.Mol:
# rdmol.UpdatePropertyCache()
# AllChem.EmbedMolecule(rdmol, randomSeed=1)
return rdmol


def mol_to_graph_matrix(topdata: DMFFTopology) -> Tuple[np.ndarray, np.ndarray]:
num_atom = topdata.getNumAtoms()
adj = np.zeros((num_atom, num_atom))
node_features = np.zeros((num_atom, 117))
atoms = [a for a in topdata.atoms()]
for i in range(num_atom):
adj[i, i] = 1
node_features[i, elem_to_index[atoms[i].element.upper()]] = 1
for bond in topdata.bonds():
adj[bond.atom1.index, bond.atom2.index] = 1
adj[bond.atom2.index, bond.atom1.index] = 1
for vsite in topdata.vsites():
adj[vsite.vatom.index, vsite.atoms[0].index] = 1
adj[vsite.atoms[0].index, vsite.vatom.index] = 1
return adj, node_features


def get_embed(topdata: DMFFTopology):
adj, node = mol_to_graph_matrix(topdata)
natom = adj.shape[0]
support = np.dot(node, EMBED_W1)
out = np.dot(adj, support)
out = np.dot(out, EMBED_W1)
out = np.dot(adj, out)
return out


def get_eqv_atoms(topdata: DMFFTopology):
embed = get_embed(topdata)
natom, nfeat = embed.shape[0], embed.shape[1]
dist = np.power(
embed.reshape((natom, 1, nfeat)) - embed.reshape((1, natom, nfeat)), 2
).sum(axis=2)
eqv_list = []
for na in range(natom):
eqv_list.append([na])
for nb in range(natom):
if dist[na, nb] < 1e-2 and na != nb:
eqv_list[-1].append(nb)
return eqv_list

4 changes: 4 additions & 0 deletions dmff/api/vstools.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,9 @@ def insertVirtualSites(topdata, vsite_list):
# regularize_aromaticity(rdmol)
# newtop._molecules.append(rdmol)
newtop.updateMolecules(sanitize=True)

# copy box info to newtop
if topdata.cell is not None:
newtop.setPeriodicBoxVectors(topdata.getPeriodicBoxVectors(use_jax=False))
return newtop

2 changes: 2 additions & 0 deletions dmff/classical/intra.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx, order):
self.refresh_calculators()

def generate_get_energy(self):
if len(self.p1idx) == 0:
return lambda positions, box, pairs, k, psi: 0.0
def get_energy(positions, box, pairs, k, psi):
p1 = positions[self.p1idx,:]
p2 = positions[self.p2idx,:]
Expand Down
Loading

0 comments on commit bd62fcc

Please sign in to comment.