diff --git a/examples/nlp/abstractive_summarization_with_bart.py b/examples/nlp/abstractive_summarization_with_bart.py
new file mode 100644
index 0000000000..f712383ee4
--- /dev/null
+++ b/examples/nlp/abstractive_summarization_with_bart.py
@@ -0,0 +1,240 @@
+"""
+Title: Abstractive Text Summarization with BART
+Author: [Abheesht Sharma](https://github.com/abheesht17/)
+Date created: 2023/07/08
+Last modified: 2023/07/08
+Description: Use KerasNLP to fine-tune BART on the abstractive summarization task.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In the era of information overload, it has become crucial to extract the crux
+of a long document or a conversation and express it in a few sentences. Because
+summarization has widespread applications in different domains, it has become a
+key, well-studied NLP task in recent years.
+
+[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461)
+is a Transformer-based encoder-decoder model, often used for
+sequence-to-sequence tasks like summarization and neural machine translation.
+BART is pre-trained in a self-supervised fashion on a large text corpus. During
+pre-training, the text is corrupted and BART is trained to reconstruct the
+original text (hence called a "denoising autoencoder"). Some pre-training tasks
+include token masking, token deletion, sentence permutation (shuffle sentences
+and train BART to fix the order), etc.
+
+In this example, we will demonstrate how to fine-tune BART on the abstractive
+summarization task (on conversations!) using KerasNLP, and generate summaries
+using the fine-tuned model.
+"""
+
+"""
+## Setup
+
+Before we start implementing the pipeline, let's install and import all the
+libraries we need. We'll be using the KerasNLP library. We will also need a
+couple of utility libraries.
+"""
+
+"""shell
+pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q
+"""
+
+"""
+This example uses [Keras Core](https://keras.io/keras_core/) so that it can run
+on any of the `"tensorflow"`, `"jax"` or `"torch"` backends. Support for Keras
+Core is baked into KerasNLP; simply change the `"KERAS_BACKEND"` environment
+variable to select the backend of your choice. We select the JAX backend below.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+
+"""
+Import all necessary libraries.
+"""
+
+import py7zr
+import time
+
+import keras_nlp
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+import keras_core as keras
+
+"""
+Let's also define our hyperparameters.
+"""
+
+BATCH_SIZE = 8
+NUM_BATCHES = 600
+EPOCHS = 1 # Can be set to a higher value for better results
+MAX_ENCODER_SEQUENCE_LENGTH = 512
+MAX_DECODER_SEQUENCE_LENGTH = 128
+MAX_GENERATION_LENGTH = 40
+
+"""
+## Dataset
+
+Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset
+contains around 15,000 pairs of conversations/dialogues and summaries.
+"""
+
+# Download the dataset.
+filename = keras.utils.get_file(
+ "corpus.7z",
+ origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
+)
+
+# Extract the `.7z` file.
+with py7zr.SevenZipFile(filename, mode="r") as z:
+ z.extractall(path="/root/tensorflow_datasets/downloads/manual")
+
+# Load data using TFDS.
+samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
+
+"""
+The dataset has two fields: `dialogue` and `summary`. Let's see a sample.
+"""
+for dialogue, summary in samsum_ds:
+ print(dialogue.numpy())
+ print(summary.numpy())
+ break
+
+"""
+We'll now batch the dataset and retain only a subset of it for the purpose of
+this example. The dialogue is fed to the encoder, and the corresponding summary
+serves as input to the decoder. We will, therefore, change the format of the
+dataset to a dictionary having two keys: `"encoder_text"` and `"decoder_text"`.
+This is the input format that `keras_nlp.models.BartSeq2SeqLMPreprocessor`
+expects.
+"""
+
+train_ds = (
+ samsum_ds.map(
+ lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
+ )
+ .batch(BATCH_SIZE)
+ .cache()
+)
+train_ds = train_ds.take(NUM_BATCHES)
+
+"""
+## Fine-tune BART
+
+Let's load the model and preprocessor first. We use sequence lengths of 512
+and 128 for the encoder and decoder, respectively, instead of 1024 (which is the
+default sequence length). This will allow us to run this example quickly
+on Colab.
+
+If you observe carefully, the preprocessor is attached to the model. What this
+means is that we don't have to worry about preprocessing the text inputs;
+everything will be done internally. The preprocessor tokenizes the encoder text
+and the decoder text, adds special tokens and pads them. To generate labels
+for auto-regressive training, the preprocessor shifts the decoder text one
+position to the right. This is done because at every timestep, the model is
+trained to predict the next token.
+"""
+
+preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
+ "bart_base_en",
+ encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
+ decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
+)
+bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
+ "bart_base_en", preprocessor=preprocessor
+)
+
+bart_lm.summary()
+
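+"""
+To make the preprocessing concrete, here is a minimal sketch of what the attached
+preprocessor does to a single example. The toy dialogue/summary pair below is
+made up purely for illustration; assuming the preprocessor can be called directly
+on raw strings (as KerasNLP preprocessors typically can), it returns the tokenized
+features `x`, the shifted labels `y` and the padding-aware sample weights.
+"""
+
+x, y, sample_weight = preprocessor(
+    {
+        "encoder_text": "Amanda: I baked cookies. Do you want some?",
+        "decoder_text": "Amanda baked cookies and will share them.",
+    }
+)
+print(x.keys())  # Encoder/decoder token IDs and padding masks.
+print(y.shape)  # Labels: the decoder token IDs shifted by one position.
+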
+"""
+Define the optimizer and loss. We use the AdamW optimizer with a constant
+learning rate, weight decay and gradient clipping; LayerNorm and bias terms
+are excluded from weight decay. Compile the model.
+"""
+
+optimizer = keras.optimizers.AdamW(
+ learning_rate=5e-5,
+ weight_decay=0.01,
+ epsilon=1e-6,
+ global_clipnorm=1.0, # Gradient clipping.
+)
+# Exclude layernorm and bias terms from weight decay.
+optimizer.exclude_from_weight_decay(var_names=["bias"])
+optimizer.exclude_from_weight_decay(var_names=["gamma"])
+optimizer.exclude_from_weight_decay(var_names=["beta"])
+
+loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+bart_lm.compile(
+ optimizer=optimizer,
+ loss=loss,
+ weighted_metrics=["accuracy"],
+)
+
+"""
+Let's train the model!
+"""
+
+bart_lm.fit(train_ds, epochs=EPOCHS)
+
+"""
+## Generate summaries and evaluate them!
+
+Now that the model has been trained, let's get to the fun part - actually
+generating summaries! Let's pick the first 100 samples from the validation set
+and generate summaries for them. We will use the default decoding strategy, i.e.,
+greedy search.
+
+Generation in KerasNLP is highly optimized. Firstly, it is compiled with XLA.
+Secondly, the key/value tensors in the decoder's self-attention and
+cross-attention layers are cached to avoid recomputation at every timestep.
+"""
+
+
+def generate_text(model, input_text, max_length=200, print_time_taken=False):
+    start = time.time()
+    output = model.generate(input_text, max_length=max_length)
+    end = time.time()
+    # Only report the wall-clock time when asked to.
+    if print_time_taken:
+        print(f"Total Time Elapsed: {end - start:.2f}s")
+    return output
+
+
+# Load the dataset.
+val_ds = tfds.load("samsum", split="validation", as_supervised=True)
+val_ds = val_ds.take(100)
+
+dialogues = []
+ground_truth_summaries = []
+for dialogue, summary in val_ds:
+ dialogues.append(dialogue.numpy())
+ ground_truth_summaries.append(summary.numpy())
+
+# Let's make a dummy call - the first call to XLA generally takes a bit longer.
+_ = generate_text(
+    bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH, print_time_taken=True
+)
+
+# Generate summaries.
+generated_summaries = generate_text(
+ bart_lm,
+ val_ds.map(lambda dialogue, _: dialogue).batch(8),
+ max_length=MAX_GENERATION_LENGTH,
+ print_time_taken=True,
+)
+
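+"""
+We used the default greedy search above. KerasNLP also ships other decoding
+strategies. As a hedged sketch (assuming the sampler API of recent KerasNLP
+releases), you can swap the sampler used by `generate()` by recompiling the
+model, for example with beam search:
+"""
+
+# Recompile with a different sampler. Calling `compile()` again resets the
+# optimizer and loss, which is fine here because we only generate afterwards.
+# The next `generate()` call will be re-traced by XLA, so it may take a while.
+bart_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=4))
+beam_summary = bart_lm.generate(
+    dialogues[0].decode("utf-8"), max_length=MAX_GENERATION_LENGTH
+)
+print("Beam search summary:", beam_summary)
+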
+"""
+Let's see some of the summaries.
+"""
+for dialogue, generated_summary, ground_truth_summary in zip(
+ dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
+):
+ print("Dialogue:", dialogue)
+ print("Generated Summary:", generated_summary)
+ print("Ground Truth Summary:", ground_truth_summary)
+ print("=============================")
+
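+"""
+For a rough quantitative check, we can score the generated summaries against the
+ground-truth summaries with ROUGE-L. The snippet below is a lightweight sketch:
+it assumes `keras_nlp.metrics.RougeL` is available, which wraps the
+`rouge_score` package (you may need to `pip install rouge-score` first).
+"""
+
+rouge_l = keras_nlp.metrics.RougeL()
+references = [summary.decode("utf-8") for summary in ground_truth_summaries]
+rouge_l.update_state(references, generated_summaries)
+# `result()` reports precision, recall and F1 score.
+print("ROUGE-L:", rouge_l.result())
+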
+"""
+The generated summaries look awesome! Not bad for a model trained for only one
+epoch on roughly 5,000 examples :)
+"""
diff --git a/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb b/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb
new file mode 100644
index 0000000000..22feb911e9
--- /dev/null
+++ b/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb
@@ -0,0 +1,457 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Abstractive Text Summarization with BART\n",
+ "\n",
+ "**Author:** [Abheesht Sharma](https://github.com/abheesht17/)
\n",
+ "**Date created:** 2023/07/08
\n",
+ "**Last modified:** 2023/07/08
\n",
+ "**Description:** Use KerasNLP to fine-tune BART on the abstractive summarization task."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "In the era of information overload, it has become crucial to extract the crux\n",
+ "of a long document or a conversation and express it in a few sentences. Owing\n",
+ "to the fact that summarization has widespread applications in different domains,\n",
+ "it has become a key, well-studied NLP task in recent years.\n",
+ "\n",
+ "[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461)\n",
+ "is a Transformer-based encoder-decoder model, often used for\n",
+ "sequence-to-sequence tasks like summarization and neural machine translation.\n",
+ "BART is pre-trained in a self-supervised fashion on a large text corpus. During\n",
+ "pre-training, the text is corrupted and BART is trained to reconstruct the\n",
+ "original text (hence called a \"denoising autoencoder\"). Some pre-training tasks\n",
+ "include token masking, token deletion, sentence permutation (shuffle sentences\n",
+ "and train BART to fix the order), etc.\n",
+ "\n",
+ "In this example, we will demonstrate how to fine-tune BART on the abstractive\n",
+ "summarization task (on conversations!) using KerasNLP, and generate summaries\n",
+ "using the fine-tuned model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup\n",
+ "\n",
+ "Before we start implementing the pipeline, let's install and import all the\n",
+ "libraries we need. We'll be using the KerasNLP library. We will also need a\n",
+ "couple of utility libraries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "This examples uses [Keras Core](https://keras.io/keras_core/) to work in any of\n",
+ "`\"tensorflow\"`, `\"jax\"` or `\"torch\"`. Support for Keras Core is baked into\n",
+ "KerasNLP, simply change the `\"KERAS_BACKEND\"` environment variable to select\n",
+ "the backend of your choice. We select the JAX backend below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"jax\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Import all necessary libraries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import py7zr\n",
+ "import time\n",
+ "\n",
+ "import keras_nlp\n",
+ "import tensorflow as tf\n",
+ "import tensorflow_datasets as tfds\n",
+ "\n",
+ "import keras_core as keras"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's also define our hyperparameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "BATCH_SIZE = 8\n",
+ "NUM_BATCHES = 600\n",
+ "EPOCHS = 1 # Can be set to a higher value for better results\n",
+ "MAX_ENCODER_SEQUENCE_LENGTH = 512\n",
+ "MAX_DECODER_SEQUENCE_LENGTH = 128\n",
+ "MAX_GENERATION_LENGTH = 40"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Dataset\n",
+ "\n",
+ "Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset\n",
+ "contains around 15,000 pairs of conversations/dialogues and summaries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Download the dataset.\n",
+ "filename = keras.utils.get_file(\n",
+ " \"corpus.7z\",\n",
+ " origin=\"https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z\",\n",
+ ")\n",
+ "\n",
+ "# Extract the `.7z` file.\n",
+ "with py7zr.SevenZipFile(filename, mode=\"r\") as z:\n",
+ " z.extractall(path=\"/root/tensorflow_datasets/downloads/manual\")\n",
+ "\n",
+ "# Load data using TFDS.\n",
+ "samsum_ds = tfds.load(\"samsum\", split=\"train\", as_supervised=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The dataset has two fields: `dialogue` and `summary`. Let's see a sample."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "for dialogue, summary in samsum_ds:\n",
+ " print(dialogue.numpy())\n",
+ " print(summary.numpy())\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We'll now batch the dataset and retain only a subset of the dataset for the\n",
+ "purpose of this example. The dialogue is fed to the encoder, and the\n",
+ "corresponding summary serves as input to the decoder. We will, therefore, change\n",
+ "the format of the dataset to a dictionary having two keys: `\"encoder_text\"` and\n",
+ "`\"decoder_text\"`.This is how `keras_nlp.models.BartSeq2SeqLMPreprocessor`\n",
+ "expects the input format to be."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "train_ds = (\n",
+ " samsum_ds.map(\n",
+ " lambda dialogue, summary: {\"encoder_text\": dialogue, \"decoder_text\": summary}\n",
+ " )\n",
+ " .batch(BATCH_SIZE)\n",
+ " .cache()\n",
+ ")\n",
+ "train_ds = train_ds.take(NUM_BATCHES)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Fine-tune BART\n",
+ "\n",
+ "Let's load the model and preprocessor first. We use sequence lengths of 512\n",
+ "and 128 for the encoder and decoder, respectively, instead of 1024 (which is the\n",
+ "default sequence length). This will allow us to run this example quickly\n",
+ "on Colab.\n",
+ "\n",
+ "If you observe carefully, the preprocessor is attached to the model. What this\n",
+ "means is that we don't have to worry about preprocessing the text inputs;\n",
+ "everything will be done internally. The preprocessor tokenizes the encoder text\n",
+ "and the decoder text, adds special tokens and pads them. To generate labels\n",
+ "for auto-regressive training, the preprocessor shifts the decoder text one\n",
+ "position to the right. This is done because at every timestep, the model is\n",
+ "trained to predict the next token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(\n",
+ " \"bart_base_en\",\n",
+ " encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,\n",
+ " decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,\n",
+ ")\n",
+ "bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(\n",
+ " \"bart_base_en\", preprocessor=preprocessor\n",
+ ")\n",
+ "\n",
+ "bart_lm.summary()"
+ ]
+ },
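+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "To make the preprocessing concrete, here is a minimal sketch of what the attached\n",
+ "preprocessor does to a single example. The toy dialogue/summary pair below is\n",
+ "made up purely for illustration; assuming the preprocessor can be called directly\n",
+ "on raw strings (as KerasNLP preprocessors typically can), it returns the tokenized\n",
+ "features `x`, the shifted labels `y` and the padding-aware sample weights."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "x, y, sample_weight = preprocessor(\n",
+ "    {\n",
+ "        \"encoder_text\": \"Amanda: I baked cookies. Do you want some?\",\n",
+ "        \"decoder_text\": \"Amanda baked cookies and will share them.\",\n",
+ "    }\n",
+ ")\n",
+ "print(x.keys())  # Encoder/decoder token IDs and padding masks.\n",
+ "print(y.shape)  # Labels: the decoder token IDs shifted by one position."
+ ]
+ },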
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Define the optimizer and loss. We use the Adam optimizer with a linearly\n",
+ "decaying learning rate. Compile the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "optimizer = keras.optimizers.AdamW(\n",
+ " learning_rate=5e-5,\n",
+ " weight_decay=0.01,\n",
+ " epsilon=1e-6,\n",
+ " global_clipnorm=1.0, # Gradient clipping.\n",
+ ")\n",
+ "# Exclude layernorm and bias terms from weight decay.\n",
+ "optimizer.exclude_from_weight_decay(var_names=[\"bias\"])\n",
+ "optimizer.exclude_from_weight_decay(var_names=[\"gamma\"])\n",
+ "optimizer.exclude_from_weight_decay(var_names=[\"beta\"])\n",
+ "\n",
+ "loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
+ "\n",
+ "bart_lm.compile(\n",
+ " optimizer=optimizer,\n",
+ " loss=loss,\n",
+ " weighted_metrics=[\"accuracy\"],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's train the model!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "bart_lm.fit(train_ds, epochs=EPOCHS)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Generate summaries and evaluate them!\n",
+ "\n",
+ "Now that the model has been trained, let's get to the fun part - actually\n",
+ "generating summaries! Let's pick the first 100 samples from the validation set\n",
+ "and generate summaries for them. We will use the default decoding strategy, i.e.,\n",
+ "greedy search.\n",
+ "\n",
+ "Generation in KerasNLP is highly optimized. It is backed by the power of XLA.\n",
+ "Secondly, key/value tensors in the self-attention layer and cross-attention layer\n",
+ "in the decoder are cached to avoid recomputation at every timestep."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def generate_text(model, input_text, max_length=200, print_time_taken=False):\n",
+ " start = time.time()\n",
+ " output = model.generate(input_text, max_length=max_length)\n",
+ " end = time.time()\n",
+ " print(f\"Total Time Elapsed: {end - start:.2f}s\")\n",
+ " return output\n",
+ "\n",
+ "\n",
+ "# Load the dataset.\n",
+ "val_ds = tfds.load(\"samsum\", split=\"validation\", as_supervised=True)\n",
+ "val_ds = val_ds.take(100)\n",
+ "\n",
+ "dialogues = []\n",
+ "ground_truth_summaries = []\n",
+ "for dialogue, summary in val_ds:\n",
+ " dialogues.append(dialogue.numpy())\n",
+ " ground_truth_summaries.append(summary.numpy())\n",
+ "\n",
+ "# Let's make a dummy call - the first call to XLA generally takes a bit longer.\n",
+ "_ = generate_text(bart_lm, \"sample text\", max_length=MAX_GENERATION_LENGTH)\n",
+ "\n",
+ "# Generate summaries.\n",
+ "generated_summaries = generate_text(\n",
+ " bart_lm,\n",
+ " val_ds.map(lambda dialogue, _: dialogue).batch(8),\n",
+ " max_length=MAX_GENERATION_LENGTH,\n",
+ " print_time_taken=True,\n",
+ ")"
+ ]
+ },
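+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We used the default greedy search above. KerasNLP also ships other decoding\n",
+ "strategies. As a hedged sketch (assuming the sampler API of recent KerasNLP\n",
+ "releases), you can swap the sampler used by `generate()` by recompiling the\n",
+ "model, for example with beam search:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Recompile with a different sampler. Calling `compile()` again resets the\n",
+ "# optimizer and loss, which is fine here because we only generate afterwards.\n",
+ "# The next `generate()` call will be re-traced by XLA, so it may take a while.\n",
+ "bart_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=4))\n",
+ "beam_summary = bart_lm.generate(\n",
+ "    dialogues[0].decode(\"utf-8\"), max_length=MAX_GENERATION_LENGTH\n",
+ ")\n",
+ "print(\"Beam search summary:\", beam_summary)"
+ ]
+ },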
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's see some of the summaries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "for dialogue, generated_summary, ground_truth_summary in zip(\n",
+ " dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]\n",
+ "):\n",
+ " print(\"Dialogue:\", dialogue)\n",
+ " print(\"Generated Summary:\", generated_summary)\n",
+ " print(\"Ground Truth Summary:\", ground_truth_summary)\n",
+ " print(\"=============================\")"
+ ]
+ },
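+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "For a rough quantitative check, we can score the generated summaries against the\n",
+ "ground-truth summaries with ROUGE-L. The snippet below is a lightweight sketch:\n",
+ "it assumes `keras_nlp.metrics.RougeL` is available, which wraps the\n",
+ "`rouge_score` package (you may need to `pip install rouge-score` first)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "rouge_l = keras_nlp.metrics.RougeL()\n",
+ "references = [summary.decode(\"utf-8\") for summary in ground_truth_summaries]\n",
+ "rouge_l.update_state(references, generated_summaries)\n",
+ "# `result()` reports precision, recall and F1 score.\n",
+ "print(\"ROUGE-L:\", rouge_l.result())"
+ ]
+ },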
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The generated summaries look awesome! Not bad for a model trained only for 1\n",
+ "epoch and on 5000 examples :)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "abstractive_summarization_with_bart",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/nlp/md/abstractive_summarization_with_bart.md b/examples/nlp/md/abstractive_summarization_with_bart.md
new file mode 100644
index 0000000000..06c1ea8018
--- /dev/null
+++ b/examples/nlp/md/abstractive_summarization_with_bart.md
@@ -0,0 +1,427 @@
+# Abstractive Text Summarization with BART
+
+**Author:** [Abheesht Sharma](https://github.com/abheesht17/)
+**Date created:** 2023/07/08
+**Last modified:** 2023/07/08
+**Description:** Use KerasNLP to fine-tune BART on the abstractive summarization task.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/abstractive_summarization_with_bart.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/nlp/abstractive_summarization_with_bart.py)
+
+
+
+---
+## Introduction
+
+In the era of information overload, it has become crucial to extract the crux
+of a long document or a conversation and express it in a few sentences. Because
+summarization has widespread applications in different domains, it has become a
+key, well-studied NLP task in recent years.
+
+[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461)
+is a Transformer-based encoder-decoder model, often used for
+sequence-to-sequence tasks like summarization and neural machine translation.
+BART is pre-trained in a self-supervised fashion on a large text corpus. During
+pre-training, the text is corrupted and BART is trained to reconstruct the
+original text (hence called a "denoising autoencoder"). Some pre-training tasks
+include token masking, token deletion, sentence permutation (shuffle sentences
+and train BART to fix the order), etc.
+
+In this example, we will demonstrate how to fine-tune BART on the abstractive
+summarization task (on conversations!) using KerasNLP, and generate summaries
+using the fine-tuned model.
+
+---
+## Setup
+
+Before we start implementing the pipeline, let's install and import all the
+libraries we need. We'll be using the KerasNLP library. We will also need a
+couple of utility libraries.
+
+
+```python
+!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q
+```
+
+
Preprocessor: "bart_seq2_seq_lm_preprocessor"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Tokenizer (type) ┃ Vocab # ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ bart_tokenizer (BartTokenizer) │ 50,265 │ +└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘ ++ + + + +
Model: "bart_seq2_seq_lm"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ decoder_padding_mask │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ +│ decoder_token_ids │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ +│ encoder_padding_mask │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ +│ encoder_token_ids │ (None, None) │ 0 │ - │ +│ (InputLayer) │ │ │ │ +├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ +│ bart_backbone (BartBackbone) │ [(None, None, 768), │ 139,417,344 │ decoder_padding_mask[0][0], │ +│ │ (None, None, 768)] │ │ decoder_token_ids[0][0], │ +│ │ │ │ encoder_padding_mask[0][0], │ +│ │ │ │ encoder_token_ids[0][0] │ +├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ +│ reverse_embedding │ (None, 50265) │ 38,603,520 │ bart_backbone[0][0] │ +│ (ReverseEmbedding) │ │ │ │ +└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘ ++ + + + +
Total params: 139,417,344 (4.15 GB) ++ + + + +
Trainable params: 139,417,344 (4.15 GB) ++ + + + +
Non-trainable params: 0 (0.00 B) ++ + + +Define the optimizer and loss. We use the Adam optimizer with a linearly +decaying learning rate. Compile the model. + + +```python +optimizer = keras.optimizers.AdamW( + learning_rate=5e-5, + weight_decay=0.01, + epsilon=1e-6, + global_clipnorm=1.0, # Gradient clipping. +) +# Exclude layernorm and bias terms from weight decay. +optimizer.exclude_from_weight_decay(var_names=["bias"]) +optimizer.exclude_from_weight_decay(var_names=["gamma"]) +optimizer.exclude_from_weight_decay(var_names=["beta"]) + +loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + +bart_lm.compile( + optimizer=optimizer, + loss=loss, + weighted_metrics=["accuracy"], +) +``` + +Let's train the model! + + +```python +bart_lm.fit(train_ds, epochs=EPOCHS) +``` + +