From db38e190ef98159396aa8c4996d0c236c5ef723b Mon Sep 17 00:00:00 2001
From: Anirudh R
Date: Wed, 25 Sep 2024 13:16:03 -0700
Subject: [PATCH] Added: docstring for return_attention_scores and added a
 test to check the working of the argument

---
 keras_hub/src/layers/modeling/transformer_encoder.py | 11 ++++++++---
 .../src/layers/modeling/transformer_encoder_test.py  | 11 +++++++++++
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/layers/modeling/transformer_encoder.py b/keras_hub/src/layers/modeling/transformer_encoder.py
index cdc8050f8..b4a975ff8 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder.py
@@ -184,7 +184,12 @@ def build(self, inputs_shape):
         self.built = True
 
     def call(
-        self, inputs, padding_mask=None, attention_mask=None, training=None, return_attention_scores=False
+        self,
+        inputs,
+        padding_mask=None,
+        attention_mask=None,
+        training=None,
+        return_attention_scores=False,
     ):
         """Forward pass of the TransformerEncoder.
 
@@ -199,6 +204,7 @@ def call(
                 [batch_size, sequence_length, sequence_length].
             training: a boolean indicating whether the layer should behave in
                 training mode or in inference mode.
+            return_attention_scores: a boolean indicating whether the output should be `(attention_output, attention_scores)` if `True` or `attention_output` if `False`. Defaults to `False`.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -232,7 +238,6 @@ def call(
                 training=training,
             )
 
-
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
@@ -248,7 +253,7 @@ def call(
         x = x + residual
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
-
+
         if return_attention_scores:
             return x, attention_scores
 
diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py
index 9640d02a1..ab6fb06ea 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder_test.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -109,3 +109,14 @@ def test_mask_propagation(self):
         inputs._keras_mask = mask
         outputs = encoder(inputs)
         self.assertAllEqual(outputs._keras_mask, mask)
+
+    def test_attention_scores(self):
+        encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
+        inputs = random.uniform(shape=[1, 4, 6])
+        outputs, attention_scores = encoder(
+            inputs, return_attention_scores=True
+        )
+        print(attention_scores)
+        assert outputs.shape == inputs.shape
+        # attention scores shape (batch_size, num_of_attn_heads, seq_length, seq_length)
+        assert attention_scores.shape == [1, 2, 4, 4]
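
Note for reviewers (not part of the patch): the snippet below is a minimal usage sketch of the new argument. It assumes keras_hub is installed with this change applied and that the layer is exposed as keras_hub.layers.TransformerEncoder; the shapes mirror the new test (batch of 1, sequence length 4, feature size 6, 2 attention heads).

import keras
import keras_hub

# Encoder block with 2 attention heads, mirroring the new test.
encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=4, num_heads=2)

# Random batch: 1 sequence of length 4 with feature dimension 6.
inputs = keras.random.uniform(shape=(1, 4, 6))

# With the patch applied, requesting attention scores returns a tuple
# instead of a single tensor.
outputs, attention_scores = encoder(inputs, return_attention_scores=True)

print(outputs.shape)           # (1, 4, 6) -- same shape as the inputs
print(attention_scores.shape)  # (1, 2, 4, 4) -- (batch, num_heads, seq_len, seq_len)

The attention score shape follows keras.layers.MultiHeadAttention, which emits scores of shape (batch_size, num_heads, query_length, key_length); for self-attention both lengths equal the sequence length.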