Do not use Numba objmode for supported AdvancedSubtensor operations
Use ScalarTypes in MakeSlice for compatibility with Numba
ricardoV94 committed Jun 21, 2024
1 parent a9c52dd commit a14cb2b
Showing 4 changed files with 193 additions and 86 deletions.
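
From the user side, the effect of this commit can be sketched as follows (a minimal illustration, assuming a pytensor build with the Numba backend available; the variable names are mine and not part of the diff). Before this change, any AdvancedSubtensor compiled with the Numba backend triggered a "Numba will use object mode" warning and ran the indexing through objmode; after it, indexing patterns Numba supports natively, such as a single one-dimensional integer index, are compiled directly, and only the unsupported cases fall back.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
idx = pt.lvector("idx")

# A vector index on a non-leading axis creates an AdvancedSubtensor node.
# With this commit it is lowered natively by Numba instead of via objmode.
fn = pytensor.function([x, idx], x[:, idx], mode="NUMBA")
print(fn(np.arange(12.0).reshape(3, 4), np.array([0, 2], dtype="int64")))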
113 changes: 43 additions & 70 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -1,12 +1,10 @@
import warnings

import numba
import numpy as np

from pytensor.graph import Type
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
@@ -17,7 +15,10 @@
)


def create_index_func(node, objmode=False):
@numba_funcify.register(Subtensor)
@numba_funcify.register(IncSubtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array."""

unique_names = unique_name_generator(
@@ -40,13 +41,13 @@ def convert_indices(indices, entry):
raise ValueError()

set_or_inc = isinstance(
node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
)
index_start_idx = 1 + int(set_or_inc)

input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
idx_list = getattr(op, "idx_list", None)

indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
@@ -61,8 +62,7 @@ def convert_indices(indices, entry):
indices_creation_src = f"indices = ({indices_creation_src})"

if set_or_inc:
fn_name = "incsubtensor"
if node.op.inplace:
if op.inplace:
index_prologue = f"z = {input_names[0]}"
else:
index_prologue = f"z = np.copy({input_names[0]})"
@@ -74,84 +74,57 @@ def convert_indices(indices, entry):
else:
y_name = input_names[1]

if node.op.set_instead_of_inc:
if op.set_instead_of_inc:
function_name = "setsubtensor"
index_body = f"z[indices] = {y_name}"
else:
function_name = "incsubtensor"
index_body = f"z[indices] += {y_name}"
else:
fn_name = "subtensor"
function_name = "subtensor"
index_prologue = ""
index_body = f"z = {input_names[0]}[indices]"

if objmode:
output_var = node.outputs[0]

if not set_or_inc:
# Since `z` is being "created" while in object mode, it's
# considered an "outgoing" variable and needs to be manually typed
output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'"
else:
output_sig = ""

index_body = f"""
with objmode({output_sig}):
{index_body}
"""

subtensor_def_src = f"""
def {fn_name}({", ".join(input_names)}):
def {function_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return np.asarray(z)
"""

return subtensor_def_src


@numba_funcify.register(Subtensor)
@numba_funcify.register(AdvancedSubtensor1)
def numba_funcify_Subtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedSubtensor."),
UserWarning,
)

subtensor_def_src = create_index_func(node, objmode=objmode)

global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode

subtensor_fn = compile_function_src(
subtensor_def_src, "subtensor", {**globals(), **global_env}
func = compile_function_src(
subtensor_def_src,
function_name=function_name,
global_env=globals() | {"np": np},
)

return numba_njit(subtensor_fn, boundscheck=True)


@numba_funcify.register(IncSubtensor)
def numba_funcify_IncSubtensor(op, node, **kwargs):
objmode = isinstance(op, AdvancedIncSubtensor)
if objmode:
warnings.warn(
("Numba will use object mode to allow run " "AdvancedIncSubtensor."),
UserWarning,
return numba_njit(func, boundscheck=True)


@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
adv_idxs_dims = [
idx.type.ndim
for idx in idxs
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
]

if (
# Numba does not support indexes with more than one dimension
# Nor multiple vector indexes
(len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
# The default index implementation does not handle duplicate indices correctly
or (
isinstance(op, AdvancedIncSubtensor)
and not op.set_instead_of_inc
and not op.ignore_duplicates
)
):
return generate_fallback_impl(op, node, **kwargs)

incsubtensor_def_src = create_index_func(node, objmode=objmode)

global_env = {"np": np}
if objmode:
global_env["objmode"] = numba.objmode

incsubtensor_fn = compile_function_src(
incsubtensor_def_src, "incsubtensor", {**globals(), **global_env}
)

return numba_njit(incsubtensor_fn, boundscheck=True)
return numba_funcify_default_subtensor(op, node, **kwargs)


@numba_funcify.register(AdvancedIncSubtensor1)
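
As a rough illustration of the dispatch rule in numba_funcify_AdvancedSubtensor above (a sketch built on the public pytensor API; the graph variables are mine), a single one-dimensional integer index takes the native string-generated implementation, while multi-dimensional indices and duplicate-sensitive increments still go through generate_fallback_impl:

import pytensor.tensor as pt

x = pt.matrix("x")
v = pt.lvector("v")
m = pt.lmatrix("m")

native = x[:, v]    # one 1-d advanced index: handled by the native Numba path
fallback_a = x[m]   # 2-d advanced index: routed to generate_fallback_impl
# An increment that must accumulate repeated indices also falls back, because
# the generated `z[indices] += y` does not handle duplicates correctly.
fallback_b = pt.inc_subtensor(x[:, v], 1.0)

print(fallback_a.owner.op, fallback_b.owner.op)  # AdvancedSubtensor and AdvancedIncSubtensor nodes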
19 changes: 16 additions & 3 deletions pytensor/tensor/subtensor.py
@@ -21,7 +21,12 @@
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
from pytensor.tensor.basic import (
ScalarFromTensor,
alloc,
get_underlying_scalar_constant_value,
nonzero,
)
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
@@ -168,8 +173,16 @@ def as_index_literal(
if isinstance(idx, Constant):
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data

if isinstance(getattr(idx, "type", None), SliceType):
idx = slice(*idx.owner.inputs)
if isinstance(idx, Variable):
if (
isinstance(idx.type, ps.ScalarType)
and idx.owner
and isinstance(idx.owner.op, ScalarFromTensor)
):
return as_index_literal(idx.owner.inputs[0])

if isinstance(idx.type, SliceType):
idx = slice(*idx.owner.inputs)

if isinstance(idx, slice):
return slice(
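
A hedged sketch of what the new ScalarFromTensor branch handles (assuming scalar_from_tensor is importable from pytensor.tensor.basic alongside the ScalarFromTensor Op imported above): once MakeSlice wraps its inputs as scalars, slice components reaching as_index_literal can be ScalarType variables produced by ScalarFromTensor, and the new recursion resolves them through to the underlying constant instead of failing:

import pytensor.tensor as pt
from pytensor.tensor.basic import scalar_from_tensor
from pytensor.tensor.subtensor import as_index_literal

stop = pt.constant(5, dtype="int64")  # a TensorConstant
wrapped = scalar_from_tensor(stop)    # ScalarFromTensor(TensorConstant): a ScalarType variable
# The new branch walks through the ScalarFromTensor Op and extracts the
# literal from the wrapped constant.
print(as_index_literal(wrapped))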
2 changes: 1 addition & 1 deletion pytensor/tensor/type_other.py
@@ -18,7 +18,7 @@ def as_int_none_variable(x):
return NoneConst
elif NoneConst.equals(x):
return x
x = pytensor.tensor.as_tensor_variable(x, ndim=0)
x = pytensor.scalar.as_scalar(x)
if x.type.dtype not in integer_dtypes:
raise TypeError("index must be integers")
return x
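
To see the effect of the one-line change above (a sketch, assuming make_slice is exposed by pytensor.tensor.type_other): the integer inputs of a MakeSlice node are now pytensor scalar variables rather than 0-d tensor variables, which is the representation the Numba backend can lower without falling back to object mode.

from pytensor.tensor.type_other import make_slice

s = make_slice(0, 10, 2)
# With pytensor.scalar.as_scalar instead of as_tensor_variable(..., ndim=0),
# the start/stop/step inputs are ScalarConstants, not 0-d TensorConstants.
print([inp.type for inp in s.owner.inputs])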