Skip to content

Commit

Permalink
Add QEQ test
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 22, 2023
1 parent 5acacff commit 81741de
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 22 deletions.
34 changes: 16 additions & 18 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@


@jit_condition()
def group_sum(val_list, indices):
max_idx = indices.max()
exceed = jnp.piecewise(
indices,
[indices < max_idx, indices >= max_idx],
[lambda x: CONST_1, lambda x: CONST_0],
def mask_index(idx, max_idx):
return jnp.piecewise(
idx, [idx < max_idx, idx >= max_idx], [lambda x: CONST_1, lambda x: CONST_0]
)
return jnp.sum(val_list[indices] * exceed)

mask_index = jax.vmap(mask_index, in_axes=(0, None))

group_sum_vmap = jax.vmap(group_sum, in_axes=(None, 0))
@jit_condition()
def group_sum(val_list, indices):
max_idx = val_list.shape[0]
mask = mask_index(indices, max_idx)
return jnp.sum(val_list[indices] * mask)
group_sum = jax.vmap(group_sum, in_axes=(None, 0))


# @jit_condition
def padding_consts(const_list, max_idx):
max_length = max([len(i) for i in const_list])
new_const_list = np.zeros((len(const_list), max_length)) + max_idx
Expand All @@ -50,13 +50,13 @@ def padding_consts(const_list, max_idx):

@jit_condition()
def E_constQ(q, lagmt, const_list, const_vals):
constraint = (group_sum_vmap(q, const_list) - const_vals) * lagmt
constraint = (group_sum(q, const_list) - const_vals) * lagmt
return jnp.sum(constraint)


@jit_condition()
def E_constP(q, lagmt, const_list, const_vals):
constraint = group_sum_vmap(q, const_list) * const_vals
constraint = group_sum(q, const_list) * const_vals
return jnp.sum(constraint)


Expand Down Expand Up @@ -202,10 +202,10 @@ def __init__(
constQ: bool = True,
pbc_flag: bool = True,
):
if not isinstance(const_vals, jnp.ndarray):
self.const_vals = jnp.array(const_vals)
else:
self.const_vals = const_vals
const_vals = np.array(const_vals)
if neutral_flag:
const_vals = const_vals - np.sum(const_vals) / len(const_vals)
self.const_vals = jnp.array(const_vals)
assert len(const_list) == len(
const_vals
), "const_list and const_vals must have the same length"
Expand Down Expand Up @@ -290,8 +290,6 @@ def get_energy(positions, box, pairs, mscales, eta, chi, J):
b_0 = jax.lax.stop_gradient(b_0)
q_0 = b_0[:-n_const]
lagmt_0 = b_0[-n_const:]
print("Q:", q_0)
print("Lagrange_multi:", lagmt_0)

energy = E_full(
q_0,
Expand Down
8 changes: 4 additions & 4 deletions dmff/generators/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def overwrite(self, paramset: ParamSet) -> None:
J0 = J[nidx]
eta0 = eta[nidx]
mask = atom_mask[nidx]
self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["chi"] = str(chi0)
self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["J"] = str(J0)
self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["eta"] = str(eta0)
self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["chi"] = chi0
self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["J"] = J0
self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["eta"] = eta0
if mask < 0.999:
self.ffinfo["Forces"][self.name]["node"][idx]["attrib"]["mask"] = "true"

Expand Down Expand Up @@ -177,7 +177,7 @@ def createPotential(
aidx = [a.index for a in r.atoms()]
const_list.append(aidx)
const_vals.append(sum(init_q[aidx]))

qeq_force = ADMPQeqForce(
init_q, r_cut, kappa, K, damp_mod=self.damp_mod,
const_list=const_list, const_vals=const_vals,
Expand Down
38 changes: 38 additions & 0 deletions tests/test_admp/test_qeq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import openmm.app as app
import openmm.unit as unit
from dmff.api import Hamiltonian
from dmff.api import DMFFTopology
from dmff.api.xmlio import XMLIO
from dmff import NeighborList
import jax.numpy as jnp
import numpy as np


def test_qeq_energy():
xml = XMLIO()
xml.loadXML("tests/data/qeq.xml")
res = xml.parseResidues()
charges = [a["charge"] for a in res[0]["particles"]]
types = [a["type"] for a in res[0]["particles"]]

pdb = app.PDBFile("tests/data/qeq.pdb")
dmfftop = DMFFTopology(from_top=pdb.topology)
pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
pos = jnp.array(pos)
box = dmfftop.getPeriodicBoxVectors()
hamilt = Hamiltonian("tests/data/qeq.xml")

atoms = [a for a in dmfftop.atoms()]
for na, a in enumerate(atoms):
a.meta["charge"] = charges[na]
a.meta["type"] = types[na]

nblist = NeighborList(box, 0.6, dmfftop.buildCovMat())
pairs = nblist.allocate(pos)

pot = hamilt.createPotential(dmfftop, nonbondedCutoff=0.6*unit.nanometer, nonbondedMethod=app.PME,
ethresh=5e-4, neutral=True, slab=False, constQ=True
)
efunc = pot.getPotentialFunc()
energy = efunc(pos, box, pairs, hamilt.paramset.parameters)
np.testing.assert_almost_equal(energy, -37.84692763, decimal=3)

0 comments on commit 81741de

Please sign in to comment.