diff --git a/examples/nlp/abstractive_summarization_with_bart.py b/examples/nlp/abstractive_summarization_with_bart.py
new file mode 100644
index 0000000000..f712383ee4
--- /dev/null
+++ b/examples/nlp/abstractive_summarization_with_bart.py
@@ -0,0 +1,240 @@
+"""
+Title: Abstractive Text Summarization with BART
+Author: [Abheesht Sharma](https://github.com/abheesht17/)
+Date created: 2023/07/08
+Last modified: 2023/07/08
+Description: Use KerasNLP to fine-tune BART on the abstractive summarization task.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In the era of information overload, it has become crucial to extract the crux
+of a long document or a conversation and express it in a few sentences. Because
+summarization has widespread applications in different domains, it has become a
+key, well-studied NLP task in recent years.
+
+[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461)
+is a Transformer-based encoder-decoder model, often used for
+sequence-to-sequence tasks like summarization and neural machine translation.
+BART is pre-trained in a self-supervised fashion on a large text corpus. During
+pre-training, the text is corrupted and BART is trained to reconstruct the
+original text (hence the name "denoising autoencoder"). Pre-training tasks
+include token masking, token deletion, and sentence permutation (shuffling
+sentences and training BART to restore the original order).
+
+In this example, we will demonstrate how to fine-tune BART on the abstractive
+summarization task (on conversations!) using KerasNLP, and generate summaries
+using the fine-tuned model.
+"""
+
+"""
+## Setup
+
+Before we start implementing the pipeline, let's install and import all the
+libraries we need. We'll be using the KerasNLP library. We will also need a
+couple of utility libraries.
+"""
+
+"""shell
+pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q
+"""
+
+"""
+This example uses [Keras Core](https://keras.io/keras_core/) to work in any of
+`"tensorflow"`, `"jax"` or `"torch"`. Support for Keras Core is baked into
+KerasNLP: simply change the `"KERAS_BACKEND"` environment variable to select
+the backend of your choice. We select the JAX backend below.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+
+"""
+Import all necessary libraries.
+"""
+
+import time
+
+import py7zr
+import keras_nlp
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+import keras_core as keras
+
+"""
+Let's also define our hyperparameters.
+"""
+
+BATCH_SIZE = 8
+NUM_BATCHES = 600
+EPOCHS = 1  # Can be set to a higher value for better results
+MAX_ENCODER_SEQUENCE_LENGTH = 512
+MAX_DECODER_SEQUENCE_LENGTH = 128
+MAX_GENERATION_LENGTH = 40
+
+"""
+## Dataset
+
+Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset
+contains around 15,000 pairs of conversations/dialogues and summaries.
+"""
+
+# Download the dataset.
+filename = keras.utils.get_file(
+    "corpus.7z",
+    origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
+)
+
+# Extract the `.7z` file.
+with py7zr.SevenZipFile(filename, mode="r") as z:
+    z.extractall(path="/root/tensorflow_datasets/downloads/manual")
+
+# Load data using TFDS.
+samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
+
+"""
+The dataset has two fields: `dialogue` and `summary`. Let's see a sample.
+""" +for dialogue, summary in samsum_ds: + print(dialogue.numpy()) + print(summary.numpy()) + break + +""" +We'll now batch the dataset and retain only a subset of the dataset for the +purpose of this example. The dialogue is fed to the encoder, and the +corresponding summary serves as input to the decoder. We will, therefore, change +the format of the dataset to a dictionary having two keys: `"encoder_text"` and +`"decoder_text"`.This is how `keras_nlp.models.BartSeq2SeqLMPreprocessor` +expects the input format to be. +""" + +train_ds = ( + samsum_ds.map( + lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary} + ) + .batch(BATCH_SIZE) + .cache() +) +train_ds = train_ds.take(NUM_BATCHES) + +""" +## Fine-tune BART + +Let's load the model and preprocessor first. We use sequence lengths of 512 +and 128 for the encoder and decoder, respectively, instead of 1024 (which is the +default sequence length). This will allow us to run this example quickly +on Colab. + +If you observe carefully, the preprocessor is attached to the model. What this +means is that we don't have to worry about preprocessing the text inputs; +everything will be done internally. The preprocessor tokenizes the encoder text +and the decoder text, adds special tokens and pads them. To generate labels +for auto-regressive training, the preprocessor shifts the decoder text one +position to the right. This is done because at every timestep, the model is +trained to predict the next token. +""" + +preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset( + "bart_base_en", + encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH, + decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH, +) +bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset( + "bart_base_en", preprocessor=preprocessor +) + +bart_lm.summary() + +""" +Define the optimizer and loss. We use the Adam optimizer with a linearly +decaying learning rate. Compile the model. +""" + +optimizer = keras.optimizers.AdamW( + learning_rate=5e-5, + weight_decay=0.01, + epsilon=1e-6, + global_clipnorm=1.0, # Gradient clipping. +) +# Exclude layernorm and bias terms from weight decay. +optimizer.exclude_from_weight_decay(var_names=["bias"]) +optimizer.exclude_from_weight_decay(var_names=["gamma"]) +optimizer.exclude_from_weight_decay(var_names=["beta"]) + +loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + +bart_lm.compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=["accuracy"], +) + +""" +Let's train the model! +""" + +bart_lm.fit(train_ds, epochs=EPOCHS) + +""" +## Generate summaries and evaluate them! + +Now that the model has been trained, let's get to the fun part - actually +generating summaries! Let's pick the first 100 samples from the validation set +and generate summaries for them. We will use the default decoding strategy, i.e., +greedy search. + +Generation in KerasNLP is highly optimized. It is backed by the power of XLA. +Secondly, key/value tensors in the self-attention layer and cross-attention layer +in the decoder are cached to avoid recomputation at every timestep. +""" + + +def generate_text(model, input_text, max_length=200, print_time_taken=False): + start = time.time() + output = model.generate(input_text, max_length=max_length) + end = time.time() + print(f"Total Time Elapsed: {end - start:.2f}s") + return output + + +# Load the dataset. 
+val_ds = tfds.load("samsum", split="validation", as_supervised=True)
+val_ds = val_ds.take(100)
+
+dialogues = []
+ground_truth_summaries = []
+for dialogue, summary in val_ds:
+    dialogues.append(dialogue.numpy())
+    ground_truth_summaries.append(summary.numpy())
+
+# Let's make a dummy call - the first call to XLA generally takes a bit longer.
+_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)
+
+# Generate summaries.
+generated_summaries = generate_text(
+    bart_lm,
+    val_ds.map(lambda dialogue, _: dialogue).batch(8),
+    max_length=MAX_GENERATION_LENGTH,
+)
+
+"""
+Let's see some of the summaries.
+"""
+for dialogue, generated_summary, ground_truth_summary in zip(
+    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
+):
+    print("Dialogue:", dialogue)
+    print("Generated Summary:", generated_summary)
+    print("Ground Truth Summary:", ground_truth_summary)
+    print("=============================")
+
+"""
+The generated summaries look awesome! Not bad for a model trained for only 1
+epoch and on ~5,000 examples :)
+"""
diff --git a/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb b/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb
new file mode 100644
index 0000000000..22feb911e9
--- /dev/null
+++ b/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb
@@ -0,0 +1,457 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "# Abstractive Text Summarization with BART\n",
+        "\n",
+        "**Author:** [Abheesht Sharma](https://github.com/abheesht17/)<br>\n",
\n", + "**Date created:** 2023/07/08
\n", + "**Last modified:** 2023/07/08
\n", + "**Description:** Use KerasNLP to fine-tune BART on the abstractive summarization task." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Introduction\n", + "\n", + "In the era of information overload, it has become crucial to extract the crux\n", + "of a long document or a conversation and express it in a few sentences. Owing\n", + "to the fact that summarization has widespread applications in different domains,\n", + "it has become a key, well-studied NLP task in recent years.\n", + "\n", + "[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461)\n", + "is a Transformer-based encoder-decoder model, often used for\n", + "sequence-to-sequence tasks like summarization and neural machine translation.\n", + "BART is pre-trained in a self-supervised fashion on a large text corpus. During\n", + "pre-training, the text is corrupted and BART is trained to reconstruct the\n", + "original text (hence called a \"denoising autoencoder\"). Some pre-training tasks\n", + "include token masking, token deletion, sentence permutation (shuffle sentences\n", + "and train BART to fix the order), etc.\n", + "\n", + "In this example, we will demonstrate how to fine-tune BART on the abstractive\n", + "summarization task (on conversations!) using KerasNLP, and generate summaries\n", + "using the fine-tuned model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Setup\n", + "\n", + "Before we start implementing the pipeline, let's install and import all the\n", + "libraries we need. We'll be using the KerasNLP library. We will also need a\n", + "couple of utility libraries." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "This examples uses [Keras Core](https://keras.io/keras_core/) to work in any of\n", + "`\"tensorflow\"`, `\"jax\"` or `\"torch\"`. Support for Keras Core is baked into\n", + "KerasNLP, simply change the `\"KERAS_BACKEND\"` environment variable to select\n", + "the backend of your choice. We select the JAX backend below." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Import all necessary libraries." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import py7zr\n", + "import time\n", + "\n", + "import keras_nlp\n", + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "\n", + "import keras_core as keras" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Let's also define our hyperparameters." 
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "BATCH_SIZE = 8\n",
+        "NUM_BATCHES = 600\n",
+        "EPOCHS = 1  # Can be set to a higher value for better results\n",
+        "MAX_ENCODER_SEQUENCE_LENGTH = 512\n",
+        "MAX_DECODER_SEQUENCE_LENGTH = 128\n",
+        "MAX_GENERATION_LENGTH = 40"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "## Dataset\n",
+        "\n",
+        "Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset\n",
+        "contains around 15,000 pairs of conversations/dialogues and summaries."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "# Download the dataset.\n",
+        "filename = keras.utils.get_file(\n",
+        "    \"corpus.7z\",\n",
+        "    origin=\"https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z\",\n",
+        ")\n",
+        "\n",
+        "# Extract the `.7z` file.\n",
+        "with py7zr.SevenZipFile(filename, mode=\"r\") as z:\n",
+        "    z.extractall(path=\"/root/tensorflow_datasets/downloads/manual\")\n",
+        "\n",
+        "# Load data using TFDS.\n",
+        "samsum_ds = tfds.load(\"samsum\", split=\"train\", as_supervised=True)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "The dataset has two fields: `dialogue` and `summary`. Let's see a sample."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "for dialogue, summary in samsum_ds:\n",
+        "    print(dialogue.numpy())\n",
+        "    print(summary.numpy())\n",
+        "    break"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "We'll now batch the dataset and retain only a subset of it for the purpose of\n",
+        "this example. The dialogue is fed to the encoder, and the corresponding summary\n",
+        "serves as input to the decoder. We will, therefore, change the format of the\n",
+        "dataset to a dictionary having two keys: `\"encoder_text\"` and `\"decoder_text\"`.\n",
+        "This is the input format that `keras_nlp.models.BartSeq2SeqLMPreprocessor`\n",
+        "expects."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "train_ds = (\n",
+        "    samsum_ds.map(\n",
+        "        lambda dialogue, summary: {\"encoder_text\": dialogue, \"decoder_text\": summary}\n",
+        "    )\n",
+        "    .batch(BATCH_SIZE)\n",
+        "    .cache()\n",
+        ")\n",
+        "train_ds = train_ds.take(NUM_BATCHES)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "## Fine-tune BART\n",
+        "\n",
+        "Let's load the model and preprocessor first. We use sequence lengths of 512\n",
+        "and 128 for the encoder and decoder, respectively, instead of 1024 (which is the\n",
+        "default sequence length). This will allow us to run this example quickly\n",
+        "on Colab.\n",
+        "\n",
+        "Note that the preprocessor is attached to the model. This means we don't have\n",
+        "to worry about preprocessing the text inputs; everything will be done\n",
+        "internally. The preprocessor tokenizes the encoder text and the decoder text,\n",
+        "adds special tokens and pads them. To generate labels for auto-regressive\n",
+        "training, the preprocessor shifts the decoder text one position to the\n",
+        "right. This is done because at every timestep, the model is trained to\n",
+        "predict the next token."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(\n",
+        "    \"bart_base_en\",\n",
+        "    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,\n",
+        "    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,\n",
+        ")\n",
+        "bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(\n",
+        "    \"bart_base_en\", preprocessor=preprocessor\n",
+        ")\n",
+        "\n",
+        "bart_lm.summary()"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "Define the optimizer and loss. We use the AdamW optimizer with weight decay\n",
+        "and global gradient clipping. Compile the model."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "optimizer = keras.optimizers.AdamW(\n",
+        "    learning_rate=5e-5,\n",
+        "    weight_decay=0.01,\n",
+        "    epsilon=1e-6,\n",
+        "    global_clipnorm=1.0,  # Gradient clipping.\n",
+        ")\n",
+        "# Exclude layernorm and bias terms from weight decay.\n",
+        "optimizer.exclude_from_weight_decay(var_names=[\"bias\"])\n",
+        "optimizer.exclude_from_weight_decay(var_names=[\"gamma\"])\n",
+        "optimizer.exclude_from_weight_decay(var_names=[\"beta\"])\n",
+        "\n",
+        "loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
+        "\n",
+        "bart_lm.compile(\n",
+        "    optimizer=optimizer,\n",
+        "    loss=loss,\n",
+        "    weighted_metrics=[\"accuracy\"],\n",
+        ")"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "Let's train the model!"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "bart_lm.fit(train_ds, epochs=EPOCHS)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "## Generate summaries and evaluate them!\n",
+        "\n",
+        "Now that the model has been trained, let's get to the fun part - actually\n",
+        "generating summaries! Let's pick the first 100 samples from the validation set\n",
+        "and generate summaries for them. We will use the default decoding strategy,\n",
+        "i.e., greedy search.\n",
+        "\n",
+        "Generation in KerasNLP is highly optimized. First, it is compiled with XLA.\n",
+        "Second, key/value tensors in the decoder's self-attention and cross-attention\n",
+        "layers are cached to avoid recomputation at every timestep."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "\n",
+        "def generate_text(model, input_text, max_length=200):\n",
+        "    start = time.time()\n",
+        "    output = model.generate(input_text, max_length=max_length)\n",
+        "    end = time.time()\n",
+        "    print(f\"Total Time Elapsed: {end - start:.2f}s\")\n",
+        "    return output\n",
+        "\n",
+        "\n",
+        "# Load the dataset.\n",
+        "val_ds = tfds.load(\"samsum\", split=\"validation\", as_supervised=True)\n",
+        "val_ds = val_ds.take(100)\n",
+        "\n",
+        "dialogues = []\n",
+        "ground_truth_summaries = []\n",
+        "for dialogue, summary in val_ds:\n",
+        "    dialogues.append(dialogue.numpy())\n",
+        "    ground_truth_summaries.append(summary.numpy())\n",
+        "\n",
+        "# Let's make a dummy call - the first call to XLA generally takes a bit longer.\n",
+        "_ = generate_text(bart_lm, \"sample text\", max_length=MAX_GENERATION_LENGTH)\n",
+        "\n",
+        "# Generate summaries.\n",
+        "generated_summaries = generate_text(\n",
+        "    bart_lm,\n",
+        "    val_ds.map(lambda dialogue, _: dialogue).batch(8),\n",
+        "    max_length=MAX_GENERATION_LENGTH,\n",
+        ")"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "Let's see some of the summaries."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab_type": "code"
+      },
+      "outputs": [],
+      "source": [
+        "for dialogue, generated_summary, ground_truth_summary in zip(\n",
+        "    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]\n",
+        "):\n",
+        "    print(\"Dialogue:\", dialogue)\n",
+        "    print(\"Generated Summary:\", generated_summary)\n",
+        "    print(\"Ground Truth Summary:\", ground_truth_summary)\n",
+        "    print(\"=============================\")"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "The generated summaries look awesome! Not bad for a model trained for only 1\n",
+        "epoch and on ~5,000 examples :)"
+      ]
+    }
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "collapsed_sections": [],
+      "name": "abstractive_summarization_with_bart",
+      "private_outputs": false,
+      "provenance": [],
+      "toc_visible": true
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.7.0"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/nlp/md/abstractive_summarization_with_bart.md b/examples/nlp/md/abstractive_summarization_with_bart.md
new file mode 100644
index 0000000000..06c1ea8018
--- /dev/null
+++ b/examples/nlp/md/abstractive_summarization_with_bart.md
@@ -0,0 +1,427 @@
+# Abstractive Text Summarization with BART
+
+**Author:** [Abheesht Sharma](https://github.com/abheesht17/)<br>
+**Date created:** 2023/07/08<br>
+**Last modified:** 2023/07/08<br>
+**Description:** Use KerasNLP to fine-tune BART on the abstractive summarization task.
+
+
+[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb)  •  [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/nlp/abstractive_summarization_with_bart.py)
+
+
+
+---
+## Introduction
+
+In the era of information overload, it has become crucial to extract the crux
+of a long document or a conversation and express it in a few sentences. Because
+summarization has widespread applications in different domains, it has become a
+key, well-studied NLP task in recent years.
+
+[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461)
+is a Transformer-based encoder-decoder model, often used for
+sequence-to-sequence tasks like summarization and neural machine translation.
+BART is pre-trained in a self-supervised fashion on a large text corpus. During
+pre-training, the text is corrupted and BART is trained to reconstruct the
+original text (hence the name "denoising autoencoder"). Pre-training tasks
+include token masking, token deletion, and sentence permutation (shuffling
+sentences and training BART to restore the original order).
+
+In this example, we will demonstrate how to fine-tune BART on the abstractive
+summarization task (on conversations!) using KerasNLP, and generate summaries
+using the fine-tuned model.
+
+---
+## Setup
+
+Before we start implementing the pipeline, let's install and import all the
+libraries we need. We'll be using the KerasNLP library. We will also need a
+couple of utility libraries.
+
+
+```python
+!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q
+```
+
+```
+  Installing build dependencies ... done
+  Getting requirements to build wheel ... done
+  Preparing metadata (pyproject.toml) ... done
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB 1.4 MB/s eta 0:00:00
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 34.8 MB/s eta 0:00:00
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 412.3/412.3 kB 30.4 MB/s eta 0:00:00
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.8/138.8 kB 15.1 MB/s eta 0:00:00
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.8/49.8 kB 5.8 MB/s eta 0:00:00
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 61.4 MB/s eta 0:00:00
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB 10.1 MB/s eta 0:00:00
+  Building wheel for keras-nlp (pyproject.toml) ... done
+
+```
+
+This example uses [Keras Core](https://keras.io/keras_core/) to work in any of
+`"tensorflow"`, `"jax"` or `"torch"`. Support for Keras Core is baked into
+KerasNLP: simply change the `"KERAS_BACKEND"` environment variable to select
+the backend of your choice. We select the JAX backend below.
+
+
+```python
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+```
+
+Import all necessary libraries.
+
+
+```python
+import time
+
+import py7zr
+import keras_nlp
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+import keras_core as keras
+```
+
+``` +Using JAX backend. + +``` +
+Let's also define our hyperparameters. + + +```python +BATCH_SIZE = 8 +NUM_BATCHES = 600 +EPOCHS = 1 # Can be set to a higher value for better results +MAX_ENCODER_SEQUENCE_LENGTH = 512 +MAX_DECODER_SEQUENCE_LENGTH = 128 +MAX_GENERATION_LENGTH = 40 +``` + +--- +## Dataset + +Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset +contains around 15,000 pairs of conversations/dialogues and summaries. + + +```python +# Download the dataset. +filename = keras.utils.get_file( + "corpus.7z", + origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z", +) + +# Extract the `.7z` file. +with py7zr.SevenZipFile(filename, mode="r") as z: + z.extractall(path="/root/tensorflow_datasets/downloads/manual") + +# Load data using TFDS. +samsum_ds = tfds.load("samsum", split="train", as_supervised=True) +``` + +
+```
+Downloading data from https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z
+ 2944100/2944100 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
+Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...
+
+Generating splits...:   0%|          | 0/3 [00:00<...]
+
+```
+
+The dataset has two fields: `dialogue` and `summary`. Let's see a sample.
+
+
+```python
+for dialogue, summary in samsum_ds:
+    print(dialogue.numpy())
+    print(summary.numpy())
+    break
+```
+
+``` +b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. " +b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.' + +``` +
+We'll now batch the dataset and retain only a subset of it for the purpose of
+this example. The dialogue is fed to the encoder, and the corresponding summary
+serves as input to the decoder. We will, therefore, change the format of the
+dataset to a dictionary having two keys: `"encoder_text"` and `"decoder_text"`.
+This is the input format that `keras_nlp.models.BartSeq2SeqLMPreprocessor`
+expects.
+
+
+```python
+train_ds = (
+    samsum_ds.map(
+        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
+    )
+    .batch(BATCH_SIZE)
+    .cache()
+)
+train_ds = train_ds.take(NUM_BATCHES)
+```
+
+---
+## Fine-tune BART
+
+Let's load the model and preprocessor first. We use sequence lengths of 512
+and 128 for the encoder and decoder, respectively, instead of 1024 (which is the
+default sequence length). This will allow us to run this example quickly
+on Colab.
+
+Note that the preprocessor is attached to the model. This means we don't have
+to worry about preprocessing the text inputs; everything will be done
+internally. The preprocessor tokenizes the encoder text and the decoder text,
+adds special tokens and pads them. To generate labels for auto-regressive
+training, the preprocessor shifts the decoder text one position to the right.
+This is done because at every timestep, the model is trained to predict the
+next token.
+
+
+```python
+preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
+    "bart_base_en",
+    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
+    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
+)
+bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
+    "bart_base_en", preprocessor=preprocessor
+)
+
+bart_lm.summary()
+```
+
+``` +Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json + 898823/898823 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step +Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt + 456318/456318 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step +Downloading data from https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5 + 557969120/557969120 ━━━━━━━━━━━━━━━━━━━━ 29s 0us/step + +``` +
+
+```
+Preprocessor: "bart_seq2_seq_lm_preprocessor"
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ Tokenizer (type)                                                                                Vocab # ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ bart_tokenizer (BartTokenizer)                     │                                              50,265 │
+└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
+
+Model: "bart_seq2_seq_lm"
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)                   Output Shape                   Param #  Connected to                   ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ decoder_padding_mask          │ (None, None)              │           0 │ -                              │
+│ (InputLayer)                  │                           │             │                                │
+├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
+│ decoder_token_ids             │ (None, None)              │           0 │ -                              │
+│ (InputLayer)                  │                           │             │                                │
+├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
+│ encoder_padding_mask          │ (None, None)              │           0 │ -                              │
+│ (InputLayer)                  │                           │             │                                │
+├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
+│ encoder_token_ids             │ (None, None)              │           0 │ -                              │
+│ (InputLayer)                  │                           │             │                                │
+├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
+│ bart_backbone (BartBackbone)  │ [(None, None, 768),       │ 139,417,344 │ decoder_padding_mask[0][0],    │
+│                               │ (None, None, 768)]        │             │ decoder_token_ids[0][0],       │
+│                               │                           │             │ encoder_padding_mask[0][0],    │
+│                               │                           │             │ encoder_token_ids[0][0]        │
+├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
+│ reverse_embedding             │ (None, 50265)             │  38,603,520 │ bart_backbone[0][0]            │
+│ (ReverseEmbedding)            │                           │             │                                │
+└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
+
+ Total params: 139,417,344 (4.15 GB)
+ Trainable params: 139,417,344 (4.15 GB)
+ Non-trainable params: 0 (0.00 B)
+```
+
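+Before compiling, let's make the label-shifting described above concrete with a
+small, optional sketch that is not part of the original pipeline: run one
+made-up dialogue/summary pair through the attached preprocessor and inspect its
+outputs. The three-tuple return structure is an assumption based on how
+KerasNLP's LM preprocessors generally behave, so verify it against the version
+you have installed.
+
+
+```python
+# Hypothetical illustration: preprocess a single example.
+features, labels, sample_weights = preprocessor(
+    {
+        "encoder_text": "Amanda: I baked cookies. Do you want some?",
+        "decoder_text": "Amanda baked cookies.",
+    }
+)
+
+# `features` holds token IDs and padding masks for the encoder and decoder.
+print(features["decoder_token_ids"][:8])
+# `labels` contains the decoder token IDs offset by one position, so that at
+# every timestep the target is the next token.
+print(labels[:8])
+```
+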
+
+Define the optimizer and loss. We use the AdamW optimizer with weight decay
+and global gradient clipping. Compile the model.
+
+
+```python
+optimizer = keras.optimizers.AdamW(
+    learning_rate=5e-5,
+    weight_decay=0.01,
+    epsilon=1e-6,
+    global_clipnorm=1.0,  # Gradient clipping.
+)
+# Exclude layernorm and bias terms from weight decay.
+optimizer.exclude_from_weight_decay(var_names=["bias"])
+optimizer.exclude_from_weight_decay(var_names=["gamma"])
+optimizer.exclude_from_weight_decay(var_names=["beta"])
+
+loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+bart_lm.compile(
+    optimizer=optimizer,
+    loss=loss,
+    weighted_metrics=["accuracy"],
+)
+```
+
+Let's train the model!
+
+
+```python
+bart_lm.fit(train_ds, epochs=EPOCHS)
+```
+
+``` + 600/600 ━━━━━━━━━━━━━━━━━━━━ 398s 586ms/step - loss: 0.4330 + + + +``` +
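+
+As an optional aside (not part of the original example), you may want to
+checkpoint the fine-tuned weights so that generation can be re-run later
+without re-training. A minimal sketch, assuming a Keras-Core-style
+`.weights.h5` filepath; the filename itself is arbitrary:
+
+
+```python
+# Save the fine-tuned weights, then restore them into the same model.
+bart_lm.save_weights("bart_summarizer.weights.h5")
+bart_lm.load_weights("bart_summarizer.weights.h5")
+```
+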
+---
+## Generate summaries and evaluate them!
+
+Now that the model has been trained, let's get to the fun part - actually
+generating summaries! Let's pick the first 100 samples from the validation set
+and generate summaries for them. We will use the default decoding strategy,
+i.e., greedy search.
+
+Generation in KerasNLP is highly optimized. First, it is compiled with XLA.
+Second, key/value tensors in the decoder's self-attention and cross-attention
+layers are cached to avoid recomputation at every timestep.
+
+
+```python
+
+def generate_text(model, input_text, max_length=200):
+    start = time.time()
+    output = model.generate(input_text, max_length=max_length)
+    end = time.time()
+    print(f"Total Time Elapsed: {end - start:.2f}s")
+    return output
+
+
+# Load the dataset.
+val_ds = tfds.load("samsum", split="validation", as_supervised=True)
+val_ds = val_ds.take(100)
+
+dialogues = []
+ground_truth_summaries = []
+for dialogue, summary in val_ds:
+    dialogues.append(dialogue.numpy())
+    ground_truth_summaries.append(summary.numpy())
+
+# Let's make a dummy call - the first call to XLA generally takes a bit longer.
+_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)
+
+# Generate summaries.
+generated_summaries = generate_text(
+    bart_lm,
+    val_ds.map(lambda dialogue, _: dialogue).batch(8),
+    max_length=MAX_GENERATION_LENGTH,
+)
+```
+
+``` +Total Time Elapsed: 21.22s +Total Time Elapsed: 49.00s + +``` +
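+
+As an aside, greedy search is only the default. If `BartSeq2SeqLM` follows the
+same generative-task API as KerasNLP's causal LMs - an assumption worth
+checking against the docs for your installed version - the decoding strategy
+can be swapped by recompiling with a different sampler:
+
+
+```python
+# Hypothetical sketch: switch to top-k sampling for more varied summaries.
+bart_lm.compile(sampler="top_k")
+sampled_summaries = generate_text(
+    bart_lm,
+    val_ds.map(lambda dialogue, _: dialogue).batch(8),
+    max_length=MAX_GENERATION_LENGTH,
+)
+print(sampled_summaries[0])
+
+# Restore greedy decoding, which the outputs below assume.
+bart_lm.compile(sampler="greedy")
+```
+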
+Let's see some of the summaries. + + +```python +for dialogue, generated_summary, ground_truth_summary in zip( + dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5] +): + print("Dialogue:", dialogue) + print("Generated Summary:", generated_summary) + print("Ground Truth Summary:", ground_truth_summary) + print("=============================") +``` + +
+``` +Dialogue: b'Tony: Is the boss in?\r\nClaire: Not yet.\r\nTony: Could let me know when he comes, please? \r\nClaire: Of course.\r\nTony: Thank you.' +Generated Summary: Tony will let Claire know when her boss comes. +Ground Truth Summary: b"The boss isn't in yet. Claire will let Tony know when he comes." +============================= +Dialogue: b"James: What shouldl I get her?\r\nTim: who?\r\nJames: gees Mary my girlfirend\r\nTim: Am I really the person you should be asking?\r\nJames: oh come on it's her birthday on Sat\r\nTim: ask Sandy\r\nTim: I honestly am not the right person to ask this\r\nJames: ugh fine!" +Generated Summary: Mary's girlfriend is birthday. James and Tim are going to ask Sandy to buy her. +Ground Truth Summary: b"Mary's birthday is on Saturday. Her boyfriend, James, is looking for gift ideas. Tim suggests that he ask Sandy." +============================= +Dialogue: b"Mary: So, how's Israel? Have you been on the beach?\r\nKate: It's so expensive! But they say, it's Tel Aviv... Tomorrow we are going to Jerusalem.\r\nMary: I've heard Israel is expensive, Monica was there on vacation last year, she complained about how pricey it is. Are you going to the Dead Sea before it dies? ahahahha\r\nKate: ahahhaha yup, in few days." +Generated Summary: Kate is on vacation in Tel Aviv. Mary will visit the Dead Sea in a few days. +Ground Truth Summary: b'Mary and Kate discuss how expensive Israel is. Kate is in Tel Aviv now, planning to travel to Jerusalem tomorrow, and to the Dead Sea few days later.' +============================= +Dialogue: b"Giny: do we have rice?\r\nRiley: nope, it's finished\r\nGiny: fuck!\r\nGiny: ok, I'll buy" +Generated Summary: Giny wants to buy rice from Riley. +Ground Truth Summary: b"Giny and Riley don't have any rice left. Giny will buy some." +============================= +Dialogue: b"Jude: i'll be in warsaw at the beginning of december so we could meet again\r\nLeon: !!!\r\nLeon: at the beginning means...?\r\nLeon: cuz I won't be here during the first weekend\r\nJude: 10\r\nJude: but i think it's a monday, so never mind i guess :D\r\nLeon: yeah monday doesn't really work for me :D\r\nLeon: :<\r\nJude: oh well next time :d\r\nLeon: yeah...!" +Generated Summary: Jude and Leon will meet again this weekend at 10 am. +Ground Truth Summary: b'Jude is coming to Warsaw on the 10th of December and wants to see Leon. Leon has no time.' +============================= + +``` +
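+
+Since this section promises evaluation, here is an optional sketch that scores
+the generated summaries against the references with ROUGE-L, using
+`keras_nlp.metrics.RougeL`. This assumes the `rouge-score` package is
+installed (e.g. via `pip install rouge-score -q`) and that the metric returns
+a dict of precision/recall/F1 values; check both against your installed
+version.
+
+
+```python
+# Compute ROUGE-L between generated and ground truth summaries. The ground
+# truths are bytes, so decode them to strings first.
+rouge_l = keras_nlp.metrics.RougeL()
+scores = rouge_l(
+    y_true=[summary.decode("utf-8") for summary in ground_truth_summaries],
+    y_pred=generated_summaries,
+)
+print("ROUGE-L F1:", scores["f1_score"])
+```
+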
+The generated summaries look awesome! Not bad for a model trained for only 1
+epoch and on ~5,000 examples :)