diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py
index 9dc7949793e88..950dd0512b6f7 100644
--- a/ivy/data_classes/array/activations.py
+++ b/ivy/data_classes/array/activations.py
@@ -349,7 +349,13 @@ def mish(
         """
         return ivy.mish(self._data, complex_mode=complex_mode, out=out)

-    def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
+    def hardswish(
+        self: ivy.Array,
+        /,
+        *,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+        out: Optional[ivy.Array] = None,
+    ) -> ivy.Array:
         """
         Apply the hardswish activation function element-wise.

@@ -357,6 +363,9 @@ def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Arr
         ----------
         x
             input array
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
         out
             optional output array, for writing the result to. It must have a shape
             that the inputs broadcast to.
@@ -385,4 +394,4 @@ def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Arr
             b: ivy.array([0., 5.])
         }
         """
-        return ivy.hardswish(self._data, out=out)
+        return ivy.hardswish(self._data, complex_mode=complex_mode, out=out)
diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py
index 98bcfced15c42..9d5883bd1d213 100644
--- a/ivy/data_classes/container/activations.py
+++ b/ivy/data_classes/container/activations.py
@@ -1120,6 +1120,7 @@ def _static_hardswish(
         to_apply: Union[bool, ivy.Container] = True,
         prune_unapplied: Union[bool, ivy.Container] = False,
         map_sequences: Union[bool, ivy.Container] = False,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
         out: Optional[ivy.Container] = None,
     ) -> ivy.Container:
         """
@@ -1142,6 +1143,9 @@ def _static_hardswish(
         map_sequences
             Whether to also map method to sequences (lists, tuples).
             Default is ``False``.
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
         out
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
@@ -1169,6 +1173,7 @@ def _static_hardswish(
             to_apply=to_apply,
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
+            complex_mode=complex_mode,
             out=out,
         )

@@ -1180,6 +1185,7 @@ def hardswish(
         to_apply: Union[bool, ivy.Container] = True,
         prune_unapplied: Union[bool, ivy.Container] = False,
         map_sequences: Union[bool, ivy.Container] = False,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
         out: Optional[ivy.Container] = None,
     ) -> ivy.Container:
         """
@@ -1202,6 +1208,9 @@ def hardswish(
         map_sequences
             Whether to also map method to sequences (lists, tuples).
             Default is ``False``.
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
         out
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
@@ -1228,5 +1237,6 @@ def hardswish(
             to_apply=to_apply,
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
+            complex_mode=complex_mode,
             out=out,
         )
diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py
index a6e968f7a24df..2dc0643665959 100644
--- a/ivy/functional/backends/jax/activations.py
+++ b/ivy/functional/backends/jax/activations.py
@@ -108,5 +108,11 @@ def mish(
     return x * jnp.tanh(jax.nn.softplus(x))


-def hardswish(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
+def hardswish(
+    x: JaxArray,
+    /,
+    *,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[JaxArray] = None,
+) -> JaxArray:
     return jax.nn.hard_swish(x)
diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py
index ecc140cdc56c5..cb9f698df3d39 100644
--- a/ivy/functional/backends/numpy/activations.py
+++ b/ivy/functional/backends/numpy/activations.py
@@ -143,7 +143,13 @@ def mish(


 @_scalar_output_to_0d_array
-def hardswish(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
+def hardswish(
+    x: np.ndarray,
+    /,
+    *,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[np.ndarray] = None,
+) -> np.ndarray:
     max_x_3 = np.maximum(x + 3, 0, dtype=x.dtype)
     return (x * np.minimum(max_x_3, 6, out=out, dtype=x.dtype) / 6).astype(x.dtype)
diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py
index cac37f0069fa7..ac1343e86aa9f 100644
--- a/ivy/functional/backends/paddle/activations.py
+++ b/ivy/functional/backends/paddle/activations.py
@@ -203,6 +203,10 @@ def mish(
     {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version
 )
 def hardswish(
-    x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
+    x: paddle.Tensor,
+    /,
+    *,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     return F.hardswish(x)
diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py
index 62bab21a8229c..ffd4efbe705b1 100644
--- a/ivy/functional/backends/tensorflow/activations.py
+++ b/ivy/functional/backends/tensorflow/activations.py
@@ -152,5 +152,11 @@ def mish(


 @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
-def hardswish(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
+def hardswish(
+    x: Tensor,
+    /,
+    *,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[Tensor] = None,
+) -> Tensor:
     return x * tf.nn.relu6(x + 3) / 6
diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py
index b1dcf2fea30af..b863be601dfa2 100644
--- a/ivy/functional/backends/torch/activations.py
+++ b/ivy/functional/backends/torch/activations.py
@@ -156,6 +156,10 @@ def mish(
     backend_version,
 )
 def hardswish(
-    x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None
+    x: torch.Tensor,
+    /,
+    *,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     return torch.nn.functional.hardswish(x)
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index e68e0a0e16e37..7f3595bf92353 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -731,6 +731,19 @@ def mish(
     return current_backend(x).mish(x, out=out)


+def _hardswish_jax_like(
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    fn_original=None,
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    def hard_sigmoid(x):
+        return ivy.relu6(x + 3.0) / 6
+
+    return ivy.multiply(x, hard_sigmoid(x).astype(x.dtype))
+
+
 @handle_exceptions
 @handle_backend_invalid
 @handle_nestable
@@ -738,8 +751,13 @@ def mish(
 @handle_out_argument
 @to_native_arrays_and_back
 @handle_array_function
+@handle_complex_input
 def hardswish(
-    x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[ivy.Array] = None,
 ) -> ivy.Array:
     """
     Apply the hardswish activation function element-wise.
@@ -748,6 +766,9 @@ def hardswish(
     ----------
     x
         input array
+    complex_mode
+        optional specifier for how to handle complex data types. See
+        ``ivy.func_wrapper.handle_complex_input`` for more detail.
     out
         optional output array, for writing the result to. It must have a shape
         that the inputs broadcast to.
@@ -777,3 +798,6 @@ def hardswish(
     }
     """
     return current_backend(x).hardswish(x, out=out)
+
+
+hardswish.jax_like = _hardswish_jax_like
diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py
index 218941bfb37fd..a2a789701a64b 100644
--- a/ivy/stateful/activations.py
+++ b/ivy/stateful/activations.py
@@ -385,8 +385,17 @@ def _forward(self, x):


 class Hardswish(Module):
-    def __init__(self):
-        """Apply the HARDSWISH activation function."""
+    def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
+        """
+        Apply the HARDSWISH activation function.
+
+        Parameters
+        ----------
+        complex_mode
+            Specifies how to handle complex input. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        """
+        self._complex_mode = complex_mode
         Module.__init__(self)

     def _forward(self, x):
@@ -402,7 +411,7 @@ def _forward(self, x):
         ret
             The outputs following the HARDSWISH activation *[batch_shape, d]*
         """
-        return ivy.hardswish(x)
+        return ivy.hardswish(x, complex_mode=self._complex_mode)


 class Logit(Module):
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
index 5c2d7545cfc16..be20daaabf52e 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
@@ -245,7 +245,7 @@ def test_jax_hard_silu(
 @handle_frontend_test(
     fn_tree="jax.nn.hard_swish",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         min_value=-10,
         max_value=10,
         safety_factor_scale="linear",
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index 8254cc3141a04..8c47ca2f40447 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -48,15 +48,17 @@ def test_gelu(
 @handle_test(
     fn_tree="functional.ivy.hardswish",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         large_abs_safety_factor=8,
         small_abs_safety_factor=8,
         safety_factor_scale="log",
     ),
+    complex_mode=st.sampled_from(["jax", "split", "magnitude"]),
 )
 def test_hardswish(
     *,
     dtype_and_x,
+    complex_mode,
     test_flags,
     backend_fw,
     fn_name,
@@ -70,6 +72,7 @@ def test_hardswish(
         fn_name=fn_name,
         on_device=on_device,
         x=x[0],
+        complex_mode=complex_mode,
     )
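
Usage sketch (not part of the diff): the snippet below shows how the new `complex_mode` argument could be exercised once this patch is applied. It assumes the NumPy backend is available and uses arbitrary example values; whether a given backend actually accepts complex inputs still depends on its dtype support (the TensorFlow backend above, for instance, marks complex dtypes as unsupported for `hardswish`).

```python
import ivy
from ivy.stateful.activations import Hardswish  # stateful layer patched above

ivy.set_backend("numpy")

# Real-valued inputs behave exactly as before; the new argument is ignored.
x = ivy.array([-3.0, -1.0, 0.0, 2.0, 4.0])
print(ivy.hardswish(x))  # x * relu6(x + 3) / 6

# Complex inputs are routed through the handle_complex_input wrapper.
z = ivy.array([1.0 + 2.0j, -0.5 - 0.5j])

# "jax" (the default) dispatches to _hardswish_jax_like, i.e. the
# hard-sigmoid product is evaluated on the complex values directly.
print(ivy.hardswish(z, complex_mode="jax"))

# "split" and "magnitude" are the other modes understood by
# handle_complex_input (applied to the real/imaginary parts or to |z|).
print(ivy.hardswish(z, complex_mode="split"))

# The stateful layer simply forwards its constructor argument to ivy.hardswish.
layer = Hardswish(complex_mode="magnitude")
print(layer(z))
```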