Skip to content

Commit

Permalink
refactor: removed infer_device and renamed handle_device_shifting dec…
Browse files Browse the repository at this point in the history
…orator (ivy-llc#23373)
  • Loading branch information
ShreyanshBardia authored and iababio committed Sep 27, 2023
1 parent 64d7409 commit a38e634
Show file tree
Hide file tree
Showing 47 changed files with 514 additions and 562 deletions.
50 changes: 8 additions & 42 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
# for wrapping (sequence matters)
FN_DECORATORS = [
"handle_complex_input",
"infer_device",
"handle_device_shifting",
"handle_device",
"infer_dtype",
"handle_array_function",
"outputs_to_ivy_arrays",
Expand Down Expand Up @@ -778,42 +777,9 @@ def _infer_dtype(*args, dtype=None, **kwargs):
# ----------------#


def infer_device(fn: Callable) -> Callable:
def handle_device(fn: Callable) -> Callable:
@functools.wraps(fn)
def _infer_device(*args, device=None, **kwargs):
"""
Determine the correct `device`, and then calls the function with the `device`
passed explicitly.
Parameters
----------
args
The arguments to be passed to the function.
device
The device for the function.
kwargs
The keyword arguments to be passed to the function.
Returns
-------
The return of the function, with `device` passed explicitly.
"""
# find the first array argument, if required
arr = None if ivy.exists(device) else _get_first_array(*args, **kwargs)
# infer the correct device
device = ivy.default_device(device, item=arr, as_native=True)
# call the function with device provided explicitly
return fn(*args, device=device, **kwargs)

_infer_device.infer_device = True
return _infer_device


def handle_device_shifting(fn: Callable) -> Callable:
@functools.wraps(fn)
def _handle_device_shifting(*args, **kwargs):
def _handle_device(*args, **kwargs):
"""
Move all array inputs of the function to `ivy.default_device()`.
Expand Down Expand Up @@ -856,8 +822,8 @@ def _handle_device_shifting(*args, **kwargs):
)
return fn(*args, **kwargs)

_handle_device_shifting.handle_device_shifting = True
return _handle_device_shifting
_handle_device.handle_device = True
return _handle_device


# Inplace Update Handling #
Expand Down Expand Up @@ -1071,9 +1037,9 @@ def _wrap_function(
"""
Apply wrapping to backend implementation `to_wrap` if the original implementation
`original` is also wrapped, and if `to_wrap` is not already wrapped. Attributes
`handle_nestable`, `infer_device` etc are set during wrapping, hence indicate to us
whether a certain function has been wrapped or not. Also handles wrapping of the
`linalg` namespace.
`handle_nestable` etc are set during wrapping, hence indicate to us whether a
certain function has been wrapped or not. Also handles wrapping of the `linalg`
namespace.
Parameters
----------
Expand Down
28 changes: 14 additions & 14 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def arange(
step: float = 1,
*,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if dtype:
Expand Down Expand Up @@ -69,7 +69,7 @@ def asarray(
*,
copy: Optional[bool] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
ivy.utils.assertions._check_jax_x64_flag(dtype)
Expand All @@ -83,7 +83,7 @@ def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.empty(shape, dtype)
Expand All @@ -94,7 +94,7 @@ def empty_like(
/,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.empty_like(x, dtype=dtype)
Expand All @@ -108,7 +108,7 @@ def eye(
k: int = 0,
batch_shape: Optional[Union[int, Sequence[int]]] = None,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if n_cols is None:
Expand All @@ -135,7 +135,7 @@ def full(
fill_value: Union[int, float, bool],
*,
dtype: Optional[Union[ivy.Dtype, jnp.dtype]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
dtype = ivy.default_dtype(dtype=dtype, item=fill_value, as_native=True)
Expand All @@ -148,7 +148,7 @@ def full_like(
fill_value: Number,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.full_like(x, fill_value, dtype=dtype)
Expand All @@ -165,7 +165,7 @@ def linspace(
axis: Optional[int] = None,
endpoint: bool = True,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if axis is None:
Expand Down Expand Up @@ -239,7 +239,7 @@ def ones(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.ones(shape, dtype)
Expand All @@ -250,7 +250,7 @@ def ones_like(
/,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.ones_like(x, dtype=dtype)
Expand All @@ -268,7 +268,7 @@ def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.zeros(shape, dtype)
Expand All @@ -279,7 +279,7 @@ def zeros_like(
/,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.zeros_like(x, dtype=dtype)
Expand Down Expand Up @@ -314,7 +314,7 @@ def one_hot(
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
on_none = on_value is None
Expand Down Expand Up @@ -356,6 +356,6 @@ def triu_indices(
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
) -> Tuple[JaxArray]:
return jnp.triu_indices(n=n_rows, k=k, m=n_cols)
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def tril_indices(
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
) -> Tuple[JaxArray, ...]:
return jnp.tril_indices(n=n_rows, k=k, m=n_cols)

Expand Down
8 changes: 4 additions & 4 deletions ivy/functional/backends/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def random_uniform(
low: Union[float, JaxArray] = 0.0,
high: Union[float, JaxArray] = 1.0,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
dtype: jnp.dtype,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
Expand All @@ -66,7 +66,7 @@ def random_normal(
mean: Union[float, JaxArray] = 0.0,
std: Union[float, JaxArray] = 1.0,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
dtype: jnp.dtype,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
Expand All @@ -91,7 +91,7 @@ def multinomial(
batch_size: int = 1,
probs: Optional[JaxArray] = None,
replace: bool = True,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
Expand Down Expand Up @@ -132,7 +132,7 @@ def randint(
/,
*,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device = None,
dtype: Optional[Union[jnp.dtype, ivy.Dtype]] = None,
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
Expand Down
43 changes: 29 additions & 14 deletions ivy/functional/backends/numpy/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def arange(
step: float = 1,
*,
dtype: Optional[np.dtype] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if dtype:
Expand All @@ -57,7 +57,7 @@ def asarray(
*,
copy: Optional[bool] = None,
dtype: Optional[np.dtype] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
ret = _to_device(np.asarray(obj, dtype=dtype), device=device)
Expand All @@ -68,14 +68,19 @@ def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.empty(shape, dtype), device=device)


def empty_like(
x: np.ndarray, /, *, dtype: np.dtype, device: str, out: Optional[np.ndarray] = None
x: np.ndarray,
/,
*,
dtype: np.dtype,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.empty_like(x, dtype=dtype), device=device)

Expand All @@ -88,7 +93,7 @@ def eye(
k: int = 0,
batch_shape: Optional[Union[int, Sequence[int]]] = None,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if n_cols is None:
Expand Down Expand Up @@ -116,7 +121,7 @@ def full(
fill_value: Union[int, float, bool],
*,
dtype: Optional[Union[ivy.Dtype, np.dtype]] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
dtype = ivy.default_dtype(dtype=dtype, item=fill_value, as_native=True)
Expand All @@ -132,7 +137,7 @@ def full_like(
fill_value: Number,
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.full_like(x, fill_value, dtype=dtype), device=device)
Expand All @@ -147,7 +152,7 @@ def linspace(
axis: Optional[int] = None,
endpoint: bool = True,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if axis is None:
Expand Down Expand Up @@ -176,14 +181,19 @@ def ones(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.ones(shape, dtype), device=device)


def ones_like(
x: np.ndarray, /, *, dtype: np.dtype, device: str, out: Optional[np.ndarray] = None
x: np.ndarray,
/,
*,
dtype: np.dtype,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.ones_like(x, dtype=dtype), device=device)

Expand All @@ -204,14 +214,19 @@ def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: np.dtype,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.zeros(shape, dtype), device=device)


def zeros_like(
x: np.ndarray, /, *, dtype: np.dtype, device: str, out: Optional[np.ndarray] = None
x: np.ndarray,
/,
*,
dtype: np.dtype,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return _to_device(np.zeros_like(x, dtype=dtype), device=device)

Expand Down Expand Up @@ -243,7 +258,7 @@ def one_hot(
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[np.dtype] = None,
device: str,
device: str = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
on_none = on_value is None
Expand Down Expand Up @@ -287,7 +302,7 @@ def triu_indices(
k: int = 0,
/,
*,
device: str,
device: str = None,
) -> Tuple[np.ndarray]:
return tuple(
_to_device(np.asarray(np.triu_indices(n=n_rows, k=k, m=n_cols)), device=device)
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def tril_indices(
k: int = 0,
/,
*,
device: str,
device: str = None,
) -> Tuple[np.ndarray, ...]:
return tuple(
_to_device(np.asarray(np.tril_indices(n=n_rows, k=k, m=n_cols)), device=device)
Expand Down
Loading

0 comments on commit a38e634

Please sign in to comment.