Skip to content

Commit

Permalink
update the logic for make_relative_positions
Browse files Browse the repository at this point in the history
  • Loading branch information
hkwon committed Sep 11, 2024
1 parent 2b2be50 commit 4de051d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ namespace ctranslate2 {
StorageView make_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t max_position,
dim_t left_max_position,
dim_t right_max_position) {
dim_t left_max_position = 0,
dim_t right_max_position = 0) {
StorageView positions({queries_length, keys_length}, DataType::INT32);
auto* positions_data = positions.data<int32_t>();

Expand Down
10 changes: 10 additions & 0 deletions tests/layers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ TEST(LayerTest, MakeRelativePositions2D) {
expect_storage_eq(positions, expected);
}

TEST(LayerTest, MakeRelativePositions2D) {
const StorageView positions = layers::make_relative_positions(4, 4, 0, 3, 2);
const StorageView expected({4, 4}, std::vector<int32_t>{
3, 4, 5, 5,
2, 3, 4, 5,
1, 2, 3, 4,
0, 1, 2, 3});
expect_storage_eq(positions, expected);
}

TEST_P(LayerDeviceFPTest, Alibi) {
const Device device = GetParam().device;
const DataType dtype = GetParam().dtype;
Expand Down

0 comments on commit 4de051d

Please sign in to comment.