Skip to content

Commit

Permalink
Remove old reduction implementation (#589)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Oct 4, 2024
1 parent 2065794 commit 7ca5ae9
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 173 deletions.
15 changes: 5 additions & 10 deletions cubed/array_api/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from cubed.core import blockwise, reduction, squeeze


def matmul(x1, x2, /, use_new_impl=True, split_every=None):
def matmul(x1, x2, /, split_every=None):
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in matmul")

Expand Down Expand Up @@ -51,9 +51,7 @@ def matmul(x1, x2, /, use_new_impl=True, split_every=None):
dtype=dtype,
)

out = _sum_wo_cat(
out, axis=-2, dtype=dtype, use_new_impl=use_new_impl, split_every=split_every
)
out = _sum_wo_cat(out, axis=-2, dtype=dtype, split_every=split_every)

if x1_is_1d:
out = squeeze(out, -2)
Expand All @@ -68,7 +66,7 @@ def _matmul(a, b):
return chunk[..., nxp.newaxis, :]


def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None):
def _sum_wo_cat(a, axis=None, dtype=None, split_every=None):
if a.shape[axis] == 1:
return squeeze(a, axis)

Expand All @@ -78,7 +76,6 @@ def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None):
_chunk_sum,
axis=axis,
dtype=dtype,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
Expand All @@ -99,7 +96,7 @@ def matrix_transpose(x, /):
return permute_dims(x, axes)


def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
def tensordot(x1, x2, /, *, axes=2, split_every=None):
from cubed.array_api.statistical_functions import sum

if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
Expand Down Expand Up @@ -147,7 +144,6 @@ def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
out,
axis=x1_axes,
dtype=dtype,
use_new_impl=use_new_impl,
split_every=split_every,
)

Expand All @@ -161,7 +157,7 @@ def _tensordot(a, b, axes):
return x


def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
def vecdot(x1, x2, /, *, axis=-1, split_every=None):
# based on the implementation in array-api-compat
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in vecdot")
Expand All @@ -176,7 +172,6 @@ def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
res = matmul(
x1_[..., None, :],
x2_[..., None],
use_new_impl=use_new_impl,
split_every=split_every,
)
return res[..., 0, 0]
6 changes: 2 additions & 4 deletions cubed/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cubed.core.ops import arg_reduction, elemwise


