Skip to content

Commit

Permalink
Remove type cast
Browse files Browse the repository at this point in the history
  • Loading branch information
iababio committed Aug 15, 2023
1 parent c1475dd commit f76738a
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions ivy/functional/frontends/paddle/nn/functional/norm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# local
import ivy
import paddle
from ivy.func_wrapper import with_supported_dtypes
from ivy.functional.frontends.paddle.func_wrapper import to_ivy_arrays_and_back


@to_ivy_arrays_and_back
@with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle")
def batch_norm(
x,
running_mean,
Expand All @@ -18,13 +17,6 @@ def batch_norm(
epsilon=1e-05,
data_format="NCHW",
):
# convert to paddle tensors
x = paddle.to_tensor(ivy.to_list(x), dtype=x.dtype)
weight = paddle.to_tensor(ivy.to_list(weight), dtype=weight.dtype)
bias = paddle.to_tensor(ivy.to_list(bias), dtype=bias.dtype)
running_mean = paddle.to_tensor(ivy.to_list(running_mean), dtype=running_mean.dtype)
running_var = paddle.to_tensor(ivy.to_list(running_var), dtype=running_var.dtype)

normalized, _, _ = ivy.batch_norm(
x,
running_mean,
Expand Down

0 comments on commit f76738a

Please sign in to comment.