Skip to content

Commit

Permalink
Correct the order of improper matching
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 11, 2023
1 parent 118b192 commit b97c462
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 15 deletions.
64 changes: 50 additions & 14 deletions dmff/generators/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,14 +607,51 @@ def _find_proper_key_index(self, key: Tuple[str, str, str, str]) -> int:
return wc_patch[0]
return None

def _find_improper_key_index(self, key: Tuple[str, str, str, str]):
pre_orders = [(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3),
(0, 2, 3, 1), (0, 3, 1, 2), (0, 3, 2, 1)]
for i, k in enumerate(self.imp_keys):
for order in pre_orders:
if k[0] == key[0] and k[1] in ["", key[order[1]]] and k[2] in ["", key[order[2]]] and k[3] in ["", key[order[3]]]:
return i, order
return None, None
def _find_improper_key_index(self, improper):

type1 = improper[0].meta[self.key_type]
type2 = improper[1].meta[self.key_type]
type3 = improper[2].meta[self.key_type]
type4 = improper[3].meta[self.key_type]

def _wild_match(tp, tps):
if tps == "":
return True
if tp == tps:
return True
return False

matched = None
for ndef, tordef in enumerate(self.imp_keys):
types1 = tordef[0]
types2 = tordef[1]
types3 = tordef[2]
types4 = tordef[3]
hasWildcard = ("" in (types1, types2, types3, types4))

if matched is not None and hasWildcard:
continue

import itertools
if type1 in types1:
for (t2, t3, t4) in itertools.permutations(((type2, 1), (type3, 2), (type4, 3))):
if _wild_match(t2[0], types2) and _wild_match(t3[0], types3) and _wild_match(t4[0], types4):
a1 = improper[t2[1]].index
a2 = improper[t3[1]].index
e1 = improper[t2[1]].element
e2 = improper[t3[1]].element
m1 = app.element.get_by_symbol(e1).mass
m2 = app.element.get_by_symbol(e2).mass
if e1 == e2 and a1 > a2:
(a1, a2) = (a2, a1)
elif e1 != "C" and (e2 == "C" or m1 < m2):
(a1, a2) = (a2, a1)
matched = (a1, a2, improper[0].index, improper[t4[1]].index, ndef)
break
if matched is None:
return None, None
return matched[4], matched[:4]


def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
nonbondedCutoff, args):
Expand Down Expand Up @@ -682,18 +719,17 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,

improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period = [], [], [], [], [], []
for improper in impr_list:
iidx, order = self._find_improper_key_index(
(improper[0].meta[self.key_type], improper[1].meta[self.key_type], improper[2].meta[self.key_type], improper[3].meta[self.key_type]))
iidx, order = self._find_improper_key_index(improper)
if iidx is None:
continue

prm_indices = self.imp_key_to_prms[iidx]
for prm_idx in prm_indices:
prm_period = self.imp_periods[prm_idx]
improper_a1.append(improper[order[0]].index)
improper_a2.append(improper[order[1]].index)
improper_a3.append(improper[order[2]].index)
improper_a4.append(improper[order[3]].index)
improper_a1.append(atoms[order[0]].index)
improper_a2.append(atoms[order[1]].index)
improper_a3.append(atoms[order[2]].index)
improper_a4.append(atoms[order[3]].index)
improper_indices.append(prm_idx)
improper_period.append(prm_period)
improper_a1 = jnp.array(improper_a1)
Expand Down
5 changes: 4 additions & 1 deletion examples/classical/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def getEnergyDecomposition(context, forcegroups):

forcegroups = forcegroupify(system)
integrator = mm.VerletIntegrator(0.1)
context = mm.Context(system, integrator)
context = mm.Context(system, integrator, mm.Platform.getPlatformByName("Reference"))
context.setPositions(pdb.positions)
state = context.getState(getEnergy=True)
energy = state.getPotentialEnergy()
Expand Down Expand Up @@ -80,3 +80,6 @@ def getEnergyDecomposition(context, forcegroups):

nbE = pot.dmff_potentials['NonbondedForce']
print("Nonbonded:", nbE(positions, box, pairs, params))

etotal = pot.getPotentialFunc()
print("Total:", etotal(positions, box, pairs, params))

0 comments on commit b97c462

Please sign in to comment.