Skip to content

Commit

Permalink
feat(frontends): Implement rfft in PaddlePaddle frontend and fix fft …
Browse files Browse the repository at this point in the history
…for Tensorflow backend (#19454)

Co-authored-by: hmahmood24 <[email protected]>
  • Loading branch information
AwkNinja and hmahmood24 authored Sep 28, 2023
1 parent acebe2b commit d05370e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
9 changes: 8 additions & 1 deletion ivy/functional/backends/tensorflow/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,9 @@ def _ifft_norm(
raise ivy.utils.exceptions.IvyError(f"Unrecognized normalization mode {norm}")


@with_supported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
@with_supported_dtypes(
{"2.13.0 and below": ("complex", "float32", "float64")}, backend_version
)
def fft(
x: Union[tf.Tensor, tf.Variable],
dim: int,
Expand All @@ -658,6 +660,11 @@ def fft(
n: Union[int, Tuple[int]] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
# ToDo: Remove conversion from float to complex when casting mode is working
if x.dtype == "float32":
x = tf.cast(x, tf.complex64)
elif x.dtype == "float64":
x = tf.cast(x, tf.complex128)
if not isinstance(dim, int):
raise ivy.utils.exceptions.IvyError(
f"Expecting <class 'int'> instead of {type(dim)}"
Expand Down
6 changes: 6 additions & 0 deletions ivy/functional/frontends/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def irfftn(x, s=None, axes=None, norm="backward", name=None):
return result_t


@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle")
@to_ivy_arrays_and_back
def rfft(x, n=None, axis=-1, norm="backward", name=None):
return ivy.dft(x, axis=axis, inverse=False, onesided=True, dft_length=n, norm=norm)


@to_ivy_arrays_and_back
def rfftfreq(n, d=1.0, dtype=None, name=None):
dtype = ivy.default_dtype()
Expand Down
41 changes: 41 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,47 @@ def test_paddle_irfftn(
)


# rfft
@handle_frontend_test(
fn_tree="paddle.fft.rfft",
dtype_input_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("valid"),
min_num_dims=1,
min_dim_size=2,
shape=helpers.get_shape(
min_num_dims=1,
max_num_dims=2,
min_dim_size=2,
max_dim_size=4,
),
large_abs_safety_factor=12,
small_abs_safety_factor=12,
safety_factor_scale="log",
force_int_axis=True,
valid_axis=True,
allow_neg_axes=True,
),
norm=st.sampled_from(["backward", "ortho", "forward"]),
n=st.integers(min_value=2, max_value=10) | st.none(),
)
def test_paddle_rfft(
dtype_input_axis, norm, n, frontend, backend_fw, test_flags, fn_tree, on_device
):
input_dtype, x, axis = dtype_input_axis
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
x=x[0],
n=n,
axis=axis,
norm=norm,
)


@handle_frontend_test(
fn_tree="paddle.fft.rfftfreq",
n=st.integers(min_value=1, max_value=1000),
Expand Down

0 comments on commit d05370e

Please sign in to comment.