diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml
index 701acf9bb..5b8de3ba1 100644
--- a/.github/workflows/ut.yml
+++ b/.github/workflows/ut.yml
@@ -33,4 +33,5 @@ jobs:
pytest -vs tests/test_common/test_*
pytest -vs tests/test_admp/test_*
pytest -vs tests/test_utils.py
- pytest -vs tests/test_mbar/test_*
+ pytest -vs tests/test_mbar/test_*
+ pytest -vs tests/test_sgnn/test_*
diff --git a/dmff/api/graph.py b/dmff/api/graph.py
index 6e1ccd6c0..64e9af10a 100644
--- a/dmff/api/graph.py
+++ b/dmff/api/graph.py
@@ -13,7 +13,7 @@
def matchTemplate(graph, template):
if graph.number_of_nodes() != template.number_of_nodes():
- print("Node with different number of nodes.")
+ # print("Node with different number of nodes.")
return False, {}, {}
def match_func(n1, n2):
diff --git a/dmff/generators/__init__.py b/dmff/generators/__init__.py
index 7ddb93293..6f37cf7f0 100644
--- a/dmff/generators/__init__.py
+++ b/dmff/generators/__init__.py
@@ -1,2 +1,3 @@
from .classical import *
-from .admp import *
\ No newline at end of file
+from .admp import *
+from .ml import *
diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py
index bf1cce2e0..cada68fe1 100644
--- a/dmff/generators/admp.py
+++ b/dmff/generators/admp.py
@@ -822,15 +822,28 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
kzs.append(kz)
# record multipoles
c0.append(float(attribs["c0"]))
- dX.append(float(attribs["dX"]))
- dY.append(float(attribs["dY"]))
- dZ.append(float(attribs["dZ"]))
- qXX.append(float(attribs["qXX"]))
- qYY.append(float(attribs["qYY"]))
- qZZ.append(float(attribs["qZZ"]))
- qXY.append(float(attribs["qXY"]))
- qXZ.append(float(attribs["qXZ"]))
- qYZ.append(float(attribs["qYZ"]))
+ if self.lmax >= 1:
+ dX.append(float(attribs["dX"]))
+ dY.append(float(attribs["dY"]))
+ dZ.append(float(attribs["dZ"]))
+ else:
+ dX.append(0.0)
+ dY.append(0.0)
+ dZ.append(0.0)
+ if self.lmax >= 2:
+ qXX.append(float(attribs["qXX"]))
+ qYY.append(float(attribs["qYY"]))
+ qZZ.append(float(attribs["qZZ"]))
+ qXY.append(float(attribs["qXY"]))
+ qXZ.append(float(attribs["qXZ"]))
+ qYZ.append(float(attribs["qYZ"]))
+ else:
+ qXX.append(0.0)
+ qYY.append(0.0)
+ qZZ.append(0.0)
+ qXY.append(0.0)
+ qXZ.append(0.0)
+ qYZ.append(0.0)
mask = 1.0
if "mask" in attribs and attribs["mask"].upper() == "TRUE":
mask = 0.0
@@ -1146,6 +1159,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof
pme_force = ADMPPmeForce(box, axis_types, axis_indices, rc,
self.ethresh, self.lmax, self.lpol, lpme,
self.step_pol)
+ self.pme_force = pme_force
def potential_fn(positions, box, pairs, params):
positions = positions * 10
@@ -1181,4 +1195,4 @@ def getMetaData(self):
return self._meta
-_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator
\ No newline at end of file
+_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator
diff --git a/dmff/generators/ml.py b/dmff/generators/ml.py
new file mode 100644
index 000000000..afec8cee8
--- /dev/null
+++ b/dmff/generators/ml.py
@@ -0,0 +1,74 @@
+from ..api.topology import DMFFTopology
+from ..api.paramset import ParamSet
+from ..api.hamiltonian import _DMFFGenerators
+from ..utils import DMFFException, isinstance_jnp
+from ..utils import jit_condition
+import numpy as np
+import jax
+import jax.numpy as jnp
+import openmm.app as app
+import openmm.unit as unit
+import pickle
+
+from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb
+from ..sgnn.gnn import MolGNNForce, prm_transform_f2i
+
+
+class SGNNGenerator:
+ def __init__(self, ffinfo: dict, paramset: ParamSet):
+
+ self.name = "SGNNForce"
+ self.ffinfo = ffinfo
+ paramset.addField(self.name)
+ self.key_type = None
+
+ self.file = self.ffinfo["Forces"][self.name]["meta"]["file"]
+ self.nn = int(self.ffinfo["Forces"][self.name]["meta"]["nn"])
+ self.pdb = self.ffinfo["Forces"][self.name]["meta"]["pdb"]
+
+ # load ML potential parameters
+ with open(self.file, 'rb') as ifile:
+ params = pickle.load(ifile)
+
+ # convert to jnp array
+ for k in params:
+ params[k] = jnp.array(params[k])
+ # set mask to all true
+ paramset.addParameter(params[k], k, field=self.name, mask=jnp.ones(params[k].shape))
+
+ # mask = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), params)
+ # paramset.addParameter(params, "params", field=self.name, mask=mask)
+
+
+ def getName(self) -> str:
+ return self.name
+
+ def overwrite(self, paramset):
+ # do not use xml to handle ML potentials
+ # for ML potentials, xml only documents param file path
+ # so for ML potentials, overwrite function overwrites the file directly
+ with open(self.file, 'wb') as ofile:
+ pickle.dump(paramset[self.name], ofile)
+ return
+
+ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs):
+ self.G = from_pdb(self.pdb)
+ n_atoms = topdata.getNumAtoms()
+ self.model = MolGNNForce(self.G, nn=self.nn)
+ n_layers = self.model.n_layers
+ def potential_fn(positions, box, pairs, params):
+ # convert unit to angstrom
+ positions = positions * 10
+ box = box * 10
+ prms = prm_transform_f2i(params[self.name], n_layers)
+ return self.model.get_energy(positions, box, prms)
+
+ self._jaxPotential = potential_fn
+ return potential_fn
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+
+_DMFFGenerators["SGNNForce"] = SGNNGenerator
+
diff --git a/dmff/sgnn/gnn.py b/dmff/sgnn/gnn.py
index 49a5ee985..403d36429 100755
--- a/dmff/sgnn/gnn.py
+++ b/dmff/sgnn/gnn.py
@@ -13,6 +13,39 @@
from jax import value_and_grad, vmap
+def prm_transform_f2i(params, n_layers):
+ p = {}
+ for k in params:
+ p[k] = jnp.array(params[k])
+ for i_nn in [0, 1]:
+ nn_name = 'fc%d' % i_nn
+ p['%s.weight' % nn_name] = []
+ p['%s.bias' % nn_name] = []
+ for i_layer in range(n_layers[i_nn]):
+ k_w = '%s.%d.weight' % (nn_name, i_layer)
+ k_b = '%s.%d.bias' % (nn_name, i_layer)
+ p['%s.weight' % nn_name].append(p.pop(k_w, None))
+ p['%s.bias' % nn_name].append(p.pop(k_b, None))
+ return p
+
+
+def prm_transform_i2f(params, n_layers):
+ # transform format
+ p = {}
+ p['w'] = params['w']
+ p['fc_final.weight'] = params['fc_final.weight']
+ p['fc_final.bias'] = params['fc_final.bias']
+ for i_nn in range(2):
+ nn_name = 'fc%d' % i_nn
+ for i_layer in range(n_layers[i_nn]):
+ p[nn_name + '.%d.weight' %
+ i_layer] = params[nn_name + '.weight'][i_layer]
+ p[nn_name +
+ '.%d.bias' % i_layer] = params[nn_name +
+ '.bias'][i_layer]
+ return p
+
+
class MolGNNForce:
def __init__(self,
@@ -146,6 +179,7 @@ def message_pass(f_in, nb_connect, w, nn):
return
+
def load_params(self, ifn):
""" Load the network parameters from saved file
@@ -160,32 +194,12 @@ def load_params(self, ifn):
for k in params.keys():
params[k] = jnp.array(params[k])
# transform format
- keys = list(params.keys())
- for i_nn in [0, 1]:
- nn_name = 'fc%d' % i_nn
- keys_weight = []
- keys_bias = []
- for k in keys:
- if re.search(nn_name + '.[0-9]+.weight', k) is not None:
- keys_weight.append(k)
- elif re.search(nn_name + '.[0-9]+.bias', k) is not None:
- keys_bias.append(k)
- if len(keys_weight) != self.n_layers[i_nn] or len(
- keys_bias) != self.n_layers[i_nn]:
- sys.exit(
- 'Error while loading GNN params, inconsistent inputs with the GNN structure, check your input!'
- )
- params['%s.weight' % nn_name] = []
- params['%s.bias' % nn_name] = []
- for i_layer in range(self.n_layers[i_nn]):
- k_w = '%s.%d.weight' % (nn_name, i_layer)
- k_b = '%s.%d.bias' % (nn_name, i_layer)
- params['%s.weight' % nn_name].append(params.pop(k_w, None))
- params['%s.bias' % nn_name].append(params.pop(k_b, None))
- # params[nn_name]
- self.params = params
+ self.params = prm_transform_f2i(params, self.n_layers)
return
+
+
+
def save_params(self, ofn):
""" Save the network parameters to a pickle file
@@ -196,18 +210,8 @@ def save_params(self, ofn):
"""
# transform format
- params = {}
- params['w'] = self.params['w']
- params['fc_final.weight'] = self.params['fc_final.weight']
- params['fc_final.bias'] = self.params['fc_final.bias']
- for i_nn in range(2):
- nn_name = 'fc%d' % i_nn
- for i_layer in range(self.n_layers[i_nn]):
- params[nn_name + '.%d.weight' %
- i_layer] = self.params[nn_name + '.weight'][i_layer]
- params[nn_name +
- '.%d.bias' % i_layer] = self.params[nn_name +
- '.bias'][i_layer]
+ params = prm_transform_i2f(self.params, self.n_layers)
with open(ofn, 'wb') as ofile:
pickle.dump(params, ofile)
return
+
diff --git a/dmff/sgnn/graph.py b/dmff/sgnn/graph.py
index 93f6a809a..164a41f3f 100755
--- a/dmff/sgnn/graph.py
+++ b/dmff/sgnn/graph.py
@@ -1219,6 +1219,20 @@ def from_pdb(pdb):
return TopGraph(list_atom_elems, bonds, positions=positions, box=box)
+# def from_dmff_top(topdata):
+# '''
+# Build the sGNN TopGraph object from a DMFFTopology object
+
+# Parameters
+# ----------
+# topdata: DMFFTopology data
+# '''
+# list_atom_elems = np.array([a.element for a in topdata.atoms()])
+# bonds = np.array([np.sort([b.atom1.index, b.atom2.index]) for b in topdata.bonds()])
+# n_atoms = len(list_atom_elems)
+# return TopGraph(list_atom_elems, bonds, positions=jnp.zeros((n_atoms, 3)), box=jnp.eye(3)*10)
+
+
def validation():
G = from_pdb('peg4.pdb')
nn = 1
diff --git a/examples/classical/test_xml.py b/examples/classical/test_xml.py
index e84c6d849..c5d594033 100755
--- a/examples/classical/test_xml.py
+++ b/examples/classical/test_xml.py
@@ -82,4 +82,5 @@ def getEnergyDecomposition(context, forcegroups):
print("Nonbonded:", nbE(positions, box, pairs, params))
etotal = pot.getPotentialFunc()
- print("Total:", etotal(positions, box, pairs, params))
\ No newline at end of file
+ print("Total:", etotal(positions, box, pairs, params))
+
diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle
deleted file mode 100644
index 0c3959cd9..000000000
Binary files a/examples/sgnn/model1.pickle and /dev/null differ
diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle
new file mode 120000
index 000000000..88c340bb9
--- /dev/null
+++ b/examples/sgnn/model1.pickle
@@ -0,0 +1 @@
+test_backend/model1.pickle
\ No newline at end of file
diff --git a/examples/sgnn/peg.xml b/examples/sgnn/peg.xml
new file mode 100644
index 000000000..d3f41baaf
--- /dev/null
+++ b/examples/sgnn/peg.xml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb
deleted file mode 100644
index 2c11081d1..000000000
--- a/examples/sgnn/peg4.pdb
+++ /dev/null
@@ -1,63 +0,0 @@
-REMARK
-CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1
-ATOM 1 C00 TER 1 -2.962 3.637 -1.170
-ATOM 2 H01 TER 1 -2.608 4.142 -0.296
-ATOM 3 H02 TER 1 -4.032 3.635 -1.171
-ATOM 4 O03 TER 1 -2.484 2.289 -1.168
-ATOM 5 C04 TER 1 -2.961 1.615 0.000
-ATOM 6 H05 TER 1 -2.604 0.606 0.000
-ATOM 7 H06 TER 1 -2.604 2.119 0.874
-ATOM 8 H07 TER 1 -4.031 1.615 0.000
-ATOM 9 C00 INT 2 -2.449 6.384 -3.596
-ATOM 10 H01 INT 2 -2.804 5.879 -4.470
-ATOM 11 H02 INT 2 -1.379 6.386 -3.595
-ATOM 12 O03 INT 2 -2.927 5.710 -2.429
-ATOM 13 C04 INT 2 -2.448 4.362 -2.427
-ATOM 14 H05 INT 2 -2.803 3.856 -3.301
-ATOM 15 H06 INT 2 -1.378 4.364 -2.425
-ATOM 16 C00 INT 3 -2.966 9.857 -4.767
-ATOM 17 H01 INT 3 -2.612 10.363 -3.893
-ATOM 18 H02 INT 3 -4.036 9.855 -4.768
-ATOM 19 O03 INT 3 -2.488 8.509 -4.765
-ATOM 20 C04 INT 3 -2.965 7.835 -3.597
-ATOM 21 H05 INT 3 -2.610 8.340 -2.724
-ATOM 22 H06 INT 3 -4.035 7.833 -3.599
-ATOM 23 C00 TER 4 -2.452 10.582 -6.024
-ATOM 24 H01 TER 4 -2.807 10.077 -6.898
-ATOM 25 H02 TER 4 -1.382 10.584 -6.022
-ATOM 26 O03 TER 4 -2.931 11.930 -6.026
-ATOM 27 C04 TER 4 -2.453 12.604 -7.193
-ATOM 28 H05 TER 4 -2.808 12.099 -8.067
-ATOM 29 H06 TER 4 -2.812 13.613 -7.194
-ATOM 30 H07 TER 4 -1.383 12.606 -7.192
-TER
-CONECT 5 6
-CONECT 5 7
-CONECT 5 8
-CONECT 5 4
-CONECT 4 1
-CONECT 1 2
-CONECT 1 3
-CONECT 1 13
-CONECT 13 14
-CONECT 13 15
-CONECT 13 12
-CONECT 12 9
-CONECT 9 10
-CONECT 9 11
-CONECT 9 20
-CONECT 20 21
-CONECT 20 22
-CONECT 20 19
-CONECT 19 16
-CONECT 16 17
-CONECT 16 18
-CONECT 16 23
-CONECT 23 24
-CONECT 23 25
-CONECT 23 26
-CONECT 26 27
-CONECT 27 28
-CONECT 27 29
-CONECT 27 30
-END
diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb
new file mode 120000
index 000000000..3c2bb15b6
--- /dev/null
+++ b/examples/sgnn/peg4.pdb
@@ -0,0 +1 @@
+test_backend/peg4.pdb
\ No newline at end of file
diff --git a/examples/sgnn/ref_out b/examples/sgnn/ref_out
index 96bc1e62f..039b75427 100644
--- a/examples/sgnn/ref_out
+++ b/examples/sgnn/ref_out
@@ -1,37 +1,5 @@
-Energy: -21.588394
-Force
-[[ 90.02814 2.0374336 35.38877 ]
- [ -98.410095 -1.6865425 -30.066338 ]
- [ 48.29245 31.675808 -43.390694 ]
- [ 59.717484 -35.94304 50.599678 ]
- [ -24.63767 218.36092 168.47194 ]
- [ 43.258293 81.24294 -87.22882 ]
- [ -67.66767 -17.780457 -5.6038494 ]
- [ -22.928284 -302.96246 -123.14815 ]
- [ 306.24683 -21.33866 -156.95491 ]
- [ -4.715515 13.664352 -23.222527 ]
- [-258.61304 -26.577957 85.58963 ]
- [ -10.179474 106.21161 64.846924 ]
- [-210.20566 -52.107193 58.04005 ]
- [ 118.68472 -8.033836 -81.18109 ]
- [ 44.02272 -34.508667 46.852356 ]
- [-214.84206 115.90286 -227.59117 ]
- [ 44.243336 -7.151741 26.06369 ]
- [ 87.46674 38.574554 192.17757 ]
- [ 27.345726 -58.87986 -44.685863 ]
- [ -83.354774 -29.714098 214.93097 ]
- [ -71.111305 34.880676 -77.53289 ]
- [ 141.12836 49.28147 -97.597305 ]
- [-220.25613 -134.58449 -23.567059 ]
- [ 75.2593 58.432755 -63.99505 ]
- [ 123.56466 -82.0066 94.63971 ]
- [ 57.822285 17.07631 -53.788273 ]
- [ -73.37115 0.50865555 16.240654 ]
- [ 54.86133 97.53715 73.672806 ]
- [ -23.997787 -73.92179 -13.749107 ]
- [ 62.348286 21.809956 25.78839 ]]
-Batched Energies:
-[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164
- -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154
- -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443
- -72.8221 -28.33014 ]
+-21.588284621154912
+[-21.58828462 -39.79334159 10.03889335 -48.22451239 -32.90970162
+ -49.68568287 -47.58035178 -51.73860617 -37.39235277 -35.01933271
+ -46.06621902 -31.69327601 -6.86739655 -5.13698524 -27.4031207
+ -44.65301991 -52.00357797 3.1734038 -72.79081259 -28.27007722]
diff --git a/examples/sgnn/residues.xml b/examples/sgnn/residues.xml
new file mode 100644
index 000000000..aa78866eb
--- /dev/null
+++ b/examples/sgnn/residues.xml
@@ -0,0 +1,37 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/sgnn/run.py b/examples/sgnn/run.py
index f87c887b5..14c7c1b84 100755
--- a/examples/sgnn/run.py
+++ b/examples/sgnn/run.py
@@ -1,45 +1,44 @@
#!/usr/bin/env python
import sys
-import numpy as np
+import jax
import jax.numpy as jnp
-import jax.lax as lax
-from jax import vmap, value_and_grad
-import dmff
-from dmff.sgnn.gnn import MolGNNForce
-from dmff.utils import jit_condition
-from dmff.sgnn.graph import MAX_VALENCE
-from dmff.sgnn.graph import TopGraph, from_pdb
+import openmm.app as app
+import openmm.unit as unit
+from dmff.api import Hamiltonian
+from dmff.common import nblist
+from jax import value_and_grad
import pickle
-import re
-from collections import OrderedDict
-from functools import partial
-
if __name__ == '__main__':
- # params = load_params('benchmark/model1.pickle')
- G = from_pdb('peg4.pdb')
- model = MolGNNForce(G, nn=1)
- model.load_params('model1.pickle')
- E = model.get_energy(G.positions, G.box, model.params)
+
+ H = Hamiltonian('peg.xml')
+ app.Topology.loadBondDefinitions("residues.xml")
+ pdb = app.PDBFile("peg4.pdb")
+ rc = 0.6
+ # generator stores all force field parameters
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4)
+
+ # construct inputs
+ positions = jnp.array(pdb.positions._value)
+ a, b, c = pdb.topology.getPeriodicBoxVectors()
+ box = jnp.array([a._value, b._value, c._value])
+ # neighbor list
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
+ nbl.allocate(positions)
+
+
+ paramset = H.getParameters()
+ # params = paramset.parameters
- with open('set_test_lowT.pickle', 'rb') as ifile:
+ with open('test_backend/set_test_lowT.pickle', 'rb') as ifile:
data = pickle.load(ifile)
- # pos = jnp.array(data['positions'][0:100])
- # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1))
- pos = jnp.array(data['positions'][0])
- box = jnp.eye(3) * 50
+ # input in nm
+ pos = jnp.array(data['positions'][0:20]) / 10
+ box = jnp.eye(3) * 5
- # energies = model.batch_forward(pos, box, model.params)
- E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params)
- F = -F
- print('Energy:', E)
- print('Force')
- print(F)
+ efunc = jax.jit(pots.getPotentialFunc())
+ efunc_vmap = jax.vmap(jax.jit(pots.getPotentialFunc()), in_axes=(0, None, None, None), out_axes=0)
+ print(efunc(pos[0], box, nbl.pairs, paramset))
+ print(efunc_vmap(pos, box, nbl.pairs, paramset))
- # test batch processing
- pos = jnp.array(data['positions'][:20])
- box = jnp.tile(box, (20, 1, 1))
- E = model.batch_forward(pos, box, model.params)
- print('Batched Energies:')
- print(E)
diff --git a/examples/sgnn/model.pickle b/examples/sgnn/test_backend/model.pickle
similarity index 100%
rename from examples/sgnn/model.pickle
rename to examples/sgnn/test_backend/model.pickle
diff --git a/examples/sgnn/test_backend/model1.pickle b/examples/sgnn/test_backend/model1.pickle
new file mode 100644
index 000000000..0c3959cd9
Binary files /dev/null and b/examples/sgnn/test_backend/model1.pickle differ
diff --git a/examples/sgnn/model1.pth b/examples/sgnn/test_backend/model1.pth
similarity index 100%
rename from examples/sgnn/model1.pth
rename to examples/sgnn/test_backend/model1.pth
diff --git a/examples/sgnn/mse_testing.xvg b/examples/sgnn/test_backend/mse_testing.xvg
similarity index 100%
rename from examples/sgnn/mse_testing.xvg
rename to examples/sgnn/test_backend/mse_testing.xvg
diff --git a/examples/sgnn/test_backend/peg4.pdb b/examples/sgnn/test_backend/peg4.pdb
new file mode 100644
index 000000000..2c11081d1
--- /dev/null
+++ b/examples/sgnn/test_backend/peg4.pdb
@@ -0,0 +1,63 @@
+REMARK
+CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1
+ATOM 1 C00 TER 1 -2.962 3.637 -1.170
+ATOM 2 H01 TER 1 -2.608 4.142 -0.296
+ATOM 3 H02 TER 1 -4.032 3.635 -1.171
+ATOM 4 O03 TER 1 -2.484 2.289 -1.168
+ATOM 5 C04 TER 1 -2.961 1.615 0.000
+ATOM 6 H05 TER 1 -2.604 0.606 0.000
+ATOM 7 H06 TER 1 -2.604 2.119 0.874
+ATOM 8 H07 TER 1 -4.031 1.615 0.000
+ATOM 9 C00 INT 2 -2.449 6.384 -3.596
+ATOM 10 H01 INT 2 -2.804 5.879 -4.470
+ATOM 11 H02 INT 2 -1.379 6.386 -3.595
+ATOM 12 O03 INT 2 -2.927 5.710 -2.429
+ATOM 13 C04 INT 2 -2.448 4.362 -2.427
+ATOM 14 H05 INT 2 -2.803 3.856 -3.301
+ATOM 15 H06 INT 2 -1.378 4.364 -2.425
+ATOM 16 C00 INT 3 -2.966 9.857 -4.767
+ATOM 17 H01 INT 3 -2.612 10.363 -3.893
+ATOM 18 H02 INT 3 -4.036 9.855 -4.768
+ATOM 19 O03 INT 3 -2.488 8.509 -4.765
+ATOM 20 C04 INT 3 -2.965 7.835 -3.597
+ATOM 21 H05 INT 3 -2.610 8.340 -2.724
+ATOM 22 H06 INT 3 -4.035 7.833 -3.599
+ATOM 23 C00 TER 4 -2.452 10.582 -6.024
+ATOM 24 H01 TER 4 -2.807 10.077 -6.898
+ATOM 25 H02 TER 4 -1.382 10.584 -6.022
+ATOM 26 O03 TER 4 -2.931 11.930 -6.026
+ATOM 27 C04 TER 4 -2.453 12.604 -7.193
+ATOM 28 H05 TER 4 -2.808 12.099 -8.067
+ATOM 29 H06 TER 4 -2.812 13.613 -7.194
+ATOM 30 H07 TER 4 -1.383 12.606 -7.192
+TER
+CONECT 5 6
+CONECT 5 7
+CONECT 5 8
+CONECT 5 4
+CONECT 4 1
+CONECT 1 2
+CONECT 1 3
+CONECT 1 13
+CONECT 13 14
+CONECT 13 15
+CONECT 13 12
+CONECT 12 9
+CONECT 9 10
+CONECT 9 11
+CONECT 9 20
+CONECT 20 21
+CONECT 20 22
+CONECT 20 19
+CONECT 19 16
+CONECT 16 17
+CONECT 16 18
+CONECT 16 23
+CONECT 23 24
+CONECT 23 25
+CONECT 23 26
+CONECT 26 27
+CONECT 27 28
+CONECT 27 29
+CONECT 27 30
+END
diff --git a/examples/sgnn/pth2pickle.py b/examples/sgnn/test_backend/pth2pickle.py
similarity index 100%
rename from examples/sgnn/pth2pickle.py
rename to examples/sgnn/test_backend/pth2pickle.py
diff --git a/examples/sgnn/test_backend/ref_out b/examples/sgnn/test_backend/ref_out
new file mode 100644
index 000000000..96bc1e62f
--- /dev/null
+++ b/examples/sgnn/test_backend/ref_out
@@ -0,0 +1,37 @@
+Energy: -21.588394
+Force
+[[ 90.02814 2.0374336 35.38877 ]
+ [ -98.410095 -1.6865425 -30.066338 ]
+ [ 48.29245 31.675808 -43.390694 ]
+ [ 59.717484 -35.94304 50.599678 ]
+ [ -24.63767 218.36092 168.47194 ]
+ [ 43.258293 81.24294 -87.22882 ]
+ [ -67.66767 -17.780457 -5.6038494 ]
+ [ -22.928284 -302.96246 -123.14815 ]
+ [ 306.24683 -21.33866 -156.95491 ]
+ [ -4.715515 13.664352 -23.222527 ]
+ [-258.61304 -26.577957 85.58963 ]
+ [ -10.179474 106.21161 64.846924 ]
+ [-210.20566 -52.107193 58.04005 ]
+ [ 118.68472 -8.033836 -81.18109 ]
+ [ 44.02272 -34.508667 46.852356 ]
+ [-214.84206 115.90286 -227.59117 ]
+ [ 44.243336 -7.151741 26.06369 ]
+ [ 87.46674 38.574554 192.17757 ]
+ [ 27.345726 -58.87986 -44.685863 ]
+ [ -83.354774 -29.714098 214.93097 ]
+ [ -71.111305 34.880676 -77.53289 ]
+ [ 141.12836 49.28147 -97.597305 ]
+ [-220.25613 -134.58449 -23.567059 ]
+ [ 75.2593 58.432755 -63.99505 ]
+ [ 123.56466 -82.0066 94.63971 ]
+ [ 57.822285 17.07631 -53.788273 ]
+ [ -73.37115 0.50865555 16.240654 ]
+ [ 54.86133 97.53715 73.672806 ]
+ [ -23.997787 -73.92179 -13.749107 ]
+ [ 62.348286 21.809956 25.78839 ]]
+Batched Energies:
+[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164
+ -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154
+ -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443
+ -72.8221 -28.33014 ]
diff --git a/examples/sgnn/test_backend/run.py b/examples/sgnn/test_backend/run.py
new file mode 100755
index 000000000..f87c887b5
--- /dev/null
+++ b/examples/sgnn/test_backend/run.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+import numpy as np
+import jax.numpy as jnp
+import jax.lax as lax
+from jax import vmap, value_and_grad
+import dmff
+from dmff.sgnn.gnn import MolGNNForce
+from dmff.utils import jit_condition
+from dmff.sgnn.graph import MAX_VALENCE
+from dmff.sgnn.graph import TopGraph, from_pdb
+import pickle
+import re
+from collections import OrderedDict
+from functools import partial
+
+
+if __name__ == '__main__':
+ # params = load_params('benchmark/model1.pickle')
+ G = from_pdb('peg4.pdb')
+ model = MolGNNForce(G, nn=1)
+ model.load_params('model1.pickle')
+ E = model.get_energy(G.positions, G.box, model.params)
+
+ with open('set_test_lowT.pickle', 'rb') as ifile:
+ data = pickle.load(ifile)
+
+ # pos = jnp.array(data['positions'][0:100])
+ # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1))
+ pos = jnp.array(data['positions'][0])
+ box = jnp.eye(3) * 50
+
+ # energies = model.batch_forward(pos, box, model.params)
+ E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params)
+ F = -F
+ print('Energy:', E)
+ print('Force')
+ print(F)
+
+ # test batch processing
+ pos = jnp.array(data['positions'][:20])
+ box = jnp.tile(box, (20, 1, 1))
+ E = model.batch_forward(pos, box, model.params)
+ print('Batched Energies:')
+ print(E)
diff --git a/examples/sgnn/set_test.pickle b/examples/sgnn/test_backend/set_test.pickle
similarity index 100%
rename from examples/sgnn/set_test.pickle
rename to examples/sgnn/test_backend/set_test.pickle
diff --git a/examples/sgnn/set_test_lowT.pickle b/examples/sgnn/test_backend/set_test_lowT.pickle
similarity index 100%
rename from examples/sgnn/set_test_lowT.pickle
rename to examples/sgnn/test_backend/set_test_lowT.pickle
diff --git a/examples/sgnn/test.py b/examples/sgnn/test_backend/test.py
similarity index 100%
rename from examples/sgnn/test.py
rename to examples/sgnn/test_backend/test.py
diff --git a/examples/sgnn/test_data.xvg b/examples/sgnn/test_backend/test_data.xvg
similarity index 100%
rename from examples/sgnn/test_data.xvg
rename to examples/sgnn/test_backend/test_data.xvg
diff --git a/examples/sgnn/train.py b/examples/sgnn/test_backend/train.py
similarity index 100%
rename from examples/sgnn/train.py
rename to examples/sgnn/test_backend/train.py
diff --git a/examples/water_fullpol/monopole_nonpol/run.py b/examples/water_fullpol/monopole_nonpol/run.py
index 3b1b4799f..617e96e76 100755
--- a/examples/water_fullpol/monopole_nonpol/run.py
+++ b/examples/water_fullpol/monopole_nonpol/run.py
@@ -13,18 +13,19 @@
H = Hamiltonian('forcefield.xml')
pdb = app.PDBFile("pair.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
params = H.getParameters()
- pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce']
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer)
+ pot_pme = pots.dmff_potentials['ADMPPmeForce']
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
nbl.allocate(positions)
E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params)
diff --git a/examples/water_fullpol/monopole_polarizable/run.py b/examples/water_fullpol/monopole_polarizable/run.py
index 808ee5801..560cd8da0 100755
--- a/examples/water_fullpol/monopole_polarizable/run.py
+++ b/examples/water_fullpol/monopole_polarizable/run.py
@@ -13,18 +13,19 @@
H = Hamiltonian('forcefield.xml')
pdb = app.PDBFile("waterbox_31ang.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
params = H.getParameters()
- pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce']
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer)
+ pot_pme = pots.dmff_potentials['ADMPPmeForce']
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
nbl.allocate(positions)
E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params)
diff --git a/examples/water_fullpol/quadrupole_nonpol/run.py b/examples/water_fullpol/quadrupole_nonpol/run.py
index b408792aa..0b6fe6394 100755
--- a/examples/water_fullpol/quadrupole_nonpol/run.py
+++ b/examples/water_fullpol/quadrupole_nonpol/run.py
@@ -14,20 +14,20 @@
H = Hamiltonian('forcefield.xml')
app.Topology.loadBondDefinitions("residues.xml")
pdb = app.PDBFile("waterbox_31ang.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
params = H.getParameters()
- pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom)
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer)
pot_disp = pots.dmff_potentials['ADMPDispForce']
pot_pme = pots.dmff_potentials['ADMPPmeForce']
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta["cov_map"])
nbl.allocate(positions)
diff --git a/examples/water_fullpol/run.py b/examples/water_fullpol/run.py
index dccf92921..f024a8d66 100755
--- a/examples/water_fullpol/run.py
+++ b/examples/water_fullpol/run.py
@@ -13,18 +13,18 @@
H = Hamiltonian('forcefield.xml')
app.Topology.loadBondDefinitions("residues.xml")
pdb = app.PDBFile("waterbox_31ang.pdb")
- rc = 6
+ rc = 0.6
# generator stores all force field parameters
disp_generator, pme_generator = H.getGenerators()
- pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4)
+ pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4)
# construct inputs
- positions = jnp.array(pdb.positions._value) * 10
+ positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
- box = jnp.array([a._value, b._value, c._value]) * 10
+ box = jnp.array([a._value, b._value, c._value])
# neighbor list
- nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map)
+ nbl = nblist.NeighborList(box, rc, pots.meta['cov_map'])
nbl.allocate(positions)
diff --git a/tests/data/admp_mono.xml b/tests/data/admp_mono.xml
new file mode 100644
index 000000000..3970ff522
--- /dev/null
+++ b/tests/data/admp_mono.xml
@@ -0,0 +1,34 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/data/admp_nonpol.xml b/tests/data/admp_nonpol.xml
new file mode 100644
index 000000000..7cc1b4653
--- /dev/null
+++ b/tests/data/admp_nonpol.xml
@@ -0,0 +1,40 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/data/peg4.pdb b/tests/data/peg4.pdb
new file mode 100644
index 000000000..eee06d7da
--- /dev/null
+++ b/tests/data/peg4.pdb
@@ -0,0 +1,64 @@
+HEADER
+TITLE MDANALYSIS FRAME 0: Created by PDBWriter
+CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1
+ATOM 1 C00 TER X 1 47.381 10.286 49.808 0.00 1.00 SYST
+ATOM 2 H01 TER X 1 47.251 11.255 50.307 0.00 1.00 SYST
+ATOM 3 H02 TER X 1 46.907 9.487 50.425 0.00 1.00 SYST
+ATOM 4 O03 TER X 1 48.814 10.202 49.785 0.00 1.00 SYST
+ATOM 5 C04 TER X 1 49.336 9.203 50.665 0.00 1.00 SYST
+ATOM 6 H05 TER X 1 50.344 9.329 51.054 0.00 1.00 SYST
+ATOM 7 H06 TER X 1 48.796 9.176 51.611 0.00 1.00 SYST
+ATOM 8 H07 TER X 1 49.296 8.320 50.177 0.00 1.00 SYST
+ATOM 9 C00 INT X 2 46.552 8.760 46.601 0.00 1.00 SYST
+ATOM 10 H01 INT X 2 46.737 9.609 45.939 0.00 1.00 SYST
+ATOM 11 H02 INT X 2 45.532 8.628 46.649 0.00 1.00 SYST
+ATOM 12 O03 INT X 2 47.247 8.976 47.799 0.00 1.00 SYST
+ATOM 13 C04 INT X 2 46.919 10.250 48.371 0.00 1.00 SYST
+ATOM 14 H05 INT X 2 47.190 11.176 47.880 0.00 1.00 SYST
+ATOM 15 H06 INT X 2 45.801 10.369 48.307 0.00 1.00 SYST
+ATOM 16 C00 INT X 3 46.760 5.982 44.153 0.00 1.00 SYST
+ATOM 17 H01 INT X 3 47.759 6.173 43.770 0.00 1.00 SYST
+ATOM 18 H02 INT X 3 46.121 5.894 43.168 0.00 1.00 SYST
+ATOM 19 O03 INT X 3 46.268 7.098 44.918 0.00 1.00 SYST
+ATOM 20 C04 INT X 3 47.139 7.493 45.949 0.00 1.00 SYST
+ATOM 21 H05 INT X 3 47.292 6.726 46.769 0.00 1.00 SYST
+ATOM 22 H06 INT X 3 48.124 7.662 45.625 0.00 1.00 SYST
+ATOM 23 C00 TER X 4 46.610 4.692 44.880 0.00 1.00 SYST
+ATOM 24 H01 TER X 4 45.686 4.613 45.520 0.00 1.00 SYST
+ATOM 25 H02 TER X 4 47.444 4.603 45.516 0.00 1.00 SYST
+ATOM 26 O03 TER X 4 46.501 3.674 43.869 0.00 1.00 SYST
+ATOM 27 C04 TER X 4 45.802 2.493 44.226 0.00 1.00 SYST
+ATOM 28 H05 TER X 4 45.959 1.651 43.497 0.00 1.00 SYST
+ATOM 29 H06 TER X 4 46.125 2.280 45.251 0.00 1.00 SYST
+ATOM 30 H07 TER X 4 44.695 2.638 44.209 0.00 1.00 SYST
+CONECT 1 2 3 4 13
+CONECT 2 1
+CONECT 3 1
+CONECT 4 1 5
+CONECT 5 4 6 7 8
+CONECT 6 5
+CONECT 7 5
+CONECT 8 5
+CONECT 9 10 11 12 20
+CONECT 10 9
+CONECT 11 9
+CONECT 12 9 13
+CONECT 13 1 12 14 15
+CONECT 14 13
+CONECT 15 13
+CONECT 16 17 18 19 23
+CONECT 17 16
+CONECT 18 16
+CONECT 19 16 20
+CONECT 20 9 19 21 22
+CONECT 21 20
+CONECT 22 20
+CONECT 23 16 24 25 26
+CONECT 24 23
+CONECT 25 23
+CONECT 26 23 27
+CONECT 27 26 28 29 30
+CONECT 28 27
+CONECT 29 27
+CONECT 30 27
+END
diff --git a/tests/data/peg_sgnn.xml b/tests/data/peg_sgnn.xml
new file mode 100644
index 000000000..206326d1e
--- /dev/null
+++ b/tests/data/peg_sgnn.xml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/data/sgnn_model.pickle b/tests/data/sgnn_model.pickle
new file mode 100644
index 000000000..0c3959cd9
Binary files /dev/null and b/tests/data/sgnn_model.pickle differ
diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py
index 02d81ead4..be4b9d99d 100644
--- a/tests/test_admp/test_compute.py
+++ b/tests/test_admp/test_compute.py
@@ -24,13 +24,17 @@ def test_init(self):
"""
rc = 4.0
H = Hamiltonian('tests/data/admp.xml')
+ H1 = Hamiltonian('tests/data/admp_mono.xml')
+ H2 = Hamiltonian('tests/data/admp_nonpol.xml')
pdb = app.PDBFile('tests/data/water_dimer.pdb')
potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
+ potential1 = H1.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
+ potential2 = H2.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
- yield potential, H.paramset
+ yield potential, potential1, potential2, H.paramset, H1.paramset, H2.paramset
def test_ADMPPmeForce(self, pot_prm):
- potential, paramset = pot_prm
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
@@ -51,7 +55,7 @@ def test_ADMPPmeForce(self, pot_prm):
def test_ADMPPmeForce_jit(self, pot_prm):
- potential, paramset = pot_prm
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
@@ -67,5 +71,47 @@ def test_ADMPPmeForce_jit(self, pot_prm):
pot = potential.getPotentialFunc(names=["ADMPPmeForce"])
j_pot_pme = jit(value_and_grad(pot))
energy, grad = j_pot_pme(positions, box, pairs, paramset.parameters)
- print(energy)
+ print('hahahah', energy)
np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1)
+
+
+ def test_ADMPPmeForce_mono(self, pot_prm):
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
+ rc = 0.4
+ pdb = app.PDBFile('tests/data/water_dimer.pdb')
+ positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
+ positions = jnp.array(positions)
+ a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
+ box = jnp.array([a, b, c])
+ # neighbor list
+
+ covalent_map = potential1.meta["cov_map"]
+
+ nblist = NeighborList(box, rc, covalent_map)
+ nblist.allocate(positions)
+ pairs = nblist.pairs
+ pot = potential1.getPotentialFunc(names=["ADMPPmeForce"])
+ energy = pot(positions, box, pairs, paramset1)
+ print(energy)
+ np.testing.assert_almost_equal(energy, -66.55921382, decimal=2)
+
+
+ def test_ADMPPmeForce_nonpol(self, pot_prm):
+ potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
+ rc = 0.4
+ pdb = app.PDBFile('tests/data/water_dimer.pdb')
+ positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
+ positions = jnp.array(positions)
+ a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
+ box = jnp.array([a, b, c])
+ # neighbor list
+
+ covalent_map = potential2.meta["cov_map"]
+
+ nblist = NeighborList(box, rc, covalent_map)
+ nblist.allocate(positions)
+ pairs = nblist.pairs
+ pot = potential2.getPotentialFunc(names=["ADMPPmeForce"])
+ energy = pot(positions, box, pairs, paramset2)
+ print(energy)
+ np.testing.assert_almost_equal(energy, -31.69025446, decimal=2)
diff --git a/tests/test_sgnn/test_energy.py b/tests/test_sgnn/test_energy.py
new file mode 100644
index 000000000..a771f4508
--- /dev/null
+++ b/tests/test_sgnn/test_energy.py
@@ -0,0 +1,51 @@
+import openmm.app as app
+import openmm.unit as unit
+import numpy as np
+import jax.numpy as jnp
+import numpy.testing as npt
+import pytest
+from dmff import Hamiltonian, NeighborList
+from jax import jit, value_and_grad
+
+class TestADMPAPI:
+
+ """ Test sGNN related generators
+ """
+
+ @pytest.fixture(scope='class', name='pot_prm')
+ def test_init(self):
+ """load generators from XML file
+
+ Yields:
+ Tuple: (
+ ADMPDispForce,
+ ADMPPmeForce, # polarized
+ )
+ """
+ rc = 4.0
+ H = Hamiltonian('tests/data/peg_sgnn.xml')
+ pdb = app.PDBFile('tests/data/peg4.pdb')
+ potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
+
+ yield potential, H.paramset
+
+ def test_sGNN_energy(self, pot_prm):
+ potential, paramset = pot_prm
+ rc = 0.4
+ pdb = app.PDBFile('tests/data/peg4.pdb')
+ positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
+ positions = jnp.array(positions)
+ a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
+ box = jnp.array([a, b, c])
+ # neighbor list
+ covalent_map = potential.meta["cov_map"]
+
+ nblist = NeighborList(box, rc, covalent_map)
+ nblist.allocate(positions)
+ pairs = nblist.pairs
+ pot = potential.getPotentialFunc(names=["SGNNForce"])
+ energy = pot(positions, box, pairs, paramset)
+ print(energy)
+ np.testing.assert_almost_equal(energy, -21.81780787, decimal=2)
+
+