feat ivy: adding backend implementations for ivy.requires_gradient
YushaArif99 committed Sep 7, 2024
1 parent feeaa1c commit 6aa399c
Showing 5 changed files with 51 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ivy/functional/backends/jax/gradients.py
@@ -145,6 +145,13 @@ def stop_gradient(
) -> JaxArray:
    return jlax.stop_gradient(x)


def requires_gradient(
    x: JaxArray,
) -> bool:
    if isinstance(x, jax.core.Tracer):
        # JAX tracers indicate whether a value is being traced for gradients
        return True
    return False

def jac(func: Callable):
    def grad_fn(x_in):
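Note on the JAX check: under a transformation such as jax.grad, function inputs are replaced by tracer objects, which is exactly what the isinstance(x, jax.core.Tracer) test detects. A minimal standalone sketch (not part of the commit) illustrating that behavior:

import jax
import jax.numpy as jnp

def traced_for_gradients(x) -> bool:
    # Same test as the backend: tracers only exist inside a JAX transformation
    return isinstance(x, jax.core.Tracer)

def loss(x):
    print(traced_for_gradients(x))  # True: x is a Tracer while jax.grad traces
    return jnp.sum(x ** 2)

x = jnp.ones(3)
print(traced_for_gradients(x))  # False: x is a concrete array here
jax.grad(loss)(x)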
3 changes: 3 additions & 0 deletions ivy/functional/backends/numpy/gradients.py
@@ -21,6 +21,9 @@ def is_variable(x, /, *, exclusive=False):
    return False


def requires_gradient(x, /) -> bool:
    # NumPy has no gradient tracking, so no array ever requires gradients
    return False


def variable_data(x, /):
    return x

8 changes: 8 additions & 0 deletions ivy/functional/backends/tensorflow/gradients.py
@@ -171,6 +171,14 @@ def stop_gradient(
        return variable(x)
    return x


def requires_gradient(
    x: Union[tf.Tensor, tf.Variable],
) -> bool:
    if is_variable(x):
        return x.trainable
    return False


def jac(func: Callable):
    def grad_fn(x_in):
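In the TensorFlow backend, only tf.Variable objects can take part in gradient tracking, and only when their trainable flag is set; plain tensors never qualify, which is why the function falls through to False. A short sketch (plain TensorFlow, no ivy wrappers) of the attributes being inspected:

import tensorflow as tf

v = tf.Variable(1.0)                 # trainable=True by default
frozen = tf.Variable(1.0, trainable=False)
t = tf.constant(1.0)

print(v.trainable)                   # True  -> requires_gradient returns True
print(frozen.trainable)              # False -> requires_gradient returns False
print(isinstance(t, tf.Variable))    # False -> requires_gradient returns False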
8 changes: 8 additions & 0 deletions ivy/functional/backends/torch/gradients.py
@@ -182,6 +182,14 @@ def stop_gradient(
        return x
    return x.detach()


def requires_gradient(
    x: Optional[torch.Tensor],
) -> bool:
    if is_variable(x):
        return x.requires_grad
    return False


def jac(func: Callable):
    def grad_fn(x_in):
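The PyTorch case is the most direct of the four backends: every tensor carries a requires_grad attribute, and the is_variable guard simply filters out inputs that are not gradient-tracking tensors. A quick sketch of the flag itself:

import torch

a = torch.ones(3, requires_grad=True)
b = torch.ones(3)

print(a.requires_grad)  # True  -> requires_gradient returns True
print(b.requires_grad)  # False -> requires_gradient returns False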
25 changes: 25 additions & 0 deletions ivy/functional/ivy/gradients.py
@@ -435,6 +435,31 @@ def stop_gradient(
    return current_backend(x).stop_gradient(x, preserve_type=preserve_type, out=out)


@handle_exceptions
@handle_backend_invalid
@handle_nestable
@handle_array_like_without_promotion
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device
def requires_gradient(
    x: Union[ivy.Array, ivy.NativeArray],
) -> bool:
    """Check if gradient computation is enabled for the given array.

    Parameters
    ----------
    x
        Array to check for gradient computation.

    Returns
    -------
    ret
        A boolean indicating whether gradient computation is enabled for
        the input array.
    """
    return current_backend(x).requires_gradient(x)


# AutoGrad #


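Putting it together, the public ivy.requires_gradient dispatches to whichever backend is currently set. A hypothetical usage sketch (assuming the torch backend is installed; the result follows the backend logic shown above):

import ivy

ivy.set_backend("torch")

x = ivy.array([1.0, 2.0, 3.0])
print(ivy.requires_gradient(x))  # False: a plain array does not track gradients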
