
Migrate Review classification using Active learning to Keras 3. (keras-team#1857)

* Migrate to Keras 3

* Migrate to Keras 3

* trim extra epoch outputs
sachinprasadhs authored and sitamgithub-MSIT committed May 30, 2024
1 parent f27b049 commit 90a1b7f
Showing 11 changed files with 1,080 additions and 283 deletions.
34 changes: 16 additions & 18 deletions examples/nlp/active_learning_review_classification.py
@@ -2,9 +2,10 @@
Title: Review Classification using Active Learning
Author: [Darshan Deshpande](https://twitter.com/getdarshan)
Date created: 2021/10/29
-Last modified: 2021/10/29
+Last modified: 2024/05/08
Description: Demonstrating the advantages of active learning through review classification.
Accelerator: GPU
+Converted to Keras 3 by: [Sachin Prasad](https://github.com/sachinprasadhs)
"""

"""
@@ -51,10 +52,14 @@
## Importing required libraries
"""

+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow" # @param ["tensorflow", "jax", "torch"]
+import keras
+from keras import ops
+from keras import layers
import tensorflow_datasets as tfds
import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
import matplotlib.pyplot as plt
import re
import string
@@ -169,16 +174,8 @@
"""


-def custom_standardization(input_data):
-    lowercase = tf.strings.lower(input_data)
-    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
-    return tf.strings.regex_replace(
-        stripped_html, f"[{re.escape(string.punctuation)}]", ""
-    )
-
-
vectorizer = layers.TextVectorization(
-    3000, standardize=custom_standardization, output_sequence_length=150
+    3000, standardize="lower_and_strip_punctuation", output_sequence_length=150
)
# Adapting the dataset
vectorizer.adapt(
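The hunk above swaps the custom standardizer for the built-in "lower_and_strip_punctuation" option. One behavioral nuance: the removed callable also replaced "<br />" tags with spaces, which the built-in mode does not. If that HTML stripping matters, `TextVectorization` still accepts a callable for `standardize` when running on the TensorFlow backend; a sketch reusing the deleted logic (the helper name is hypothetical):

import re
import string

import tensorflow as tf
from keras import layers


def html_aware_standardization(input_data):
    # Reproduces the removed standardizer: lowercase, drop "<br />", strip punctuation.
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, f"[{re.escape(string.punctuation)}]", ""
    )


vectorizer = layers.TextVectorization(
    3000, standardize=html_aware_standardization, output_sequence_length=150
)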
@@ -289,7 +286,7 @@ def train_full_model(full_train_dataset, val_dataset, test_dataset):
callbacks=[
keras.callbacks.EarlyStopping(patience=4, verbose=1),
keras.callbacks.ModelCheckpoint(
"FullModelCheckpoint.h5", verbose=1, save_best_only=True
"FullModelCheckpoint.keras", verbose=1, save_best_only=True
),
],
)
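The checkpoint extension changes from ".h5" to ".keras" because Keras 3 infers the saving format from the filename, and ".keras" is its native single-file archive. A small round-trip sketch (the toy model and filename are illustrative only):

import keras

model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
model.build((None, 4))

# Save in the native Keras 3 format and load it back.
model.save("demo_checkpoint.keras")
restored = keras.models.load_model("demo_checkpoint.keras")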
@@ -303,7 +300,7 @@ def train_full_model(full_train_dataset, val_dataset, test_dataset):
)

# Loading the best checkpoint
-model = keras.models.load_model("FullModelCheckpoint.h5")
+model = keras.models.load_model("FullModelCheckpoint.keras")

print("-" * 100)
print(
@@ -370,6 +367,7 @@ def train_active_learning_models(
num_iterations=3,
sampling_size=5000,
):

# Creating lists for storing metrics
losses, val_losses, accuracies, val_accuracies = [], [], [], []

@@ -389,7 +387,7 @@
# Defining checkpoints.
# The checkpoint callback is reused throughout the training since it only saves the best overall model.
checkpoint = keras.callbacks.ModelCheckpoint(
"AL_Model.h5", save_best_only=True, verbose=1
"AL_Model.keras", save_best_only=True, verbose=1
)
# Here, patience is set to 4. This can be set higher if desired.
early_stopping = keras.callbacks.EarlyStopping(patience=4, verbose=1)
@@ -413,9 +411,9 @@
predictions = model.predict(test_dataset)

# Generating labels from the output probabilities
-rounded = tf.where(tf.greater(predictions, 0.5), 1, 0)
+rounded = ops.where(ops.greater(predictions, 0.5), 1, 0)

# Evaluating the number of zeros and ones incorrectly classified
_, _, false_negatives, false_positives = model.evaluate(test_dataset, verbose=0)

print("-" * 100)
@@ -482,7 +480,7 @@ def train_active_learning_models(
)

# Loading the best model from this training loop
-model = keras.models.load_model("AL_Model.h5")
+model = keras.models.load_model("AL_Model.keras")

# Plotting the overall history and evaluating the final model
plot_history(losses, val_losses, accuracies, val_accuracies)
59 changes: 29 additions & 30 deletions examples/nlp/ipynb/active_learning_review_classification.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Darshan Deshpande](https://twitter.com/getdarshan)<br>\n",
"**Date created:** 2021/10/29<br>\n",
"**Last modified:** 2021/10/29<br>\n",
"**Last modified:** 2024/05/08<br>\n",
"**Description:** Demonstrating the advantages of active learning through review classification."
]
},
@@ -33,7 +33,7 @@
"ensure consistency in labeling of new data.\n",
"\n",
"The process starts with annotating a small subset of the full dataset and training an\n",
"initial model. The best model checkpoint is saved and then tested on a balanced test\n",
"initial model. The best model checkpoint is saved and then tested on\u00a0a balanced test\n",
"set. The test set must be carefully sampled because the full training process will be\n",
"dependent on it. Once we have the initial evaluation scores, the oracle is tasked with\n",
"labeling more samples; the number of data points to be sampled is usually determined by\n",
@@ -70,16 +70,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # @param [\"tensorflow\", \"jax\", \"torch\"]\n",
"import keras\n",
"from keras import ops\n",
"from keras import layers\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import matplotlib.pyplot as plt\n",
"import re\n",
"import string\n",
@@ -102,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -133,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -229,23 +233,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def custom_standardization(input_data):\n",
" lowercase = tf.strings.lower(input_data)\n",
" stripped_html = tf.strings.regex_replace(lowercase, \"<br />\", \" \")\n",
" return tf.strings.regex_replace(\n",
" stripped_html, f\"[{re.escape(string.punctuation)}]\", \"\"\n",
" )\n",
"\n",
"\n",
"vectorizer = layers.TextVectorization(\n",
" 3000, standardize=custom_standardization, output_sequence_length=150\n",
" 3000, standardize=\"lower_and_strip_punctuation\", output_sequence_length=150\n",
")\n",
"# Adapting the dataset\n",
"vectorizer.adapt(\n",
@@ -283,12 +279,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"# Helper function for merging new history objects with older ones\n",
"def append_history(losses, val_losses, accuracy, val_accuracy, history):\n",
" losses = losses + history.history[\"loss\"]\n",
Expand All @@ -312,7 +309,8 @@
" plt.legend([\"train_accuracy\", \"val_accuracy\"])\n",
" plt.xlabel(\"Epochs\")\n",
" plt.ylabel(\"Accuracy\")\n",
" plt.show()\n"
" plt.show()\n",
""
]
},
{
@@ -331,7 +329,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -351,7 +349,8 @@
" ]\n",
" )\n",
" model.summary()\n",
" return model\n"
" return model\n",
""
]
},
{
@@ -368,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -395,7 +394,7 @@
" callbacks=[\n",
" keras.callbacks.EarlyStopping(patience=4, verbose=1),\n",
" keras.callbacks.ModelCheckpoint(\n",
" \"FullModelCheckpoint.h5\", verbose=1, save_best_only=True\n",
" \"FullModelCheckpoint.keras\", verbose=1, save_best_only=True\n",
" ),\n",
" ],\n",
" )\n",
@@ -409,7 +408,7 @@
" )\n",
"\n",
" # Loading the best checkpoint\n",
" model = keras.models.load_model(\"FullModelCheckpoint.h5\")\n",
" model = keras.models.load_model(\"FullModelCheckpoint.keras\")\n",
"\n",
" print(\"-\" * 100)\n",
" print(\n",
@@ -474,7 +473,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -510,7 +509,7 @@
" # Defining checkpoints.\n",
" # The checkpoint callback is reused throughout the training since it only saves the best overall model.\n",
" checkpoint = keras.callbacks.ModelCheckpoint(\n",
" \"AL_Model.h5\", save_best_only=True, verbose=1\n",
" \"AL_Model.keras\", save_best_only=True, verbose=1\n",
" )\n",
" # Here, patience is set to 4. This can be set higher if desired.\n",
" early_stopping = keras.callbacks.EarlyStopping(patience=4, verbose=1)\n",
@@ -534,9 +533,9 @@
" predictions = model.predict(test_dataset)\n",
"\n",
" # Generating labels from the output probabilities\n",
" rounded = tf.where(tf.greater(predictions, 0.5), 1, 0)\n",
" rounded = ops.where(ops.greater(predictions, 0.5), 1, 0)\n",
"\n",
" # Evaluating the number of zeros and ones classified\n",
" # Evaluating the number of zeros and ones incorrrectly classified\n",
" _, _, false_negatives, false_positives = model.evaluate(test_dataset, verbose=0)\n",
"\n",
" print(\"-\" * 100)\n",
@@ -603,7 +602,7 @@
" )\n",
"\n",
" # Loading the best model from this training loop\n",
" model = keras.models.load_model(\"AL_Model.h5\")\n",
" model = keras.models.load_model(\"AL_Model.keras\")\n",
"\n",
" # Plotting the overall history and evaluating the final model\n",
" plot_history(losses, val_losses, accuracies, val_accuracies)\n",
@@ -682,4 +681,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
\ No newline at end of file
+}