diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu index 447c5cc16..79637443f 100644 --- a/src/cuda/primitives.cu +++ b/src/cuda/primitives.cu @@ -59,8 +59,25 @@ namespace ctranslate2 { template void primitives::convert(const float16_t*, float*, dim_t); template void primitives::convert(const float*, bfloat16_t*, dim_t); template void primitives::convert(const bfloat16_t*, float*, dim_t); - template void primitives::convert(const float16_t*, bfloat16_t*, dim_t); - template void primitives::convert(const bfloat16_t*, float16_t*, dim_t); + + struct convert_via_float { + template + __device__ float operator()(T x) const { + return x; + } + }; + + template<> + template<> + void primitives::convert(const float16_t* x, bfloat16_t* y, dim_t size) { + cuda::unary_transform(x, y, size, convert_via_float()); + } + + template<> + template<> + void primitives::convert(const bfloat16_t* x, float16_t* y, dim_t size) { + cuda::unary_transform(x, y, size, convert_via_float()); + } template<> template