diff --git a/pytensor/sparse/sandbox/sp.py b/pytensor/sparse/sandbox/sp.py
index fb945c8fc1..22cc8b6d62 100644
--- a/pytensor/sparse/sandbox/sp.py
+++ b/pytensor/sparse/sandbox/sp.py
@@ -19,7 +19,6 @@
 from pytensor.tensor.math import dot
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.shape import reshape
-from pytensor.tensor.subtensor import DimShuffle
 
 
 def register_specialize(lopt, *tags, **kwargs):
@@ -375,7 +374,7 @@ def convolve(
         [images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)]
     )
     tensout = reshape(output, newshp, ndim=3)
-    output = DimShuffle((False,) * tensout.ndim, (0, 2, 1))(tensout)
+    output = tensout.transpose(0, 2, 1)
 
     if flatten:
         output = pt.flatten(output, 2)
@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
     )
 
     out2 = reshape(out1, pshape, ndim=3)
-    out3 = DimShuffle(out2.broadcastable, (0, 2, 1))(out2)
+    out3 = out2.transpose(0, 2, 1)
 
     return pt.flatten(out3, 2), outshp
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 9295a130c2..48ec908c7e 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
         # No-op
         return _x
 
-    ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
+    ret = _x.dimshuffle(axes)
 
     if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
         ret.name = _x.name + ".T"
@@ -3518,7 +3518,7 @@ def grad(self, inp, grads):
                 newdims.append(i)
                 i += 1
 
-        gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
+        gx = gx.dimshuffle(newdims)
         assert gx.type.ndim == x.type.ndim
         assert all(
             s1 == s2
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 0d524c709e..53302c28c4 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -1,5 +1,7 @@
+from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
+from typing import Literal
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
 
     Parameters
     ----------
-    input_broadcastable
-        The expected broadcastable pattern of the input
+    input_ndim
+        The expected number of dimensions of the input
     new_order
         A list representing the relationship between the input's
         dimensions and the output's dimensions. Each element of the
         list can either be an index or 'x'. Indices must be encoded
         as python integers, not pytensor symbolic integers.
-    inplace : bool, optional
-        If True (default), the output will be a view of the input.
+        Missing indices correspond to dropped dimensions.
 
     Notes
     -----
@@ -77,10 +78,10 @@ class DimShuffle(ExternalCOp):
 
     .. code-block:: python
 
-        DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
+        DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])
 
-    This `Op` will only work on 3d tensors with no broadcastable
-    dimensions. The first dimension will be broadcastable,
+    This `Op` will only work on 3d tensors.
+    The first dimension of the output will be broadcastable,
     then we will have the third dimension of the input tensor as
     the second of the resulting tensor, etc. If the tensor has
     shape (20, 30, 40), the resulting tensor will have dimensions
@@ -88,39 +89,36 @@ class DimShuffle(ExternalCOp):
 
     .. code-block:: python
 
-        DimShuffle((True, False), [1])
+        DimShuffle(input_ndim=2, new_order=[1])
 
-    This `Op` will only work on 2d tensors with the first dimension
-    broadcastable.
-    The second dimension of the input tensor will be the first dimension of
-    the resulting tensor.
-    If the tensor has shape (1, 20), the resulting tensor will have shape
-    (20, ).
+    This `Op` will only work on 2d tensors with the first dimension broadcastable.
+    The second dimension of the input tensor will be the first dimension of the resulting tensor.
+    If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
 
     Examples
     --------
 
     .. code-block:: python
 
-        DimShuffle((), ["x"])  # make a 0d (scalar) into a 1d vector
-        DimShuffle((False, False), [0, 1])  # identity
-        DimShuffle((False, False), [1, 0])  # inverts the 1st and 2nd dimensions
-        DimShuffle((False,), ["x", 0])  # make a row out of a 1d vector
-        # (N to 1xN)
-        DimShuffle((False,), [0, "x"])  # make a column out of a 1d vector
-        # (N to Nx1)
-        DimShuffle((False, False, False), [2, 0, 1])  # AxBxC to CxAxB
-        DimShuffle((False, False), [0, "x", 1])  # AxB to Ax1xB
-        DimShuffle((False, False), [1, "x", 0])  # AxB to Bx1xA
-
-    The reordering of the dimensions can be done with the numpy.transpose
-    function.
-    Adding, subtracting dimensions can be done with reshape.
+        DimShuffle(input_ndim=0, new_order=["x"])  # make a 0d (scalar) into a 1d vector
+        DimShuffle(input_ndim=2, new_order=[0, 1])  # identity
+        DimShuffle(input_ndim=2, new_order=[1, 0])  # transposition
+        # Make a row out of a 1d vector (N to 1xN)
+        DimShuffle(input_ndim=1, new_order=["x", 0])
+        # Make a column out of a 1d vector (N to Nx1)
+        DimShuffle(input_ndim=1, new_order=[0, "x"])
+        DimShuffle(input_ndim=3, new_order=[2, 0, 1])  # AxBxC to CxAxB
+        DimShuffle(input_ndim=2, new_order=[0, "x", 1])  # AxB to Ax1xB
+        DimShuffle(input_ndim=2, new_order=[1, "x", 0])  # AxB to Bx1xA
+
+    Notes
+    -----
+    The Python implementation of this Op combines numpy.transpose for reordering of the dimensions
+    and numpy.reshape for removing and adding broadcastable dimensions.
     """
 
     _f16_ok = True
     check_input = False
-    __props__ = ("input_broadcastable", "new_order", "inplace")
+    __props__ = ("input_ndim", "new_order", "inplace")
     c_func_file = "c_code/dimshuffle.c"
     c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
@@ -133,16 +131,14 @@ def params_type(self):
             inplace=scalar_bool,
         )
 
-    def __init__(self, input_broadcastable, new_order):
+    def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
         super().__init__([self.c_func_file], self.c_func_name)
 
-        self.input_broadcastable = tuple(input_broadcastable)
-        if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
-            raise ValueError(
-                f"input_broadcastable must be boolean, {self.input_broadcastable}"
-            )
-        self.new_order = tuple(new_order)
+        if not isinstance(input_ndim, int):
+            raise TypeError(f"input_ndim must be an integer, got {type(input_ndim)}")
+        self.input_ndim = input_ndim
+        self.new_order = tuple(new_order)
         self.inplace = True
 
         for i, j in enumerate(new_order):
@@ -152,10 +148,10 @@ def __init__(self, input_broadcastable, new_order):
                     "DimShuffle indices must be Python ints; got "
                     f"{j} of type {type(j)}."
                 )
-            if j >= len(input_broadcastable):
+            if j >= input_ndim:
                 raise ValueError(
                     f"new_order[{i}] is {j}, but the input only has "
-                    f"{len(input_broadcastable)} axes."
+                    f"{input_ndim} axes."
                 )
             if j in new_order[(i + 1) :]:
                 raise ValueError(
@@ -164,19 +160,7 @@ def __init__(self, input_broadcastable, new_order):
                 )
 
         # List of input dimensions to drop
-        drop = []
-        for i, b in enumerate(input_broadcastable):
-            if i not in new_order:
-                # We want to drop this dimension because it's not a value in
-                # `new_order`
-                if b == 1:
-                    drop.append(i)
-                else:
-                    # We cannot drop non-broadcastable dimensions
-                    raise ValueError(
-                        "Cannot drop a non-broadcastable dimension: "
-                        f"{input_broadcastable}, {new_order}"
-                    )
+        drop = [i for i in range(input_ndim) if i not in new_order]
 
         # This is the list of the original dimensions that we keep
         self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +170,6 @@ def __init__(self, input_broadcastable, new_order):
         self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
         self.drop = drop
 
-        input_ndim = len(input_broadcastable)
         self.is_left_expand_dims = self.augment and (
             input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
         )
@@ -204,30 +187,29 @@ def __setstate__(self, state):
         # Let's just build the ExternalCOp.
         super().__init__([self.c_func_file], self.c_func_name)
 
-    def make_node(self, _input):
-        input = as_tensor_variable(_input)
-        ib = tuple(s == 1 for s in input.type.shape)
-        if ib != self.input_broadcastable:
-            if len(ib) != len(self.input_broadcastable):
+    def make_node(self, inp):
+        input = as_tensor_variable(inp)
+        if input.type.ndim != self.input_ndim:
+            raise TypeError(
+                "The number of dimensions of the input is incorrect for this op. "
+                f"Expected {self.input_ndim}, got {input.type.ndim}."
+            )
+
+        input_static_shape = input.type.shape
+
+        # Runtime check for invalid drop
+        for d in self.drop:
+            if input_static_shape[d] not in (1, None):
                 raise TypeError(
-                    "The number of dimensions of the "
-                    f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
+                    f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
                 )
-            for expected, b in zip(self.input_broadcastable, ib):
-                if expected and not b:
-                    raise TypeError(
-                        "The broadcastable pattern of the "
-                        f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
-                    )
-                # else, expected == b or not expected and b
-                # Both case are good.
 
         out_static_shape = []
         for dim_idx in self.new_order:
             if dim_idx == "x":
                 out_static_shape.append(1)
             else:
-                out_static_shape.append(input.type.shape[dim_idx])
+                out_static_shape.append(input_static_shape[dim_idx])
 
         output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
@@ -254,12 +236,14 @@ def perform(self, node, inp, out):
         if not isinstance(res, np.ndarray | np.memmap):
             raise TypeError(res)
 
+        # Put dropped axes at the end
         res = res.transpose(self.transposition)
 
-        shape = list(res.shape[: len(self.shuffle)])
+        # Define the new shape without dropped axes and with the new broadcastable ones
+        new_shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
-            shape.insert(augm, 1)
-        res = res.reshape(shape)
+            new_shape.insert(augm, 1)
+        res = res.reshape(new_shape)
 
         if not self.inplace:
             res = np.copy(res)
@@ -284,22 +268,15 @@ def R_op(self, inputs, eval_points):
 
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
-        gz = as_tensor_variable(gz)
         grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
-        # Do not make the DimShuffle inplace as an optimization at the
-        # canonicalization optimization phase will remove the inplace.
-        # The inplace will be reintroduced automatically later in the graph.
-        if inp[0].dtype in discrete_dtypes:
-            return [inp[0].zeros_like(dtype=config.floatX)]
+
+        if x.type.dtype in discrete_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
         else:
-            return [
-                DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
-                    Elemwise(scalar_identity)(gz)
-                )
-            ]
+            return [gz.dimshuffle(grad_order)]
 
 
 class DimShufflePrinter(Printer):
@@ -409,7 +386,7 @@ def __setstate__(self, d):
         self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)
 
-    def get_output_info(self, dim_shuffle, *inputs):
+    def get_output_info(self, *inputs):
         """Return the outputs dtype and broadcastable pattern and the
         dimshuffled inputs.
 
@@ -427,12 +404,7 @@ def get_output_info(self, dim_shuffle, *inputs):
             if not difference:
                 args.append(input)
             else:
-                args.append(
-                    dim_shuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(length)),
-                    )(input)
-                )
+                args.append(input.dimshuffle(["x"] * difference + list(range(length))))
         inputs = args
 
         # HERE: all the broadcast dims have the same length now
@@ -489,7 +461,7 @@ def make_node(self, *inputs):
         using DimShuffle.
         """
         inputs = [as_tensor_variable(i) for i in inputs]
-        out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
+        out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
         outputs = [
             TensorType(dtype=dtype, shape=shape)()
             for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +606,7 @@ def transform(r):
             res = pytensor.tensor.basic.constant(
                 np.asarray(r.data), dtype=r.type.dtype
             )
-            return DimShuffle((), ["x"] * nd)(res)
+            return res.dimshuffle(["x"] * nd)
 
         new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
         if isinstance(new_r, list | tuple):
@@ -1707,13 +1679,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
     batched_ndims = x.type.ndim - node.inputs[0].type.ndim
     if not batched_ndims:
         return node.op.make_node(x)
-    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
-    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
-    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    # e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
     new_order = list(range(batched_ndims)) + [
         "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
     ]
-    return DimShuffle(input_broadcastable, new_order).make_node(x)
+    return x.dimshuffle(new_order).owner
 
 
 def get_normalized_batch_axes(
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index 009880274e..6f181062de 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -41,7 +41,7 @@
 )
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.shape import Shape_i, specify_broadcastable
+from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.variable import TensorVariable
@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
         # Nothing could be squeezed
         return _x
 
-    # `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
-    # We add a `specify_broadcastable` instead of raising.
-    non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
-    _x = specify_broadcastable(_x, *non_broadcastable_axis)
-
     return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py
index 73b3942327..76738fdb63 100644
--- a/pytensor/tensor/inplace.py
+++ b/pytensor/tensor/inplace.py
@@ -1,6 +1,6 @@
 from pytensor import printing
 from pytensor.printing import pprint
-from pytensor.tensor.elemwise import DimShuffle, scalar_elemwise
+from pytensor.tensor.elemwise import scalar_elemwise
 
 
 @scalar_elemwise
@@ -429,4 +429,4 @@ def hyp2f1_inplace(a, b, c, z):
 def transpose_inplace(x, **kwargs):
     "Perform a transpose on a tensor without copying the underlying storage"
     dims = list(range(x.ndim - 1, -1, -1))
-    return DimShuffle(x.broadcastable, dims)(x)
+    return x.dimshuffle(dims)
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 1ad9ce0158..1b5b94aa7f 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -33,7 +33,6 @@
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
     CAReduce,
-    DimShuffle,
     Elemwise,
     get_normalized_batch_axes,
     scalar_elemwise,
@@ -2338,8 +2337,7 @@ def L_op(self, inp, out, grads):
             else:
                 new_dims.append(i)
                 i += 1
-        ds_op = DimShuffle(gz.type.broadcastable, new_dims)
-        gx = Elemwise(ps.second)(x, ds_op(gz))
+        gx = Elemwise(ps.second)(x, gz.dimshuffle(new_dims))
         return [gx]
 
     def R_op(self, inputs, eval_points):
diff --git a/pytensor/tensor/random/rewriting/jax.py b/pytensor/tensor/random/rewriting/jax.py
index ef68235889..d25c9e16ea 100644
--- a/pytensor/tensor/random/rewriting/jax.py
+++ b/pytensor/tensor/random/rewriting/jax.py
@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
 
     if isinstance(size_node.op, MakeVector) or (
         isinstance(size_node.op, DimShuffle)
-        and size_node.op.input_broadcastable == ()
+        and size_node.op.input_ndim == 0
         and size_node.op.new_order == ("x",)
     ):
         # Here PyTensor converted a tuple or list to a tensor
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 6a038cab15..3cdd5b7ad6 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
         dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list(
             range(len(new_output_shape))
         )
-        return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
+        return [inner.dimshuffle(dimshuffle_new_order)]
 
 
 @node_rewriter([AllocEmpty])
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index 99dee1fd3f..66261ef21f 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
 
     """
     op = node.op
-    if not isinstance(op, DimShuffle):
-        return False
 
     inp = node.inputs[0]
     inode = inp.owner
@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
             # Don't use make_node to have tag.test_value set.
             new_inputs = []
             for inp in inode.inputs:
-                new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp)
+                new_inp = inp.dimshuffle(op.new_order)
                 new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))
             copy_stack_trace(node.outputs[0], new_inputs)
             ret = inode.op(*new_inputs, return_list=True)
@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
     if is_dimshuffle_useless(new_order, inp):
         return [inp]
     elif inode and isinstance(inode.op, DimShuffle):
-        ret = op.__class__(inp.type.broadcastable, new_order)(inp)
+        ret = inp.dimshuffle(new_order)
         ret = apply_local_dimshuffle_lift(fgraph, ret)
         copy_stack_trace(node.outputs[0], ret)
         return [ret]
diff --git a/pytensor/tensor/rewriting/jax.py b/pytensor/tensor/rewriting/jax.py
index 59e701d328..00ed3f2b14 100644
--- a/pytensor/tensor/rewriting/jax.py
+++ b/pytensor/tensor/rewriting/jax.py
@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
 
     if isinstance(shape_node.op, MakeVector) or (
         isinstance(shape_node.op, DimShuffle)
-        and shape_node.op.input_broadcastable == ()
+        and shape_node.op.input_ndim == 0
         and shape_node.op.new_order == ("x",)
     ):
         # Here PyTensor converted a tuple or list to a tensor
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 47ca08cf21..d34966775a 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
         if ndims < 2:
             return False
         transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2)
-        return cast(bool, node.op.new_order == transpose_order)
+        return node.op.new_order == transpose_order
     return False
 
 
diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index 3c1b648d88..91c731a4ff 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
     if index != output.type.ndim:
         inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
         copy_stack_trace(output, inner)
-        new_node = [
-            DimShuffle(tuple(s == 1 for s in inner.type.shape), dimshuffle_new_order)(
-                inner
-            )
-        ]
+        new_node = [inner.dimshuffle(dimshuffle_new_order)]
         copy_stack_trace(output, new_node)
         return new_node
 
diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index 613fb80f3e..261a8bbc4a 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -344,8 +344,8 @@ def dimshuffle(self, *pattern):
         """
         if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
             pattern = pattern[0]
-        op = pt.elemwise.DimShuffle(list(self.type.broadcastable), pattern)
-        return op(self)
+        ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
+        return ds_op(self)
 
     def flatten(self, ndim=1):
         return pt.basic.flatten(self, ndim)
diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py
index 0f08944814..856d8c4881 100644
--- a/tests/link/jax/test_elemwise.py
+++ b/tests/link/jax/test_elemwise.py
@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
     compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
 
     a_pt = tensor(dtype=config.floatX, shape=(None, 1))
-    x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
+    x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
     x_fg = FunctionGraph([a_pt], [x])
     compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
 
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 8bbbe164fc..4c13004409 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -15,7 +15,7 @@
 from pytensor.gradient import grad
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
-from pytensor.tensor import elemwise as pt_elemwise
+from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
     ],
 )
 def test_Dimshuffle(v, new_order):
-    g = pt_elemwise.DimShuffle(v.broadcastable, new_order)(v)
+    g = v.dimshuffle(new_order)
     g_fg = FunctionGraph(outputs=[g])
     compare_numba_and_py(
         g_fg,
@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
 
 def test_Dimshuffle_returns_array():
     x = pt.vector("x", shape=(1,))
-    y = 2 * pt_elemwise.DimShuffle([True], [])(x)
+    y = 2 * x.dimshuffle([])
     func = pytensor.function([x], y, mode="NUMBA")
     out = func(np.zeros(1, dtype=config.floatX))
     assert out.ndim == 0
@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
     non-contiguous arrays, make sure we work around thpt."""
     x = pt.dvector()
     idx = pt.vector(dtype="int64")
-    op = pytensor.tensor.elemwise.DimShuffle([True], [])
+    op = DimShuffle(input_ndim=1, new_order=[])
     out = op(pt.specify_shape(x[idx][::2], (1,)))
     func = pytensor.function([x, idx], out, mode="NUMBA")
     assert func(np.zeros(3), np.array([1])).ndim == 0
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index afb62848cc..8b334e0efe 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -5,7 +5,6 @@
 import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
-from pytensor.tensor import elemwise as pt_elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
     x_fg = FunctionGraph([a_pt], [x])
     compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
 
-    a_pt = tensor(dtype=config.floatX, shape=(None, 1))
-    x = pt_elemwise.DimShuffle([False, True], (0,))(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
-
 
 def test_multiple_input_output():
     x = vector("x")
diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index 692598c2c7..82cfa884af 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -79,7 +79,7 @@
 
 
 def ds(x, y):
-    return DimShuffle(x.type.broadcastable, y)(x)
+    return x.dimshuffle(y)
 
 
 def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 174858da30..4444fc6891 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -160,7 +160,7 @@
 
 
 def ds(x, y):
-    return DimShuffle(x.type.broadcastable, y)(x)
+    return x.dimshuffle(y)
 
 
 def rewrite(g, level="fast_run"):
@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
     check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)
 
     # If a transpose is applied to the sum
-    transpose_op = DimShuffle((False, False), (1, 0))
+    transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
     check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)
 
     # If the sum is performed with keepdims=True
@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
     assert np.allclose(naive_ret, rewritten_ret)
 
     # If a transpose is applied
-    transpose_op = DimShuffle((False, False), (1, 0))
+    transpose_op = DimShuffle(input_ndim=2, new_order=(1, 0))
     f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
     naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T)
     rewritten_ret = f(x_val)
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index 05aa15aa05..c3db43dddd 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -3418,7 +3418,7 @@ def test_unalign():
 def test_dimshuffle_duplicate():
     x = vector()
     with pytest.raises(ValueError, match="may not appear twice"):
-        DimShuffle((False,), (0, 0))(x)
+        DimShuffle(input_ndim=1, new_order=(0, 0))(x)
 
 
 class TestGetUnderlyingScalarConstantValue:
diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py
index 34a1d1bcf9..c2a9c95e18 100644
--- a/tests/tensor/test_blas.py
+++ b/tests/tensor/test_blas.py
@@ -593,9 +593,9 @@ def test_basic(self):
         b = pt.constant(np.asarray([[[0.5]]]))
         b2 = b.dimshuffle()
         assert b2.ndim == 0
-        d_a = DimShuffle([], [])(a)
-        d_b = DimShuffle([True, True, True], [0, 2, 1])(b)
-        d_a2 = DimShuffle([], ["x", "x", "x"])(a)
+        d_a = DimShuffle(input_ndim=0, new_order=[])(a)
+        d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
+        d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
 
         assert _as_scalar(a) == a
         assert _as_scalar(b) != b
@@ -607,13 +607,13 @@ def test_basic_1(self):
         # Test that it fails on nonscalar constants
         a = pt.constant(np.ones(5))
         assert _as_scalar(a) is None
-        assert _as_scalar(DimShuffle([False], [0, "x"])(a)) is None
+        assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
 
     def test_basic_2(self):
         # Test that it works on scalar variables
         a = dscalar()
-        d_a = DimShuffle([], [])(a)
-        d_a2 = DimShuffle([], ["x", "x"])(a)
+        d_a = DimShuffle(input_ndim=0, new_order=[])(a)
+        d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
 
         assert _as_scalar(a) is a
         assert _as_scalar(d_a) is a
@@ -623,13 +623,15 @@ def test_basic_3(self):
         # Test that it fails on nonscalar variables
         a = matrix()
         assert _as_scalar(a) is None
-        assert _as_scalar(DimShuffle([False, False], [0, "x", 1])(a)) is None
+        assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
 
 
 class TestRealMatrix:
     def test_basic(self):
-        assert _is_real_matrix(DimShuffle([False, False], [1, 0])(matrix()))
-        assert not _is_real_matrix(DimShuffle([False], ["x", 0])(dvector()))
+        assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
+        assert not _is_real_matrix(
+            DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
+        )
 
 
 """
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 284e8051a7..76906232af 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -60,46 +60,40 @@ def with_linker(self, linker):
             ((1,), ("x", "x"), (1, 1)),
         ]:
             i_shape = [entry if entry == 1 else None for entry in xsh]
-            ib = [entry == 1 for entry in i_shape]
             x = self.type(self.dtype, shape=i_shape)("x")
-            e = self.op(ib, shuffle)(x)
+            e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
             f = pytensor.function([x], e, mode=Mode(linker=linker))
             assert f(np.ones(xsh, dtype=self.dtype)).shape == zsh
 
             # test that DimShuffle.infer_shape work correctly
             x = self.type(self.dtype, shape=i_shape)("x")
-            e = self.op(ib, shuffle)(x)
+            e = self.op(input_ndim=len(i_shape), new_order=shuffle)(x)
             f = pytensor.function(
                 [x], e.shape, mode=Mode(linker=linker), on_unused_input="ignore"
             )
             assert all(f(np.ones(xsh, dtype=self.dtype))) == all(zsh)
 
         # Test when we drop a axis that is not broadcastable
-        ib = [False, True, False]
-        x = self.type(self.dtype, shape=(None, 1, None))("x")
-        with pytest.raises(ValueError):
-            self.op(ib, shuffle)
+        x = self.type(self.dtype, shape=(2, 1, None))("x")
+        with pytest.raises(TypeError):
+            self.op(input_ndim=3, new_order=shuffle)(x)
 
         # Test when we drop a axis that don't have shape 1
-        ib = [True, True, False]
-        x = self.type(self.dtype, shape=(1, 1, None))("x")
-        e = self.op(ib, (1, 2))(x)
-        f = pytensor.function([x], e.shape, mode=Mode(linker=linker))
-        with pytest.raises(TypeError):
-            f(np.ones((2, 1, 4)))
+        x = self.type(self.dtype, shape=(None, 1, None))("x")
+        e = self.op(input_ndim=3, new_order=(1, 2))(x)
+        f = pytensor.function([x], e, mode=Mode(linker=linker))
+        with pytest.raises(ValueError):
+            f(np.ones((2, 1, 4), dtype=self.dtype))
 
         # Test that we can't take a dimensions multiple time
         xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4))
-        ib = [False, True, False]
         x = self.type(self.dtype, shape=(None, 1, None))("x")
         with pytest.raises(ValueError):
-            DimShuffle(ib, shuffle)
+            DimShuffle(input_ndim=3, new_order=shuffle)
 
     def test_perform(self):
         self.with_linker(PerformLinker())
 
     def test_c_or_py(self):
-        # Shape op don't have C code.
-        # But This will test DimShuffle c code
         self.with_linker(OpWiseCLinker())
 
     def test_infer_shape(self):
@@ -115,12 +109,11 @@ def test_infer_shape(self):
             ((1,), ("x", "x")),
         ]:
             i_shape = [entry if entry == 1 else None for entry in xsh]
-            ib = [(entry == 1) for entry in xsh]
             adtens = self.type(self.dtype, shape=i_shape)("x")
             adtens_val = np.ones(xsh, dtype=self.dtype)
             self._compile_and_check(
                 [adtens],
-                [self.op(ib, shuffle)(adtens)],
+                [self.op(input_ndim=len(xsh), new_order=shuffle)(adtens)],
                 [adtens_val],
                 self.op,
                 warn=False,
@@ -191,11 +184,11 @@ def test_static_shape(self):
         y = x.dimshuffle([0, 1, "x"])
         assert y.type.shape == (1, 2, 1)
 
-    def test_valid_input_broadcastable(self):
-        assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False)
+    def test_valid_input_ndim(self):
+        assert DimShuffle(input_ndim=2, new_order=(1, 0)).input_ndim == 2
 
-        with pytest.raises(ValueError, match="input_broadcastable must be boolean"):
-            DimShuffle([None, None], (1, 0))
+        with pytest.raises(TypeError, match="input_ndim must be an integer"):
+            DimShuffle(input_ndim=(True, False), new_order=(1, 0))
 
 
 class TestBroadcast:
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index b8a4b46192..0da714c3bf 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -480,12 +480,9 @@ def test_invalid_input(self):
         assert f([0]) == 0
 
         # Test that we cannot squeeze dimensions whose length is greater than 1
-        error_txt_1 = re.escape("SpecifyShape: Got shape (3,), expected (1,).")
-        error_txt_2 = re.escape("SpecifyShape: dim 0 of input has shape 3, expected 1")
-        match = error_txt_1 if pytensor.config.mode == "FAST_COMPILE" else error_txt_2
         with pytest.raises(
-            AssertionError,
-            match=match,
+            ValueError,
+            match="cannot reshape array of size 3 into shape ()",
         ):
             f([0, 1, 2])
diff --git a/tests/tensor/test_fft.py b/tests/tensor/test_fft.py
index 3599c97de3..94c49662bc 100644
--- a/tests/tensor/test_fft.py
+++ b/tests/tensor/test_fft.py
@@ -204,3 +204,12 @@ def f_irfft(inp):
             pytensor.config.floatX
         )
         utt.verify_grad(f_irfft, [inputs_val], eps=eps)
+
+    def test_rfft_expanded_dims_grad(self):
+        # Regression test for https://github.com/pymc-devs/pytensor/issues/969
+        def test_func(x):
+            return fft.rfft(x[None, :])
+
+        rng = np.random.default_rng(213)
+        inputs_val = rng.random((N,)).astype(pytensor.config.floatX)
+        utt.verify_grad(test_func, [inputs_val], rng=rng)
diff --git a/tests/tensor/test_keepdims.py b/tests/tensor/test_keepdims.py
index 17a8d6cdcc..06aaeb5ae9 100644
--- a/tests/tensor/test_keepdims.py
+++ b/tests/tensor/test_keepdims.py
@@ -4,7 +4,6 @@
 import pytensor
 from pytensor import function
 from pytensor.compile.mode import Mode
-from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import any as pt_any
 from pytensor.tensor.math import argmax, argmin, max_and_argmax, mean, prod, std, var
@@ -40,7 +39,7 @@ def makeKeepDims_local(self, x, y, axis):
             new_dims.append(i)
             i += 1
 
-        return DimShuffle(y.type.broadcastable, new_dims)(y)
+        return y.dimshuffle(new_dims)
 
     @pytest.mark.parametrize(
         "axis",
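A minimal usage sketch of the keyword-only DimShuffle constructor introduced above, assuming this patch is applied; the variable names below are illustrative only and are not part of the patch.

# Sketch only, assuming the patch above is applied.
import numpy as np

import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle

x = pt.tensor3("x")  # a 3d input, shape (A, B, C)

# Old spelling, removed by this patch: DimShuffle(x.type.broadcastable, (2, 0, 1))(x)
y = DimShuffle(input_ndim=3, new_order=(2, 0, 1))(x)  # AxBxC -> CxAxB

# The x.dimshuffle(...) helper used throughout the patch builds the same Op.
assert y.type.shape == x.dimshuffle(2, 0, 1).type.shape

# perform() is numpy.transpose (reorder/drop axes) followed by numpy.reshape
# (insert the "x" axes); e.g. new_order=("x", 2, "x", 0, 1) on a (20, 30, 40) array:
a = np.empty((20, 30, 40))
out = a.transpose(2, 0, 1).reshape((1, 40, 1, 20, 30))
assert out.shape == (1, 40, 1, 20, 30)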