Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove infer_device and rename handle_device_shifting decorator #23373

Merged
merged 19 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9419b6a
removed infer_device function
ShreyanshBardia Sep 10, 2023
392ef32
renamed handle_device_shifting to handle_device
ShreyanshBardia Sep 10, 2023
d1372e0
removed unwanted changed
ShreyanshBardia Sep 10, 2023
3f01570
Merge branch 'main' into refactor-decorator
ShreyanshBardia Sep 10, 2023
5ff99e1
removed infer_device from FN_DECORATORS list
ShreyanshBardia Sep 10, 2023
42b94b5
removed mention of infer_device from docstring
ShreyanshBardia Sep 10, 2023
ebb802f
added handle_device to functions infer_device was present but handle_…
ShreyanshBardia Sep 11, 2023
d2e6cea
Merge branch 'main' into refactor-decorator
ShreyanshBardia Sep 11, 2023
ccbf232
fixing function inspecting for torch
ShreyanshBardia Sep 13, 2023
392b11f
assigning device to kwargs for jax
ShreyanshBardia Sep 13, 2023
2bd202c
set device to None by default in backends
ShreyanshBardia Sep 22, 2023
2ece4f3
Merge branch 'main' into refactor-decorator
ShreyanshBardia Sep 22, 2023
ea58017
remove changes added by mistake
ShreyanshBardia Sep 22, 2023
07cac6a
set default device value None
ShreyanshBardia Sep 22, 2023
4ac4be4
updated paddle/experimental
ShreyanshBardia Sep 23, 2023
5c21849
removed max_unpool which might have been added during merge
ShreyanshBardia Sep 23, 2023
5033955
updated handle_device_shifting for newly added functions
ShreyanshBardia Sep 25, 2023
248f084
Merge branch 'unifyai:main' into refactor-decorator
ShreyanshBardia Sep 25, 2023
48cc8d5
updated handle_device_shifting for newly added functions
ShreyanshBardia Sep 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 8 additions & 42 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
# for wrapping (sequence matters)
FN_DECORATORS = [
"handle_complex_input",
"infer_device",
"handle_device_shifting",
"handle_device",
"infer_dtype",
"handle_array_function",
"outputs_to_ivy_arrays",
Expand Down Expand Up @@ -779,42 +778,9 @@ def _infer_dtype(*args, dtype=None, **kwargs):
# ----------------#


def infer_device(fn: Callable) -> Callable:
def handle_device(fn: Callable) -> Callable:
@functools.wraps(fn)
def _infer_device(*args, device=None, **kwargs):
"""
Determine the correct `device`, and then calls the function with the `device`
passed explicitly.

Parameters
----------
args
The arguments to be passed to the function.

device
The device for the function.

kwargs
The keyword arguments to be passed to the function.

Returns
-------
The return of the function, with `device` passed explicitly.
"""
# find the first array argument, if required
arr = None if ivy.exists(device) else _get_first_array(*args, **kwargs)
# infer the correct device
device = ivy.default_device(device, item=arr, as_native=True)
# call the function with device provided explicitly
return fn(*args, device=device, **kwargs)

_infer_device.infer_device = True
return _infer_device


def handle_device_shifting(fn: Callable) -> Callable:
@functools.wraps(fn)
def _handle_device_shifting(*args, **kwargs):
def _handle_device(*args, **kwargs):
"""
Move all array inputs of the function to `ivy.default_device()`.

Expand Down Expand Up @@ -857,8 +823,8 @@ def _handle_device_shifting(*args, **kwargs):
)
return fn(*args, **kwargs)

_handle_device_shifting.handle_device_shifting = True
return _handle_device_shifting
_handle_device.handle_device = True
return _handle_device


# Inplace Update Handling #
Expand Down Expand Up @@ -1074,9 +1040,9 @@ def _wrap_function(
"""
Apply wrapping to backend implementation `to_wrap` if the original implementation
`original` is also wrapped, and if `to_wrap` is not already wrapped. Attributes
`handle_nestable`, `infer_device` etc are set during wrapping, hence indicate to us
whether a certain function has been wrapped or not. Also handles wrapping of the
`linalg` namespace.
`handle_nestable` etc are set during wrapping, hence indicate to us whether a
certain function has been wrapped or not. Also handles wrapping of the `linalg`
namespace.

Parameters
----------
Expand Down
20 changes: 10 additions & 10 deletions ivy/functional/ivy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
to_native_arrays_and_back,
handle_nestable,
handle_array_like_without_promotion,
handle_device_shifting,
handle_device,
handle_complex_input,
handle_backend_invalid,
)
Expand All @@ -38,7 +38,7 @@ def _gelu_jax_like(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
@handle_complex_input
def gelu(
x: Union[ivy.Array, ivy.NativeArray],
Expand Down Expand Up @@ -128,7 +128,7 @@ def _leaky_relu_jax_like(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
@handle_complex_input
def leaky_relu(
x: Union[ivy.Array, ivy.NativeArray],
Expand Down Expand Up @@ -209,7 +209,7 @@ def leaky_relu(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
def log_softmax(
x: Union[ivy.Array, ivy.NativeArray],
/,
Expand Down Expand Up @@ -301,7 +301,7 @@ def _relu_jax_like(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
@handle_complex_input
def relu(
x: Union[ivy.Array, ivy.NativeArray],
Expand Down Expand Up @@ -373,7 +373,7 @@ def relu(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
def sigmoid(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
Expand Down Expand Up @@ -451,7 +451,7 @@ def sigmoid(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
@handle_complex_input
def softmax(
x: Union[ivy.Array, ivy.NativeArray],
Expand Down Expand Up @@ -548,7 +548,7 @@ def _softplus_jax_like(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
@handle_complex_input
def softplus(
x: Union[ivy.Array, ivy.NativeArray],
Expand Down Expand Up @@ -622,7 +622,7 @@ def softplus(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
def softsign(
x: Union[ivy.Array, ivy.NativeArray],
/,
Expand Down Expand Up @@ -662,7 +662,7 @@ def softsign(
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_device
def mish(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/ivy/control_flow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
handle_array_like_without_promotion,
to_native_arrays_and_back,
to_ivy_arrays_and_back,
handle_device_shifting,
handle_device,
)


Expand Down Expand Up @@ -60,7 +60,7 @@ def if_else(

@to_native_arrays_and_back
@handle_array_like_without_promotion
@handle_device_shifting
@handle_device
def _if_else(cond, body_fn, orelse_fn, vars):
return current_backend().if_else(cond, body_fn, orelse_fn, vars)

Expand Down Expand Up @@ -116,7 +116,7 @@ def while_loop(

@to_native_arrays_and_back
@handle_array_like_without_promotion
@handle_device_shifting
@handle_device
def _while_loop(test_fn, body_fn, vars):
return current_backend().while_loop(test_fn, body_fn, vars)

Expand Down
Loading
Loading