Skip to content

Commit

Permalink
Fix compilation of convert for older architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Jul 5, 2023
1 parent fcbbbe3 commit 0df93a7
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/cuda/primitives.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,25 @@ namespace ctranslate2 {
template void primitives<Device::CUDA>::convert(const float16_t*, float*, dim_t);
template void primitives<Device::CUDA>::convert(const float*, bfloat16_t*, dim_t);
template void primitives<Device::CUDA>::convert(const bfloat16_t*, float*, dim_t);
template void primitives<Device::CUDA>::convert(const float16_t*, bfloat16_t*, dim_t);
template void primitives<Device::CUDA>::convert(const bfloat16_t*, float16_t*, dim_t);

struct convert_via_float {
template <typename T>
__device__ float operator()(T x) const {
return x;
}
};

template<>
template<>
void primitives<Device::CUDA>::convert(const float16_t* x, bfloat16_t* y, dim_t size) {
cuda::unary_transform(x, y, size, convert_via_float());
}

template<>
template<>
void primitives<Device::CUDA>::convert(const bfloat16_t* x, float16_t* y, dim_t size) {
cuda::unary_transform(x, y, size, convert_via_float());
}

template<>
template <typename T>
Expand Down

0 comments on commit 0df93a7

Please sign in to comment.