Skip to content

Commit

Permalink
Add frontend for sGNN (#125)
Browse files Browse the repository at this point in the history
* Add sGNN generator
fixed a few problems in ADMPPmeGenerator

* remove debugging codes
  • Loading branch information
KuangYu authored Oct 22, 2023
1 parent 66d2eb9 commit 800480a
Show file tree
Hide file tree
Showing 38 changed files with 737 additions and 208 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ jobs:
pytest -vs tests/test_common/test_*
pytest -vs tests/test_admp/test_*
pytest -vs tests/test_utils.py
pytest -vs tests/test_mbar/test_*
pytest -vs tests/test_mbar/test_*
pytest -vs tests/test_sgnn/test_*
2 changes: 1 addition & 1 deletion dmff/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def matchTemplate(graph, template):
if graph.number_of_nodes() != template.number_of_nodes():
print("Node with different number of nodes.")
# print("Node with different number of nodes.")
return False, {}, {}

def match_func(n1, n2):
Expand Down
3 changes: 2 additions & 1 deletion dmff/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .classical import *
from .admp import *
from .admp import *
from .ml import *
34 changes: 24 additions & 10 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,15 +822,28 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
kzs.append(kz)
# record multipoles
c0.append(float(attribs["c0"]))
dX.append(float(attribs["dX"]))
dY.append(float(attribs["dY"]))
dZ.append(float(attribs["dZ"]))
qXX.append(float(attribs["qXX"]))
qYY.append(float(attribs["qYY"]))
qZZ.append(float(attribs["qZZ"]))
qXY.append(float(attribs["qXY"]))
qXZ.append(float(attribs["qXZ"]))
qYZ.append(float(attribs["qYZ"]))
if self.lmax >= 1:
dX.append(float(attribs["dX"]))
dY.append(float(attribs["dY"]))
dZ.append(float(attribs["dZ"]))
else:
dX.append(0.0)
dY.append(0.0)
dZ.append(0.0)
if self.lmax >= 2:
qXX.append(float(attribs["qXX"]))
qYY.append(float(attribs["qYY"]))
qZZ.append(float(attribs["qZZ"]))
qXY.append(float(attribs["qXY"]))
qXZ.append(float(attribs["qXZ"]))
qYZ.append(float(attribs["qYZ"]))
else:
qXX.append(0.0)
qYY.append(0.0)
qZZ.append(0.0)
qXY.append(0.0)
qXZ.append(0.0)
qYZ.append(0.0)
mask = 1.0
if "mask" in attribs and attribs["mask"].upper() == "TRUE":
mask = 0.0
Expand Down Expand Up @@ -1146,6 +1159,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof
pme_force = ADMPPmeForce(box, axis_types, axis_indices, rc,
self.ethresh, self.lmax, self.lpol, lpme,
self.step_pol)
self.pme_force = pme_force

def potential_fn(positions, box, pairs, params):
positions = positions * 10
Expand Down Expand Up @@ -1181,4 +1195,4 @@ def getMetaData(self):
return self._meta


_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator
_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator
74 changes: 74 additions & 0 deletions dmff/generators/ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from ..api.topology import DMFFTopology
from ..api.paramset import ParamSet
from ..api.hamiltonian import _DMFFGenerators
from ..utils import DMFFException, isinstance_jnp
from ..utils import jit_condition
import numpy as np
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
import pickle

from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb
from ..sgnn.gnn import MolGNNForce, prm_transform_f2i


class SGNNGenerator:
def __init__(self, ffinfo: dict, paramset: ParamSet):

self.name = "SGNNForce"
self.ffinfo = ffinfo
paramset.addField(self.name)
self.key_type = None

self.file = self.ffinfo["Forces"][self.name]["meta"]["file"]
self.nn = int(self.ffinfo["Forces"][self.name]["meta"]["nn"])
self.pdb = self.ffinfo["Forces"][self.name]["meta"]["pdb"]

# load ML potential parameters
with open(self.file, 'rb') as ifile:
params = pickle.load(ifile)

# convert to jnp array
for k in params:
params[k] = jnp.array(params[k])
# set mask to all true
paramset.addParameter(params[k], k, field=self.name, mask=jnp.ones(params[k].shape))

# mask = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), params)
# paramset.addParameter(params, "params", field=self.name, mask=mask)


def getName(self) -> str:
return self.name

def overwrite(self, paramset):
# do not use xml to handle ML potentials
# for ML potentials, xml only documents param file path
# so for ML potentials, overwrite function overwrites the file directly
with open(self.file, 'wb') as ofile:
pickle.dump(paramset[self.name], ofile)
return

def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs):
self.G = from_pdb(self.pdb)
n_atoms = topdata.getNumAtoms()
self.model = MolGNNForce(self.G, nn=self.nn)
n_layers = self.model.n_layers
def potential_fn(positions, box, pairs, params):
# convert unit to angstrom
positions = positions * 10
box = box * 10
prms = prm_transform_f2i(params[self.name], n_layers)
return self.model.get_energy(positions, box, prms)

