From a14cb2bd4ebda0d25eee2811127a6c9d02f89a0d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 5 Jun 2024 17:37:56 +0200 Subject: [PATCH] Do not use Numba objmode for supported AdvancedSubtensor operations Use ScalarTypes in MakeSlice for compatibility with Numba --- pytensor/link/numba/dispatch/subtensor.py | 113 +++++++---------- pytensor/tensor/subtensor.py | 19 ++- pytensor/tensor/type_other.py | 2 +- tests/link/numba/test_subtensor.py | 145 ++++++++++++++++++++-- 4 files changed, 193 insertions(+), 86 deletions(-) diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 3d2f3f2901..178ce0b857 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,12 +1,10 @@ -import warnings - -import numba import numpy as np from pytensor.graph import Type from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import numba_njit +from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit from pytensor.link.utils import compile_function_src, unique_name_generator +from pytensor.tensor import TensorType from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -17,7 +15,10 @@ ) -def create_index_func(node, objmode=False): +@numba_funcify.register(Subtensor) +@numba_funcify.register(IncSubtensor) +@numba_funcify.register(AdvancedSubtensor1) +def numba_funcify_default_subtensor(op, node, **kwargs): """Create a Python function that assembles and uses an index on an array.""" unique_names = unique_name_generator( @@ -40,13 +41,13 @@ def convert_indices(indices, entry): raise ValueError() set_or_inc = isinstance( - node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor + op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor ) index_start_idx = 1 + int(set_or_inc) input_names = [unique_names(v, force_unique=True) for v in node.inputs] op_indices = list(node.inputs[index_start_idx:]) - idx_list = getattr(node.op, "idx_list", None) + idx_list = getattr(op, "idx_list", None) indices_creation_src = ( tuple(convert_indices(op_indices, idx) for idx in idx_list) @@ -61,8 +62,7 @@ def convert_indices(indices, entry): indices_creation_src = f"indices = ({indices_creation_src})" if set_or_inc: - fn_name = "incsubtensor" - if node.op.inplace: + if op.inplace: index_prologue = f"z = {input_names[0]}" else: index_prologue = f"z = np.copy({input_names[0]})" @@ -74,84 +74,57 @@ def convert_indices(indices, entry): else: y_name = input_names[1] - if node.op.set_instead_of_inc: + if op.set_instead_of_inc: + function_name = "setsubtensor" index_body = f"z[indices] = {y_name}" else: + function_name = "incsubtensor" index_body = f"z[indices] += {y_name}" else: - fn_name = "subtensor" + function_name = "subtensor" index_prologue = "" index_body = f"z = {input_names[0]}[indices]" - if objmode: - output_var = node.outputs[0] - - if not set_or_inc: - # Since `z` is being "created" while in object mode, it's - # considered an "outgoing" variable and needs to be manually typed - output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'" - else: - output_sig = "" - - index_body = f""" - with objmode({output_sig}): - {index_body} - """ - subtensor_def_src = f""" -def {fn_name}({", ".join(input_names)}): +def {function_name}({", ".join(input_names)}): {index_prologue} {indices_creation_src} {index_body} return np.asarray(z) """ - return subtensor_def_src - - -@numba_funcify.register(Subtensor) 
-@numba_funcify.register(AdvancedSubtensor)
-@numba_funcify.register(AdvancedSubtensor1) -def numba_funcify_Subtensor(op, node, **kwargs): - objmode = isinstance(op, AdvancedSubtensor) - if objmode: - warnings.warn( - ("Numba will use object mode to allow run " "AdvancedSubtensor."), - UserWarning, - ) - - subtensor_def_src = create_index_func(node, objmode=objmode) - - global_env = {"np": np} - if objmode: - global_env["objmode"] = numba.objmode - - subtensor_fn = compile_function_src( - subtensor_def_src, "subtensor", {**globals(), **global_env} + func = compile_function_src( + subtensor_def_src, + function_name=function_name, + global_env=globals() | {"np": np}, ) - - return numba_njit(subtensor_fn, boundscheck=True) - - -@numba_funcify.register(IncSubtensor) -def numba_funcify_IncSubtensor(op, node, **kwargs): - objmode = isinstance(op, AdvancedIncSubtensor) - if objmode: - warnings.warn( - ("Numba will use object mode to allow run " "AdvancedIncSubtensor."), - UserWarning, + return numba_njit(func, boundscheck=True) + + +@numba_funcify.register(AdvancedSubtensor) +@numba_funcify.register(AdvancedIncSubtensor) +def numba_funcify_AdvancedSubtensor(op, node, **kwargs): + idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:] + adv_idxs_dims = [ + idx.type.ndim + for idx in idxs + if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) + ] + + if ( + # Numba does not support indexes with more than one dimension + # Nor multiple vector indexes + (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1) + # The default index implementation does not handle duplicate indices correctly + or ( + isinstance(op, AdvancedIncSubtensor) + and not op.set_instead_of_inc + and not op.ignore_duplicates ) + ): + return generate_fallback_impl(op, node, **kwargs) - incsubtensor_def_src = create_index_func(node, objmode=objmode) - - global_env = {"np": np} - if objmode: - global_env["objmode"] = numba.objmode - - incsubtensor_fn = compile_function_src( - incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} - ) - - return numba_njit(incsubtensor_fn, boundscheck=True) + return numba_funcify_default_subtensor(op, node, **kwargs) @numba_funcify.register(AdvancedIncSubtensor1) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 64543170c0..e223077f81 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -21,7 +21,12 @@ from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length -from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero +from pytensor.tensor.basic import ( + ScalarFromTensor, + alloc, + get_underlying_scalar_constant_value, + nonzero, +) from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError @@ -168,8 +173,16 @@ def as_index_literal( if isinstance(idx, Constant): return idx.data.item() if isinstance(idx, np.ndarray) else idx.data - if isinstance(getattr(idx, "type", None), SliceType): - idx = slice(*idx.owner.inputs) + if isinstance(idx, Variable): + if ( + isinstance(idx.type, ps.ScalarType) + and idx.owner + and isinstance(idx.owner.op, ScalarFromTensor) + ): + return as_index_literal(idx.owner.inputs[0]) + + if isinstance(idx.type, SliceType): + idx = slice(*idx.owner.inputs) if isinstance(idx, slice): return slice( diff --git 
a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py index 593204b1ef..bc293d8906 100644 --- a/pytensor/tensor/type_other.py +++ b/pytensor/tensor/type_other.py @@ -18,7 +18,7 @@ def as_int_none_variable(x): return NoneConst elif NoneConst.equals(x): return x - x = pytensor.tensor.as_tensor_variable(x, ndim=0) + x = pytensor.scalar.as_scalar(x) if x.type.dtype not in integer_dtypes: raise TypeError("index must be integers") return x diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index 87f1300bfb..5e1784f368 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -1,6 +1,9 @@ +import contextlib + import numpy as np import pytest +import pytensor.tensor as pt from pytensor.graph import FunctionGraph from pytensor.tensor import as_tensor from pytensor.tensor.subtensor import ( @@ -48,8 +51,8 @@ def test_Subtensor(x, indices): @pytest.mark.parametrize( "x, indices", [ - (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), - (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)), + (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), + (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)), ], ) def test_AdvancedSubtensor1(x, indices): @@ -69,21 +72,46 @@ def test_AdvancedSubtensor1_out_of_bounds(): @pytest.mark.parametrize( - "x, indices", + "x, indices, objmode_needed", [ - (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + (0, [1, 2, 2, 3]), + False, + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + (np.array([True, False, False])), + False, + ), + ( + as_tensor(np.arange(3 * 3).reshape((3, 3))), + (np.eye(3).astype(bool)), + True, + ), + (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True), ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], slice(None), [3, 4]), + True, ), ], ) -def test_AdvancedSubtensor(x, indices): +@pytest.mark.filterwarnings("error") +def test_AdvancedSubtensor(x, indices, objmode_needed): """Test NumPy's advanced indexing in more than one dimension.""" out_pt = x[indices] assert isinstance(out_pt.owner.op, AdvancedSubtensor) out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedSubtensor's perform method", + ) + if objmode_needed + else contextlib.nullcontext() + ): + compare_numba_and_py(out_fg, []) @pytest.mark.parametrize( @@ -194,35 +222,120 @@ def test_AdvancedIncSubtensor1(x, y, indices): @pytest.mark.parametrize( - "x, y, indices", + "x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode", [ + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + -np.arange(3 * 5).reshape(3, 5), + (slice(None, None, 2), [1, 2, 3]), + False, + False, + False, + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + -99, + (slice(None, None, 2), [1, 2, 3], -1), + False, + False, + False, + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + -99, # Broadcasted value + (slice(None, None, 2), [1, 2, 3]), + False, + False, + False, + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + -np.arange(4 * 5).reshape(4, 5), + (0, [1, 2, 2, 3]), + True, + False, + True, + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + [-99], # Broadcsasted value + (0, [1, 2, 2, 3]), + True, + False, + True, + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + 
-np.arange(1 * 4 * 5).reshape(1, 4, 5), + (np.array([True, False, False])), + False, + False, + False, + ), + ( + as_tensor(np.arange(3 * 3).reshape((3, 3))), + -np.arange(3), + (np.eye(3).astype(bool)), + False, + True, + True, + ), ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(rng.poisson(size=(2, 5))), ([1, 2], [2, 3]), + False, + True, + True, ), ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(rng.poisson(size=(2, 4))), ([1, 2], slice(None), [3, 4]), + False, + True, + True, ), pytest.param( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(rng.poisson(size=(2, 5))), ([1, 1], [2, 2]), + False, + True, + True, ), ], ) -def test_AdvancedIncSubtensor(x, y, indices): +@pytest.mark.filterwarnings("error") +def test_AdvancedIncSubtensor( + x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode +): out_pt = set_subtensor(x[indices], y) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - out_pt = inc_subtensor(x[indices], y) + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedSetSubtensor's perform method", + ) + if set_requires_objmode + else contextlib.nullcontext() + ): + compare_numba_and_py(out_fg, []) + + out_pt = inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedIncSubtensor's perform method", + ) + if inc_requires_objmode + else contextlib.nullcontext() + ): + compare_numba_and_py(out_fg, []) x_pt = x.type() out_pt = set_subtensor(x_pt[indices], y) @@ -231,4 +344,12 @@ def test_AdvancedIncSubtensor(x, y, indices): out_pt.owner.op.inplace = True assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) out_fg = FunctionGraph([x_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data]) + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedSetSubtensor's perform method", + ) + if set_requires_objmode + else contextlib.nullcontext() + ): + compare_numba_and_py(out_fg, [x.data])
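
For reference, a minimal sketch of the new dispatch behaviour (assumes a PyTensor build with the Numba backend installed; the variable names and shapes below are illustrative, not taken from the patch): a single vector index now compiles natively, while patterns that numba_funcify_AdvancedSubtensor cannot handle (several advanced indices, advanced indices with more than one dimension, or increments that must honour duplicate indices) still go through generate_fallback_impl and its object-mode warning.

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x", dtype="int64")
x_val = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype("int64")

# A single integer-vector index is handled by the new native Numba
# implementation (per the updated tests), so no object-mode warning is expected.
single_vec = pytensor.function([x], x[0, [1, 2, 2, 3]], mode="NUMBA")
print(single_vec(x_val))

# Two integer-vector indices are still unsupported natively; they are dispatched
# to generate_fallback_impl, which emits the object-mode UserWarning that the
# updated tests match against.
two_vec = pytensor.function([x], x[[1, 2], [2, 3]], mode="NUMBA")
print(two_vec(x_val))

Falling back per node keeps the supported cases in compiled code while preserving correctness for increments with duplicate indices, where a plain z[indices] += y would not accumulate repeated entries.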