Fix masks in English-to-Spanish translation example (#1046)
* fix masks for machine translation example

* black and modified date

* remove manual masking and rely on automatic masks instead

* replace + by Add layer to preserve mask

* regenerate ipynb and md files

* fix black
nikoladze committed Aug 17, 2023
1 parent a42469e commit 0b12c5f
Showing 3 changed files with 163 additions and 153 deletions.
140 changes: 71 additions & 69 deletions examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
"**Date created:** 2021/05/26<br>\n",
"**Last modified:** 2023/02/25<br>\n",
"**Last modified:** 2023/08/17<br>\n",
"**Description:** Implementing a sequence-to-sequene Transformer and training it on a machine translation task."
]
},
@@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -84,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -113,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -139,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -161,7 +161,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -196,7 +196,7 @@
"The English layer will use the default string standardization (strip punctuation characters)\n",
"and splitting scheme (split on whitespace), while\n",
"the Spanish layer will use a custom standardization, where we add the character\n",
"`\"¿\"` to the set of punctuation characters to be stripped.\n",
"`\"\u00bf\"` to the set of punctuation characters to be stripped.\n",
"\n",
"Note: in a production-grade machine translation model, I would not recommend\n",
"stripping the punctuation characters in either language. Instead, I would recommend turning\n",
@@ -206,13 +206,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"strip_chars = string.punctuation + \"¿\"\n",
"strip_chars = string.punctuation + \"\u00bf\"\n",
"strip_chars = strip_chars.replace(\"[\", \"\")\n",
"strip_chars = strip_chars.replace(\"]\", \"\")\n",
"\n",
@@ -227,7 +227,9 @@
"\n",
"\n",
"eng_vectorization = TextVectorization(\n",
" max_tokens=vocab_size, output_mode=\"int\", output_sequence_length=sequence_length,\n",
" max_tokens=vocab_size,\n",
" output_mode=\"int\",\n",
" output_sequence_length=sequence_length,\n",
")\n",
"spa_vectorization = TextVectorization(\n",
" max_tokens=vocab_size,\n",
@@ -263,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -273,7 +275,13 @@
"def format_dataset(eng, spa):\n",
" eng = eng_vectorization(eng)\n",
" spa = spa_vectorization(spa)\n",
" return ({\"encoder_inputs\": eng, \"decoder_inputs\": spa[:, :-1],}, spa[:, 1:])\n",
" return (\n",
" {\n",
" \"encoder_inputs\": eng,\n",
" \"decoder_inputs\": spa[:, :-1],\n",
" },\n",
" spa[:, 1:],\n",
" )\n",
"\n",
"\n",
"def make_dataset(pairs):\n",
@@ -302,7 +310,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -333,7 +341,7 @@
"The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).\n",
"\n",
"A key detail that makes this possible is causal masking\n",
"(see method `get_causal_attention_mask()` on the `TransformerDecoder`).\n",
"(`use_causal_mask=True` in the first attention layer of the `TransformerDecoder`).\n",
"The `TransformerDecoder` sees the entire sequences at once, and thus we must make\n",
"sure that it only uses information from target tokens 0 to N when predicting token N+1\n",
"(otherwise, it could use information from the future, which would\n",
@@ -342,7 +350,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
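To make the causal-masking note above concrete, here is a standalone check of what `use_causal_mask=True` does. It assumes a TF/Keras version recent enough for `MultiHeadAttention` to accept that argument (which the updated decoder code further down relies on):

```python
import tensorflow as tf
from tensorflow.keras import layers

mha = layers.MultiHeadAttention(num_heads=1, key_dim=4)
x = tf.random.normal((1, 5, 4))  # (batch, sequence, features)

_, scores = mha(
    query=x, value=x, use_causal_mask=True, return_attention_scores=True
)
# scores: (batch, heads, query_pos, key_pos). Everything above the diagonal
# is ~0, so position i never attends to positions j > i.
print(tf.squeeze(scores).numpy().round(2))
```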
@@ -359,28 +367,30 @@
" num_heads=num_heads, key_dim=embed_dim\n",
" )\n",
" self.dense_proj = keras.Sequential(\n",
" [layers.Dense(dense_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
" [\n",
" layers.Dense(dense_dim, activation=\"relu\"),\n",
" layers.Dense(embed_dim),\n",
" ]\n",
" )\n",
" self.layernorm_1 = layers.LayerNormalization()\n",
" self.layernorm_2 = layers.LayerNormalization()\n",
" self.supports_masking = True\n",
"\n",
" def call(self, inputs, mask=None):\n",
" if mask is not None:\n",
" padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n",
" attention_output = self.attention(\n",
" query=inputs, value=inputs, key=inputs, attention_mask=padding_mask\n",
" )\n",
" attention_output = self.attention(query=inputs, value=inputs, key=inputs)\n",
" proj_input = self.layernorm_1(inputs + attention_output)\n",
" proj_output = self.dense_proj(proj_input)\n",
" return self.layernorm_2(proj_input + proj_output)\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
" config.update({\n",
" \"embed_dim\": self.embed_dim,\n",
" \"dense_dim\": self.dense_dim,\n",
" \"num_heads\": self.num_heads,\n",
" })\n",
" config.update(\n",
" {\n",
" \"embed_dim\": self.embed_dim,\n",
" \"dense_dim\": self.dense_dim,\n",
" \"num_heads\": self.num_heads,\n",
" }\n",
" )\n",
" return config\n",
"\n",
"\n",
@@ -406,13 +416,16 @@
"\n",
" def compute_mask(self, inputs, mask=None):\n",
" return tf.math.not_equal(inputs, 0)\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
" config.update({\n",
" \"sequence_length\": self.sequence_length,\n",
" \"vocab_size\": self.vocab_size,\n",
" \"embed_dim\": self.embed_dim,\n",
" })\n",
" config.update(\n",
" {\n",
" \"sequence_length\": self.sequence_length,\n",
" \"vocab_size\": self.vocab_size,\n",
" \"embed_dim\": self.embed_dim,\n",
" }\n",
" )\n",
" return config\n",
"\n",
"\n",
@@ -429,55 +442,44 @@
" num_heads=num_heads, key_dim=embed_dim\n",
" )\n",
" self.dense_proj = keras.Sequential(\n",
" [layers.Dense(latent_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
" [\n",
" layers.Dense(latent_dim, activation=\"relu\"),\n",
" layers.Dense(embed_dim),\n",
" ]\n",
" )\n",
" self.layernorm_1 = layers.LayerNormalization()\n",
" self.layernorm_2 = layers.LayerNormalization()\n",
" self.layernorm_3 = layers.LayerNormalization()\n",
" self.add = layers.Add() # instead of `+` to preserve mask\n",
" self.supports_masking = True\n",
"\n",
" def call(self, inputs, encoder_outputs, mask=None):\n",
" causal_mask = self.get_causal_attention_mask(inputs)\n",
" if mask is not None:\n",
" padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n",
" padding_mask = tf.minimum(padding_mask, causal_mask)\n",
"\n",
" attention_output_1 = self.attention_1(\n",
" query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
" query=inputs, value=inputs, key=inputs, use_causal_mask=True\n",
" )\n",
" out_1 = self.layernorm_1(inputs + attention_output_1)\n",
" out_1 = self.layernorm_1(self.add([inputs, attention_output_1]))\n",
"\n",
" attention_output_2 = self.attention_2(\n",
" query=out_1,\n",
" value=encoder_outputs,\n",
" key=encoder_outputs,\n",
" attention_mask=padding_mask,\n",
" )\n",
" out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
" out_2 = self.layernorm_2(self.add([out_1, attention_output_2]))\n",
"\n",
" proj_output = self.dense_proj(out_2)\n",
" return self.layernorm_3(out_2 + proj_output)\n",
"\n",
" def get_causal_attention_mask(self, inputs):\n",
" input_shape = tf.shape(inputs)\n",
" batch_size, sequence_length = input_shape[0], input_shape[1]\n",
" i = tf.range(sequence_length)[:, tf.newaxis]\n",
" j = tf.range(sequence_length)\n",
" mask = tf.cast(i >= j, dtype=\"int32\")\n",
" mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))\n",
" mult = tf.concat(\n",
" [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],\n",
" axis=0,\n",
" )\n",
" return tf.tile(mask, mult)\n",
" return self.layernorm_3(self.add([out_2, proj_output]))\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
" config.update({\n",
" \"embed_dim\": self.embed_dim,\n",
" \"latent_dim\": self.latent_dim,\n",
" \"num_heads\": self.num_heads,\n",
" })\n",
" return config\n"
" config.update(\n",
" {\n",
" \"embed_dim\": self.embed_dim,\n",
" \"latent_dim\": self.latent_dim,\n",
" \"num_heads\": self.num_heads,\n",
" }\n",
" )\n",
" return config\n",
""
]
},
{
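One of the changes in this hunk, `self.add([...])` in place of `+`, matters for the same reason: plain tensor ops drop the Keras mask, whereas merge layers such as `Add` propagate it to their output. A minimal check (standalone; behaviour as on recent TF 2.x, which is an assumption here):

```python
import tensorflow as tf
from tensorflow.keras import layers

emb = layers.Embedding(input_dim=100, output_dim=8, mask_zero=True)
x = emb(tf.constant([[5, 3, 0, 0]]))
y = tf.zeros_like(x)

print(getattr(x + y, "_keras_mask", None))                 # None -- `+` loses the mask
print(getattr(layers.Add()([x, y]), "_keras_mask", None))  # mask preserved by the Add layer
```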
@@ -491,7 +493,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -537,7 +539,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -568,7 +570,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -610,19 +612,19 @@
"After 30 epochs, we get results such as:\n",
"\n",
"> She handed him the money.\n",
"> [start] ella le pasó el dinero [end]\n",
"> [start] ella le pas\u00f3 el dinero [end]\n",
"\n",
"> Tom has never heard Mary sing.\n",
"> [start] tom nunca ha oído cantar a mary [end]\n",
"> [start] tom nunca ha o\u00eddo cantar a mary [end]\n",
"\n",
"> Perhaps she will come tomorrow.\n",
"> [start] tal vez ella vendrá mañana [end]\n",
"> [start] tal vez ella vendr\u00e1 ma\u00f1ana [end]\n",
"\n",
"> I love to write.\n",
"> [start] me encanta escribir [end]\n",
"\n",
"> His French is improving little by little.\n",
"> [start] su francés va a [UNK] sólo un poco [end]\n",
"> [start] su franc\u00e9s va a [UNK] s\u00f3lo un poco [end]\n",
"\n",
"> My hotel told me to call you.\n",
"> [start] mi hotel me dijo que te [UNK] [end]"
@@ -658,4 +660,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}