Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NPU] Support aclnn for argsort_grad #1244

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 160 additions & 13 deletions backends/npu/kernels/argsort_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,28 @@
#include "kernels/funcs/npu_funcs.h"
#include "kernels/funcs/npu_op_runner.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/kernels/funcs/tensor_formatter.h"

namespace custom_kernel {

// Forward declarations of sibling custom kernels defined in other .cc files
// of this NPU backend; they are reused below to build the aclnn-based
// FullAssign path for the argsort gradient.

// Casts tensor `x` to `dtype`, writing the converted values into `out`.
template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
phi::DataType dtype,
phi::DenseTensor* out);

// Permutes the dimensions of `x` according to `axis` (a permutation of
// dimension indices) and writes the result into `out`.
template <typename T, typename Context>
void TransposeKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
const std::vector<int>& axis,
phi::DenseTensor* out);

// Element-wise `out = x + y`. NOTE(review): used below with shapes
// {H, W} + {H, 1} — presumably broadcasts over the trailing axis; confirm
// against the AddKernel implementation.
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
const phi::DenseTensor& y,
phi::DenseTensor* out);

template <typename Context, typename T>
static void TranposeNPU(const Context& dev_ctx,
const aclrtStream& stream,
Expand All @@ -34,12 +53,12 @@ static void TranposeNPU(const Context& dev_ctx,
}

template <typename Context, typename T, typename Type>
static void FullAssignNPU(const Context& dev_ctx,
const aclrtStream& stream,
const phi::DDim in_dims,
const phi::DenseTensor& input,
const phi::DenseTensor& indices,
phi::DenseTensor* t_out) {
static void AclopFullAssignNPU(const Context& dev_ctx,
const aclrtStream& stream,
const phi::DDim in_dims,
const phi::DenseTensor& input,
const phi::DenseTensor& indices,
phi::DenseTensor* t_out) {
const int64_t input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
Expand Down Expand Up @@ -86,13 +105,138 @@ static void FullAssignNPU(const Context& dev_ctx,
runner.Run(stream);
}

// Scatter-based "full assign" using the aclnn op set: writes each element of
// `input` into `t_out` at the position named by `indices` along the trailing
// axis, i.e. t_out[..., indices[..., j]] = input[..., j]. Falls back to the
// legacy aclop implementation when aclnnScatterNd is not available in the
// installed CANN package (DO_COMPATIBILITY returns early in that case).
template <typename Context, typename T, typename Type>
static void FullAssignNPU(const Context& dev_ctx,
                          const phi::DDim in_dims,
                          const phi::DenseTensor& input,
                          const phi::DenseTensor& indices,
                          phi::DenseTensor* t_out) {
  DO_COMPATIBILITY(
      aclnnScatterNd,
      (custom_kernel::AclopFullAssignNPU<Context, T, Type>(
          dev_ctx, dev_ctx.stream(), in_dims, input, indices, t_out)));

  // Collapse every leading dimension into "rows"; the last dimension (the
  // sort axis) is the row width.
  const int64_t num_rows =
      phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
  const int64_t row_width = in_dims[in_dims.size() - 1];

  // View the input as a flat column vector {rows * width, 1}.
  phi::DenseTensor flat_input(input);
  flat_input.Resize(
      phi::make_ddim(std::vector<int64_t>{num_rows * row_width, 1}));

  // View the indices as a {rows, width} matrix of per-row positions.
  phi::DenseTensor row_indices(indices);
  row_indices.Resize(
      phi::make_ddim(std::vector<int64_t>{num_rows, row_width}));

  // Per-row base offsets {0, width, 2*width, ...}; adding them (broadcast
  // over columns) converts per-row indices into flat indices.
  std::vector<int64_t> base_offsets;
  base_offsets.reserve(num_rows);
  for (Type row = 0; row < num_rows; row++) {
    base_offsets.push_back(row * row_width);
  }
  phi::DenseTensor offsets;
  phi::DenseTensorMeta offsets_meta = {
      indices.dtype(), phi::make_ddim(std::vector<int64_t>{num_rows, 1})};
  offsets.set_meta(offsets_meta);
  dev_ctx.template Alloc<int64_t>(&offsets);
  TensorFromVector<int64_t>(dev_ctx, base_offsets, dev_ctx, &offsets);
  // Re-assert the {rows, 1} shape — TensorFromVector may leave the tensor
  // flattened to 1-D.
  offsets.Resize(phi::make_ddim(std::vector<int64_t>{num_rows, 1}));

  // flat_indices[r][c] = row_indices[r][c] + r * width.
  phi::DenseTensor flat_indices;
  phi::DenseTensorMeta flat_indices_meta = {indices.dtype(),
                                            row_indices.dims()};
  flat_indices.set_meta(flat_indices_meta);
  dev_ctx.template Alloc<int64_t>(&flat_indices);
  custom_kernel::AddKernel<T, Context>(
      dev_ctx, row_indices, offsets, &flat_indices);

  flat_indices.Resize(
      phi::make_ddim(std::vector<int64_t>{num_rows * row_width, 1}));

  // Ensure INT64 indices for aclnnScatterNd (no-op copy when the indices
  // dtype is already int64).
  phi::DenseTensor flat_indices_i64;
  phi::DenseTensorMeta cast_meta = {phi::DataType::INT64, flat_indices.dims()};
  flat_indices_i64.set_meta(cast_meta);
  custom_kernel::CastKernel<T, Context>(
      dev_ctx, flat_indices, phi::DataType::INT64, &flat_indices_i64);

  // Scatter the flat input through the flat indices. For argsort gradients
  // the indices of each row form a permutation, so every output position is
  // written and the base tensor's initial contents are irrelevant.
  dev_ctx.template Alloc<T>(t_out);
  phi::DenseTensor flat_out(*t_out);
  flat_out.Resize(flat_input.dims());
  EXEC_NPU_CMD(aclnnScatterNd,
               dev_ctx,
               flat_input,
               flat_indices_i64,
               flat_input,
               flat_out);
  flat_out.Resize(t_out->dims());
}

// Legacy-op (aclop) fallback for the argsort gradient: scatters `out_grad`
// back to the pre-sort positions recorded in `indices`, i.e.
// in_grad[..., indices[..., j]] = out_grad[..., j] along `axis`.
// `descending` and `stable` do not affect the gradient (it depends only on
// the produced indices) and are intentionally unused.
template <typename T, typename Context>
void AclopArgsortGradKernel(const Context& dev_ctx,
                            const phi::DenseTensor& indices,
                            const phi::DenseTensor& input,
                            const phi::DenseTensor& out_grad,
                            int axis,
                            bool descending,
                            bool stable,
                            phi::DenseTensor* in_grad) {
  auto stream = dev_ctx.stream();
  auto in_dims = indices.dims();
  auto rank = input.dims().size();
  // Normalize a negative axis to its non-negative equivalent.
  axis = (axis < 0) ? (in_dims.size() + axis) : axis;
  dev_ctx.template Alloc<T>(in_grad);
  if (out_grad.numel() == 0) return;

  // 0-D tensor: sorting a scalar is the identity, and so is its gradient.
  if (rank == 0) {
    phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad);
    return;
  }

  // Fast path: the sort ran along the innermost axis, assign directly.
  if (axis == -1 || axis + 1 == in_dims.size()) {
    AclopFullAssignNPU<Context, T, int64_t>(
        dev_ctx, stream, in_dims, out_grad, indices, in_grad);
  } else {
    // Otherwise transpose the sort axis to the innermost position, assign
    // there, then transpose the result back.
    std::vector<int64_t> perm;
    for (int64_t i = 0; i < in_dims.size(); i++) {
      perm.emplace_back(i);
    }
    std::swap(perm[axis], perm[in_dims.size() - 1]);

    std::vector<int64_t> shape;
    for (size_t i = 0; i < perm.size(); i++) {
      shape.emplace_back(in_dims[perm[i]]);
    }
    auto trans_dims = phi::make_ddim(shape);
    phi::DenseTensor trans_dout;
    phi::DenseTensor trans_ids;
    phi::DenseTensorMeta trans_dout_meta = {out_grad.dtype(), trans_dims};
    phi::DenseTensorMeta trans_ids_meta = {indices.dtype(), trans_dims};
    trans_dout.set_meta(trans_dout_meta);
    trans_ids.set_meta(trans_ids_meta);
    dev_ctx.template Alloc<T>(&trans_dout);
    // BUGFIX: the transposed indices hold int64 data (see the
    // TranposeNPU<Context, int64_t> call below), so allocate them as int64.
    // The previous Alloc<T> under-allocated the buffer whenever
    // sizeof(T) < sizeof(int64_t) (e.g. float32/float16 inputs).
    dev_ctx.template Alloc<int64_t>(&trans_ids);

    TranposeNPU<Context, T>(dev_ctx, stream, &perm, out_grad, &trans_dout);
    TranposeNPU<Context, int64_t>(dev_ctx, stream, &perm, indices, &trans_ids);

    phi::DenseTensor trans_dx;
    phi::DenseTensorMeta trans_dx_meta = {out_grad.dtype(), trans_dims};
    trans_dx.set_meta(trans_dx_meta);
    dev_ctx.template Alloc<T>(&trans_dx);

    AclopFullAssignNPU<Context, T, int64_t>(
        dev_ctx, stream, trans_dims, trans_dout, trans_ids, &trans_dx);

    // Transpose back; the same swap permutation is its own inverse.
    TranposeNPU<Context, T>(dev_ctx, stream, &perm, trans_dx, in_grad);
  }
}

template <typename T, typename Context>
void ArgsortGradKernel(const Context& dev_ctx,
const phi::DenseTensor& indices,
const phi::DenseTensor& input,
const phi::DenseTensor& out_grad,
int axis,
bool descending,
bool stable,
phi::DenseTensor* in_grad) {
auto stream = dev_ctx.stream();
auto in_dims = indices.dims();
Expand All @@ -109,10 +253,10 @@ void ArgsortGradKernel(const Context& dev_ctx,
// Do full assign
if (axis == -1 || axis + 1 == in_dims.size()) {
FullAssignNPU<Context, T, int64_t>(
dev_ctx, stream, in_dims, out_grad, indices, in_grad);
dev_ctx, in_dims, out_grad, indices, in_grad);
} else {
std::vector<int64_t> perm;
for (int64_t i = 0; i < in_dims.size(); i++) {
std::vector<int> perm;
for (int i = 0; i < in_dims.size(); i++) {
perm.emplace_back(i);
}
std::swap(perm[axis], perm[in_dims.size() - 1]);
Expand All @@ -131,18 +275,21 @@ void ArgsortGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(&trans_dout);
dev_ctx.template Alloc<T>(&trans_ids);

TranposeNPU<Context, T>(dev_ctx, stream, &perm, out_grad, &trans_dout);
TranposeNPU<Context, int64_t>(dev_ctx, stream, &perm, indices, &trans_ids);
custom_kernel::TransposeKernel<T, Context>(
dev_ctx, out_grad, perm, &trans_dout);
custom_kernel::TransposeKernel<int64_t, Context>(
dev_ctx, indices, perm, &trans_ids);

phi::DenseTensor trans_dx;
phi::DenseTensorMeta trans_dx_meta = {out_grad.dtype(), trans_dims};
trans_dx.set_meta(trans_dx_meta);
dev_ctx.template Alloc<T>(&trans_dx);

FullAssignNPU<Context, T, int64_t>(
dev_ctx, stream, trans_dims, trans_dout, trans_ids, &trans_dx);
dev_ctx, trans_dims, trans_dout, trans_ids, &trans_dx);

TranposeNPU<Context, T>(dev_ctx, stream, &perm, trans_dx, in_grad);
custom_kernel::TransposeKernel<T, Context>(
dev_ctx, trans_dx, perm, in_grad);
}
}

Expand Down