diff --git a/src/layers/attention.cc b/src/layers/attention.cc index 8b3b878c0..f2ff94995 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -16,8 +16,8 @@ namespace ctranslate2 { StorageView make_relative_positions(dim_t queries_length, dim_t keys_length, dim_t max_position, - dim_t left_max_position, - dim_t right_max_position) { + dim_t left_max_position = 0, + dim_t right_max_position = 0) { StorageView positions({queries_length, keys_length}, DataType::INT32); auto* positions_data = positions.data(); diff --git a/tests/layers_test.cc b/tests/layers_test.cc index 3a8e40958..3b4704161 100644 --- a/tests/layers_test.cc +++ b/tests/layers_test.cc @@ -21,6 +21,16 @@ TEST(LayerTest, MakeRelativePositions2D) { expect_storage_eq(positions, expected); } +TEST(LayerTest, MakeRelativePositions2D) { + const StorageView positions = layers::make_relative_positions(4, 4, 0, 3, 2); + const StorageView expected({4, 4}, std::vector{ + 3, 4, 5, 5, + 2, 3, 4, 5, + 1, 2, 3, 4, + 0, 1, 2, 3}); + expect_storage_eq(positions, expected); +} + TEST_P(LayerDeviceFPTest, Alibi) { const Device device = GetParam().device; const DataType dtype = GetParam().dtype;