Skip to content

Commit

Permalink
ELU added to Ivy Stateful Api (#19820)
Browse files Browse the repository at this point in the history
Co-authored-by: tejeshbhalla <[email protected]>
  • Loading branch information
tejeshbhalla and tejeshbhalla authored Jul 26, 2023
1 parent eded225 commit c0f3bb6
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
21 changes: 21 additions & 0 deletions ivy/stateful/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,24 @@ def _forward(self, x):
The outputs following the SELU activation *[batch_shape, d]*
"""
return ivy.selu(x)


class ELU(Module):
def __init__(self):
"""Apply the ELU activation function."""
Module.__init__(self)

def _forward(self, x, alpha=1.0):
"""
Parameters
----------
x
Inputs to process *[batch_shape, d]*.
alpha
scaler for controlling the slope of the function for x <= 0 Default: 1.0
Returns
-------
ret
The outputs following the ELU activation *[batch_shape, d]*
"""
return ivy.elu(x, alpha=alpha)
46 changes: 46 additions & 0 deletions ivy_tests/test_ivy/test_stateful/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,49 @@ def test_selu(
test_gradients=test_gradients,
on_device=on_device,
)


# ELU
@handle_method(
method_tree="stateful.activations.ELU.__call__",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
num_arrays=2,
shared_dtype=True,
min_num_dims=2,
large_abs_safety_factor=8,
small_abs_safety_factor=8,
safety_factor_scale="log",
),
method_num_positional_args=helpers.num_positional_args(fn_name="ELU._forward"),
test_gradients=st.just(True),
alpha=helpers.floats(min_value=0.1, max_value=1),
)
def test_elu(
*,
dtype_and_x,
alpha,
test_gradients,
class_name,
method_name,
ground_truth_backend,
init_flags,
method_flags,
on_device,
):
input_dtype, x = dtype_and_x
helpers.test_method(
ground_truth_backend=ground_truth_backend,
init_flags=init_flags,
method_flags=method_flags,
init_input_dtypes=input_dtype,
method_input_dtypes=input_dtype,
init_all_as_kwargs_np={},
method_all_as_kwargs_np={"x": x[0], "alpha": alpha},
class_name=class_name,
method_name=method_name,
rtol_=1e-2,
atol_=1e-2,
test_gradients=test_gradients,
on_device=on_device,
)

0 comments on commit c0f3bb6

Please sign in to comment.