Skip to content

Commit

Permalink
Add more precise output type info to RFFT Op
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 8, 2024
1 parent e9f58c9 commit d68f53f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions pytensor/tensor/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ class RFFTOp(Op):

def output_type(self, inp):
# add extra dim for real/imag
return TensorType(inp.dtype, shape=(None,) * (inp.type.ndim + 1))
return TensorType(inp.dtype, shape=((None,) * inp.type.ndim) + (2,))

def make_node(self, a, s=None):
a = as_tensor_variable(a)
if a.ndim < 2:
raise TypeError(
f"{self.__class__.__name__}: input must have dimension > 2, with first dimension batches"
f"{self.__class__.__name__}: input must have dimension >= 2, with first dimension batches"
)

if s is None:
Expand All @@ -39,9 +39,10 @@ def perform(self, node, inputs, output_storage):
a = inputs[0]
s = inputs[1]

# FIXME: This call is deprecated in numpy 2.0
# axis must be provided when s is not None
A = np.fft.rfftn(a, s=tuple(s))
# Format output with two extra dimensions for real and imaginary
# parts.
# Format output with two extra dimensions for real and imaginary parts.
out = np.zeros((*A.shape, 2), dtype=a.dtype)
out[..., 0], out[..., 1] = np.real(A), np.imag(A)
output_storage[0][0] = out
Expand Down

0 comments on commit d68f53f

Please sign in to comment.