From 5ce196df25a9bd2a38fb04d27dcb6dd0ac02f853 Mon Sep 17 00:00:00 2001 From: dingye Date: Sun, 17 Dec 2023 17:30:44 +0800 Subject: [PATCH] Add aux support in ml.py for MLP. Fix jit error when saving ml potential with save_dmff2tf.py. --- dmff/eann/eann.py | 5 +++-- dmff/generators/ml.py | 13 ++++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/dmff/eann/eann.py b/dmff/eann/eann.py index bd7db7487..dc1edd5ca 100644 --- a/dmff/eann/eann.py +++ b/dmff/eann/eann.py @@ -247,7 +247,7 @@ def get_features(self, radial, dr, pairs, buffer_scales, orb_coeff): f_cut = cutoff_cosine(dr_norm, self.rc) neigh_list = jnp.concatenate((pairs,pairs[:,[1,0]]),axis=0) buffer_scales_ = jnp.concatenate((buffer_scales,buffer_scales),axis=0) - totneighbour = len(neigh_list) + totneighbour = neigh_list.shape[0] prefacs = f_cut.reshape(1, -1) angular = prefacs for ipsin in range(1,self.nipsin+1): @@ -310,7 +310,8 @@ def get_energy(positions, box, pairs, params): self.rs = params['density.rs'] self.inta = params['density.inta'] - radial_i, radial_j = get_gto(jnp.arange(len(dr_norm)), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices) + length_dr_norm = dr_norm.shape[0] + radial_i, radial_j = get_gto(jnp.arange(length_dr_norm), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices) radial = jnp.concatenate((radial_i,radial_j), axis=0) orb_coeff = params['density.params'][self.elem_indices,:] # (48,16) diff --git a/dmff/generators/ml.py b/dmff/generators/ml.py index ae747bb88..4bf5595ad 100644 --- a/dmff/generators/ml.py +++ b/dmff/generators/ml.py @@ -116,12 +116,19 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof n_elem, elem_indices = get_elem_indices(self.ommtopology) self.model = EANNForce(n_elem, elem_indices, n_gto=self.ngto, nipsin=self.nipsin, rc=self.rc) n_layers = self.model.n_layers - def potential_fn(positions, box, pairs, params): + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions, box, pairs, params, aux=None): # convert unit to angstrom positions = positions * 10 box = box * 10 - - return self.model.get_energy(positions, box, pairs, params[self.name]) + if has_aux: + return self.model.get_energy(positions, box, pairs, params[self.name]), aux + else: + return self.model.get_energy(positions, box, pairs, params[self.name]) self._jaxPotential = potential_fn return potential_fn