Skip to content

Commit

Permalink
Add aux support in ml.py for MLP. Fix jit error when saving ml potent…
Browse files Browse the repository at this point in the history
…ial with save_dmff2tf.py.
  • Loading branch information
dingye18 committed Dec 17, 2023
1 parent f91e947 commit 5ce196d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
5 changes: 3 additions & 2 deletions dmff/eann/eann.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def get_features(self, radial, dr, pairs, buffer_scales, orb_coeff):
f_cut = cutoff_cosine(dr_norm, self.rc)
neigh_list = jnp.concatenate((pairs,pairs[:,[1,0]]),axis=0)
buffer_scales_ = jnp.concatenate((buffer_scales,buffer_scales),axis=0)
totneighbour = len(neigh_list)
totneighbour = neigh_list.shape[0]
prefacs = f_cut.reshape(1, -1)
angular = prefacs
for ipsin in range(1,self.nipsin+1):
Expand Down Expand Up @@ -310,7 +310,8 @@ def get_energy(positions, box, pairs, params):
self.rs = params['density.rs']
self.inta = params['density.inta']

radial_i, radial_j = get_gto(jnp.arange(len(dr_norm)), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices)
length_dr_norm = dr_norm.shape[0]
radial_i, radial_j = get_gto(jnp.arange(length_dr_norm), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices)
radial = jnp.concatenate((radial_i,radial_j), axis=0)
orb_coeff = params['density.params'][self.elem_indices,:] # (48,16)

Expand Down
13 changes: 10 additions & 3 deletions dmff/generators/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,19 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof
n_elem, elem_indices = get_elem_indices(self.ommtopology)
self.model = EANNForce(n_elem, elem_indices, n_gto=self.ngto, nipsin=self.nipsin, rc=self.rc)
n_layers = self.model.n_layers
def potential_fn(positions, box, pairs, params):

has_aux = False
if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True

def potential_fn(positions, box, pairs, params, aux=None):
# convert unit to angstrom
positions = positions * 10
box = box * 10

return self.model.get_energy(positions, box, pairs, params[self.name])
if has_aux:
return self.model.get_energy(positions, box, pairs, params[self.name]), aux
else:
return self.model.get_energy(positions, box, pairs, params[self.name])

self._jaxPotential = potential_fn
return potential_fn
Expand Down

0 comments on commit 5ce196d

Please sign in to comment.