Skip to content

Commit

Permalink
Merge pull request #1237 from IntelPython/inplace-operator-initial-impl
Browse files Browse the repository at this point in the history
In-place addition, multiplication, subtraction of usm_ndarrays
  • Loading branch information
ndgrigorian authored Jun 13, 2023
2 parents 43f3b7b + da5f2f7 commit 81553f8
Show file tree
Hide file tree
Showing 14 changed files with 1,378 additions and 23 deletions.
134 changes: 132 additions & 2 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_empty_like_pair_orderK,
_find_buf_dtype,
_find_buf_dtype2,
_find_inplace_dtype,
_to_device_supported_dtype,
)

Expand Down Expand Up @@ -331,11 +332,19 @@ class BinaryElementwiseFunc:
Class that implements binary element-wise functions.
"""

def __init__(self, name, result_type_resolver_fn, binary_dp_impl_fn, docs):
def __init__(
    self,
    name,
    result_type_resolver_fn,
    binary_dp_impl_fn,
    docs,
    binary_inplace_fn=None,
):
    """Initialize a binary element-wise function wrapper.

    Args:
        name (str): name of the function, used in error messages and repr.
        result_type_resolver_fn (callable): maps a pair of input dtypes to
            the result dtype, or None when the combination is unsupported.
        binary_dp_impl_fn (callable): out-of-place implementation kernel.
        docs (str): docstring to expose on the callable instance.
        binary_inplace_fn (callable, optional): in-place implementation
            kernel writing the result into the left-hand operand; when
            None, in-place application raises (see ``_inplace``).
    """
    self.__name__ = "BinaryElementwiseFunc"
    self.name_ = name
    self.result_type_resolver_fn_ = result_type_resolver_fn
    self.binary_fn_ = binary_dp_impl_fn
    # May be None: not every binary function supports in-place evaluation.
    self.binary_inplace_fn_ = binary_inplace_fn
    self.__doc__ = docs

def __str__(self):
Expand All @@ -345,6 +354,13 @@ def __repr__(self):
return f"<BinaryElementwiseFunc '{self.name_}'>"

def __call__(self, o1, o2, out=None, order="K"):
# FIXME: replace with check against base array
# when views can be identified
if o1 is out:
return self._inplace(o1, o2)
elif o2 is out:
return self._inplace(o2, o1)

if order not in ["K", "C", "F", "A"]:
order = "K"
q1, o1_usm_type = _get_queue_usm_type(o1)
Expand Down Expand Up @@ -388,6 +404,7 @@ def __call__(self, o1, o2, out=None, order="K"):
raise TypeError(
"Shape of arguments can not be inferred. "
"Arguments are expected to be "
"lists, tuples, or both"
)
try:
res_shape = _broadcast_shape_impl(
Expand Down Expand Up @@ -415,7 +432,7 @@ def __call__(self, o1, o2, out=None, order="K"):

if res_dt is None:
raise TypeError(
"function 'add' does not support input types "
f"function '{self.name_}' does not support input types "
f"({o1_dtype}, {o2_dtype}), "
"and the inputs could not be safely coerced to any "
"supported types according to the casting rule ''safe''."
Expand Down Expand Up @@ -631,3 +648,116 @@ def __call__(self, o1, o2, out=None, order="K"):
)
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
return out

def _inplace(self, lhs, val):
    """Apply this binary function in place: ``lhs <op>= val``.

    ``lhs`` must be a ``dpctl.tensor.usm_ndarray``; ``val`` may be an
    array or a Python scalar.  The result is written into ``lhs``,
    which is also returned.

    Raises:
        ValueError: if no in-place kernel exists for this function, if
            shapes cannot be broadcast, if the broadcast shape differs
            from ``lhs.shape``, or if ``val`` has an unsupported type.
        TypeError: if ``lhs`` is not a usm_ndarray, shapes cannot be
            inferred, or dtypes cannot be safely coerced.
        ExecutionPlacementError: if operand queues are incompatible.
    """
    # In-place kernels are optional; bail out early for functions
    # constructed without one.
    if self.binary_inplace_fn_ is None:
        raise ValueError(
            f"In-place operation not supported for ufunc '{self.name_}'"
        )
    if not isinstance(lhs, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
        )
    # Determine a common execution queue and USM allocation type.
    q1, lhs_usm_type = _get_queue_usm_type(lhs)
    q2, val_usm_type = _get_queue_usm_type(val)
    if q2 is None:
        # val is a scalar (no queue of its own): inherit lhs's placement.
        exec_q = q1
        usm_type = lhs_usm_type
    else:
        exec_q = dpctl.utils.get_execution_queue((q1, q2))
        if exec_q is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        usm_type = dpctl.utils.get_coerced_usm_type(
            (
                lhs_usm_type,
                val_usm_type,
            )
        )
    dpctl.utils.validate_usm_type(usm_type, allow_none=False)
    lhs_shape = _get_shape(lhs)
    val_shape = _get_shape(val)
    if not all(
        isinstance(s, (tuple, list))
        for s in (
            lhs_shape,
            val_shape,
        )
    ):
        raise TypeError(
            "Shape of arguments can not be inferred. "
            "Arguments are expected to be "
            "lists, tuples, or both"
        )
    try:
        res_shape = _broadcast_shape_impl(
            [
                lhs_shape,
                val_shape,
            ]
        )
    except ValueError:
        raise ValueError(
            "operands could not be broadcast together with shapes "
            f"{lhs_shape} and {val_shape}"
        )
    # In-place requires the broadcast result to fit exactly into lhs:
    # val may broadcast up to lhs's shape, but never the reverse.
    if res_shape != lhs_shape:
        raise ValueError(
            f"output shape {lhs_shape} does not match "
            f"broadcast shape {res_shape}"
        )
    sycl_dev = exec_q.sycl_device
    lhs_dtype = lhs.dtype
    val_dtype = _get_dtype(val, sycl_dev)
    if not _validate_dtype(val_dtype):
        raise ValueError("Input operand of unsupported type")

    lhs_dtype, val_dtype = _resolve_weak_types(
        lhs_dtype, val_dtype, sycl_dev
    )

    # Find a dtype to bring val to such that (lhs_dtype op buf_dt)
    # resolves back to lhs_dtype, keeping the operation truly in-place.
    buf_dt = _find_inplace_dtype(
        lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
    )

    if buf_dt is None:
        raise TypeError(
            f"In-place '{self.name_}' does not support input types "
            f"({lhs_dtype}, {val_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule ''safe''."
        )

    if isinstance(val, dpt.usm_ndarray):
        rhs = val
        # Memory overlap between operands forces a defensive copy below.
        overlap = ti._array_overlap(lhs, rhs)
    else:
        # Scalar operand: materialize it as a device array on exec_q.
        rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
        overlap = False

    if buf_dt == val_dtype and overlap is False:
        # Fast path: rhs already has the right dtype and does not alias
        # lhs, so the kernel can read it directly.
        rhs = dpt.broadcast_to(rhs, res_shape)
        ht_, _ = self.binary_inplace_fn_(
            lhs=lhs, rhs=rhs, sycl_queue=exec_q
        )
        ht_.wait()

    else:
        # Copy rhs into a temporary buffer of the required dtype (also
        # resolves any memory overlap), then run the in-place kernel
        # depending on the copy's completion event.
        buf = dpt.empty_like(rhs, dtype=buf_dt)
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=rhs, dst=buf, sycl_queue=exec_q
        )

        buf = dpt.broadcast_to(buf, res_shape)
        ht_, _ = self.binary_inplace_fn_(
            lhs=lhs,
            rhs=buf,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        # Block until both host tasks complete before returning.
        ht_copy_ev.wait()
        ht_.wait()

    return lhs
18 changes: 15 additions & 3 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@
returned array is determined by the Type Promotion Rules.
"""
# B01: ===== ADD (x1, x2)
# In-place kernel supplied so that `a += b` runs without a temporary result.
add = BinaryElementwiseFunc(
    "add",
    ti._add_result_type,
    ti._add,
    _add_docstring_,
    binary_inplace_fn=ti._add_inplace,
)

