Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN]fix one hot decomposite bug #68380

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 7 additions & 69 deletions paddle/fluid/primitive/composite/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,75 +209,13 @@ std::tuple<Tensor, Tensor> huber_loss_decomp(const Tensor& input,

template <typename T>
Tensor one_hot_decomp(const Tensor& x, const Tensor& num_classes) {
auto num_classes_tensor =
backend::full_with_tensor<T>(num_classes, 0, x.dtype());
Tensor ans;
if (has_dynamic_shape(x.shape())) {
Tensor x_shape = shape<T>(x);
Tensor one_tensor = full<T>({1}, 1, x_shape.dtype());
Tensor zero_tensor = full<T>({1}, 0, x_shape.dtype());
Tensor x_dims = one_tensor;
for (size_t i = 0; i < x.shape().size(); i++) {
x_dims = x_dims * get_slice<T>(x_shape, i);
}
Tensor input_dim = concat<T>(
{x_dims, full<T>({1}, num_classes_tensor.shape()[0], x_shape.dtype())});
auto input_tensor = backend::full_with_tensor<T>(input_dim, 0, x.dtype());
auto output_dim = concat<T>(
{x_shape,
full<T>({1}, num_classes_tensor.shape()[0], x_shape.dtype())});
auto arange_tensor = backend::arange_with_tensor<T>(
zero_tensor, x_dims, one_tensor, x_shape.dtype());
Tensor reshape_dim = concat<T>({x_dims, one_tensor});
auto x_reshape = backend::reshape<T>(x, reshape_dim);
auto arange_tensor_reshape =
backend::reshape<T>(arange_tensor, reshape_dim);
auto index_tensor = concat<T>({arange_tensor_reshape, x_reshape}, 1);
auto update_tensor = backend::full_with_tensor<T>(x_dims, 1, x.dtype());
ans = backend::reshape<T>(
cast<T>(scatter_nd_add<T>(input_tensor, index_tensor, update_tensor),
DataType::FLOAT32),
output_dim);
} else {
std::vector<int64_t> input_dim;
int x_dims = 1;
for (size_t i = 0; i < x.shape().size(); i++) {
x_dims *= x.shape()[i];
}

input_dim.push_back(x_dims);
input_dim.push_back(num_classes_tensor.shape()[0]);
auto input_tensor = full<T>(input_dim, 0, x.dtype());

std::vector<int64_t> output_dim;
for (size_t i = 0; i < x.shape().size(); i++) {
output_dim.push_back(x.shape()[i]);
}
output_dim.push_back(num_classes_tensor.shape()[0]);

auto end = full<T>({1}, x_dims, x.dtype());
auto start = full<T>({1}, 0, x.dtype());
auto step = full<T>({1}, 1, x.dtype());
auto arange_tensor =
backend::arange_with_tensor<T>(start, end, step, x.dtype());

std::vector<int64_t> reshape_dim{x_dims, 1};
auto x_reshape = reshape<T>(x, reshape_dim);
auto arange_tensor_reshape = reshape<T>(arange_tensor, reshape_dim);

std::vector<Tensor> index_concat;
index_concat.push_back(arange_tensor_reshape);
index_concat.push_back(x_reshape);
auto index_tensor = concat<T>(index_concat, 1);

auto update_tensor = full<T>({x_dims}, 1, x.dtype());

ans = reshape<T>(
cast<T>(scatter_nd_add<T>(input_tensor, index_tensor, update_tensor),
DataType::FLOAT32),
output_dim);
}
return ans;
auto start = full<T>({1}, 0, x.dtype());
auto step = full<T>({1}, 1, x.dtype());
auto arange_class =
backend::arange_with_tensor<T>(start, num_classes, step, x.dtype());
auto reshape_x = backend::unsqueeze<T>(x, {-1});
auto equal_res = backend::equal<T>(reshape_x, arange_class);
return cast<T>(equal_res, phi::DataType::FLOAT32);
}

template <typename T>
Expand Down