Only require input_ndim and not input_broadcastable in DimShuffle
ricardoV94 committed Oct 8, 2024
1 parent d68f53f commit e88117e
Showing 24 changed files with 132 additions and 181 deletions.
5 changes: 2 additions & 3 deletions pytensor/sparse/sandbox/sp.py
@@ -19,7 +19,6 @@
from pytensor.tensor.math import dot
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.shape import reshape
from pytensor.tensor.elemwise import DimShuffle


def register_specialize(lopt, *tags, **kwargs):
@@ -375,7 +374,7 @@ def convolve(
[images.shape[0], pt.as_tensor(np.prod(outshp)), pt.as_tensor(nkern)]
)
tensout = reshape(output, newshp, ndim=3)
output = DimShuffle((False,) * tensout.ndim, (0, 2, 1))(tensout)
output = tensout.transpose(0, 2, 1)
if flatten:
output = pt.flatten(output, 2)

@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
)
out2 = reshape(out1, pshape, ndim=3)

out3 = DimShuffle(out2.broadcastable, (0, 2, 1))(out2)
out3 = out2.transpose(0, 2, 1)

return pt.flatten(out3, 2), outshp
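Both hunks above swap a hand-built DimShuffle for the variable-level transpose helper. A minimal sketch of the equivalence, assuming an illustrative tensor3 variable t:

import numpy as np
import pytensor.tensor as pt

t = pt.tensor3("t")
old_style = t.dimshuffle(0, 2, 1)  # what DimShuffle((False,) * 3, (0, 2, 1))(t) used to build
new_style = t.transpose(0, 2, 1)   # equivalent, and independent of t's broadcast pattern
assert old_style.type == new_style.type
print(new_style.eval({t: np.zeros((2, 3, 4), dtype=t.dtype)}).shape)  # (2, 4, 3)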
4 changes: 2 additions & 2 deletions pytensor/tensor/basic.py
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
# No-op
return _x

ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
ret = _x.dimshuffle(axes)

if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T"
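transpose now routes through dimshuffle, which derives the input's number of dimensions itself instead of needing a broadcast pattern. A small sketch with an illustrative static shape:

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))
axes = (1, 0)
ret = x.dimshuffle(axes)  # replaces DimShuffle(tuple(s == 1 for s in x.type.shape), axes)(x)
assert ret.type.shape == (3, 2)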
@@ -3518,7 +3518,7 @@ def grad(self, inp, grads):
newdims.append(i)
i += 1

gx = DimShuffle(tuple(s == 1 for s in gx.type.shape), newdims)(gx)
gx = gx.dimshuffle(newdims)
assert gx.type.ndim == x.type.ndim
assert all(
s1 == s2
161 changes: 66 additions & 95 deletions pytensor/tensor/elemwise.py
@@ -1,5 +1,7 @@
from collections.abc import Sequence
from copy import copy
from textwrap import dedent
from typing import Literal

import numpy as np
from numpy.core.numeric import normalize_axis_tuple
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
Parameters
----------
input_broadcastable
The expected broadcastable pattern of the input
input_ndim
The expected number of dimensions of the input
new_order
A list representing the relationship between the input's
dimensions and the output's dimensions. Each element of the
list can either be an index or 'x'. Indices must be encoded
as python integers, not pytensor symbolic integers.
inplace : bool, optional
If True (default), the output will be a view of the input.
Missing indices correspond to dropped dimensions.
Notes
-----
@@ -77,50 +78,47 @@ class DimShuffle(ExternalCOp):
.. code-block:: python
DimShuffle((False, False, False), ["x", 2, "x", 0, 1])
DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])
This `Op` will only work on 3d tensors with no broadcastable
dimensions. The first dimension will be broadcastable,
This `Op` will only work on 3d tensors.
The first dimension of the output will be broadcastable,
then we will have the third dimension of the input tensor as
the second of the resulting tensor, etc. If the tensor has
shape (20, 30, 40), the resulting tensor will have dimensions
(1, 40, 1, 20, 30). (AxBxC tensor is mapped to 1xCx1xAxB tensor)
.. code-block:: python
DimShuffle((True, False), [1])
DimShuffle(input_ndim=2, new_order=[1])
This `Op` will only work on 2d tensors with the first dimension
broadcastable.
The second dimension of the input tensor will be the first dimension of
the resulting tensor.
If the tensor has shape (1, 20), the resulting tensor will have shape
(20, ).
This `Op` will only work on 2d tensors with the first dimension broadcastable.
The second dimension of the input tensor will be the first dimension of the resulting tensor.
If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
Examples
--------
.. code-block:: python
DimShuffle((), ["x"]) # make a 0d (scalar) into a 1d vector
DimShuffle((False, False), [0, 1]) # identity
DimShuffle((False, False), [1, 0]) # inverts the 1st and 2nd dimensions
DimShuffle((False,), ["x", 0]) # make a row out of a 1d vector
# (N to 1xN)
DimShuffle((False,), [0, "x"]) # make a column out of a 1d vector
# (N to Nx1)
DimShuffle((False, False, False), [2, 0, 1]) # AxBxC to CxAxB
DimShuffle((False, False), [0, "x", 1]) # AxB to Ax1xB
DimShuffle((False, False), [1, "x", 0]) # AxB to Bx1xA
The reordering of the dimensions can be done with the numpy.transpose
function.
Adding, subtracting dimensions can be done with reshape.
DimShuffle(input_ndim=0, new_order=["x"]) # make a 0d (scalar) into a 1d vector
DimShuffle(input_ndim=2, new_order=[0, 1]) # identity
DimShuffle(input_ndim=2, new_order=[1, 0]) # transposition
# Make a row out of a 1d vector (N to 1xN)
DimShuffle(input_ndim=1, new_order=["x", 0])
# Make a column out of a 1d vector (N to Nx1)
DimShuffle(input_ndim=1, new_order=[0, "x"])
DimShuffle(input_ndim=3, new_order=[2, 0, 1]) # AxBxC to CxAxB
DimShuffle(input_ndim=2, new_order=[0, "x", 1]) # AxB to Ax1xB
DimShuffle(input_ndim=2, new_order=[1, "x", 0]) # AxB to Bx1xA
Notes
-----
The Python implementation of this Op combines numpy.transpose to reorder the dimensions
and numpy.reshape to drop and add broadcastable dimensions.
"""

_f16_ok = True
check_input = False
__props__ = ("input_broadcastable", "new_order", "inplace")
__props__ = ("input_ndim", "new_order", "inplace")
c_func_file = "c_code/dimshuffle.c"
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
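With input_ndim in __props__, two DimShuffle instances compare equal whenever input_ndim, new_order, and inplace match, regardless of the eventual input's broadcast pattern. A quick sketch of the new keyword-only constructor:

import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle

op = DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])
x = pt.tensor3("x")
print(op(x).type.shape)  # (1, None, 1, None, None)
# Equality now ignores the input's broadcast pattern entirely:
assert op == DimShuffle(input_ndim=3, new_order=["x", 2, "x", 0, 1])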

@@ -133,16 +131,14 @@ def params_type(self):
inplace=scalar_bool,
)

def __init__(self, input_broadcastable, new_order):
def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
super().__init__([self.c_func_file], self.c_func_name)

self.input_broadcastable = tuple(input_broadcastable)
if not all(isinstance(bs, bool | np.bool_) for bs in self.input_broadcastable):
raise ValueError(
f"input_broadcastable must be boolean, {self.input_broadcastable}"
)
self.new_order = tuple(new_order)
if not isinstance(input_ndim, int):
raise TypeError(f"input_ndim must be an integer, got {type(int)}")

self.input_ndim = input_ndim
self.new_order = tuple(new_order)
self.inplace = True

for i, j in enumerate(new_order):
@@ -152,10 +148,10 @@ def __init__(self, input_broadcastable, new_order):
"DimShuffle indices must be Python ints; got "
f"{j} of type {type(j)}."
)
if j >= len(input_broadcastable):
if j >= input_ndim:
raise ValueError(
f"new_order[{i}] is {j}, but the input only has "
f"{len(input_broadcastable)} axes."
f"{input_ndim} axes."
)
if j in new_order[(i + 1) :]:
raise ValueError(
@@ -164,19 +160,7 @@
)

# List of input dimensions to drop
drop = []
for i, b in enumerate(input_broadcastable):
if i not in new_order:
# We want to drop this dimension because it's not a value in
# `new_order`
if b == 1:
drop.append(i)
else:
# We cannot drop non-broadcastable dimensions
raise ValueError(
"Cannot drop a non-broadcastable dimension: "
f"{input_broadcastable}, {new_order}"
)
drop = [i for i in range(input_ndim) if i not in new_order]

# This is the list of the original dimensions that we keep
self.shuffle = [x for x in new_order if x != "x"]
@@ -186,7 +170,6 @@ def __init__(self, input_broadcastable, new_order):
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop

input_ndim = len(input_broadcastable)
self.is_left_expand_dims = self.augment and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)
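A pure-Python sketch of the bookkeeping above for new_order=[1, "x"] on a 2d input; the final line mirrors the transposition order used later in perform (kept axes first, dropped axes last):

input_ndim = 2
new_order = [1, "x"]

drop = [i for i in range((input_ndim)) if i not in new_order]      # [0]
shuffle = [x for x in new_order if x != "x"]                       # [1]
augment = sorted(i for i, x in enumerate(new_order) if x == "x")   # [1]
transposition = shuffle + drop                                     # [1, 0]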
@@ -204,30 +187,29 @@ def __setstate__(self, state):
# Let's just build the ExternalCOp.
super().__init__([self.c_func_file], self.c_func_name)

def make_node(self, _input):
input = as_tensor_variable(_input)
ib = tuple(s == 1 for s in input.type.shape)
if ib != self.input_broadcastable:
if len(ib) != len(self.input_broadcastable):
def make_node(self, inp):
input = as_tensor_variable(inp)
if input.type.ndim != self.input_ndim:
raise TypeError(
"The number of dimensions of the input is incorrect for this op. "
f"Expected {self.input_ndim}, got {input.type.ndim}."
)

input_static_shape = input.type.shape

# Runtime check for invalid drop
for d in self.drop:
if input_static_shape[d] not in (1, None):
raise TypeError(
"The number of dimensions of the "
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
f"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
)
for expected, b in zip(self.input_broadcastable, ib):
if expected and not b:
raise TypeError(
"The broadcastable pattern of the "
f"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
)
# else, expected == b or not expected and b
# Both case are good.

out_static_shape = []
for dim_idx in self.new_order:
if dim_idx == "x":
out_static_shape.append(1)
else:
out_static_shape.append(input.type.shape[dim_idx])
out_static_shape.append(input_static_shape[dim_idx])

output = TensorType(dtype=input.type.dtype, shape=out_static_shape)()
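The old constructor-time broadcastability check becomes this static-shape check in make_node: a dropped dimension may have known length 1 or unknown length, and only a known non-1 length is rejected. An illustrative sketch:

import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle

drop_first = DimShuffle(input_ndim=2, new_order=[1])

x = pt.matrix("x")               # static shape (None, None): dim 0 may turn out to be 1
print(drop_first(x).type.shape)  # (None,): accepted here, verified when the graph runs

y = pt.tensor("y", shape=(2, 3))
# drop_first(y) raises TypeError at graph-build time (dropped dim has known length 2)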

@@ -254,12 +236,14 @@ def perform(self, node, inp, out):
if not isinstance(res, np.ndarray | np.memmap):
raise TypeError(res)

# Put dropped axis at end
res = res.transpose(self.transposition)

shape = list(res.shape[: len(self.shuffle)])
# Define new shape without dropped axis and including new ones
new_shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment:
shape.insert(augm, 1)
res = res.reshape(shape)
new_shape.insert(augm, 1)
res = res.reshape(new_shape)

if not self.inplace:
res = np.copy(res)
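A NumPy-only sketch of the perform logic for new_order=["x", 2, 0] on a 3d input whose middle (length-1) axis is dropped:

import numpy as np

a = np.arange(35).reshape(5, 1, 7)  # axis 1 is the dropped (length-1) axis

shuffle, drop, augment = [2, 0], [1], [0]
res = a.transpose(*shuffle, *drop)  # kept axes first, dropped axis last: shape (7, 5, 1)

new_shape = list(res.shape[: len(shuffle)])  # [7, 5]
for augm in augment:
    new_shape.insert(augm, 1)                # [1, 7, 5]
res = res.reshape(new_shape)                 # the reshape drops the trailing length-1 axis
assert res.shape == (1, 7, 5)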
@@ -284,22 +268,15 @@ def R_op(self, inputs, eval_points):
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
gz = as_tensor_variable(gz)
grad_order = ["x"] * x.type.ndim
for i, v in enumerate(self.new_order):
if v != "x":
grad_order[v] = i
# Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph.
if inp[0].dtype in discrete_dtypes:
return [inp[0].zeros_like(dtype=config.floatX)]

if x.type.dtype in discrete_dtypes:
return [x.zeros_like(dtype=config.floatX)]
else:
return [
DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
Elemwise(scalar_identity)(gz)
)
]
return [gz.dimshuffle(grad_order)]
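The gradient is now just the inverse dimshuffle applied to gz. A sketch of how grad_order is derived for new_order=["x", 1, 0] on a matrix:

new_order = ["x", 1, 0]  # the forward shuffle applied to a matrix x
x_ndim = 2

grad_order = ["x"] * x_ndim
for i, v in enumerate(new_order):
    if v != "x":
        grad_order[v] = i

print(grad_order)  # [2, 1]: x's axis 0 came from gz's axis 2, x's axis 1 from gz's axis 1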


class DimShufflePrinter(Printer):
@@ -409,7 +386,7 @@ def __setstate__(self, d):
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)

def get_output_info(self, dim_shuffle, *inputs):
def get_output_info(self, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
dimshuffled inputs.
@@ -427,12 +404,7 @@ def get_output_info(self, dim_shuffle, *inputs):
if not difference:
args.append(input)
else:
args.append(
dim_shuffle(
input.type.broadcastable,
["x"] * difference + list(range(length)),
)(input)
)
args.append(input.dimshuffle(["x"] * difference + list(range(length))))
inputs = args

# HERE: all the broadcast dims have the same length now
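get_output_info no longer needs the DimShuffle class threaded through: left-padding a lower-rank input with broadcastable dimensions is a plain dimshuffle on the variable. Sketch with illustrative variables:

import pytensor.tensor as pt

v = pt.vector("v")   # ndim 1
m = pt.tensor3("m")  # ndim 3
difference = m.ndim - v.ndim
padded = v.dimshuffle(["x"] * difference + list(range(v.ndim)))
print(padded.type.shape)  # (1, 1, None): ready to broadcast against m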
@@ -489,7 +461,7 @@ def make_node(self, *inputs):
using DimShuffle.
"""
inputs = [as_tensor_variable(i) for i in inputs]
out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
out_dtypes, out_shapes, inputs = self.get_output_info(*inputs)
outputs = [
TensorType(dtype=dtype, shape=shape)()
for dtype, shape in zip(out_dtypes, out_shapes)
@@ -634,7 +606,7 @@ def transform(r):
res = pytensor.tensor.basic.constant(
np.asarray(r.data), dtype=r.type.dtype
)
return DimShuffle((), ["x"] * nd)(res)
return res.dimshuffle(["x"] * nd)

new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
if isinstance(new_r, list | tuple):
@@ -1707,13 +1679,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
if not batched_ndims:
return node.op.make_node(x)
input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
# e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
# e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
# e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
# e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
new_order = list(range(batched_ndims)) + [
"x" if (o == "x") else (o + batched_ndims) for o in op.new_order
]
return DimShuffle(input_broadcastable, new_order).make_node(x)
return x.dimshuffle(new_order).owner
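A pure-Python sketch of the new_order remapping, matching the worked example in the comments above:

batched_ndims = 2
op_new_order = [1, "x", 0]  # order of the original DimShuffle(input_ndim=2, ...)

new_order = list(range(batched_ndims)) + [
    "x" if o == "x" else o + batched_ndims for o in op_new_order
]
print(new_order)  # [0, 1, 3, 'x', 2]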


def get_normalized_batch_axes(
7 changes: 1 addition & 6 deletions pytensor/tensor/extra_ops.py
@@ -41,7 +41,7 @@
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import Shape_i, specify_broadcastable
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
@@ -609,11 +604,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed
return _x

# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
_x = specify_broadcastable(_x, *non_broadcastable_axis)

return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
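Because DimShuffle now defers the length-1 check to run time, squeeze no longer needs specify_broadcastable for axes whose static length is unknown. An illustrative sketch:

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import squeeze

x = pt.matrix("x")       # static shape (None, None)
sq = squeeze(x, axis=0)  # builds fine even though axis 0 is not statically broadcastable
print(sq.eval({x: np.ones((1, 4), dtype=x.dtype)}))  # works: axis 0 has length 1 at run time
# sq.eval({x: np.ones((2, 4), dtype=x.dtype)}) would now fail at run time, not at build time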


4 changes: 2 additions & 2 deletions pytensor/tensor/inplace.py
@@ -1,6 +1,6 @@
from pytensor import printing
from pytensor.printing import pprint
from pytensor.tensor.elemwise import DimShuffle, scalar_elemwise
from pytensor.tensor.elemwise import scalar_elemwise


@scalar_elemwise
@@ -429,4 +429,4 @@ def hyp2f1_inplace(a, b, c, z):
def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1))
return DimShuffle(x.broadcastable, dims)(x)
return x.dimshuffle(dims)
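transpose_inplace builds its axis reversal through the same variable-level helper; an illustrative check:

import pytensor.tensor as pt
from pytensor.tensor.inplace import transpose_inplace

x = pt.tensor3("x")
dims = list(range(x.ndim - 1, -1, -1))  # [2, 1, 0]
assert transpose_inplace(x).type == x.dimshuffle(dims).type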