self._jaxPotential = potential_fn
return potential_fn

def getJaxPotential(self):
return self._jaxPotential


_DMFFGenerators["SGNNForce"] = SGNNGenerator

76 changes: 40 additions & 36 deletions dmff/sgnn/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,39 @@
from jax import value_and_grad, vmap


def prm_transform_f2i(params, n_layers):
p = {}
for k in params:
p[k] = jnp.array(params[k])
for i_nn in [0, 1]:
nn_name = 'fc%d' % i_nn
p['%s.weight' % nn_name] = []
p['%s.bias' % nn_name] = []
for i_layer in range(n_layers[i_nn]):
k_w = '%s.%d.weight' % (nn_name, i_layer)
k_b = '%s.%d.bias' % (nn_name, i_layer)
p['%s.weight' % nn_name].append(p.pop(k_w, None))
p['%s.bias' % nn_name].append(p.pop(k_b, None))
return p


def prm_transform_i2f(params, n_layers):
# transform format
p = {}
p['w'] = params['w']
p['fc_final.weight'] = params['fc_final.weight']
p['fc_final.bias'] = params['fc_final.bias']
for i_nn in range(2):
nn_name = 'fc%d' % i_nn
for i_layer in range(n_layers[i_nn]):
p[nn_name + '.%d.weight' %
i_layer] = params[nn_name + '.weight'][i_layer]
p[nn_name +
'.%d.bias' % i_layer] = params[nn_name +
'.bias'][i_layer]
return p


class MolGNNForce:

