-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
KerasNLP: Abstractive Summarization with BART (#1422)
* KerasNLP: Abstractive Summarization with BART * Add ROUGE score * mixed precision * Address comments * Address comments * Add generated files
- Loading branch information
1 parent
36d6eb0
commit cd737c8
Showing
3 changed files
with
1,124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
""" | ||
Title: Abstractive Text Summarization with BART | ||
Author: [Abheesht Sharma](https://github.com/abheesht17/) | ||
Date created: 2023/07/08 | ||
Last modified: 2023/07/08 | ||
Description: Use KerasNLP to fine-tune BART on the abstractive summarization task. | ||
Accelerator: GPU | ||
""" | ||
|
||
""" | ||
## Introduction | ||
In the era of information overload, it has become crucial to extract the crux | ||
of a long document or a conversation and express it in a few sentences. Owing | ||
to the fact that summarization has widespread applications in different domains, | ||
it has become a key, well-studied NLP task in recent years. | ||
[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461) | ||
is a Transformer-based encoder-decoder model, often used for | ||
sequence-to-sequence tasks like summarization and neural machine translation. | ||
BART is pre-trained in a self-supervised fashion on a large text corpus. During | ||
pre-training, the text is corrupted and BART is trained to reconstruct the | ||
original text (hence called a "denoising autoencoder"). Some pre-training tasks | ||
include token masking, token deletion, sentence permutation (shuffle sentences | ||
and train BART to fix the order), etc. | ||
In this example, we will demonstrate how to fine-tune BART on the abstractive | ||
summarization task (on conversations!) using KerasNLP, and generate summaries | ||
using the fine-tuned model. | ||
""" | ||
|
||
""" | ||
## Setup | ||
Before we start implementing the pipeline, let's install and import all the | ||
libraries we need. We'll be using the KerasNLP library. We will also need a | ||
couple of utility libraries. | ||
""" | ||
|
||
"""shell | ||
pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q | ||
""" | ||
|
||
""" | ||
This examples uses [Keras Core](https://keras.io/keras_core/) to work in any of | ||
`"tensorflow"`, `"jax"` or `"torch"`. Support for Keras Core is baked into | ||
KerasNLP, simply change the `"KERAS_BACKEND"` environment variable to select | ||
the backend of your choice. We select the JAX backend below. | ||
""" | ||
|
||
import os | ||
|
||
os.environ["KERAS_BACKEND"] = "jax" | ||
|
||
""" | ||
Import all necessary libraries. | ||
""" | ||
|
||
import py7zr | ||
import time | ||
|
||
import keras_nlp | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
|
||
import keras_core as keras | ||
|
||
""" | ||
Let's also define our hyperparameters. | ||
""" | ||
|
||
BATCH_SIZE = 8 | ||
NUM_BATCHES = 600 | ||
EPOCHS = 1 # Can be set to a higher value for better results | ||
MAX_ENCODER_SEQUENCE_LENGTH = 512 | ||
MAX_DECODER_SEQUENCE_LENGTH = 128 | ||
MAX_GENERATION_LENGTH = 40 | ||
|
||
""" | ||
## Dataset | ||
Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset | ||
contains around 15,000 pairs of conversations/dialogues and summaries. | ||
""" | ||
|
||
# Download the dataset. | ||
filename = keras.utils.get_file( | ||
"corpus.7z", | ||
origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z", | ||
) | ||
|
||
# Extract the `.7z` file. | ||
with py7zr.SevenZipFile(filename, mode="r") as z: | ||
z.extractall(path="/root/tensorflow_datasets/downloads/manual") | ||
|
||
# Load data using TFDS. | ||
samsum_ds = tfds.load("samsum", split="train", as_supervised=True) | ||
|
||
""" | ||
The dataset has two fields: `dialogue` and `summary`. Let's see a sample. | ||
""" | ||
for dialogue, summary in samsum_ds: | ||
print(dialogue.numpy()) | ||
print(summary.numpy()) | ||
break | ||
|
||
""" | ||
We'll now batch the dataset and retain only a subset of the dataset for the | ||
purpose of this example. The dialogue is fed to the encoder, and the | ||
corresponding summary serves as input to the decoder. We will, therefore, change | ||
the format of the dataset to a dictionary having two keys: `"encoder_text"` and | ||
`"decoder_text"`.This is how `keras_nlp.models.BartSeq2SeqLMPreprocessor` | ||
expects the input format to be. | ||
""" | ||
|
||
train_ds = ( | ||
samsum_ds.map( | ||
lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary} | ||
) | ||
.batch(BATCH_SIZE) | ||
.cache() | ||
) | ||
train_ds = train_ds.take(NUM_BATCHES) | ||
|
||
""" | ||
## Fine-tune BART | ||
Let's load the model and preprocessor first. We use sequence lengths of 512 | ||
and 128 for the encoder and decoder, respectively, instead of 1024 (which is the | ||
default sequence length). This will allow us to run this example quickly | ||
on Colab. | ||
If you observe carefully, the preprocessor is attached to the model. What this | ||
means is that we don't have to worry about preprocessing the text inputs; | ||
everything will be done internally. The preprocessor tokenizes the encoder text | ||
and the decoder text, adds special tokens and pads them. To generate labels | ||
for auto-regressive training, the preprocessor shifts the decoder text one | ||
position to the right. This is done because at every timestep, the model is | ||
trained to predict the next token. | ||
""" | ||
|
||
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset( | ||
"bart_base_en", | ||
encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH, | ||
decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH, | ||
) | ||
bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset( | ||
"bart_base_en", preprocessor=preprocessor | ||
) | ||
|
||
bart_lm.summary() | ||
|
||
""" | ||
Define the optimizer and loss. We use the Adam optimizer with a linearly | ||
decaying learning rate. Compile the model. | ||
""" | ||
|
||
optimizer = keras.optimizers.AdamW( | ||
learning_rate=5e-5, | ||
weight_decay=0.01, | ||
epsilon=1e-6, | ||
global_clipnorm=1.0, # Gradient clipping. | ||
) | ||
# Exclude layernorm and bias terms from weight decay. | ||
optimizer.exclude_from_weight_decay(var_names=["bias"]) | ||
optimizer.exclude_from_weight_decay(var_names=["gamma"]) | ||
optimizer.exclude_from_weight_decay(var_names=["beta"]) | ||
|
||
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) | ||
|
||
bart_lm.compile( | ||
optimizer=optimizer, | ||
loss=loss, | ||
weighted_metrics=["accuracy"], | ||
) | ||
|
||
""" | ||
Let's train the model! | ||
""" | ||
|
||
bart_lm.fit(train_ds, epochs=EPOCHS) | ||
|
||
""" | ||
## Generate summaries and evaluate them! | ||
Now that the model has been trained, let's get to the fun part - actually | ||
generating summaries! Let's pick the first 100 samples from the validation set | ||
and generate summaries for them. We will use the default decoding strategy, i.e., | ||
greedy search. | ||
Generation in KerasNLP is highly optimized. It is backed by the power of XLA. | ||
Secondly, key/value tensors in the self-attention layer and cross-attention layer | ||
in the decoder are cached to avoid recomputation at every timestep. | ||
""" | ||
|
||
|
||
def generate_text(model, input_text, max_length=200, print_time_taken=False): | ||
start = time.time() | ||
output = model.generate(input_text, max_length=max_length) | ||
end = time.time() | ||
print(f"Total Time Elapsed: {end - start:.2f}s") | ||
return output | ||
|
||
|
||
# Load the dataset. | ||
val_ds = tfds.load("samsum", split="validation", as_supervised=True) | ||
val_ds = val_ds.take(100) | ||
|
||
dialogues = [] | ||
ground_truth_summaries = [] | ||
for dialogue, summary in val_ds: | ||
dialogues.append(dialogue.numpy()) | ||
ground_truth_summaries.append(summary.numpy()) | ||
|
||
# Let's make a dummy call - the first call to XLA generally takes a bit longer. | ||
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH) | ||
|
||
# Generate summaries. | ||
generated_summaries = generate_text( | ||
bart_lm, | ||
val_ds.map(lambda dialogue, _: dialogue).batch(8), | ||
max_length=MAX_GENERATION_LENGTH, | ||
print_time_taken=True, | ||
) | ||
|
||
""" | ||
Let's see some of the summaries. | ||
""" | ||
for dialogue, generated_summary, ground_truth_summary in zip( | ||
dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5] | ||
): | ||
print("Dialogue:", dialogue) | ||
print("Generated Summary:", generated_summary) | ||
print("Ground Truth Summary:", ground_truth_summary) | ||
print("=============================") | ||
|
||
""" | ||
The generated summaries look awesome! Not bad for a model trained only for 1 | ||
epoch and on 5000 examples :) | ||
""" |
Oops, something went wrong.