[NPU] Fix inplace multiply_grad (#1274)
will-jl944 authored Jun 4, 2024
1 parent 2444df5 commit 677b0fb
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions backends/npu/kernels/elementwise_mul_kernel.cc
@@ -212,6 +212,12 @@ void MultiplyGradKernel(const Context& dev_ctx,
   if (dx) {
     phi::DenseTensor trans_y;
     NpuBroadcast<T>(dev_ctx, &y, y_axis, dst_dims, &trans_y);
+    // For inplace strategy, dx will be stored in addr of dout, which makes
+    // the result of dy wrong.
+    if (dx->IsSharedWith(dout)) {
+      dx->clear();
+      dx->Resize(x.dims());
+    }
     if (dx->dims() == dout.dims()) {
       dev_ctx.template Alloc<T>(dx);
       EXEC_NPU_CMD(aclnnMul, dev_ctx, dout, trans_y, *dx);
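
Note on the hazard this hunk guards against: the following is a minimal, standalone sketch, not the kernel's actual code. It assumes that dy is computed from dout after dx has been written, as the added comment implies. `Mul`, `DyWithAliasedDx`, and `DyWithFreshDx` are hypothetical names introduced only for illustration, and plain `std::vector<float>` buffers stand in for `phi::DenseTensor`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an elementwise multiply kernel:
// out[i] = a[i] * b[i]. `out` is allowed to alias `a`, which mimics an
// inplace launch where the output shares the input's buffer.
void Mul(const float* a, const float* b, float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] * b[i];
  }
}

// Inplace case: dx is written into dout's storage, so the later
// dy = dout * x reads values that were already overwritten by dx.
std::vector<float> DyWithAliasedDx(std::vector<float> dout,
                                   const std::vector<float>& x,
                                   const std::vector<float>& y) {
  Mul(dout.data(), y.data(), dout.data(), dout.size());  // dx aliases dout
  std::vector<float> dy(dout.size());
  Mul(dout.data(), x.data(), dy.data(), dout.size());  // reads clobbered dout
  return dy;
}

// Guarded case: dx gets its own buffer first (the analogue of detaching dx
// from dout and re-allocating it), so dout is intact when dy is computed.
std::vector<float> DyWithFreshDx(const std::vector<float>& dout,
                                 const std::vector<float>& x,
                                 const std::vector<float>& y) {
  std::vector<float> dx(dout.size());
  Mul(dout.data(), y.data(), dx.data(), dout.size());  // dx in separate storage
  std::vector<float> dy(dout.size());
  Mul(dout.data(), x.data(), dy.data(), dout.size());  // dout is unchanged
  return dy;
}
```

With x = {2, 3}, y = {5, 7}, and dout = {1, 1}, the guarded version returns dy = dout * x = {2, 3}, while the aliased version returns dout * y * x = {10, 21}. Detaching dx from dout before allocating it, as the added lines do, corresponds to the guarded path.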