diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml index 701acf9bb..5b8de3ba1 100644 --- a/.github/workflows/ut.yml +++ b/.github/workflows/ut.yml @@ -33,4 +33,5 @@ jobs: pytest -vs tests/test_common/test_* pytest -vs tests/test_admp/test_* pytest -vs tests/test_utils.py - pytest -vs tests/test_mbar/test_* + pytest -vs tests/test_mbar/test_* + pytest -vs tests/test_sgnn/test_* diff --git a/dmff/api/graph.py b/dmff/api/graph.py index 6e1ccd6c0..64e9af10a 100644 --- a/dmff/api/graph.py +++ b/dmff/api/graph.py @@ -13,7 +13,7 @@ def matchTemplate(graph, template): if graph.number_of_nodes() != template.number_of_nodes(): - print("Node with different number of nodes.") + # print("Node with different number of nodes.") return False, {}, {} def match_func(n1, n2): diff --git a/dmff/generators/__init__.py b/dmff/generators/__init__.py index 7ddb93293..6f37cf7f0 100644 --- a/dmff/generators/__init__.py +++ b/dmff/generators/__init__.py @@ -1,2 +1,3 @@ from .classical import * -from .admp import * \ No newline at end of file +from .admp import * +from .ml import * diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index bf1cce2e0..cada68fe1 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -822,15 +822,28 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): kzs.append(kz) # record multipoles c0.append(float(attribs["c0"])) - dX.append(float(attribs["dX"])) - dY.append(float(attribs["dY"])) - dZ.append(float(attribs["dZ"])) - qXX.append(float(attribs["qXX"])) - qYY.append(float(attribs["qYY"])) - qZZ.append(float(attribs["qZZ"])) - qXY.append(float(attribs["qXY"])) - qXZ.append(float(attribs["qXZ"])) - qYZ.append(float(attribs["qYZ"])) + if self.lmax >= 1: + dX.append(float(attribs["dX"])) + dY.append(float(attribs["dY"])) + dZ.append(float(attribs["dZ"])) + else: + dX.append(0.0) + dY.append(0.0) + dZ.append(0.0) + if self.lmax >= 2: + qXX.append(float(attribs["qXX"])) + qYY.append(float(attribs["qYY"])) + qZZ.append(float(attribs["qZZ"])) + qXY.append(float(attribs["qXY"])) + qXZ.append(float(attribs["qXZ"])) + qYZ.append(float(attribs["qYZ"])) + else: + qXX.append(0.0) + qYY.append(0.0) + qZZ.append(0.0) + qXY.append(0.0) + qXZ.append(0.0) + qYZ.append(0.0) mask = 1.0 if "mask" in attribs and attribs["mask"].upper() == "TRUE": mask = 0.0 @@ -1146,6 +1159,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof pme_force = ADMPPmeForce(box, axis_types, axis_indices, rc, self.ethresh, self.lmax, self.lpol, lpme, self.step_pol) + self.pme_force = pme_force def potential_fn(positions, box, pairs, params): positions = positions * 10 @@ -1181,4 +1195,4 @@ def getMetaData(self): return self._meta -_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator \ No newline at end of file +_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator diff --git a/dmff/generators/ml.py b/dmff/generators/ml.py new file mode 100644 index 000000000..afec8cee8 --- /dev/null +++ b/dmff/generators/ml.py @@ -0,0 +1,74 @@ +from ..api.topology import DMFFTopology +from ..api.paramset import ParamSet +from ..api.hamiltonian import _DMFFGenerators +from ..utils import DMFFException, isinstance_jnp +from ..utils import jit_condition +import numpy as np +import jax +import jax.numpy as jnp +import openmm.app as app +import openmm.unit as unit +import pickle + +from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb +from ..sgnn.gnn import MolGNNForce, prm_transform_f2i + + +class SGNNGenerator: + def __init__(self, ffinfo: dict, paramset: ParamSet): + + self.name = "SGNNForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self.key_type = None + + self.file = self.ffinfo["Forces"][self.name]["meta"]["file"] + self.nn = int(self.ffinfo["Forces"][self.name]["meta"]["nn"]) + self.pdb = self.ffinfo["Forces"][self.name]["meta"]["pdb"] + + # load ML potential parameters + with open(self.file, 'rb') as ifile: + params = pickle.load(ifile) + + # convert to jnp array + for k in params: + params[k] = jnp.array(params[k]) + # set mask to all true + paramset.addParameter(params[k], k, field=self.name, mask=jnp.ones(params[k].shape)) + + # mask = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), params) + # paramset.addParameter(params, "params", field=self.name, mask=mask) + + + def getName(self) -> str: + return self.name + + def overwrite(self, paramset): + # do not use xml to handle ML potentials + # for ML potentials, xml only documents param file path + # so for ML potentials, overwrite function overwrites the file directly + with open(self.file, 'wb') as ofile: + pickle.dump(paramset[self.name], ofile) + return + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs): + self.G = from_pdb(self.pdb) + n_atoms = topdata.getNumAtoms() + self.model = MolGNNForce(self.G, nn=self.nn) + n_layers = self.model.n_layers + def potential_fn(positions, box, pairs, params): + # convert unit to angstrom + positions = positions * 10 + box = box * 10 + prms = prm_transform_f2i(params[self.name], n_layers) + return self.model.get_energy(positions, box, prms) + + self._jaxPotential = potential_fn + return potential_fn + + def getJaxPotential(self): + return self._jaxPotential + + +_DMFFGenerators["SGNNForce"] = SGNNGenerator + diff --git a/dmff/sgnn/gnn.py b/dmff/sgnn/gnn.py index 49a5ee985..403d36429 100755 --- a/dmff/sgnn/gnn.py +++ b/dmff/sgnn/gnn.py @@ -13,6 +13,39 @@ from jax import value_and_grad, vmap +def prm_transform_f2i(params, n_layers): + p = {} + for k in params: + p[k] = jnp.array(params[k]) + for i_nn in [0, 1]: + nn_name = 'fc%d' % i_nn + p['%s.weight' % nn_name] = [] + p['%s.bias' % nn_name] = [] + for i_layer in range(n_layers[i_nn]): + k_w = '%s.%d.weight' % (nn_name, i_layer) + k_b = '%s.%d.bias' % (nn_name, i_layer) + p['%s.weight' % nn_name].append(p.pop(k_w, None)) + p['%s.bias' % nn_name].append(p.pop(k_b, None)) + return p + + +def prm_transform_i2f(params, n_layers): + # transform format + p = {} + p['w'] = params['w'] + p['fc_final.weight'] = params['fc_final.weight'] + p['fc_final.bias'] = params['fc_final.bias'] + for i_nn in range(2): + nn_name = 'fc%d' % i_nn + for i_layer in range(n_layers[i_nn]): + p[nn_name + '.%d.weight' % + i_layer] = params[nn_name + '.weight'][i_layer] + p[nn_name + + '.%d.bias' % i_layer] = params[nn_name + + '.bias'][i_layer] + return p + + class MolGNNForce: def __init__(self, @@ -146,6 +179,7 @@ def message_pass(f_in, nb_connect, w, nn): return + def load_params(self, ifn): """ Load the network parameters from saved file @@ -160,32 +194,12 @@ def load_params(self, ifn): for k in params.keys(): params[k] = jnp.array(params[k]) # transform format - keys = list(params.keys()) - for i_nn in [0, 1]: - nn_name = 'fc%d' % i_nn - keys_weight = [] - keys_bias = [] - for k in keys: - if re.search(nn_name + '.[0-9]+.weight', k) is not None: - keys_weight.append(k) - elif re.search(nn_name + '.[0-9]+.bias', k) is not None: - keys_bias.append(k) - if len(keys_weight) != self.n_layers[i_nn] or len( - keys_bias) != self.n_layers[i_nn]: - sys.exit( - 'Error while loading GNN params, inconsistent inputs with the GNN structure, check your input!' - ) - params['%s.weight' % nn_name] = [] - params['%s.bias' % nn_name] = [] - for i_layer in range(self.n_layers[i_nn]): - k_w = '%s.%d.weight' % (nn_name, i_layer) - k_b = '%s.%d.bias' % (nn_name, i_layer) - params['%s.weight' % nn_name].append(params.pop(k_w, None)) - params['%s.bias' % nn_name].append(params.pop(k_b, None)) - # params[nn_name] - self.params = params + self.params = prm_transform_f2i(params, self.n_layers) return + + + def save_params(self, ofn): """ Save the network parameters to a pickle file @@ -196,18 +210,8 @@ def save_params(self, ofn): """ # transform format - params = {} - params['w'] = self.params['w'] - params['fc_final.weight'] = self.params['fc_final.weight'] - params['fc_final.bias'] = self.params['fc_final.bias'] - for i_nn in range(2): - nn_name = 'fc%d' % i_nn - for i_layer in range(self.n_layers[i_nn]): - params[nn_name + '.%d.weight' % - i_layer] = self.params[nn_name + '.weight'][i_layer] - params[nn_name + - '.%d.bias' % i_layer] = self.params[nn_name + - '.bias'][i_layer] + params = prm_transform_i2f(self.params, self.n_layers) with open(ofn, 'wb') as ofile: pickle.dump(params, ofile) return + diff --git a/dmff/sgnn/graph.py b/dmff/sgnn/graph.py index 93f6a809a..164a41f3f 100755 --- a/dmff/sgnn/graph.py +++ b/dmff/sgnn/graph.py @@ -1219,6 +1219,20 @@ def from_pdb(pdb): return TopGraph(list_atom_elems, bonds, positions=positions, box=box) +# def from_dmff_top(topdata): +# ''' +# Build the sGNN TopGraph object from a DMFFTopology object + +# Parameters +# ---------- +# topdata: DMFFTopology data +# ''' +# list_atom_elems = np.array([a.element for a in topdata.atoms()]) +# bonds = np.array([np.sort([b.atom1.index, b.atom2.index]) for b in topdata.bonds()]) +# n_atoms = len(list_atom_elems) +# return TopGraph(list_atom_elems, bonds, positions=jnp.zeros((n_atoms, 3)), box=jnp.eye(3)*10) + + def validation(): G = from_pdb('peg4.pdb') nn = 1 diff --git a/examples/classical/test_xml.py b/examples/classical/test_xml.py index e84c6d849..c5d594033 100755 --- a/examples/classical/test_xml.py +++ b/examples/classical/test_xml.py @@ -82,4 +82,5 @@ def getEnergyDecomposition(context, forcegroups): print("Nonbonded:", nbE(positions, box, pairs, params)) etotal = pot.getPotentialFunc() - print("Total:", etotal(positions, box, pairs, params)) \ No newline at end of file + print("Total:", etotal(positions, box, pairs, params)) + diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle deleted file mode 100644 index 0c3959cd9..000000000 Binary files a/examples/sgnn/model1.pickle and /dev/null differ diff --git a/examples/sgnn/model1.pickle b/examples/sgnn/model1.pickle new file mode 120000 index 000000000..88c340bb9 --- /dev/null +++ b/examples/sgnn/model1.pickle @@ -0,0 +1 @@ +test_backend/model1.pickle \ No newline at end of file diff --git a/examples/sgnn/peg.xml b/examples/sgnn/peg.xml new file mode 100644 index 000000000..d3f41baaf --- /dev/null +++ b/examples/sgnn/peg.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb deleted file mode 100644 index 2c11081d1..000000000 --- a/examples/sgnn/peg4.pdb +++ /dev/null @@ -1,63 +0,0 @@ -REMARK -CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 -ATOM 1 C00 TER 1 -2.962 3.637 -1.170 -ATOM 2 H01 TER 1 -2.608 4.142 -0.296 -ATOM 3 H02 TER 1 -4.032 3.635 -1.171 -ATOM 4 O03 TER 1 -2.484 2.289 -1.168 -ATOM 5 C04 TER 1 -2.961 1.615 0.000 -ATOM 6 H05 TER 1 -2.604 0.606 0.000 -ATOM 7 H06 TER 1 -2.604 2.119 0.874 -ATOM 8 H07 TER 1 -4.031 1.615 0.000 -ATOM 9 C00 INT 2 -2.449 6.384 -3.596 -ATOM 10 H01 INT 2 -2.804 5.879 -4.470 -ATOM 11 H02 INT 2 -1.379 6.386 -3.595 -ATOM 12 O03 INT 2 -2.927 5.710 -2.429 -ATOM 13 C04 INT 2 -2.448 4.362 -2.427 -ATOM 14 H05 INT 2 -2.803 3.856 -3.301 -ATOM 15 H06 INT 2 -1.378 4.364 -2.425 -ATOM 16 C00 INT 3 -2.966 9.857 -4.767 -ATOM 17 H01 INT 3 -2.612 10.363 -3.893 -ATOM 18 H02 INT 3 -4.036 9.855 -4.768 -ATOM 19 O03 INT 3 -2.488 8.509 -4.765 -ATOM 20 C04 INT 3 -2.965 7.835 -3.597 -ATOM 21 H05 INT 3 -2.610 8.340 -2.724 -ATOM 22 H06 INT 3 -4.035 7.833 -3.599 -ATOM 23 C00 TER 4 -2.452 10.582 -6.024 -ATOM 24 H01 TER 4 -2.807 10.077 -6.898 -ATOM 25 H02 TER 4 -1.382 10.584 -6.022 -ATOM 26 O03 TER 4 -2.931 11.930 -6.026 -ATOM 27 C04 TER 4 -2.453 12.604 -7.193 -ATOM 28 H05 TER 4 -2.808 12.099 -8.067 -ATOM 29 H06 TER 4 -2.812 13.613 -7.194 -ATOM 30 H07 TER 4 -1.383 12.606 -7.192 -TER -CONECT 5 6 -CONECT 5 7 -CONECT 5 8 -CONECT 5 4 -CONECT 4 1 -CONECT 1 2 -CONECT 1 3 -CONECT 1 13 -CONECT 13 14 -CONECT 13 15 -CONECT 13 12 -CONECT 12 9 -CONECT 9 10 -CONECT 9 11 -CONECT 9 20 -CONECT 20 21 -CONECT 20 22 -CONECT 20 19 -CONECT 19 16 -CONECT 16 17 -CONECT 16 18 -CONECT 16 23 -CONECT 23 24 -CONECT 23 25 -CONECT 23 26 -CONECT 26 27 -CONECT 27 28 -CONECT 27 29 -CONECT 27 30 -END diff --git a/examples/sgnn/peg4.pdb b/examples/sgnn/peg4.pdb new file mode 120000 index 000000000..3c2bb15b6 --- /dev/null +++ b/examples/sgnn/peg4.pdb @@ -0,0 +1 @@ +test_backend/peg4.pdb \ No newline at end of file diff --git a/examples/sgnn/ref_out b/examples/sgnn/ref_out index 96bc1e62f..039b75427 100644 --- a/examples/sgnn/ref_out +++ b/examples/sgnn/ref_out @@ -1,37 +1,5 @@ -Energy: -21.588394 -Force -[[ 90.02814 2.0374336 35.38877 ] - [ -98.410095 -1.6865425 -30.066338 ] - [ 48.29245 31.675808 -43.390694 ] - [ 59.717484 -35.94304 50.599678 ] - [ -24.63767 218.36092 168.47194 ] - [ 43.258293 81.24294 -87.22882 ] - [ -67.66767 -17.780457 -5.6038494 ] - [ -22.928284 -302.96246 -123.14815 ] - [ 306.24683 -21.33866 -156.95491 ] - [ -4.715515 13.664352 -23.222527 ] - [-258.61304 -26.577957 85.58963 ] - [ -10.179474 106.21161 64.846924 ] - [-210.20566 -52.107193 58.04005 ] - [ 118.68472 -8.033836 -81.18109 ] - [ 44.02272 -34.508667 46.852356 ] - [-214.84206 115.90286 -227.59117 ] - [ 44.243336 -7.151741 26.06369 ] - [ 87.46674 38.574554 192.17757 ] - [ 27.345726 -58.87986 -44.685863 ] - [ -83.354774 -29.714098 214.93097 ] - [ -71.111305 34.880676 -77.53289 ] - [ 141.12836 49.28147 -97.597305 ] - [-220.25613 -134.58449 -23.567059 ] - [ 75.2593 58.432755 -63.99505 ] - [ 123.56466 -82.0066 94.63971 ] - [ 57.822285 17.07631 -53.788273 ] - [ -73.37115 0.50865555 16.240654 ] - [ 54.86133 97.53715 73.672806 ] - [ -23.997787 -73.92179 -13.749107 ] - [ 62.348286 21.809956 25.78839 ]] -Batched Energies: -[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164 - -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154 - -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443 - -72.8221 -28.33014 ] +-21.588284621154912 +[-21.58828462 -39.79334159 10.03889335 -48.22451239 -32.90970162 + -49.68568287 -47.58035178 -51.73860617 -37.39235277 -35.01933271 + -46.06621902 -31.69327601 -6.86739655 -5.13698524 -27.4031207 + -44.65301991 -52.00357797 3.1734038 -72.79081259 -28.27007722] diff --git a/examples/sgnn/residues.xml b/examples/sgnn/residues.xml new file mode 100644 index 000000000..aa78866eb --- /dev/null +++ b/examples/sgnn/residues.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/sgnn/run.py b/examples/sgnn/run.py index f87c887b5..14c7c1b84 100755 --- a/examples/sgnn/run.py +++ b/examples/sgnn/run.py @@ -1,45 +1,44 @@ #!/usr/bin/env python import sys -import numpy as np +import jax import jax.numpy as jnp -import jax.lax as lax -from jax import vmap, value_and_grad -import dmff -from dmff.sgnn.gnn import MolGNNForce -from dmff.utils import jit_condition -from dmff.sgnn.graph import MAX_VALENCE -from dmff.sgnn.graph import TopGraph, from_pdb +import openmm.app as app +import openmm.unit as unit +from dmff.api import Hamiltonian +from dmff.common import nblist +from jax import value_and_grad import pickle -import re -from collections import OrderedDict -from functools import partial - if __name__ == '__main__': - # params = load_params('benchmark/model1.pickle') - G = from_pdb('peg4.pdb') - model = MolGNNForce(G, nn=1) - model.load_params('model1.pickle') - E = model.get_energy(G.positions, G.box, model.params) + + H = Hamiltonian('peg.xml') + app.Topology.loadBondDefinitions("residues.xml") + pdb = app.PDBFile("peg4.pdb") + rc = 0.6 + # generator stores all force field parameters + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4) + + # construct inputs + positions = jnp.array(pdb.positions._value) + a, b, c = pdb.topology.getPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) + # neighbor list + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) + nbl.allocate(positions) + + + paramset = H.getParameters() + # params = paramset.parameters - with open('set_test_lowT.pickle', 'rb') as ifile: + with open('test_backend/set_test_lowT.pickle', 'rb') as ifile: data = pickle.load(ifile) - # pos = jnp.array(data['positions'][0:100]) - # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1)) - pos = jnp.array(data['positions'][0]) - box = jnp.eye(3) * 50 + # input in nm + pos = jnp.array(data['positions'][0:20]) / 10 + box = jnp.eye(3) * 5 - # energies = model.batch_forward(pos, box, model.params) - E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params) - F = -F - print('Energy:', E) - print('Force') - print(F) + efunc = jax.jit(pots.getPotentialFunc()) + efunc_vmap = jax.vmap(jax.jit(pots.getPotentialFunc()), in_axes=(0, None, None, None), out_axes=0) + print(efunc(pos[0], box, nbl.pairs, paramset)) + print(efunc_vmap(pos, box, nbl.pairs, paramset)) - # test batch processing - pos = jnp.array(data['positions'][:20]) - box = jnp.tile(box, (20, 1, 1)) - E = model.batch_forward(pos, box, model.params) - print('Batched Energies:') - print(E) diff --git a/examples/sgnn/model.pickle b/examples/sgnn/test_backend/model.pickle similarity index 100% rename from examples/sgnn/model.pickle rename to examples/sgnn/test_backend/model.pickle diff --git a/examples/sgnn/test_backend/model1.pickle b/examples/sgnn/test_backend/model1.pickle new file mode 100644 index 000000000..0c3959cd9 Binary files /dev/null and b/examples/sgnn/test_backend/model1.pickle differ diff --git a/examples/sgnn/model1.pth b/examples/sgnn/test_backend/model1.pth similarity index 100% rename from examples/sgnn/model1.pth rename to examples/sgnn/test_backend/model1.pth diff --git a/examples/sgnn/mse_testing.xvg b/examples/sgnn/test_backend/mse_testing.xvg similarity index 100% rename from examples/sgnn/mse_testing.xvg rename to examples/sgnn/test_backend/mse_testing.xvg diff --git a/examples/sgnn/test_backend/peg4.pdb b/examples/sgnn/test_backend/peg4.pdb new file mode 100644 index 000000000..2c11081d1 --- /dev/null +++ b/examples/sgnn/test_backend/peg4.pdb @@ -0,0 +1,63 @@ +REMARK +CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER 1 -2.962 3.637 -1.170 +ATOM 2 H01 TER 1 -2.608 4.142 -0.296 +ATOM 3 H02 TER 1 -4.032 3.635 -1.171 +ATOM 4 O03 TER 1 -2.484 2.289 -1.168 +ATOM 5 C04 TER 1 -2.961 1.615 0.000 +ATOM 6 H05 TER 1 -2.604 0.606 0.000 +ATOM 7 H06 TER 1 -2.604 2.119 0.874 +ATOM 8 H07 TER 1 -4.031 1.615 0.000 +ATOM 9 C00 INT 2 -2.449 6.384 -3.596 +ATOM 10 H01 INT 2 -2.804 5.879 -4.470 +ATOM 11 H02 INT 2 -1.379 6.386 -3.595 +ATOM 12 O03 INT 2 -2.927 5.710 -2.429 +ATOM 13 C04 INT 2 -2.448 4.362 -2.427 +ATOM 14 H05 INT 2 -2.803 3.856 -3.301 +ATOM 15 H06 INT 2 -1.378 4.364 -2.425 +ATOM 16 C00 INT 3 -2.966 9.857 -4.767 +ATOM 17 H01 INT 3 -2.612 10.363 -3.893 +ATOM 18 H02 INT 3 -4.036 9.855 -4.768 +ATOM 19 O03 INT 3 -2.488 8.509 -4.765 +ATOM 20 C04 INT 3 -2.965 7.835 -3.597 +ATOM 21 H05 INT 3 -2.610 8.340 -2.724 +ATOM 22 H06 INT 3 -4.035 7.833 -3.599 +ATOM 23 C00 TER 4 -2.452 10.582 -6.024 +ATOM 24 H01 TER 4 -2.807 10.077 -6.898 +ATOM 25 H02 TER 4 -1.382 10.584 -6.022 +ATOM 26 O03 TER 4 -2.931 11.930 -6.026 +ATOM 27 C04 TER 4 -2.453 12.604 -7.193 +ATOM 28 H05 TER 4 -2.808 12.099 -8.067 +ATOM 29 H06 TER 4 -2.812 13.613 -7.194 +ATOM 30 H07 TER 4 -1.383 12.606 -7.192 +TER +CONECT 5 6 +CONECT 5 7 +CONECT 5 8 +CONECT 5 4 +CONECT 4 1 +CONECT 1 2 +CONECT 1 3 +CONECT 1 13 +CONECT 13 14 +CONECT 13 15 +CONECT 13 12 +CONECT 12 9 +CONECT 9 10 +CONECT 9 11 +CONECT 9 20 +CONECT 20 21 +CONECT 20 22 +CONECT 20 19 +CONECT 19 16 +CONECT 16 17 +CONECT 16 18 +CONECT 16 23 +CONECT 23 24 +CONECT 23 25 +CONECT 23 26 +CONECT 26 27 +CONECT 27 28 +CONECT 27 29 +CONECT 27 30 +END diff --git a/examples/sgnn/pth2pickle.py b/examples/sgnn/test_backend/pth2pickle.py similarity index 100% rename from examples/sgnn/pth2pickle.py rename to examples/sgnn/test_backend/pth2pickle.py diff --git a/examples/sgnn/test_backend/ref_out b/examples/sgnn/test_backend/ref_out new file mode 100644 index 000000000..96bc1e62f --- /dev/null +++ b/examples/sgnn/test_backend/ref_out @@ -0,0 +1,37 @@ +Energy: -21.588394 +Force +[[ 90.02814 2.0374336 35.38877 ] + [ -98.410095 -1.6865425 -30.066338 ] + [ 48.29245 31.675808 -43.390694 ] + [ 59.717484 -35.94304 50.599678 ] + [ -24.63767 218.36092 168.47194 ] + [ 43.258293 81.24294 -87.22882 ] + [ -67.66767 -17.780457 -5.6038494 ] + [ -22.928284 -302.96246 -123.14815 ] + [ 306.24683 -21.33866 -156.95491 ] + [ -4.715515 13.664352 -23.222527 ] + [-258.61304 -26.577957 85.58963 ] + [ -10.179474 106.21161 64.846924 ] + [-210.20566 -52.107193 58.04005 ] + [ 118.68472 -8.033836 -81.18109 ] + [ 44.02272 -34.508667 46.852356 ] + [-214.84206 115.90286 -227.59117 ] + [ 44.243336 -7.151741 26.06369 ] + [ 87.46674 38.574554 192.17757 ] + [ 27.345726 -58.87986 -44.685863 ] + [ -83.354774 -29.714098 214.93097 ] + [ -71.111305 34.880676 -77.53289 ] + [ 141.12836 49.28147 -97.597305 ] + [-220.25613 -134.58449 -23.567059 ] + [ 75.2593 58.432755 -63.99505 ] + [ 123.56466 -82.0066 94.63971 ] + [ 57.822285 17.07631 -53.788273 ] + [ -73.37115 0.50865555 16.240654 ] + [ 54.86133 97.53715 73.672806 ] + [ -23.997787 -73.92179 -13.749107 ] + [ 62.348286 21.809956 25.78839 ]] +Batched Energies: +[-21.653 -39.830627 9.988983 -48.292953 -32.959183 -49.7164 + -47.617737 -51.76767 -37.42943 -35.06703 -46.111145 -31.748154 + -6.939003 -5.1853027 -27.427734 -44.695312 -52.027237 3.1541443 + -72.8221 -28.33014 ] diff --git a/examples/sgnn/test_backend/run.py b/examples/sgnn/test_backend/run.py new file mode 100755 index 000000000..f87c887b5 --- /dev/null +++ b/examples/sgnn/test_backend/run.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +import sys +import numpy as np +import jax.numpy as jnp +import jax.lax as lax +from jax import vmap, value_and_grad +import dmff +from dmff.sgnn.gnn import MolGNNForce +from dmff.utils import jit_condition +from dmff.sgnn.graph import MAX_VALENCE +from dmff.sgnn.graph import TopGraph, from_pdb +import pickle +import re +from collections import OrderedDict +from functools import partial + + +if __name__ == '__main__': + # params = load_params('benchmark/model1.pickle') + G = from_pdb('peg4.pdb') + model = MolGNNForce(G, nn=1) + model.load_params('model1.pickle') + E = model.get_energy(G.positions, G.box, model.params) + + with open('set_test_lowT.pickle', 'rb') as ifile: + data = pickle.load(ifile) + + # pos = jnp.array(data['positions'][0:100]) + # box = jnp.tile(jnp.eye(3) * 50, (100, 1, 1)) + pos = jnp.array(data['positions'][0]) + box = jnp.eye(3) * 50 + + # energies = model.batch_forward(pos, box, model.params) + E, F = value_and_grad(model.get_energy, argnums=(0))(pos, box, model.params) + F = -F + print('Energy:', E) + print('Force') + print(F) + + # test batch processing + pos = jnp.array(data['positions'][:20]) + box = jnp.tile(box, (20, 1, 1)) + E = model.batch_forward(pos, box, model.params) + print('Batched Energies:') + print(E) diff --git a/examples/sgnn/set_test.pickle b/examples/sgnn/test_backend/set_test.pickle similarity index 100% rename from examples/sgnn/set_test.pickle rename to examples/sgnn/test_backend/set_test.pickle diff --git a/examples/sgnn/set_test_lowT.pickle b/examples/sgnn/test_backend/set_test_lowT.pickle similarity index 100% rename from examples/sgnn/set_test_lowT.pickle rename to examples/sgnn/test_backend/set_test_lowT.pickle diff --git a/examples/sgnn/test.py b/examples/sgnn/test_backend/test.py similarity index 100% rename from examples/sgnn/test.py rename to examples/sgnn/test_backend/test.py diff --git a/examples/sgnn/test_data.xvg b/examples/sgnn/test_backend/test_data.xvg similarity index 100% rename from examples/sgnn/test_data.xvg rename to examples/sgnn/test_backend/test_data.xvg diff --git a/examples/sgnn/train.py b/examples/sgnn/test_backend/train.py similarity index 100% rename from examples/sgnn/train.py rename to examples/sgnn/test_backend/train.py diff --git a/examples/water_fullpol/monopole_nonpol/run.py b/examples/water_fullpol/monopole_nonpol/run.py index 3b1b4799f..617e96e76 100755 --- a/examples/water_fullpol/monopole_nonpol/run.py +++ b/examples/water_fullpol/monopole_nonpol/run.py @@ -13,18 +13,19 @@ H = Hamiltonian('forcefield.xml') pdb = app.PDBFile("pair.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters params = H.getParameters() - pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce'] + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer) + pot_pme = pots.dmff_potentials['ADMPPmeForce'] # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) nbl.allocate(positions) E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params) diff --git a/examples/water_fullpol/monopole_polarizable/run.py b/examples/water_fullpol/monopole_polarizable/run.py index 808ee5801..560cd8da0 100755 --- a/examples/water_fullpol/monopole_polarizable/run.py +++ b/examples/water_fullpol/monopole_polarizable/run.py @@ -13,18 +13,19 @@ H = Hamiltonian('forcefield.xml') pdb = app.PDBFile("waterbox_31ang.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters params = H.getParameters() - pot_pme = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom).dmff_potentials['ADMPPmeForce'] + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer) + pot_pme = pots.dmff_potentials['ADMPPmeForce'] # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) nbl.allocate(positions) E_pme, F_pme = value_and_grad(pot_pme)(positions, box, nbl.pairs, params) diff --git a/examples/water_fullpol/quadrupole_nonpol/run.py b/examples/water_fullpol/quadrupole_nonpol/run.py index b408792aa..0b6fe6394 100755 --- a/examples/water_fullpol/quadrupole_nonpol/run.py +++ b/examples/water_fullpol/quadrupole_nonpol/run.py @@ -14,20 +14,20 @@ H = Hamiltonian('forcefield.xml') app.Topology.loadBondDefinitions("residues.xml") pdb = app.PDBFile("waterbox_31ang.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters params = H.getParameters() - pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom) + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer) pot_disp = pots.dmff_potentials['ADMPDispForce'] pot_pme = pots.dmff_potentials['ADMPPmeForce'] # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta["cov_map"]) nbl.allocate(positions) diff --git a/examples/water_fullpol/run.py b/examples/water_fullpol/run.py index dccf92921..f024a8d66 100755 --- a/examples/water_fullpol/run.py +++ b/examples/water_fullpol/run.py @@ -13,18 +13,18 @@ H = Hamiltonian('forcefield.xml') app.Topology.loadBondDefinitions("residues.xml") pdb = app.PDBFile("waterbox_31ang.pdb") - rc = 6 + rc = 0.6 # generator stores all force field parameters disp_generator, pme_generator = H.getGenerators() - pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4) + pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, ethresh=5e-4) # construct inputs - positions = jnp.array(pdb.positions._value) * 10 + positions = jnp.array(pdb.positions._value) a, b, c = pdb.topology.getPeriodicBoxVectors() - box = jnp.array([a._value, b._value, c._value]) * 10 + box = jnp.array([a._value, b._value, c._value]) # neighbor list - nbl = nblist.NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl = nblist.NeighborList(box, rc, pots.meta['cov_map']) nbl.allocate(positions) diff --git a/tests/data/admp_mono.xml b/tests/data/admp_mono.xml new file mode 100644 index 000000000..3970ff522 --- /dev/null +++ b/tests/data/admp_mono.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/admp_nonpol.xml b/tests/data/admp_nonpol.xml new file mode 100644 index 000000000..7cc1b4653 --- /dev/null +++ b/tests/data/admp_nonpol.xml @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/peg4.pdb b/tests/data/peg4.pdb new file mode 100644 index 000000000..eee06d7da --- /dev/null +++ b/tests/data/peg4.pdb @@ -0,0 +1,64 @@ +HEADER +TITLE MDANALYSIS FRAME 0: Created by PDBWriter +CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER X 1 47.381 10.286 49.808 0.00 1.00 SYST +ATOM 2 H01 TER X 1 47.251 11.255 50.307 0.00 1.00 SYST +ATOM 3 H02 TER X 1 46.907 9.487 50.425 0.00 1.00 SYST +ATOM 4 O03 TER X 1 48.814 10.202 49.785 0.00 1.00 SYST +ATOM 5 C04 TER X 1 49.336 9.203 50.665 0.00 1.00 SYST +ATOM 6 H05 TER X 1 50.344 9.329 51.054 0.00 1.00 SYST +ATOM 7 H06 TER X 1 48.796 9.176 51.611 0.00 1.00 SYST +ATOM 8 H07 TER X 1 49.296 8.320 50.177 0.00 1.00 SYST +ATOM 9 C00 INT X 2 46.552 8.760 46.601 0.00 1.00 SYST +ATOM 10 H01 INT X 2 46.737 9.609 45.939 0.00 1.00 SYST +ATOM 11 H02 INT X 2 45.532 8.628 46.649 0.00 1.00 SYST +ATOM 12 O03 INT X 2 47.247 8.976 47.799 0.00 1.00 SYST +ATOM 13 C04 INT X 2 46.919 10.250 48.371 0.00 1.00 SYST +ATOM 14 H05 INT X 2 47.190 11.176 47.880 0.00 1.00 SYST +ATOM 15 H06 INT X 2 45.801 10.369 48.307 0.00 1.00 SYST +ATOM 16 C00 INT X 3 46.760 5.982 44.153 0.00 1.00 SYST +ATOM 17 H01 INT X 3 47.759 6.173 43.770 0.00 1.00 SYST +ATOM 18 H02 INT X 3 46.121 5.894 43.168 0.00 1.00 SYST +ATOM 19 O03 INT X 3 46.268 7.098 44.918 0.00 1.00 SYST +ATOM 20 C04 INT X 3 47.139 7.493 45.949 0.00 1.00 SYST +ATOM 21 H05 INT X 3 47.292 6.726 46.769 0.00 1.00 SYST +ATOM 22 H06 INT X 3 48.124 7.662 45.625 0.00 1.00 SYST +ATOM 23 C00 TER X 4 46.610 4.692 44.880 0.00 1.00 SYST +ATOM 24 H01 TER X 4 45.686 4.613 45.520 0.00 1.00 SYST +ATOM 25 H02 TER X 4 47.444 4.603 45.516 0.00 1.00 SYST +ATOM 26 O03 TER X 4 46.501 3.674 43.869 0.00 1.00 SYST +ATOM 27 C04 TER X 4 45.802 2.493 44.226 0.00 1.00 SYST +ATOM 28 H05 TER X 4 45.959 1.651 43.497 0.00 1.00 SYST +ATOM 29 H06 TER X 4 46.125 2.280 45.251 0.00 1.00 SYST +ATOM 30 H07 TER X 4 44.695 2.638 44.209 0.00 1.00 SYST +CONECT 1 2 3 4 13 +CONECT 2 1 +CONECT 3 1 +CONECT 4 1 5 +CONECT 5 4 6 7 8 +CONECT 6 5 +CONECT 7 5 +CONECT 8 5 +CONECT 9 10 11 12 20 +CONECT 10 9 +CONECT 11 9 +CONECT 12 9 13 +CONECT 13 1 12 14 15 +CONECT 14 13 +CONECT 15 13 +CONECT 16 17 18 19 23 +CONECT 17 16 +CONECT 18 16 +CONECT 19 16 20 +CONECT 20 9 19 21 22 +CONECT 21 20 +CONECT 22 20 +CONECT 23 16 24 25 26 +CONECT 24 23 +CONECT 25 23 +CONECT 26 23 27 +CONECT 27 26 28 29 30 +CONECT 28 27 +CONECT 29 27 +CONECT 30 27 +END diff --git a/tests/data/peg_sgnn.xml b/tests/data/peg_sgnn.xml new file mode 100644 index 000000000..206326d1e --- /dev/null +++ b/tests/data/peg_sgnn.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/sgnn_model.pickle b/tests/data/sgnn_model.pickle new file mode 100644 index 000000000..0c3959cd9 Binary files /dev/null and b/tests/data/sgnn_model.pickle differ diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py index 02d81ead4..be4b9d99d 100644 --- a/tests/test_admp/test_compute.py +++ b/tests/test_admp/test_compute.py @@ -24,13 +24,17 @@ def test_init(self): """ rc = 4.0 H = Hamiltonian('tests/data/admp.xml') + H1 = Hamiltonian('tests/data/admp_mono.xml') + H2 = Hamiltonian('tests/data/admp_nonpol.xml') pdb = app.PDBFile('tests/data/water_dimer.pdb') potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + potential1 = H1.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + potential2 = H2.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) - yield potential, H.paramset + yield potential, potential1, potential2, H.paramset, H1.paramset, H2.paramset def test_ADMPPmeForce(self, pot_prm): - potential, paramset = pot_prm + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) @@ -51,7 +55,7 @@ def test_ADMPPmeForce(self, pot_prm): def test_ADMPPmeForce_jit(self, pot_prm): - potential, paramset = pot_prm + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) @@ -67,5 +71,47 @@ def test_ADMPPmeForce_jit(self, pot_prm): pot = potential.getPotentialFunc(names=["ADMPPmeForce"]) j_pot_pme = jit(value_and_grad(pot)) energy, grad = j_pot_pme(positions, box, pairs, paramset.parameters) - print(energy) + print('hahahah', energy) np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1) + + + def test_ADMPPmeForce_mono(self, pot_prm): + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + rc = 0.4 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + positions = jnp.array(positions) + a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer) + box = jnp.array([a, b, c]) + # neighbor list + + covalent_map = potential1.meta["cov_map"] + + nblist = NeighborList(box, rc, covalent_map) + nblist.allocate(positions) + pairs = nblist.pairs + pot = potential1.getPotentialFunc(names=["ADMPPmeForce"]) + energy = pot(positions, box, pairs, paramset1) + print(energy) + np.testing.assert_almost_equal(energy, -66.55921382, decimal=2) + + + def test_ADMPPmeForce_nonpol(self, pot_prm): + potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + rc = 0.4 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + positions = jnp.array(positions) + a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer) + box = jnp.array([a, b, c]) + # neighbor list + + covalent_map = potential2.meta["cov_map"] + + nblist = NeighborList(box, rc, covalent_map) + nblist.allocate(positions) + pairs = nblist.pairs + pot = potential2.getPotentialFunc(names=["ADMPPmeForce"]) + energy = pot(positions, box, pairs, paramset2) + print(energy) + np.testing.assert_almost_equal(energy, -31.69025446, decimal=2) diff --git a/tests/test_sgnn/test_energy.py b/tests/test_sgnn/test_energy.py new file mode 100644 index 000000000..a771f4508 --- /dev/null +++ b/tests/test_sgnn/test_energy.py @@ -0,0 +1,51 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian, NeighborList +from jax import jit, value_and_grad + +class TestADMPAPI: + + """ Test sGNN related generators + """ + + @pytest.fixture(scope='class', name='pot_prm') + def test_init(self): + """load generators from XML file + + Yields: + Tuple: ( + ADMPDispForce, + ADMPPmeForce, # polarized + ) + """ + rc = 4.0 + H = Hamiltonian('tests/data/peg_sgnn.xml') + pdb = app.PDBFile('tests/data/peg4.pdb') + potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + + yield potential, H.paramset + + def test_sGNN_energy(self, pot_prm): + potential, paramset = pot_prm + rc = 0.4 + pdb = app.PDBFile('tests/data/peg4.pdb') + positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + positions = jnp.array(positions) + a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer) + box = jnp.array([a, b, c]) + # neighbor list + covalent_map = potential.meta["cov_map"] + + nblist = NeighborList(box, rc, covalent_map) + nblist.allocate(positions) + pairs = nblist.pairs + pot = potential.getPotentialFunc(names=["SGNNForce"]) + energy = pot(positions, box, pairs, paramset) + print(energy) + np.testing.assert_almost_equal(energy, -21.81780787, decimal=2) + +