diff --git a/src/ops/position_encodings_add.cc b/src/ops/position_encodings_add.cc index 4896f41a6..280f3f05f 100644 --- a/src/ops/position_encodings_add.cc +++ b/src/ops/position_encodings_add.cc @@ -30,18 +30,20 @@ namespace ctranslate2 { output.resize_as(input); - DEVICE_AND_FLOAT_DISPATCH( - "PositionEncodingsAdd", input.device(), input.dtype(), - ({ - if (offsets) - compute(step, offsets, input, encodings, output); - else - primitives::add_batch_broadcast(encodings.data() + step * depth, - input.data(), - output.data(), - time * depth, - input.size()); - })); + if (offsets) { + DEVICE_AND_FLOAT_DISPATCH( + "PositionEncodingsAdd", input.device(), input.dtype(), + (compute(step, offsets, input, encodings, output))); + + } else { + DEVICE_AND_FLOAT_DISPATCH( + "PositionEncodingsAdd", input.device(), input.dtype(), + (primitives::add_batch_broadcast(encodings.data() + step * depth, + input.data(), + output.data(), + time * depth, + input.size()))); + } } }