Skip to content

Commit

Permalink
Create LossCTCKernels.h
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel authored Sep 19, 2024
1 parent db3d203 commit 30904df
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/ATen/native/xpu/sycl/LossCTCKernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once
#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

TORCH_XPU_API std::tuple<Tensor, Tensor> ctc_loss_kernel(
const Tensor& log_probs,
const Tensor& targets,
IntArrayRef input_lengths,
IntArrayRef target_lengths,
int64_t BLANK,
bool zero_infinity);

} // namespace at::native::xpu

0 comments on commit 30904df

Please sign in to comment.