From ceef34d3f63bad566e94ade2440fb4db4065bdda Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Tue, 1 Aug 2023 10:05:04 -0700 Subject: [PATCH] Update KerasNLP getting started guide for multi-backend keras (#1456) * Update the getting started guide for multi-backend keras * Address comments --- guides/ipynb/keras_nlp/getting_started.ipynb | 128 ++-- guides/keras_nlp/getting_started.py | 64 +- guides/md/keras_nlp/getting_started.md | 612 ++++++++++--------- 3 files changed, 436 insertions(+), 368 deletions(-) diff --git a/guides/ipynb/keras_nlp/getting_started.ipynb b/guides/ipynb/keras_nlp/getting_started.ipynb index 3ae5fa8b81..e0d73ffc34 100644 --- a/guides/ipynb/keras_nlp/getting_started.ipynb +++ b/guides/ipynb/keras_nlp/getting_started.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -16,7 +15,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -27,14 +25,19 @@ "KerasNLP is a natural language processing library that supports users through\n", "their entire development cycle. Our workflows are built from modular components\n", "that have state-of-the-art preset weights and architectures when used\n", - "out-of-the-box and are easily customizable when more control is needed. We\n", - "emphasize in-graph computation for all workflows so that developers can expect\n", - "easy productionization using the TensorFlow ecosystem.\n", + "out-of-the-box and are easily customizable when more control is needed.\n", "\n", "This library is an extension of the core Keras API; all high-level modules are\n", "[`Layers`](/api/layers/) or [`Models`](/api/models/). If you are familiar with Keras,\n", "congratulations! You already understand most of KerasNLP.\n", "\n", + "KerasNLP uses the [Keras Core](https://keras.io/keras_core/) library to work\n", + "with any of TensorFlow, PyTorch, and JAX. In the guide below, we will use the\n", + "`jax` backend for training our models, and [tf.data](https://www.tensorflow.org/guide/data)\n", + "for efficiently running our input preprocessing. 
But feel free to mix things up!\n", "This guide runs on the TensorFlow or PyTorch backends with zero changes; simply update\n", "the `KERAS_BACKEND` below.\n", "\n", "This guide demonstrates our modular approach using a sentiment analysis example at six\n", "levels of complexity:\n", "\n", @@ -53,33 +56,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ - "!pip install -q --upgrade keras-nlp tensorflow" + "!pip install -q --upgrade keras-nlp" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", + "import os\n", "\n", - "# Use mixed precision for optimal performance\n", - "keras.mixed_precision.set_global_policy(\"mixed_float16\")" + "os.environ[\"KERAS_BACKEND\"] = \"jax\" # or \"tensorflow\" or \"torch\"\n", + "\n", + "import keras_nlp\n", + "import keras_core as keras" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -93,7 +95,7 @@ "modules:\n", "\n", "* **Tokenizer**: `keras_nlp.models.XXTokenizer`\n", - " * **What it does**: Converts strings to `tf.RaggedTensor`s of token ids.\n", + " * **What it does**: Converts strings to sequences of token ids.\n", " * **Why it's important**: The raw bytes of a string are too high dimensional to be useful\n", " features so we first map them to a small number of tokens, for example `\"The quick brown\n", " fox\"` to `[\"the\", \"qu\", \"##ick\", \"br\", \"##own\", \"fox\"]`.\n", @@ -134,7 +136,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -152,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -166,29 +167,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "BATCH_SIZE = 16\n", - "imdb_train = tf.keras.utils.text_dataset_from_directory(\n", + "imdb_train = keras.utils.text_dataset_from_directory(\n", " \"aclImdb/train\",\n", " batch_size=BATCH_SIZE,\n", ")\n", - "imdb_test = tf.keras.utils.text_dataset_from_directory(\n", + "imdb_test = keras.utils.text_dataset_from_directory(\n", " \"aclImdb/test\",\n", " batch_size=BATCH_SIZE,\n", ")\n", "\n", "# Inspect first review\n", "# Format is (review text tensor, label tensor)\n", - "print(imdb_train.unbatch().take(1).get_single_element())" + "print(imdb_train.unbatch().take(1).get_single_element())\n", + "" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -208,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -220,7 +221,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -256,7 +256,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -266,7 +265,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -295,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -313,7 +311,6 @@ ] }, { - "attachments": {}, "cell_type": 
"markdown", "metadata": { "colab_type": "text" @@ -324,7 +321,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -343,7 +339,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -358,25 +353,32 @@ "In this workflow we train the model over three epochs using `tf.data.Dataset.cache()`,\n", "which computes the preprocessing once and caches the result before fitting begins.\n", "\n", - "**Note:** this code only works if your data fits in memory. If not, pass a `filename` to\n", - "`cache()`." + "**Note:** we can use `tf.data` for preprocessing while running on the\n", + "Jax or PyTorch backend. The input dataset will automatically be converted to\n", + "backend native tensor types during fit. In fact, given the efficiency of `tf.data`\n", + "for running preprocessing, this is good practice on all backends." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ + "import tensorflow as tf\n", + "\n", "preprocessor = keras_nlp.models.BertPreprocessor.from_preset(\n", " \"bert_tiny_en_uncased\",\n", " sequence_length=512,\n", ")\n", + "\n", "# Apply the preprocessor to every sample of train and test data using `map()`.\n", "# `tf.data.AUTOTUNE` and `prefetch()` are options to tune performance, see\n", "# https://www.tensorflow.org/guide/data_performance for details.\n", + "\n", + "# Note: only call `cache()` if you training data fits in CPU memory!\n", "imdb_train_cached = (\n", " imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)\n", ")\n", @@ -395,7 +397,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -408,7 +409,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -421,12 +421,14 @@ "constructor to get the vocabulary matching pretraining.\n", "\n", "**Note:** `BertTokenizer` does not pad sequences by default, so the output is\n", - "a `tf.RaggedTensor`." + "ragged (each sequence has varying length). The `MultiSegmentPacker` below\n", + "handles padding these ragged sequences to dense tensor types (e.g. `tf.Tensor`\n", + "or `torch.Tensor`)." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -470,7 +472,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -496,7 +497,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -527,8 +528,8 @@ "model = keras.Model(inputs, outputs)\n", "model.compile(\n", " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " optimizer=keras.optimizers.experimental.AdamW(5e-5),\n", - " metrics=keras.metrics.SparseCategoricalAccuracy(),\n", + " optimizer=keras.optimizers.AdamW(5e-5),\n", + " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", " jit_compile=True,\n", ")\n", "model.summary()\n", @@ -540,7 +541,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -552,7 +552,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -582,7 +581,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -593,7 +591,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -648,7 +646,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -659,7 +656,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -680,10 +677,10 @@ ")\n", "\n", "inputs = {\n", - " \"token_ids\": keras.Input(shape=(None,), dtype=tf.int32),\n", - " \"segment_ids\": keras.Input(shape=(None,), dtype=tf.int32),\n", - " \"padding_mask\": keras.Input(shape=(None,), dtype=tf.int32),\n", - " \"mask_positions\": keras.Input(shape=(None,), dtype=tf.int32),\n", + " \"token_ids\": keras.Input(shape=(None,), dtype=tf.int32, name=\"token_ids\"),\n", + " \"segment_ids\": keras.Input(shape=(None,), dtype=tf.int32, name=\"segment_ids\"),\n", + " \"padding_mask\": keras.Input(shape=(None,), dtype=tf.int32, name=\"padding_mask\"),\n", + " \"mask_positions\": keras.Input(shape=(None,), dtype=tf.int32, name=\"mask_positions\"),\n", "}\n", "\n", "# Encoded token sequence\n", @@ -692,15 +689,15 @@ "# Predict an output word for each masked input token.\n", "# We use the input token embedding to project from our encoded vectors to\n", "# vocabulary logits, which has been shown to improve training efficiency.\n", - "outputs = mlm_head(sequence, mask_positions=inputs[\"mask_positions\"])\n", + "outputs = mlm_head(sequence, masked_positions=inputs[\"mask_positions\"])\n", "\n", "# Define and compile our pretraining model.\n", "pretraining_model = keras.Model(inputs, outputs)\n", "pretraining_model.summary()\n", "pretraining_model.compile(\n", " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " optimizer=keras.optimizers.experimental.AdamW(learning_rate=5e-4),\n", - " weighted_metrics=keras.metrics.SparseCategoricalAccuracy(),\n", + " optimizer=keras.optimizers.AdamW(learning_rate=5e-4),\n", + " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", " jit_compile=True,\n", ")\n", "\n", @@ -713,7 +710,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -723,7 +719,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -745,7 +740,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -756,7 +750,7 @@ }, { 
"cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -778,7 +772,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -789,7 +782,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -819,7 +812,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -830,7 +822,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -862,7 +854,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -873,7 +864,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -881,8 +872,8 @@ "source": [ "model.compile(\n", " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " optimizer=keras.optimizers.experimental.AdamW(5e-5),\n", - " metrics=keras.metrics.SparseCategoricalAccuracy(),\n", + " optimizer=keras.optimizers.AdamW(5e-5),\n", + " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", " jit_compile=True,\n", ")\n", "model.fit(\n", @@ -893,7 +884,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text" @@ -934,4 +924,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/guides/keras_nlp/getting_started.py b/guides/keras_nlp/getting_started.py index b75cb85e6e..22dd256179 100644 --- a/guides/keras_nlp/getting_started.py +++ b/guides/keras_nlp/getting_started.py @@ -12,14 +12,19 @@ KerasNLP is a natural language processing library that supports users through their entire development cycle. Our workflows are built from modular components that have state-of-the-art preset weights and architectures when used -out-of-the-box and are easily customizable when more control is needed. We -emphasize in-graph computation for all workflows so that developers can expect -easy productionization using the TensorFlow ecosystem. +out-of-the-box and are easily customizable when more control is needed. This library is an extension of the core Keras API; all high-level modules are [`Layers`](/api/layers/) or [`Models`](/api/models/). If you are familiar with Keras, congratulations! You already understand most of KerasNLP. +KerasNLP uses the [Keras Core](https://keras.io/keras_core/) library to work +with any of TensorFlow, Pytorch and Jax. In the guide below, we will use the +`jax` backend for training our models, and [tf.data](https://www.tensorflow.org/guide/data) +for efficiently running our input preprocessing. But feel free to mix things up! +This guide runs in TensorFlow or PyTorch backends with zero changes, simply update +the `KERAS_BACKEND` below. + This guide demonstrates our modular approach using a sentiment analysis example at six levels of complexity: @@ -37,15 +42,15 @@ """ """shell -pip install -q --upgrade keras-nlp tensorflow +pip install -q --upgrade keras-nlp """ -import keras_nlp -import tensorflow as tf -from tensorflow import keras +import os -# Use mixed precision for optimal performance -keras.mixed_precision.set_global_policy("mixed_float16") +os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch" + +import keras_nlp +import keras_core as keras """ ## API quickstart @@ -56,7 +61,7 @@ modules: * **Tokenizer**: `keras_nlp.models.XXTokenizer` - * **What it does**: Converts strings to `tf.RaggedTensor`s of token ids. 
+ * **What it does**: Converts strings to sequences of token ids. * **Why it's important**: The raw bytes of a string are too high dimensional to be useful features so we first map them to a small number of tokens, for example `"The quick brown fox"` to `["the", "qu", "##ick", "br", "##own", "fox"]`. @@ -115,11 +120,11 @@ """ BATCH_SIZE = 16 -imdb_train = tf.keras.utils.text_dataset_from_directory( +imdb_train = keras.utils.text_dataset_from_directory( "aclImdb/train", batch_size=BATCH_SIZE, ) -imdb_test = tf.keras.utils.text_dataset_from_directory( +imdb_test = keras.utils.text_dataset_from_directory( "aclImdb/test", batch_size=BATCH_SIZE, ) @@ -231,17 +236,24 @@ In this workflow we train the model over three epochs using `tf.data.Dataset.cache()`, which computes the preprocessing once and caches the result before fitting begins. -**Note:** this code only works if your data fits in memory. If not, pass a `filename` to -`cache()`. +**Note:** we can use `tf.data` for preprocessing while running on the +JAX or PyTorch backend. The input dataset will automatically be converted to +backend-native tensor types during `fit()`. In fact, given the efficiency of `tf.data` +for running preprocessing, this is good practice on all backends. """ +import tensorflow as tf + preprocessor = keras_nlp.models.BertPreprocessor.from_preset( "bert_tiny_en_uncased", sequence_length=512, ) + # Apply the preprocessor to every sample of train and test data using `map()`. # `tf.data.AUTOTUNE` and `prefetch()` are options to tune performance, see # https://www.tensorflow.org/guide/data_performance for details. + +# Note: only call `cache()` if your training data fits in CPU memory! imdb_train_cached = ( imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) ) @@ -273,7 +285,9 @@ constructor to get the vocabulary matching pretraining. **Note:** `BertTokenizer` does not pad sequences by default, so the output is -a `tf.RaggedTensor`. +ragged (sequences vary in length). The `MultiSegmentPacker` below +handles padding these ragged sequences to dense tensor types (e.g. `tf.Tensor` +or `torch.Tensor`). 
""" tokenizer = keras_nlp.models.BertTokenizer.from_preset("bert_tiny_en_uncased") @@ -356,8 +370,8 @@ def preprocessor(x, y): model = keras.Model(inputs, outputs) model.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.experimental.AdamW(5e-5), - metrics=keras.metrics.SparseCategoricalAccuracy(), + optimizer=keras.optimizers.AdamW(5e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) model.summary() @@ -467,10 +481,10 @@ def preprocess(inputs, label): ) inputs = { - "token_ids": keras.Input(shape=(None,), dtype=tf.int32), - "segment_ids": keras.Input(shape=(None,), dtype=tf.int32), - "padding_mask": keras.Input(shape=(None,), dtype=tf.int32), - "mask_positions": keras.Input(shape=(None,), dtype=tf.int32), + "token_ids": keras.Input(shape=(None,), dtype=tf.int32, name="token_ids"), + "segment_ids": keras.Input(shape=(None,), dtype=tf.int32, name="segment_ids"), + "padding_mask": keras.Input(shape=(None,), dtype=tf.int32, name="padding_mask"), + "mask_positions": keras.Input(shape=(None,), dtype=tf.int32, name="mask_positions"), } # Encoded token sequence @@ -486,8 +500,8 @@ def preprocess(inputs, label): pretraining_model.summary() pretraining_model.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.experimental.AdamW(learning_rate=5e-4), - weighted_metrics=keras.metrics.SparseCategoricalAccuracy(), + optimizer=keras.optimizers.AdamW(learning_rate=5e-4), + weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) @@ -597,8 +611,8 @@ def preprocess(x, y): model.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.experimental.AdamW(5e-5), - metrics=keras.metrics.SparseCategoricalAccuracy(), + optimizer=keras.optimizers.AdamW(5e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) model.fit( diff --git a/guides/md/keras_nlp/getting_started.md b/guides/md/keras_nlp/getting_started.md index 9fe3a896d9..9d1d4c01a6 100644 --- a/guides/md/keras_nlp/getting_started.md +++ b/guides/md/keras_nlp/getting_started.md @@ -16,14 +16,19 @@ KerasNLP is a natural language processing library that supports users through their entire development cycle. Our workflows are built from modular components that have state-of-the-art preset weights and architectures when used -out-of-the-box and are easily customizable when more control is needed. We -emphasize in-graph computation for all workflows so that developers can expect -easy productionization using the TensorFlow ecosystem. +out-of-the-box and are easily customizable when more control is needed. This library is an extension of the core Keras API; all high-level modules are [`Layers`](/api/layers/) or [`Models`](/api/models/). If you are familiar with Keras, congratulations! You already understand most of KerasNLP. +KerasNLP uses the [Keras Core](https://keras.io/keras_core/) library to work +with any of TensorFlow, Pytorch and Jax. In the guide below, we will use the +`jax` backend for training our models, and [tf.data](https://www.tensorflow.org/guide/data) +for efficiently running our input preprocessing. But feel free to mix things up! +This guide runs in TensorFlow or PyTorch backends with zero changes, simply update +the `KERAS_BACKEND` below. 
+ This guide demonstrates our modular approach using a sentiment analysis example at six levels of complexity: @@ -41,23 +46,22 @@ reference for the complexity of the material: ```python -!pip install -q --upgrade keras-nlp tensorflow +!pip install -q --upgrade keras-nlp ``` + ```python -import keras_nlp -import tensorflow as tf -from tensorflow import keras +import os + +os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch" -# Use mixed precision for optimal performance -keras.mixed_precision.set_global_policy("mixed_float16") +import keras_nlp +import keras_core as keras ``` +
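The import hunk above is the crux of the backend switch: `KERAS_BACKEND` must be set before Keras is first imported, and a `tf.data` pipeline can then feed `fit()` on the `jax` backend with no further changes. A minimal, self-contained sketch of that pattern (the toy dataset and model are illustrative stand-ins, not part of the guide):

```python
import os

# The backend must be chosen before keras_core (or keras_nlp) is imported.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import tensorflow as tf
import keras_core as keras

# A toy stand-in for the IMDb pipeline: tf.data runs the preprocessing on
# CPU, and `fit()` converts each batch to backend-native (here JAX) arrays.
features = tf.random.uniform((64, 16))
labels = tf.random.uniform((64,), maxval=2, dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

model = keras.Sequential([keras.layers.Dense(2)])
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(5e-5),
)
model.fit(ds, epochs=1)
```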
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ +┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ +│ padding_mask │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ segment_ids │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ token_ids │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ bert_backbone_3 │ [(None, 128), │ 4,385,… │ padding_mask[0][0], │ +│ (BertBackbone) │ (None, None, │ │ segment_ids[0][0], │ +│ │ 128)] │ │ token_ids[0][0] │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ transformer_encoder │ (None, None, 128) │ 198,272 │ bert_backbone_3[0][… │ +│ (TransformerEncode… │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ transformer_encode… │ (None, None, 128) │ 198,272 │ transformer_encoder… │ +│ (TransformerEncode… │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ get_item_4 │ (None, 128) │ 0 │ transformer_encoder… │ +│ (GetItem) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ dense_20 (Dense) │ (None, 2) │ 258 │ get_item_4[0][0] │ +└─────────────────────┴───────────────────┴─────────┴──────────────────────┘ ++ + + + +
Total params: 4,782,722 (145.96 MB) ++ + + + +
Trainable params: 396,802 (12.11 MB) ++ + + + +
Non-trainable params: 4,385,920 (133.85 MB) ++ + +
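The split above — 396,802 trainable against 4,385,920 non-trainable parameters — comes from freezing the pretrained backbone and training only the layers stacked on top of it. The code for this section is unchanged by the patch and therefore elided from the diff; the following is a hedged reconstruction consistent with the layer names and parameter counts in the summary (`intermediate_dim=512` is implied by the 198,272 parameters per encoder, while `num_heads=2` and `dropout=0.1` are assumptions):

```python
import keras_nlp
import keras_core as keras

backbone = keras_nlp.models.BertBackbone.from_preset("bert_tiny_en_uncased")
backbone.trainable = False  # keep all 4,385,920 backbone weights fixed

inputs = backbone.input
sequence = backbone(inputs)["sequence_output"]
# Two new, trainable encoder blocks on top of the frozen representation
# (2 x 198,272 parameters in the summary above).
for _ in range(2):
    sequence = keras_nlp.layers.TransformerEncoder(
        num_heads=2,
        intermediate_dim=512,
        dropout=0.1,
    )(sequence)
# Classify from the "[CLS]" token: 128 * 2 + 2 = 258 parameters.
outputs = keras.layers.Dense(2)(sequence[:, backbone.cls_token_index, :])
model = keras.Model(inputs, outputs)
```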
Model: "functional_3"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ +┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ +│ mask_positions │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ padding_mask │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ segment_ids │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ token_ids │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ bert_backbone_4 │ [(None, 128), │ 4,385,… │ mask_positions[0][0… │ +│ (BertBackbone) │ (None, None, │ │ padding_mask[0][0], │ +│ │ 128)] │ │ segment_ids[0][0], │ +│ │ │ │ token_ids[0][0] │ +├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ +│ masked_lm_head │ (None, 30522) │ 3,954,… │ bert_backbone_4[0][… │ +│ (MaskedLMHead) │ │ │ mask_positions[0][0] │ +└─────────────────────┴───────────────────┴─────────┴──────────────────────┘ ++ + + + +
Total params: 4,433,210 (135.29 MB) ++ + + + +
Trainable params: 4,433,210 (135.29 MB) ++ + + + +
Non-trainable params: 0 (0.00 B) ++ + +
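Note that the pretraining total above (4,433,210) is far below the backbone's 4,385,920 plus an untied 30,522-way output projection. That is the weight tying the guide's comment describes: `MaskedLMHead` reuses the token embedding matrix to produce its vocabulary logits. A short sketch of the tying, assuming the keras-nlp 0.6 attribute names (`backbone.token_embedding` in particular):

```python
import keras_nlp

backbone = keras_nlp.models.BertBackbone.from_preset("bert_tiny_en_uncased")

# Sharing `token_embedding` ties the head's output projection to the
# backbone's 30,522 x 128 embedding matrix, so the head only adds a small
# dense transform, layer norm, and output bias (~47K parameters).
mlm_head = keras_nlp.layers.MaskedLMHead(
    token_embedding=backbone.token_embedding,
    activation=None,  # return logits, for use with `from_logits=True` losses
)

# Called as in the patch: predict a token id for each masked position.
# outputs = mlm_head(sequence, masked_positions=inputs["mask_positions"])
```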
Model: "functional_5"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ +│ token_ids (InputLayer) │ (None, None) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ token_and_position_embedding │ (None, None, 64) │ 1,259,648 │ +│ (TokenAndPositionEmbedding) │ │ │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ transformer_encoder_2 │ (None, None, 64) │ 33,472 │ +│ (TransformerEncoder) │ │ │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ get_item_6 (GetItem) │ (None, 64) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ dense_28 (Dense) │ (None, 2) │ 130 │ +└─────────────────────────────────┴───────────────────────────┴────────────┘ ++ + + + +
Total params: 1,293,250 (39.47 MB) ++ + + + +
Trainable params: 1,293,250 (39.47 MB) ++ + + + +
Non-trainable params: 0 (0.00 B) ++ + + ### Train the transformer directly on the classification objective ```python model.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.experimental.AdamW(5e-5), - metrics=keras.metrics.SparseCategoricalAccuracy(), + optimizer=keras.optimizers.AdamW(5e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) model.fit( @@ -981,13 +1045,13 @@ model.fit(