Skip to content

Commit

Permalink
Merge pull request #154 from dingye18/devel
Browse files Browse the repository at this point in the history
Add aux support in ml.py for MLP. Fix jit error when saving ml potential with save_dmff2tf.py
  • Loading branch information
KuangYu authored Jan 12, 2024
2 parents 0be33ff + 2547786 commit 37fb33a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 7 deletions.
12 changes: 11 additions & 1 deletion backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,17 @@ class OPENMM_EXPORT_DMFF DMFFForce : public OpenMM::Force {
* @param hasAux : true if model was saved with auxilary input.
*/
void setHasAux(const bool hasAux);

/**
* @brief Set the Cutoff for neighbor list fetching.
*
* @param cutoff
*/
void setCutoff(const double cutoff);
/**
* @brief get the DMFF graph file.
*
* @return const std::string&
*/
const std::string& getDMFFGraphFile() const;
/**
* @brief Get the Coord Unit Coefficient.
Expand Down
4 changes: 4 additions & 0 deletions backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ void DMFFForce::setHasAux(const bool hasAux){
this->has_aux = hasAux;
}

void DMFFForce::setCutoff(const double cutoff){
this->cutoff = cutoff;
}

double DMFFForce::getCoordUnitCoefficient() const {return coordCoeff;}
double DMFFForce::getForceUnitCoefficient() const {return forceCoeff;}
double DMFFForce::getEnergyUnitCoefficient() const {return energyCoeff;}
Expand Down
2 changes: 1 addition & 1 deletion backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public:

void setUnitTransformCoefficients(const double coordCoefficient, const double forceCoefficient, const double energyCoefficient);
void setHasAux(const bool hasAux);

void setCutoff(const double cutoff);
/*
* Add methods for casting a Force to a DMFFForce.
*/
Expand Down
10 changes: 10 additions & 0 deletions backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def setHasAux(self, has_aux = False):
has_aux (bool, optional): Defaults to False.
"""
self.dmff_force.setHasAux(has_aux)
return

def setCutoff(self, cutoff = 1.2):
"""Set the cutoff for the DMFF model.
Args:
cutoff (float, optional): Defaults to 1.2.
"""
self.dmff_force.setCutoff(cutoff)
return

def createSystem(self, topology):
"""Create the OpenMM System object for the DMFF model.
Expand Down
5 changes: 3 additions & 2 deletions dmff/eann/eann.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def get_features(self, radial, dr, pairs, buffer_scales, orb_coeff):
f_cut = cutoff_cosine(dr_norm, self.rc)
neigh_list = jnp.concatenate((pairs,pairs[:,[1,0]]),axis=0)
buffer_scales_ = jnp.concatenate((buffer_scales,buffer_scales),axis=0)
totneighbour = len(neigh_list)
totneighbour = neigh_list.shape[0]
prefacs = f_cut.reshape(1, -1)
angular = prefacs
for ipsin in range(1,self.nipsin+1):
Expand Down Expand Up @@ -310,7 +310,8 @@ def get_energy(positions, box, pairs, params):
self.rs = params['density.rs']
self.inta = params['density.inta']

radial_i, radial_j = get_gto(jnp.arange(len(dr_norm)), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices)
length_dr_norm = dr_norm.shape[0]
radial_i, radial_j = get_gto(jnp.arange(length_dr_norm), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices)
radial = jnp.concatenate((radial_i,radial_j), axis=0)
orb_coeff = params['density.params'][self.elem_indices,:] # (48,16)

Expand Down
13 changes: 10 additions & 3 deletions dmff/generators/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,19 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof
n_elem, elem_indices = get_elem_indices(self.ommtopology)
self.model = EANNForce(n_elem, elem_indices, n_gto=self.ngto, nipsin=self.nipsin, rc=self.rc)
n_layers = self.model.n_layers
def potential_fn(positions, box, pairs, params):

has_aux = False
if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True

def potential_fn(positions, box, pairs, params, aux=None):
# convert unit to angstrom
positions = positions * 10
box = box * 10

return self.model.get_energy(positions, box, pairs, params[self.name])
if has_aux:
return self.model.get_energy(positions, box, pairs, params[self.name]), aux
else:
return self.model.get_energy(positions, box, pairs, params[self.name])

self._jaxPotential = potential_fn
return potential_fn
Expand Down

0 comments on commit 37fb33a

Please sign in to comment.