Commit
Added: docstring for return_attention_scores and added a test to check the working of the argument

anirudhr20 committed Sep 25, 2024
1 parent fbcb810 commit db38e19
Showing 2 changed files with 19 additions and 3 deletions.
11 changes: 8 additions & 3 deletions keras_hub/src/layers/modeling/transformer_encoder.py
@@ -184,7 +184,12 @@ def build(self, inputs_shape):
self.built = True

def call(
self, inputs, padding_mask=None, attention_mask=None, training=None, return_attention_scores=False
self,
inputs,
padding_mask=None,
attention_mask=None,
training=None,
return_attention_scores=False,
):
"""Forward pass of the TransformerEncoder.
@@ -199,6 +204,7 @@ def call(
[batch_size, sequence_length, sequence_length].
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.
return_attention_scores: a boolean indicating whether the output should
be `(attention_output, attention_scores)` if `True` or
`attention_output` if `False`. Defaults to `False`.
Returns:
A Tensor of the same shape as the `inputs`.
@@ -232,7 +238,6 @@ def call(
training=training,
)


x = self._self_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
@@ -248,7 +253,7 @@ def call(
x = x + residual
if not self.normalize_first:
x = self._feedforward_layer_norm(x)

if return_attention_scores:
return x, attention_scores

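For reference, a minimal usage sketch (not part of this commit) of the new argument, assuming the layer is exposed as `keras_hub.layers.TransformerEncoder` and NumPy is used to build a dummy batch:

import numpy as np
from keras_hub.layers import TransformerEncoder

# Small encoder mirroring the test below.
encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)

# Dummy batch: (batch_size=1, sequence_length=4, feature_dim=6).
inputs = np.random.uniform(size=(1, 4, 6)).astype("float32")

# Default behavior: only the encoded output is returned.
outputs = encoder(inputs)

# With return_attention_scores=True, the per-head attention weights are
# returned as well, shaped (batch_size, num_heads, seq_length, seq_length).
outputs, scores = encoder(inputs, return_attention_scores=True)
print(outputs.shape)  # (1, 4, 6)
print(scores.shape)   # (1, 2, 4, 4)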
11 changes: 11 additions & 0 deletions keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -109,3 +109,14 @@ def test_mask_propagation(self):
inputs._keras_mask = mask
outputs = encoder(inputs)
self.assertAllEqual(outputs._keras_mask, mask)

def test_attention_scores(self):
    encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
    inputs = random.uniform(shape=[1, 4, 6])
    outputs, attention_scores = encoder(
        inputs, return_attention_scores=True
    )
    self.assertEqual(outputs.shape, inputs.shape)
    # Attention scores shape: (batch_size, num_heads, seq_length, seq_length).
    self.assertEqual(attention_scores.shape, (1, 2, 4, 4))
