Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove infer_device and rename handle_device_shifting decorator #23373

Merged
merged 19 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9419b6a
removed infer_device function
ShreyanshBardia Sep 10, 2023
392ef32
renamed handle_device_shifting to handle_device
ShreyanshBardia Sep 10, 2023
d1372e0
removed unwanted changed
ShreyanshBardia Sep 10, 2023
3f01570
Merge branch 'main' into refactor-decorator
ShreyanshBardia Sep 10, 2023
5ff99e1
removed infer_device from FN_DECORATORS list
ShreyanshBardia Sep 10, 2023
42b94b5
removed mention of infer_device from docstring
ShreyanshBardia Sep 10, 2023
ebb802f
added handle_device to functions infer_device was present but handle_…
ShreyanshBardia Sep 11, 2023
d2e6cea
Merge branch 'main' into refactor-decorator
ShreyanshBardia Sep 11, 2023
ccbf232
fixing function inspecting for torch
ShreyanshBardia Sep 13, 2023
392b11f
assigning device to kwargs for jax
ShreyanshBardia Sep 13, 2023
2bd202c
set device to None by default in backends
ShreyanshBardia Sep 22, 2023
2ece4f3
Merge branch 'main' into refactor-decorator
ShreyanshBardia Sep 22, 2023
ea58017
remove changes added by mistake
ShreyanshBardia Sep 22, 2023
07cac6a
set default device value None
ShreyanshBardia Sep 22, 2023
4ac4be4
updated paddle/experimental
ShreyanshBardia Sep 23, 2023
5c21849
removed max_unpool which might have been added during merge
ShreyanshBardia Sep 23, 2023
5033955
updated handle_device_shifting for newly added functions
ShreyanshBardia Sep 25, 2023
248f084
Merge branch 'unifyai:main' into refactor-decorator
ShreyanshBardia Sep 25, 2023
48cc8d5
updated handle_device_shifting for newly added functions
ShreyanshBardia Sep 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 8 additions & 42 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
# for wrapping (sequence matters)
FN_DECORATORS = [
"handle_complex_input",
"infer_device",
"handle_device_shifting",
"handle_device",
"infer_dtype",
"handle_array_function",
"outputs_to_ivy_arrays",
Expand Down Expand Up @@ -778,42 +777,9 @@ def _infer_dtype(*args, dtype=None, **kwargs):
# ----------------#


def infer_device(fn: Callable) -> Callable:
def handle_device(fn: Callable) -> Callable:
@functools.wraps(fn)
def _infer_device(*args, device=None, **kwargs):
"""
Determine the correct `device`, and then calls the function with the `device`
passed explicitly.

Parameters
----------
args
The arguments to be passed to the function.

device
The device for the function.

kwargs
The keyword arguments to be passed to the function.

Returns
-------
The return of the function, with `device` passed explicitly.
"""
# find the first array argument, if required
arr = None if ivy.exists(device) else _get_first_array(*args, **kwargs)
# infer the correct device
device = ivy.default_device(device, item=arr, as_native=True)
# call the function with device provided explicitly
return fn(*args, device=device, **kwargs)

_infer_device.infer_device = True
return _infer_device


def handle_device_shifting(fn: Callable) -> Callable:
@functools.wraps(fn)
def _handle_device_shifting(*args, **kwargs):
def _handle_device(*args, **kwargs):
"""
Move all array inputs of the function to `ivy.default_device()`.

Expand Down Expand Up @@ -856,8 +822,8 @@ def _handle_device_shifting(*args, **kwargs):
)
return fn(*args, **kwargs)

_handle_device_shifting.handle_device_shifting = True
return _handle_device_shifting
_handle_device.handle_device = True
return _handle_device


# Inplace Update Handling #
Expand Down Expand Up @@ -1071,9 +1037,9 @@ def _wrap_function(
"""
Apply wrapping to backend implementation `to_wrap` if the original implementation
`original` is also wrapped, and if `to_wrap` is not already wrapped. Attributes
`handle_nestable`, `infer_device` etc are set during wrapping, hence indicate to us
whether a certain function has been wrapped or not. Also handles wrapping of the
`linalg` namespace.
`handle_nestable` etc are set during wrapping, hence indicate to us whether a
certain function has been wrapped or not. Also handles wrapping of the `linalg`
namespace.

Parameters
----------
Expand Down
28 changes: 14 additions & 14 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def arange(
step: float = 1,
*,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if dtype:
Expand Down Expand Up @@ -69,7 +69,7 @@ def asarray(
*,
copy: Optional[bool] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
ivy.utils.assertions._check_jax_x64_flag(dtype)
Expand All @@ -83,7 +83,7 @@ def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.empty(shape, dtype)
Expand All @@ -94,7 +94,7 @@ def empty_like(
/,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.empty_like(x, dtype=dtype)
Expand All @@ -108,7 +108,7 @@ def eye(
k: int = 0,
batch_shape: Optional[Union[int, Sequence[int]]] = None,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if n_cols is None:
Expand All @@ -135,7 +135,7 @@ def full(
fill_value: Union[int, float, bool],
*,
dtype: Optional[Union[ivy.Dtype, jnp.dtype]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
dtype = ivy.default_dtype(dtype=dtype, item=fill_value, as_native=True)
Expand All @@ -148,7 +148,7 @@ def full_like(
fill_value: Number,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.full_like(x, fill_value, dtype=dtype)
Expand All @@ -165,7 +165,7 @@ def linspace(
axis: Optional[int] = None,
endpoint: bool = True,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if axis is None:
Expand Down Expand Up @@ -239,7 +239,7 @@ def ones(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.ones(shape, dtype)
Expand All @@ -250,7 +250,7 @@ def ones_like(
/,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.ones_like(x, dtype=dtype)
Expand All @@ -268,7 +268,7 @@ def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.zeros(shape, dtype)
Expand All @@ -279,7 +279,7 @@ def zeros_like(
/,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.zeros_like(x, dtype=dtype)
Expand Down Expand Up @@ -314,7 +314,7 @@ def one_hot(
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
on_none = on_value is None
Expand Down Expand Up @@ -356,6 +356,6 @@ def triu_indices(
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
) -> Tuple[JaxArray]:
return jnp.triu_indices(n=n_rows, k=k, m=n_cols)
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def tril_indices(
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
) -> Tuple[JaxArray, ...]:
return jnp.tril_indices(n=n_rows, k=k, m=n_cols)

Expand Down
8 changes: 4 additions & 4 deletions ivy/functional/backends/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def random_uniform(
low: Union[float, JaxArray] = 0.0,
high: Union[float, JaxArray] = 1.0,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
dtype: jnp.dtype,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
Expand All @@ -66,7 +66,7 @@ def random_normal(
mean: Union[float, JaxArray] = 0.0,
std: Union[float, JaxArray] = 1.0,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
dtype: jnp.dtype,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
Expand All @@ -91,7 +91,7 @@ def multinomial(
batch_size: int = 1,
probs: Optional[JaxArray] = None,
replace: bool = True,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
Expand Down Expand Up @@ -132,7 +132,7 @@ def randint(
/,
*,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
dtype: Optional[Union[jnp.dtype, ivy.Dtype]] = None,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
Expand Down
43 changes: 29 additions & 14 deletions ivy/functional/backends/numpy/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def arange(
step: float = 1,
*,
dtype: Optional[np.dtype] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if dtype:
Expand All @@ -57,7 +57,7 @@ def asarray(
*,
copy: Optional[bool] = None,
dtype: Optional[np.dtype] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
ret = _to_device(np.asarray(obj, dtype=dtype), device=device)
Expand All @@ -68,14 +68,19 @@ def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.empty(shape, dtype), device=device)


def empty_like(
x: np.ndarray, /, *, dtype: np.dtype, device: str, out: Optional[np.ndarray] = None
x: np.ndarray,
/,
*,
dtype: np.dtype,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.empty_like(x, dtype=dtype), device=device)

Expand All @@ -88,7 +93,7 @@ def eye(
k: int = 0,
batch_shape: Optional[Union[int, Sequence[int]]] = None,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if n_cols is None:
Expand Down Expand Up @@ -116,7 +121,7 @@ def full(
fill_value: Union[int, float, bool],
*,
dtype: Optional[Union[ivy.Dtype, np.dtype]] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
dtype = ivy.default_dtype(dtype=dtype, item=fill_value, as_native=True)
Expand All @@ -132,7 +137,7 @@ def full_like(
fill_value: Number,
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.full_like(x, fill_value, dtype=dtype), device=device)
Expand All @@ -147,7 +152,7 @@ def linspace(
axis: Optional[int] = None,
endpoint: bool = True,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if axis is None:
Expand Down Expand Up @@ -176,14 +181,19 @@ def ones(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.ones(shape, dtype), device=device)


def ones_like(
x: np.ndarray, /, *, dtype: np.dtype, device: str, out: Optional[np.ndarray] = None
x: np.ndarray,
/,
*,
dtype: np.dtype,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.ones_like(x, dtype=dtype), device=device)

Expand All @@ -204,14 +214,19 @@ def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.zeros(shape, dtype), device=device)


def zeros_like(
x: np.ndarray, /, *, dtype: np.dtype, device: str, out: Optional[np.ndarray] = None
x: np.ndarray,
/,
*,
dtype: np.dtype,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.zeros_like(x, dtype=dtype), device=device)

Expand Down Expand Up @@ -243,7 +258,7 @@ def one_hot(
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[np.dtype] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
on_none = on_value is None
Expand Down Expand Up @@ -287,7 +302,7 @@ def triu_indices(
k: int = 0,
/,
*,
device: str,
device: str = None,
) -> Tuple[np.ndarray]:
return tuple(
_to_device(np.asarray(np.triu_indices(n=n_rows, k=k, m=n_cols)), device=device)
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def tril_indices(
k: int = 0,
/,
*,
device: str,
device: str = None,
) -> Tuple[np.ndarray, ...]:
return tuple(
_to_device(np.asarray(np.tril_indices(n=n_rows, k=k, m=n_cols)), device=device)
Expand Down
Loading
Loading