def __init__(self,
Expand Down Expand Up @@ -146,6 +179,7 @@ def message_pass(f_in, nb_connect, w, nn):

return


def load_params(self, ifn):
""" Load the network parameters from saved file
Expand All @@ -160,32 +194,12 @@ def load_params(self, ifn):
for k in params.keys():
params[k] = jnp.array(params[k])
# transform format
keys = list(params.keys())
for i_nn in [0, 1]:
nn_name = 'fc%d' % i_nn
keys_weight = []
keys_bias = []
for k in keys:
if re.search(nn_name + '.[0-9]+.weight', k) is not None:
keys_weight.append(k)
elif re.search(nn_name + '.[0-9]+.bias', k) is not None:
keys_bias.append(k)
if len(keys_weight) != self.n_layers[i_nn] or len(
keys_bias) != self.n_layers[i_nn]:
sys.exit(
'Error while loading GNN params, inconsistent inputs with the GNN structure, check your input!'
)
params['%s.weight' % nn_name] = []
params['%s.bias' % nn_name] = []
for i_layer in range(self.n_layers[i_nn]):
k_w = '%s.%d.weight' % (nn_name, i_layer)
k_b = '%s.%d.bias' % (nn_name, i_layer)
params['%s.weight' % nn_name].append(params.pop(k_w, None))
params['%s.bias' % nn_name].append(params.pop(k_b, None))
# params[nn_name]
self.params = params
self.params = prm_transform_f2i(params, self.n_layers)
return




def save_params(self, ofn):
""" Save the network parameters to a pickle file
Expand All @@ -196,18 +210,8 @@ def save_params(self, ofn):
"""
# transform format
params = {}
params['w'] = self.params['w']
params['fc_final.weight'] = self.params['fc_final.weight']
params['fc_final.bias'] = self.params['fc_final.bias']
for i_nn in range(2):
nn_name = 'fc%d' % i_nn
for i_layer in range(self.n_layers[i_nn]):
params[nn_name + '.%d.weight' %
i_layer] = self.params[nn_name + '.weight'][i_layer]
params[nn_name +
'.%d.bias' % i_layer] = self.params[nn_name +
'.bias'][i_layer]
params = prm_transform_i2f(self.params, self.n_layers)
with open(ofn, 'wb') as ofile:
pickle.dump(params, ofile)
return

14 changes: 14 additions & 0 deletions dmff/sgnn/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,20 @@ def from_pdb(pdb):
return TopGraph(list_atom_elems, bonds, positions=positions, box=box)


# def from_dmff_top(topdata):
# '''
# Build the sGNN TopGraph object from a DMFFTopology object

# Parameters
# ----------
# topdata: DMFFTopology data
# '''
# list_atom_elems = np.array([a.element for a in topdata.atoms()])
# bonds = np.array([np.sort([b.atom1.index, b.atom2.index]) for b in topdata.bonds()])
# n_atoms = len(list_atom_elems)
# return TopGraph(list_atom_elems, bonds, positions=jnp.zeros((n_atoms, 3)), box=jnp.eye(3)*10)


def validation():
G = from_pdb('peg4.pdb')
nn = 1
Expand Down
3 changes: 2 additions & 1 deletion examples/classical/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ def getEnergyDecomposition(context, forcegroups):
print("Nonbonded:", nbE(positions, box, pairs, params))

etotal = pot.getPotentialFunc()
print("Total:", etotal(positions, box, pairs, params))
print("Total:", etotal(positions, box, pairs, params))

Binary file removed examples/sgnn/model1.pickle
Binary file not shown.
1 change: 1 addition & 0 deletions examples/sgnn/model1.pickle
48 changes: 48 additions & 0 deletions examples/sgnn/peg.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<ForceField>
<AtomTypes>
<Type class="CT" element="C" mass="12.0107" name="1" />
<Type class="HC" element="H" mass="1.00784" name="2" />
<Type class="OS" element="O" mass="15.999" name="3" />
<Type class="CT" element="C" mass="12.0107" name="4" />
<Type class="HC" element="H" mass="1.00784" name="5" />
</AtomTypes>
<Residues>
<Residue name="TER">
<Atom name="C00" type="1" />
<Atom name="H01" type="2" />
<Atom name="H02" type="2" />
<Atom name="O03" type="3" />
<Atom name="C04" type="4" />
<Atom name="H05" type="5" />
<Atom name="H06" type="5" />
<Atom name="H07" type="5" />
<Bond from="0" to="1" />
<Bond from="0" to="2" />
<Bond from="0" to="3" />
<Bond from="3" to="4" />
<Bond from="4" to="5" />
<Bond from="4" to="6" />
<Bond from="4" to="7" />
<ExternalBond atomName="C00" />
</Residue>
<Residue name="INT">
<Atom name="C00" type="1" />
<Atom name="H01" type="2" />
<Atom name="H02" type="2" />
<Atom name="O03" type="3" />
<Atom name="C04" type="1" />
<Atom name="H05" type="2" />
<Atom name="H06" type="2" />
<Bond from="0" to="1" />
<Bond from="0" to="2" />
<Bond from="0" to="3" />
<Bond from="3" to="4" />
<Bond from="4" to="5" />
<Bond from="4" to="6" />
<ExternalBond atomName="C00" />
<ExternalBond atomName="C04" />
</Residue>
</Residues>
<SGNNForce file="model1.pickle" pdb="peg4.pdb" nn="1"/>
</ForceField>

Loading

0 comments on commit 800480a

Please sign in to comment.