From 0b12c5f419c226cdac42d38a2fb562d4fc5e4077 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann Date: Fri, 18 Aug 2023 01:24:29 +0200 Subject: [PATCH] Fix masks in English-to-Spanish translation example (#1046) * fix masks for machine translation example * black and modified date * remove manual masking and rely on automatic masks instead * replace + by Add layer to preserve mask * regenerate ipynb and md files * fix black --- ...machine_translation_with_transformer.ipynb | 140 +++++++++--------- ...al_machine_translation_with_transformer.md | 138 ++++++++++------- ...al_machine_translation_with_transformer.py | 38 +---- 3 files changed, 163 insertions(+), 153 deletions(-) diff --git a/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb b/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb index 3de5b8993a..dc561b7169 100644 --- a/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb +++ b/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb @@ -10,7 +10,7 @@ "\n", "**Author:** [fchollet](https://twitter.com/fchollet)
\n", "**Date created:** 2021/05/26
\n", - "**Last modified:** 2023/02/25
\n", + "**Last modified:** 2023/08/17
\n", "**Description:** Implementing a sequence-to-sequene Transformer and training it on a machine translation task." ] }, @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -139,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -161,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -196,7 +196,7 @@ "The English layer will use the default string standardization (strip punctuation characters)\n", "and splitting scheme (split on whitespace), while\n", "the Spanish layer will use a custom standardization, where we add the character\n", - "`\"¿\"` to the set of punctuation characters to be stripped.\n", + "`\"\u00bf\"` to the set of punctuation characters to be stripped.\n", "\n", "Note: in a production-grade machine translation model, I would not recommend\n", "stripping the punctuation characters in either language. Instead, I would recommend turning\n", @@ -206,13 +206,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ - "strip_chars = string.punctuation + \"¿\"\n", + "strip_chars = string.punctuation + \"\u00bf\"\n", "strip_chars = strip_chars.replace(\"[\", \"\")\n", "strip_chars = strip_chars.replace(\"]\", \"\")\n", "\n", @@ -227,7 +227,9 @@ "\n", "\n", "eng_vectorization = TextVectorization(\n", - " max_tokens=vocab_size, output_mode=\"int\", output_sequence_length=sequence_length,\n", + " max_tokens=vocab_size,\n", + " output_mode=\"int\",\n", + " output_sequence_length=sequence_length,\n", ")\n", "spa_vectorization = TextVectorization(\n", " max_tokens=vocab_size,\n", @@ -263,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -273,7 +275,13 @@ "def format_dataset(eng, spa):\n", " eng = eng_vectorization(eng)\n", " spa = spa_vectorization(spa)\n", - " return ({\"encoder_inputs\": eng, \"decoder_inputs\": spa[:, :-1],}, spa[:, 1:])\n", + " return (\n", + " {\n", + " \"encoder_inputs\": eng,\n", + " \"decoder_inputs\": spa[:, :-1],\n", + " },\n", + " spa[:, 1:],\n", + " )\n", "\n", "\n", "def make_dataset(pairs):\n", @@ -302,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -333,7 +341,7 @@ "The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).\n", "\n", "A key detail that makes this possible is causal masking\n", - "(see method `get_causal_attention_mask()` on the `TransformerDecoder`).\n", + "(`use_causal_mask=True` in the first attention layer of the `TransformerDecoder`).\n", "The `TransformerDecoder` sees the entire sequences at once, and thus we must make\n", "sure that it only uses information from target tokens 0 to N when predicting token N+1\n", "(otherwise, it could use information from the future, which would\n", @@ -342,7 +350,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -359,28 +367,30 @@ " 
num_heads=num_heads, key_dim=embed_dim\n", " )\n", " self.dense_proj = keras.Sequential(\n", - " [layers.Dense(dense_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", + " [\n", + " layers.Dense(dense_dim, activation=\"relu\"),\n", + " layers.Dense(embed_dim),\n", + " ]\n", " )\n", " self.layernorm_1 = layers.LayerNormalization()\n", " self.layernorm_2 = layers.LayerNormalization()\n", " self.supports_masking = True\n", "\n", " def call(self, inputs, mask=None):\n", - " if mask is not None:\n", - " padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n", - " attention_output = self.attention(\n", - " query=inputs, value=inputs, key=inputs, attention_mask=padding_mask\n", - " )\n", + " attention_output = self.attention(query=inputs, value=inputs, key=inputs)\n", " proj_input = self.layernorm_1(inputs + attention_output)\n", " proj_output = self.dense_proj(proj_input)\n", " return self.layernorm_2(proj_input + proj_output)\n", + "\n", " def get_config(self):\n", " config = super().get_config()\n", - " config.update({\n", - " \"embed_dim\": self.embed_dim,\n", - " \"dense_dim\": self.dense_dim,\n", - " \"num_heads\": self.num_heads,\n", - " })\n", + " config.update(\n", + " {\n", + " \"embed_dim\": self.embed_dim,\n", + " \"dense_dim\": self.dense_dim,\n", + " \"num_heads\": self.num_heads,\n", + " }\n", + " )\n", " return config\n", "\n", "\n", @@ -406,13 +416,16 @@ "\n", " def compute_mask(self, inputs, mask=None):\n", " return tf.math.not_equal(inputs, 0)\n", + "\n", " def get_config(self):\n", " config = super().get_config()\n", - " config.update({\n", - " \"sequence_length\": self.sequence_length,\n", - " \"vocab_size\": self.vocab_size,\n", - " \"embed_dim\": self.embed_dim,\n", - " })\n", + " config.update(\n", + " {\n", + " \"sequence_length\": self.sequence_length,\n", + " \"vocab_size\": self.vocab_size,\n", + " \"embed_dim\": self.embed_dim,\n", + " }\n", + " )\n", " return config\n", "\n", "\n", @@ -429,55 +442,44 @@ " num_heads=num_heads, key_dim=embed_dim\n", " )\n", " self.dense_proj = keras.Sequential(\n", - " [layers.Dense(latent_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", + " [\n", + " layers.Dense(latent_dim, activation=\"relu\"),\n", + " layers.Dense(embed_dim),\n", + " ]\n", " )\n", " self.layernorm_1 = layers.LayerNormalization()\n", " self.layernorm_2 = layers.LayerNormalization()\n", " self.layernorm_3 = layers.LayerNormalization()\n", + " self.add = layers.Add() # instead of `+` to preserve mask\n", " self.supports_masking = True\n", "\n", " def call(self, inputs, encoder_outputs, mask=None):\n", - " causal_mask = self.get_causal_attention_mask(inputs)\n", - " if mask is not None:\n", - " padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n", - " padding_mask = tf.minimum(padding_mask, causal_mask)\n", - "\n", " attention_output_1 = self.attention_1(\n", - " query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n", + " query=inputs, value=inputs, key=inputs, use_causal_mask=True\n", " )\n", - " out_1 = self.layernorm_1(inputs + attention_output_1)\n", + " out_1 = self.layernorm_1(self.add([inputs, attention_output_1]))\n", "\n", " attention_output_2 = self.attention_2(\n", " query=out_1,\n", " value=encoder_outputs,\n", " key=encoder_outputs,\n", - " attention_mask=padding_mask,\n", " )\n", - " out_2 = self.layernorm_2(out_1 + attention_output_2)\n", + " out_2 = self.layernorm_2(self.add([out_1, attention_output_2]))\n", "\n", " proj_output = self.dense_proj(out_2)\n", - " return self.layernorm_3(out_2 + 
proj_output)\n", - "\n", - " def get_causal_attention_mask(self, inputs):\n", - " input_shape = tf.shape(inputs)\n", - " batch_size, sequence_length = input_shape[0], input_shape[1]\n", - " i = tf.range(sequence_length)[:, tf.newaxis]\n", - " j = tf.range(sequence_length)\n", - " mask = tf.cast(i >= j, dtype=\"int32\")\n", - " mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))\n", - " mult = tf.concat(\n", - " [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],\n", - " axis=0,\n", - " )\n", - " return tf.tile(mask, mult)\n", + " return self.layernorm_3(self.add([out_2, proj_output]))\n", + "\n", " def get_config(self):\n", " config = super().get_config()\n", - " config.update({\n", - " \"embed_dim\": self.embed_dim,\n", - " \"latent_dim\": self.latent_dim,\n", - " \"num_heads\": self.num_heads,\n", - " })\n", - " return config\n" + " config.update(\n", + " {\n", + " \"embed_dim\": self.embed_dim,\n", + " \"latent_dim\": self.latent_dim,\n", + " \"num_heads\": self.num_heads,\n", + " }\n", + " )\n", + " return config\n", + "" ] }, { @@ -491,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -537,7 +539,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -568,7 +570,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -610,19 +612,19 @@ "After 30 epochs, we get results such as:\n", "\n", "> She handed him the money.\n", - "> [start] ella le pasó el dinero [end]\n", + "> [start] ella le pas\u00f3 el dinero [end]\n", "\n", "> Tom has never heard Mary sing.\n", - "> [start] tom nunca ha oído cantar a mary [end]\n", + "> [start] tom nunca ha o\u00eddo cantar a mary [end]\n", "\n", "> Perhaps she will come tomorrow.\n", - "> [start] tal vez ella vendrá mañana [end]\n", + "> [start] tal vez ella vendr\u00e1 ma\u00f1ana [end]\n", "\n", "> I love to write.\n", "> [start] me encanta escribir [end]\n", "\n", "> His French is improving little by little.\n", - "> [start] su francés va a [UNK] sólo un poco [end]\n", + "> [start] su franc\u00e9s va a [UNK] s\u00f3lo un poco [end]\n", "\n", "> My hotel told me to call you.\n", "> [start] mi hotel me dijo que te [UNK] [end]" @@ -658,4 +660,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/examples/nlp/md/neural_machine_translation_with_transformer.md b/examples/nlp/md/neural_machine_translation_with_transformer.md index 82ddd2dfe4..3ddd67c402 100644 --- a/examples/nlp/md/neural_machine_translation_with_transformer.md +++ b/examples/nlp/md/neural_machine_translation_with_transformer.md @@ -2,7 +2,7 @@ **Author:** [fchollet](https://twitter.com/fchollet)
**Date created:** 2021/05/26
-**Last modified:** 2023/02/25
+**Last modified:** 2023/08/17
**Description:** Implementing a sequence-to-sequene Transformer and training it on a machine translation task. @@ -92,11 +92,11 @@ for _ in range(5):
``` -("You can dance, can't you?", '[start] Puedes bailar, ¿verdad? [end]') -('I passed by her house yesterday.', '[start] Me pasé por su casa ayer. [end]') -('I like tulips.', '[start] Me gustan los tulipanes. [end]') -('He is fluent in French.', '[start] Habla un francés fluido. [end]') -('Tom asked me what I had been doing.', '[start] Tom me preguntó qué había estado haciendo. [end]') +("I've got to find Tom.", '[start] Tengo que encontrar a Tom. [end]') +("I'll keep it a secret. Don't worry.", '[start] Lo mantendré en secreto. No te preocupes. [end]') +('I have good news for you.', '[start] Tengo buenas noticias para ustedes. [end]') +('I can see the top of the mountain.', '[start] Puedo ver la cima de la montaña. [end]') +("Next week, we're heading to the mountain.", '[start] La próxima semana vamos a la montaña. [end]') ```
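Once the `TextVectorization` layers shown in the next hunk have been adapted, the Spanish standardization can be sanity-checked in isolation. This is an illustrative snippet, not part of the example or of this patch; the sample sentence is made up, and `custom_standardization` / `spa_vectorization` refer to the objects defined in the example.

```python
import tensorflow as tf  # already imported by the example

# A toy Spanish target sentence, wrapped in the [start]/[end] markers the
# example adds to every target.
sample = tf.constant(["[start] ¿Puedes bailar, verdad? [end]"])

# Punctuation and the leading "¿" should be stripped and the text lower-cased,
# while "[start]" / "[end]" survive because "[" and "]" were removed from
# strip_chars. Expected output is something like:
#   [b'[start] puedes bailar verdad [end]']
print(custom_standardization(sample))

# After adapt(), the Spanish vectorizer maps each word to an integer id and
# right-pads with zeros up to sequence_length + 1.
print(spa_vectorization(sample))
```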
@@ -162,7 +162,9 @@ def custom_standardization(input_string): eng_vectorization = TextVectorization( - max_tokens=vocab_size, output_mode="int", output_sequence_length=sequence_length, + max_tokens=vocab_size, + output_mode="int", + output_sequence_length=sequence_length, ) spa_vectorization = TextVectorization( max_tokens=vocab_size, @@ -195,7 +197,13 @@ it provides the next words in the target sentence -- what the model will try to def format_dataset(eng, spa): eng = eng_vectorization(eng) spa = spa_vectorization(spa) - return ({"encoder_inputs": eng, "decoder_inputs": spa[:, :-1],}, spa[:, 1:]) + return ( + { + "encoder_inputs": eng, + "decoder_inputs": spa[:, :-1], + }, + spa[:, 1:], + ) def make_dataset(pairs): @@ -245,7 +253,7 @@ to the `TransformerDecoder`, together with the target sequence so far (target wo The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond). A key detail that makes this possible is causal masking -(see method `get_causal_attention_mask()` on the `TransformerDecoder`). +(`use_causal_mask=True` in the first attention layer of the `TransformerDecoder`). The `TransformerDecoder` sees the entire sequences at once, and thus we must make sure that it only uses information from target tokens 0 to N when predicting token N+1 (otherwise, it could use information from the future, which would @@ -264,22 +272,32 @@ class TransformerEncoder(layers.Layer): num_heads=num_heads, key_dim=embed_dim ) self.dense_proj = keras.Sequential( - [layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim),] + [ + layers.Dense(dense_dim, activation="relu"), + layers.Dense(embed_dim), + ] ) self.layernorm_1 = layers.LayerNormalization() self.layernorm_2 = layers.LayerNormalization() self.supports_masking = True def call(self, inputs, mask=None): - if mask is not None: - padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32") - attention_output = self.attention( - query=inputs, value=inputs, key=inputs, attention_mask=padding_mask - ) + attention_output = self.attention(query=inputs, value=inputs, key=inputs) proj_input = self.layernorm_1(inputs + attention_output) proj_output = self.dense_proj(proj_input) return self.layernorm_2(proj_input + proj_output) + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "dense_dim": self.dense_dim, + "num_heads": self.num_heads, + } + ) + return config + class PositionalEmbedding(layers.Layer): def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs): @@ -304,6 +322,17 @@ class PositionalEmbedding(layers.Layer): def compute_mask(self, inputs, mask=None): return tf.math.not_equal(inputs, 0) + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "vocab_size": self.vocab_size, + "embed_dim": self.embed_dim, + } + ) + return config + class TransformerDecoder(layers.Layer): def __init__(self, embed_dim, latent_dim, num_heads, **kwargs): @@ -318,47 +347,43 @@ class TransformerDecoder(layers.Layer): num_heads=num_heads, key_dim=embed_dim ) self.dense_proj = keras.Sequential( - [layers.Dense(latent_dim, activation="relu"), layers.Dense(embed_dim),] + [ + layers.Dense(latent_dim, activation="relu"), + layers.Dense(embed_dim), + ] ) self.layernorm_1 = layers.LayerNormalization() self.layernorm_2 = layers.LayerNormalization() self.layernorm_3 = layers.LayerNormalization() + self.add = layers.Add() # instead of `+` to preserve mask self.supports_masking = True def 
call(self, inputs, encoder_outputs, mask=None): - causal_mask = self.get_causal_attention_mask(inputs) - if mask is not None: - padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32") - padding_mask = tf.minimum(padding_mask, causal_mask) - attention_output_1 = self.attention_1( - query=inputs, value=inputs, key=inputs, attention_mask=causal_mask + query=inputs, value=inputs, key=inputs, use_causal_mask=True ) - out_1 = self.layernorm_1(inputs + attention_output_1) + out_1 = self.layernorm_1(self.add([inputs, attention_output_1])) attention_output_2 = self.attention_2( query=out_1, value=encoder_outputs, key=encoder_outputs, - attention_mask=padding_mask, ) - out_2 = self.layernorm_2(out_1 + attention_output_2) + out_2 = self.layernorm_2(self.add([out_1, attention_output_2])) proj_output = self.dense_proj(out_2) - return self.layernorm_3(out_2 + proj_output) - - def get_causal_attention_mask(self, inputs): - input_shape = tf.shape(inputs) - batch_size, sequence_length = input_shape[0], input_shape[1] - i = tf.range(sequence_length)[:, tf.newaxis] - j = tf.range(sequence_length) - mask = tf.cast(i >= j, dtype="int32") - mask = tf.reshape(mask, (1, input_shape[1], input_shape[1])) - mult = tf.concat( - [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], - axis=0, + return self.layernorm_3(self.add([out_2, proj_output])) + + def get_config(self): + config = super().get_config() + config.update( + { + "embed_dim": self.embed_dim, + "latent_dim": self.latent_dim, + "num_heads": self.num_heads, + } ) - return tf.tile(mask, mult) + return config ``` @@ -413,26 +438,31 @@ transformer.fit(train_ds, epochs=epochs, validation_data=val_ds) ``` Model: "transformer" __________________________________________________________________________________________________ -Layer (type) Output Shape Param # Connected to + Layer (type) Output Shape Param # Connected to ================================================================================================== -encoder_inputs (InputLayer) [(None, None)] 0 -__________________________________________________________________________________________________ -positional_embedding (Positiona (None, None, 256) 3845120 encoder_inputs[0][0] -__________________________________________________________________________________________________ -decoder_inputs (InputLayer) [(None, None)] 0 -__________________________________________________________________________________________________ -transformer_encoder (Transforme (None, None, 256) 3155456 positional_embedding[0][0] -__________________________________________________________________________________________________ -model_1 (Functional) (None, None, 15000) 12959640 decoder_inputs[0][0] - transformer_encoder[0][0] + encoder_inputs (InputLayer [(None, None)] 0 [] + ) + + positional_embedding (Posi (None, None, 256) 3845120 ['encoder_inputs[0][0]'] + tionalEmbedding) + + decoder_inputs (InputLayer [(None, None)] 0 [] + ) + + transformer_encoder (Trans (None, None, 256) 3155456 ['positional_embedding[0][0]'] + formerEncoder) + + model_1 (Functional) (None, None, 15000) 1295964 ['decoder_inputs[0][0]', + 0 'transformer_encoder[0][0]'] + ================================================================================================== -Total params: 19,960,216 -Trainable params: 19,960,216 -Non-trainable params: 0 +Total params: 19960216 (76.14 MB) +Trainable params: 19960216 (76.14 MB) +Non-trainable params: 0 (0.00 Byte) 
__________________________________________________________________________________________________ -1302/1302 [==============================] - 1297s 993ms/step - loss: 1.6495 - accuracy: 0.4284 - val_loss: 1.2843 - val_accuracy: 0.5211 +1302/1302 [==============================] - 99s 69ms/step - loss: 3.6480 - accuracy: 0.4675 - val_loss: 2.4991 - val_accuracy: 0.6096 - + ``` diff --git a/examples/nlp/neural_machine_translation_with_transformer.py b/examples/nlp/neural_machine_translation_with_transformer.py index ac518352ac..05ff450c3d 100644 --- a/examples/nlp/neural_machine_translation_with_transformer.py +++ b/examples/nlp/neural_machine_translation_with_transformer.py @@ -2,7 +2,7 @@ Title: English-to-Spanish translation with a sequence-to-sequence Transformer Author: [fchollet](https://twitter.com/fchollet) Date created: 2021/05/26 -Last modified: 2023/02/25 +Last modified: 2023/08/17 Description: Implementing a sequence-to-sequene Transformer and training it on a machine translation task. Accelerator: GPU """ @@ -210,7 +210,7 @@ def make_dataset(pairs): The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond). A key detail that makes this possible is causal masking -(see method `get_causal_attention_mask()` on the `TransformerDecoder`). +(`use_causal_mask=True` in the first attention layer of the `TransformerDecoder`). The `TransformerDecoder` sees the entire sequences at once, and thus we must make sure that it only uses information from target tokens 0 to N when predicting token N+1 (otherwise, it could use information from the future, which would @@ -238,11 +238,7 @@ def __init__(self, embed_dim, dense_dim, num_heads, **kwargs): self.supports_masking = True def call(self, inputs, mask=None): - if mask is not None: - padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32") - attention_output = self.attention( - query=inputs, value=inputs, key=inputs, attention_mask=padding_mask - ) + attention_output = self.attention(query=inputs, value=inputs, key=inputs) proj_input = self.layernorm_1(inputs + attention_output) proj_output = self.dense_proj(proj_input) return self.layernorm_2(proj_input + proj_output) @@ -315,42 +311,24 @@ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs): self.layernorm_1 = layers.LayerNormalization() self.layernorm_2 = layers.LayerNormalization() self.layernorm_3 = layers.LayerNormalization() + self.add = layers.Add() # instead of `+` to preserve mask self.supports_masking = True def call(self, inputs, encoder_outputs, mask=None): - causal_mask = self.get_causal_attention_mask(inputs) - if mask is not None: - padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32") - padding_mask = tf.minimum(padding_mask, causal_mask) - attention_output_1 = self.attention_1( - query=inputs, value=inputs, key=inputs, attention_mask=causal_mask + query=inputs, value=inputs, key=inputs, use_causal_mask=True ) - out_1 = self.layernorm_1(inputs + attention_output_1) + out_1 = self.layernorm_1(self.add([inputs, attention_output_1])) attention_output_2 = self.attention_2( query=out_1, value=encoder_outputs, key=encoder_outputs, - attention_mask=padding_mask, ) - out_2 = self.layernorm_2(out_1 + attention_output_2) + out_2 = self.layernorm_2(self.add([out_1, attention_output_2])) proj_output = self.dense_proj(out_2) - return self.layernorm_3(out_2 + proj_output) - - def get_causal_attention_mask(self, inputs): - input_shape = tf.shape(inputs) - batch_size, sequence_length = input_shape[0], input_shape[1] - i = 
tf.range(sequence_length)[:, tf.newaxis] - j = tf.range(sequence_length) - mask = tf.cast(i >= j, dtype="int32") - mask = tf.reshape(mask, (1, input_shape[1], input_shape[1])) - mult = tf.concat( - [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], - axis=0, - ) - return tf.tile(mask, mult) + return self.layernorm_3(self.add([out_2, proj_output])) def get_config(self): config = super().get_config()
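The behaviour this patch relies on is Keras' automatic mask propagation: `PositionalEmbedding.compute_mask` emits the padding mask, `MultiHeadAttention` picks it up implicitly (with `use_causal_mask=True` supplying the causal mask in the decoder's self-attention), and the residual connections use `layers.Add()` because a plain `+` returns a tensor without the mask attached. The sketch below is not part of the patch; it is a minimal, standalone check of that behaviour, assuming TensorFlow 2.10+ (`tf.keras`) mask semantics and toy shapes.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Toy batch of token ids; 0 is the padding id, as in the example.
tokens = tf.constant([[5, 3, 9, 0, 0]])

# An Embedding with mask_zero=True attaches a padding mask to its output,
# analogous to PositionalEmbedding.compute_mask in the example.
x = layers.Embedding(input_dim=20, output_dim=8, mask_zero=True)(tokens)
print(x._keras_mask)  # [[ True  True  True False False]]

# A plain `+` produces an ordinary tensor: the mask attribute is lost.
y = x + tf.zeros_like(x)
print(hasattr(y, "_keras_mask"))  # False

# layers.Add() is mask-aware, so the residual sum keeps the mask. This is why
# the decoder now uses self.add([...]) instead of `+`.
z = layers.Add()([x, tf.zeros_like(x)])
print(z._keras_mask)  # [[ True  True  True False False]]

# MultiHeadAttention reads the masks attached to its inputs, so no explicit
# attention_mask is needed; use_causal_mask=True additionally stops each
# position from attending to later positions (decoder self-attention).
mha = layers.MultiHeadAttention(num_heads=2, key_dim=8)
out = mha(query=z, value=z, key=z, use_causal_mask=True)
print(out.shape)  # (1, 5, 8)
```

In the patched decoder, the same mechanism is what masks out padded source positions in the cross-attention layer without an explicit `padding_mask`.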