Skip to content

Commit

Permalink
Update expected attrs of VarHandleOp -- add debug_name.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558314532
  • Loading branch information
jamesmullenbach authored and TF2JAXDev committed Aug 19, 2023
1 parent 22dfe41 commit a66033d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2266,7 +2266,8 @@ def _func(x: jnp.ndarray) -> List[jnp.ndarray]:
@register_operation("VarHandleOp")
def _var_handle(proto):
_check_attrs(
proto, {"shared_name", "container", "allowed_devices", "shape", "dtype"})
proto, {"shared_name", "container", "allowed_devices", "shape", "dtype",
"debug_name"})

def _func():
raise ValueError(f"VarHandleOp `{proto.name}` cannot be evaluated.")
Expand Down

0 comments on commit a66033d

Please sign in to comment.