# U04: ===== ASIN (x)
Expand Down Expand Up @@ -603,7 +607,11 @@
the returned array is determined by the Type Promotion Rules.
"""
# B19: ===== MULTIPLY (x1, x2)
# Pass the in-place kernel by keyword for consistency with `add`.
multiply = BinaryElementwiseFunc(
    "multiply",
    ti._multiply_result_type,
    ti._multiply,
    _multiply_docstring_,
    binary_inplace_fn=ti._multiply_inplace,
)

# U25: ==== NEGATIVE (x)
Expand Down Expand Up @@ -782,7 +790,11 @@
of the returned array is determined by the Type Promotion Rules.
"""
# B23: ===== SUBTRACT (x1, x2)
# Pass the in-place kernel by keyword for consistency with `add`.
subtract = BinaryElementwiseFunc(
    "subtract",
    ti._subtract_result_type,
    ti._subtract,
    _subtract_docstring_,
    binary_inplace_fn=ti._subtract_inplace,
)


Expand Down
18 changes: 18 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,27 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
return None, None, None


def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
res_dt = query_fn(lhs_dtype, rhs_dtype)
if res_dt and res_dt == lhs_dtype:
return rhs_dtype

_fp16 = sycl_dev.has_aspect_fp16
_fp64 = sycl_dev.has_aspect_fp64
all_dts = _all_data_types(_fp16, _fp64)
for buf_dt in all_dts:
if _can_cast(rhs_dtype, buf_dt, _fp16, _fp64):
res_dt = query_fn(lhs_dtype, buf_dt)
if res_dt and res_dt == lhs_dtype:
return buf_dt

return None


__all__ = [
"_find_buf_dtype",
"_find_buf_dtype2",
"_find_inplace_dtype",
"_empty_like_orderK",
"_empty_like_pair_orderK",
"_to_device_supported_dtype",
Expand Down
21 changes: 6 additions & 15 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1245,11 +1245,8 @@ cdef class usm_ndarray:
return _dispatch_binary_elementwise2(other, "logical_xor", self)

def __iadd__(self, other):
    """In-place addition: ``self += other`` without a temporary result."""
    # Imported locally to avoid a circular import at module load time.
    from ._elementwise_funcs import add
    return add._inplace(self, other)

def __iand__(self, other):
res = self.__and__(other)
Expand Down Expand Up @@ -1287,11 +1284,8 @@ cdef class usm_ndarray:
return self

def __imul__(self, other):
    """In-place multiplication: ``self *= other`` without a temporary result."""
    # Imported locally to avoid a circular import at module load time.
    from ._elementwise_funcs import multiply
    return multiply._inplace(self, other)

def __ior__(self, other):
res = self.__or__(other)
Expand All @@ -1315,11 +1309,8 @@ cdef class usm_ndarray:
return self

def __isub__(self, other):
    """In-place subtraction: ``self -= other`` without a temporary result."""
    # Imported locally to avoid a circular import at module load time.
    from ._elementwise_funcs import subtract
    return subtract._inplace(self, other)

def __itruediv__(self, other):
res = self.__truediv__(other)
Expand Down
Loading

0 comments on commit 81553f8

Please sign in to comment.