Skip to content

Commit

Permalink
Add jaxopt requirement in github workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 22, 2023
1 parent 81741de commit 4aa5b77
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
conda create -n dmff -y python=${{ matrix.python-version }} numpy openmm==7.7.0 pytest rdkit biopandas openbabel mdtraj ambertools -c conda-forge
conda activate dmff
pip install --upgrade pip
pip install jax jaxlib networkx parmed pymbar==4.0.1 chex==0.1.4 tqdm
pip install jax jaxlib jaxopt networkx parmed pymbar==4.0.1 chex==0.1.4 tqdm
- name: Install DMFF
run: |
source $CONDA/bin/activate dmff && pip install .
Expand Down
5 changes: 5 additions & 0 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ def mask_index(idx, max_idx):
return jnp.piecewise(
idx, [idx < max_idx, idx >= max_idx], [lambda x: CONST_1, lambda x: CONST_0]
)


mask_index = jax.vmap(mask_index, in_axes=(0, None))


@jit_condition()
def group_sum(val_list, indices):
max_idx = val_list.shape[0]
mask = mask_index(indices, max_idx)
return jnp.sum(val_list[indices] * mask)


group_sum = jax.vmap(group_sum, in_axes=(None, 0))


Expand Down

0 comments on commit 4aa5b77

Please sign in to comment.