Skip to content

Commit

Permalink
[JAX] Replace uses of jax._src.ad_checkpoint.optimization_barrier wit…
Browse files Browse the repository at this point in the history
…h jax.lax.optimization_barrier.

PiperOrigin-RevId: 672878937
  • Loading branch information
hawkinsp authored and TF2JAXDev committed Sep 10, 2024
1 parent b1425c0 commit db3f7d1
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from absl import logging

import jax
from jax._src.lax import control_flow as lax_control_flow
from jax.experimental import checkify
from jax.lib import xla_client
import jax.numpy as jnp
Expand Down Expand Up @@ -2486,7 +2485,7 @@ def _xla_optimization_barrier(proto):
def _func(*operands: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
# TODO(b/241584320) Note this does not reproduce the remat transform in the
# forward pass, which may require some heurstics when parsing the graphdef.
return lax_control_flow.optimization_barrier_p.bind(*operands)
return jax.lax.optimization_barrier_p.bind(*operands)

return _func

Expand Down

0 comments on commit db3f7d1

Please sign in to comment.