def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def argmax(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmax")
if axis is None:
Expand All @@ -17,12 +17,11 @@ def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No
nxp.argmax,
axis=axis,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def argmin(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmin")
if axis is None:
Expand All @@ -34,7 +33,6 @@ def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No
nxp.argmin,
axis=axis,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)

Expand Down
19 changes: 5 additions & 14 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,20 @@
from cubed.core import reduction


def max(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def max(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in max")
return reduction(
x,
nxp.max,
axis=axis,
dtype=x.dtype,
use_new_impl=use_new_impl,
split_every=split_every,
keepdims=keepdims,
)


def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def mean(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
Expand All @@ -53,7 +52,6 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
Expand Down Expand Up @@ -108,23 +106,20 @@ def _numel(x, **kwargs):
return nxp.broadcast_to(nxp.asarray(prod, dtype=dtype), new_shape)


def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def min(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in min")
return reduction(
x,
nxp.min,
axis=axis,
dtype=x.dtype,
use_new_impl=use_new_impl,
split_every=split_every,
keepdims=keepdims,
)


def prod(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
Expand All @@ -148,15 +143,12 @@ def prod(
axis=axis,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def sum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
Expand All @@ -180,7 +172,6 @@ def sum(
axis=axis,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
6 changes: 2 additions & 4 deletions cubed/array_api/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from cubed.core import reduction


def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def all(x, /, *, axis=None, keepdims=False, split_every=None):
if x.size == 0:
return asarray(True, dtype=x.dtype)
return reduction(
Expand All @@ -12,12 +12,11 @@ def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
axis=axis,
dtype=bool,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def any(x, /, *, axis=None, keepdims=False, split_every=None):
if x.size == 0:
return asarray(False, dtype=x.dtype)
return reduction(
Expand All @@ -26,6 +25,5 @@ def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
axis=axis,
dtype=bool,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)
4 changes: 2 additions & 2 deletions cubed/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_blocks, map_direct, reduction_new
from cubed.core.ops import map_blocks, map_direct, reduction
from cubed.utils import array_memory, get_item
from cubed.vendor.dask.array.core import normalize_chunks

Expand Down Expand Up @@ -105,7 +105,7 @@ def wrapper(a, by, **kwargs):
out = expand_dims(out, axis=dummy_axis)

# then reduce across blocks
return reduction_new(
return reduction(
out,
func=None,
combine_func=combine_func,
Expand Down
121 changes: 1 addition & 120 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,122 +1056,6 @@ def key_function(out_key):


def reduction(
x: "Array",
func,
combine_func=None,
aggregate_func=None,
axis=None,
intermediate_dtype=None,
dtype=None,
keepdims=False,
use_new_impl=True,
split_every=None,
extra_func_kwargs=None,
) -> "Array":
"""Apply a function to reduce an array along one or more axes."""
if use_new_impl:
return reduction_new(
x,
func,
combine_func=combine_func,
aggregate_func=aggregate_func,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
if combine_func is None:
combine_func = func
if axis is None:
axis = tuple(range(x.ndim))
if isinstance(axis, Integral):
axis = (axis,)
axis = validate_axis(axis, x.ndim)
if intermediate_dtype is None:
intermediate_dtype = dtype

inds = tuple(range(x.ndim))

result = x
allowed_mem = x.spec.allowed_mem
max_mem = allowed_mem - x.spec.reserved_mem

# reduce initial chunks
args = (result, inds)
adjust_chunks = {
i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks)
}
result = blockwise(
func,
inds,
*args,
axis=axis,
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

# merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis):
# merge along axis
target_chunks = list(result.chunksize)
chunk_mem = array_memory(intermediate_dtype, result.chunksize)
for i, s in enumerate(result.shape):
if i in axis:
assert result.chunksize[i] == 1 # result of reduction
if len(axis) > 1:
# multi-axis: don't exceed original chunksize in any reduction axis
# TODO: improve to use up to max_mem
target_chunks[i] = min(s, x.chunksize[i])
else:
# single axis: see how many result chunks fit in max_mem
# factor of 4 is memory for {compressed, uncompressed} x {input, output}
target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4)
if target_chunk_size <= 1:
raise ValueError(
f"Not enough memory for reduction. Increase allowed_mem ({allowed_mem}) or decrease chunk size"
)
target_chunks[i] = min(s, target_chunk_size)
_target_chunks = tuple(target_chunks)
result = merge_chunks(result, _target_chunks)

# reduce chunks (if any axis chunksize is > 1)
if any(s > 1 for i, s in enumerate(result.chunksize) if i in axis):
args = (result, inds)
adjust_chunks = {
i: (1,) * len(c) if i in axis else c
for i, c in enumerate(result.chunks)
}
result = blockwise(
combine_func,
inds,
*args,
axis=axis,
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

if aggregate_func is not None:
result = map_blocks(aggregate_func, result, dtype=dtype)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
if len(axis_to_squeeze) > 0:
result = squeeze(result, axis_to_squeeze)

from cubed.array_api import astype

result = astype(result, dtype, copy=False)

return result


def reduction_new(
x: "Array",
func,
combine_func=None,
Expand Down Expand Up @@ -1426,9 +1310,7 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
return result


def arg_reduction(
x, /, arg_func, axis=None, *, keepdims=False, use_new_impl=True, split_every=None
):
def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False, split_every=None):
"""A reduction that returns the array indexes, not the values."""
dtype = nxp.int64 # index data type
intermediate_dtype = [("i", dtype), ("v", x.dtype)]
Expand All @@ -1454,7 +1336,6 @@ def arg_reduction(
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)

Expand Down
Loading

0 comments on commit 7ca5ae9

Please sign in to comment.