diff --git a/examples/vision/img/shiftvit/shiftvit_46_0.png b/examples/vision/img/shiftvit/shiftvit_46_0.png
new file mode 100644
index 0000000000..ad616c3bd1
Binary files /dev/null and b/examples/vision/img/shiftvit/shiftvit_46_0.png differ
diff --git a/examples/vision/ipynb/shiftvit.ipynb b/examples/vision/ipynb/shiftvit.ipynb
index 13a30f6c2c..682d161424 100644
--- a/examples/vision/ipynb/shiftvit.ipynb
+++ b/examples/vision/ipynb/shiftvit.ipynb
@@ -8,9 +8,9 @@
"source": [
"# A Vision Transformer without Attention\n",
"\n",
- "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
\n",
+ "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
\n",
"**Date created:** 2022/02/24
\n",
- "**Last modified:** 2022/03/01
\n",
+ "**Last modified:** 2022/10/15
\n",
"**Description:** A minimal implementation of ShiftViT."
]
},
@@ -39,12 +39,19 @@
"In this example, we minimally implement the paper with close alignement to the author's\n",
"[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).\n",
"\n",
- "This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can\n",
- "be installed using the following command:\n",
- "\n",
- "```shell\n",
- "pip install -qq -U tensorflow-addons\n",
- "```"
+ "This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can\n",
+ "be installed using the following command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -qq -U tensorflow-addons"
]
},
{
@@ -70,9 +77,11 @@
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
- "\n",
"import tensorflow_addons as tfa\n",
"\n",
+ "import pathlib\n",
+ "import glob\n",
+ "\n",
"# Setting seed for reproducibiltiy\n",
"SEED = 42\n",
"keras.utils.set_random_seed(SEED)"
@@ -128,6 +137,21 @@
" # TRAINING\n",
" epochs = 100\n",
"\n",
+ " # INFERENCE\n",
+ " label_map = {\n",
+ " 0: \"airplane\",\n",
+ " 1: \"automobile\",\n",
+ " 2: \"bird\",\n",
+ " 3: \"cat\",\n",
+ " 4: \"deer\",\n",
+ " 5: \"dog\",\n",
+ " 6: \"frog\",\n",
+ " 7: \"horse\",\n",
+ " 8: \"ship\",\n",
+ " 9: \"truck\",\n",
+ " }\n",
+ " tf_ds_batch_size = 20\n",
+ "\n",
"\n",
"config = Config()"
]
@@ -284,7 +308,7 @@
"source": [
"#### The MLP block\n",
"\n",
- "The MLP block is intended to be a stack of densely-connected layers.s"
+ "The MLP block is intended to be a stack of densely-connected layers"
]
},
{
@@ -315,7 +339,10 @@
"\n",
" self.mlp = keras.Sequential(\n",
" [\n",
- " layers.Dense(units=initial_filters, activation=tf.nn.gelu,),\n",
+ " layers.Dense(\n",
+ " units=initial_filters,\n",
+ " activation=tf.nn.gelu,\n",
+ " ),\n",
" layers.Dropout(rate=self.mlp_dropout_rate),\n",
" layers.Dense(units=input_channels),\n",
" layers.Dropout(rate=self.mlp_dropout_rate),\n",
@@ -382,7 +409,7 @@
"source": [
"#### Block\n",
"\n",
- "The most important operation in this paper is the **shift opperation**. In this section,\n",
+ "The most important operation in this paper is the **shift operation**. In this section,\n",
"we describe the shift operation and compare it with its original implementation provided\n",
"by the authors.\n",
"\n",
@@ -693,6 +720,24 @@
" if self.is_merge:\n",
" x = self.patch_merge(x)\n",
" return x\n",
+ "\n",
+ " # Since this is a custom layer, we need to overwrite get_config()\n",
+ " # so that model can be easily saved & loaded after training\n",
+ " def get_config(self):\n",
+ " config = super().get_config()\n",
+ " config.update(\n",
+ " {\n",
+ " \"epsilon\": self.epsilon,\n",
+ " \"mlp_dropout_rate\": self.mlp_dropout_rate,\n",
+ " \"num_shift_blocks\": self.num_shift_blocks,\n",
+ " \"stochastic_depth_rate\": self.stochastic_depth_rate,\n",
+ " \"is_merge\": self.is_merge,\n",
+ " \"num_div\": self.num_div,\n",
+ " \"shift_pixel\": self.shift_pixel,\n",
+ " \"mlp_expand_ratio\": self.mlp_expand_ratio,\n",
+ " }\n",
+ " )\n",
+ " return config\n",
""
]
},
@@ -780,6 +825,8 @@
" )\n",
" self.global_avg_pool = layers.GlobalAveragePooling2D()\n",
"\n",
+ " self.classifier = layers.Dense(config.num_classes)\n",
+ "\n",
" def get_config(self):\n",
" config = super().get_config()\n",
" config.update(\n",
@@ -788,6 +835,7 @@
" \"patch_projection\": self.patch_projection,\n",
" \"stages\": self.stages,\n",
" \"global_avg_pool\": self.global_avg_pool,\n",
+ " \"classifier\": self.classifier,\n",
" }\n",
" )\n",
" return config\n",
@@ -807,7 +855,8 @@
" x = stage(x, training=training)\n",
"\n",
" # Get the logits.\n",
- " logits = self.global_avg_pool(x)\n",
+ " x = self.global_avg_pool(x)\n",
+ " logits = self.classifier(x)\n",
"\n",
" # Calculate the loss and return it.\n",
" total_loss = self.compiled_loss(labels, logits)\n",
@@ -824,6 +873,7 @@
" self.data_augmentation.trainable_variables,\n",
" self.patch_projection.trainable_variables,\n",
" self.global_avg_pool.trainable_variables,\n",
+ " self.classifier.trainable_variables,\n",
" ]\n",
" train_vars = train_vars + [stage.trainable_variables for stage in self.stages]\n",
"\n",
@@ -845,6 +895,15 @@
" # Update the metrics\n",
" self.compiled_metrics.update_state(labels, logits)\n",
" return {m.name: m.result() for m in self.metrics}\n",
+ "\n",
+ " def call(self, images):\n",
+ " augmented_images = self.data_augmentation(images)\n",
+ " x = self.patch_projection(augmented_images)\n",
+ " for stage in self.stages:\n",
+ " x = stage(x, training=False)\n",
+ " x = self.global_avg_pool(x)\n",
+ " logits = self.classifier(x)\n",
+ " return logits\n",
""
]
},
@@ -976,6 +1035,15 @@
" return tf.where(\n",
" step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
" )\n",
+ "\n",
+ " def get_config(self):\n",
+ " config = {\n",
+ " \"lr_start\": self.lr_start,\n",
+ " \"lr_max\": self.lr_max,\n",
+ " \"total_steps\": self.total_steps,\n",
+ " \"warmup_steps\": self.warmup_steps,\n",
+ " }\n",
+ " return config\n",
""
]
},
@@ -996,6 +1064,11 @@
},
"outputs": [],
"source": [
+ "# pass sample data to the model so that input shape is available at the time of\n",
+ "# saving the model\n",
+ "sample_ds, _ = next(iter(train_ds))\n",
+ "model(sample_ds, training=False)\n",
+ "\n",
"# Get the total number of steps for training.\n",
"total_steps = int((len(x_train) / config.batch_size) * config.epochs)\n",
"\n",
@@ -1005,7 +1078,10 @@
"\n",
"# Initialize the warmupcosine schedule.\n",
"scheduled_lrs = WarmUpCosine(\n",
- " lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,\n",
+ " lr_start=1e-5,\n",
+ " lr_max=1e-3,\n",
+ " warmup_steps=warmup_steps,\n",
+ " total_steps=total_steps,\n",
")\n",
"\n",
"# Get the optimizer.\n",
@@ -1029,7 +1105,11 @@
" epochs=config.epochs,\n",
" validation_data=val_ds,\n",
" callbacks=[\n",
- " keras.callbacks.EarlyStopping(monitor=\"val_accuracy\", patience=5, mode=\"auto\",)\n",
+ " keras.callbacks.EarlyStopping(\n",
+ " monitor=\"val_accuracy\",\n",
+ " patience=5,\n",
+ " mode=\"auto\",\n",
+ " )\n",
" ],\n",
")\n",
"\n",
@@ -1041,6 +1121,211 @@
"print(f\"Top 5 test accuracy: {acc_top5*100:0.2f}%\")"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Save trained model\n",
+ "\n",
+ "Since we created the model by Subclassing, we can't save the model in HDF5 format.\n",
+ "\n",
+ "It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model.save(\"ShiftViT\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "**Download sample data for inference**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip\n",
+ "!unzip -q inference_set.zip"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "**Load saved model**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Custom objects are not included when the model is saved.\n",
+ "# At loading time, these objects need to be passed for reconstruction of the model\n",
+ "saved_model = tf.keras.models.load_model(\n",
+ " \"ShiftViT\",\n",
+ " custom_objects={\"WarmUpCosine\": WarmUpCosine, \"AdamW\": tfa.optimizers.AdamW},\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "**Utility functions for inference**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def process_image(img_path):\n",
+ " # read image file from string path\n",
+ " img = tf.io.read_file(img_path)\n",
+ "\n",
+ " # decode jpeg to uint8 tensor\n",
+ " img = tf.io.decode_jpeg(img, channels=3)\n",
+ "\n",
+ " # resize image to match input size accepted by model\n",
+ " # use `method` as `nearest` to preserve dtype of input passed to `resize()`\n",
+ " img = tf.image.resize(\n",
+ " img, [config.input_shape[0], config.input_shape[1]], method=\"nearest\"\n",
+ " )\n",
+ " return img\n",
+ "\n",
+ "\n",
+ "def create_tf_dataset(image_dir):\n",
+ " data_dir = pathlib.Path(image_dir)\n",
+ "\n",
+ " # create tf.data dataset using directory of images\n",
+ " predict_ds = tf.data.Dataset.list_files(str(data_dir / \"*.jpg\"), shuffle=False)\n",
+ "\n",
+ " # use map to convert string paths to uint8 image tensors\n",
+ " # setting `num_parallel_calls' helps in processing multiple images parallely\n",
+ " predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)\n",
+ "\n",
+ " # create a Prefetch Dataset for better latency & throughput\n",
+ " predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)\n",
+ " return predict_ds\n",
+ "\n",
+ "\n",
+ "def predict(predict_ds):\n",
+ " # ShiftViT model returns logits (non-normalized predictions)\n",
+ " logits = saved_model.predict(predict_ds)\n",
+ "\n",
+ " # normalize predictions by calling softmax()\n",
+ " probabilities = tf.nn.softmax(logits)\n",
+ " return probabilities\n",
+ "\n",
+ "\n",
+ "def get_predicted_class(probabilities):\n",
+ " pred_label = np.argmax(probabilities)\n",
+ " predicted_class = config.label_map[pred_label]\n",
+ " return predicted_class\n",
+ "\n",
+ "\n",
+ "def get_confidence_scores(probabilities):\n",
+ " # get the indices of the probability scores sorted in descending order\n",
+ " labels = np.argsort(probabilities)[::-1]\n",
+ " confidences = {\n",
+ " config.label_map[label]: np.round((probabilities[label]) * 100, 2)\n",
+ " for label in labels\n",
+ " }\n",
+ " return confidences\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "**Get predictions**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "img_dir = \"inference_set\"\n",
+ "predict_ds = create_tf_dataset(img_dir)\n",
+ "probabilities = predict(predict_ds)\n",
+ "print(f\"probabilities: {probabilities[0]}\")\n",
+ "confidences = get_confidence_scores(probabilities[0])\n",
+ "print(confidences)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "**View predictions**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(10, 10))\n",
+ "for images in predict_ds:\n",
+ " for i in range(min(6, probabilities.shape[0])):\n",
+ " ax = plt.subplot(3, 3, i + 1)\n",
+ " plt.imshow(images[i].numpy().astype(\"uint8\"))\n",
+ " predicted_class = get_predicted_class(probabilities[i])\n",
+ " plt.title(predicted_class)\n",
+ " plt.axis(\"off\")"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -1069,6 +1354,19 @@
"- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for\n",
"helping us with the Learning Rate Schedule."
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "**Example available on HuggingFace**\n",
+ "\n",
+ "| Trained Model | Demo |\n",
+ "| :--: | :--: |\n",
+ "| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) |"
+ ]
}
],
"metadata": {
@@ -1100,4 +1398,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
+}
\ No newline at end of file
diff --git a/examples/vision/md/shiftvit.md b/examples/vision/md/shiftvit.md
index 6738bfb777..0b583bd1c2 100644
--- a/examples/vision/md/shiftvit.md
+++ b/examples/vision/md/shiftvit.md
@@ -1,8 +1,8 @@
# A Vision Transformer without Attention
-**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
**Date created:** 2022/02/24
-**Last modified:** 2022/03/01
+**Last modified:** 2022/10/15
**Description:** A minimal implementation of ShiftViT.
@@ -30,13 +30,15 @@ operation with a shifting operation.
In this example, we minimally implement the paper with close alignement to the author's
[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
-This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can
+This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
be installed using the following command:
-```shell
-pip install -qq -U tensorflow-addons
+
+```python
+!pip install -qq -U tensorflow-addons
```
+
---
## Setup and imports
@@ -48,9 +50,11 @@ import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
-
import tensorflow_addons as tfa
+import pathlib
+import glob
+
# Setting seed for reproducibiltiy
SEED = 42
keras.utils.set_random_seed(SEED)
@@ -94,6 +98,21 @@ class Config(object):
# TRAINING
epochs = 100
+ # INFERENCE
+ label_map = {
+ 0: "airplane",
+ 1: "automobile",
+ 2: "bird",
+ 3: "cat",
+ 4: "deer",
+ 5: "dog",
+ 6: "frog",
+ 7: "horse",
+ 8: "ship",
+ 9: "truck",
+ }
+ tf_ds_batch_size = 20
+
config = Config()
```
@@ -127,14 +146,12 @@ test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
```
+Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
+170498071/170498071 [==============================] - 3s 0us/step
Training samples: 40000
Validation samples: 10000
Testing samples: 10000
-2022-03-01 03:10:21.342684: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
-To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
-2022-03-01 03:10:21.850844: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38420 MB memory: -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:61:00.0, compute capability: 8.0
-
```
---
@@ -219,7 +236,7 @@ The Shift Block as shown in Fig. 3, comprises of the following:
#### The MLP block
-The MLP block is intended to be a stack of densely-connected layers.s
+The MLP block is intended to be a stack of densely-connected layers
```python
@@ -243,7 +260,10 @@ class MLP(layers.Layer):
self.mlp = keras.Sequential(
[
- layers.Dense(units=initial_filters, activation=tf.nn.gelu,),
+ layers.Dense(
+ units=initial_filters,
+ activation=tf.nn.gelu,
+ ),
layers.Dropout(rate=self.mlp_dropout_rate),
layers.Dense(units=input_channels),
layers.Dropout(rate=self.mlp_dropout_rate),
@@ -291,7 +311,7 @@ class DropPath(layers.Layer):
#### Block
-The most important operation in this paper is the **shift opperation**. In this section,
+The most important operation in this paper is the **shift operation**. In this section,
we describe the shift operation and compare it with its original implementation provided
by the authors.
@@ -563,6 +583,24 @@ class StackedShiftBlocks(layers.Layer):
x = self.patch_merge(x)
return x
+ # Since this is a custom layer, we need to overwrite get_config()
+ # so that model can be easily saved & loaded after training
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "epsilon": self.epsilon,
+ "mlp_dropout_rate": self.mlp_dropout_rate,
+ "num_shift_blocks": self.num_shift_blocks,
+ "stochastic_depth_rate": self.stochastic_depth_rate,
+ "is_merge": self.is_merge,
+ "num_div": self.num_div,
+ "shift_pixel": self.shift_pixel,
+ "mlp_expand_ratio": self.mlp_expand_ratio,
+ }
+ )
+ return config
+
```
---
@@ -637,6 +675,8 @@ class ShiftViTModel(keras.Model):
)
self.global_avg_pool = layers.GlobalAveragePooling2D()
+ self.classifier = layers.Dense(config.num_classes)
+
def get_config(self):
config = super().get_config()
config.update(
@@ -645,6 +685,7 @@ class ShiftViTModel(keras.Model):
"patch_projection": self.patch_projection,
"stages": self.stages,
"global_avg_pool": self.global_avg_pool,
+ "classifier": self.classifier,
}
)
return config
@@ -664,7 +705,8 @@ class ShiftViTModel(keras.Model):
x = stage(x, training=training)
# Get the logits.
- logits = self.global_avg_pool(x)
+ x = self.global_avg_pool(x)
+ logits = self.classifier(x)
# Calculate the loss and return it.
total_loss = self.compiled_loss(labels, logits)
@@ -681,6 +723,7 @@ class ShiftViTModel(keras.Model):
self.data_augmentation.trainable_variables,
self.patch_projection.trainable_variables,
self.global_avg_pool.trainable_variables,
+ self.classifier.trainable_variables,
]
train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
@@ -703,6 +746,15 @@ class ShiftViTModel(keras.Model):
self.compiled_metrics.update_state(labels, logits)
return {m.name: m.result() for m in self.metrics}
+ def call(self, images):
+ augmented_images = self.data_augmentation(images)
+ x = self.patch_projection(augmented_images)
+ for stage in self.stages:
+ x = stage(x, training=False)
+ x = self.global_avg_pool(x)
+ logits = self.classifier(x)
+ return logits
+
```
---
@@ -810,6 +862,15 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
+ def get_config(self):
+ config = {
+ "lr_start": self.lr_start,
+ "lr_max": self.lr_max,
+ "total_steps": self.total_steps,
+ "warmup_steps": self.warmup_steps,
+ }
+ return config
+
```
---
@@ -817,6 +878,11 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
```python
+# pass sample data to the model so that input shape is available at the time of
+# saving the model
+sample_ds, _ = next(iter(train_ds))
+model(sample_ds, training=False)
+
# Get the total number of steps for training.
total_steps = int((len(x_train) / config.batch_size) * config.epochs)
@@ -826,7 +892,10 @@ warmup_steps = int(total_steps * warmup_epoch_percentage)
# Initialize the warmupcosine schedule.
scheduled_lrs = WarmUpCosine(
- lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,
+ lr_start=1e-5,
+ lr_max=1e-3,
+ warmup_steps=warmup_steps,
+ total_steps=total_steps,
)
# Get the optimizer.
@@ -850,7 +919,11 @@ history = model.fit(
epochs=config.epochs,
validation_data=val_ds,
callbacks=[
- keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5, mode="auto",)
+ keras.callbacks.EarlyStopping(
+ monitor="val_accuracy",
+ patience=5,
+ mode="auto",
+ )
],
)
@@ -865,109 +938,249 @@ print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
```
Epoch 1/100
-
-2022-03-01 03:10:41.373231: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8202
-2022-03-01 03:10:43.145958: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
-
-157/157 [==============================] - 34s 84ms/step - loss: 3.2975 - accuracy: 0.1084 - top-5-accuracy: 0.4806 - val_loss: 2.1575 - val_accuracy: 0.2017 - val_top-5-accuracy: 0.7184
+157/157 [==============================] - 72s 332ms/step - loss: 2.3844 - accuracy: 0.1444 - top-5-accuracy: 0.6051 - val_loss: 2.0984 - val_accuracy: 0.2610 - val_top-5-accuracy: 0.7638
Epoch 2/100
-157/157 [==============================] - 11s 67ms/step - loss: 2.1727 - accuracy: 0.2289 - top-5-accuracy: 0.7516 - val_loss: 1.8819 - val_accuracy: 0.3182 - val_top-5-accuracy: 0.8386
+157/157 [==============================] - 49s 314ms/step - loss: 1.9457 - accuracy: 0.2893 - top-5-accuracy: 0.8103 - val_loss: 1.9459 - val_accuracy: 0.3356 - val_top-5-accuracy: 0.8614
Epoch 3/100
-157/157 [==============================] - 10s 67ms/step - loss: 1.8169 - accuracy: 0.3426 - top-5-accuracy: 0.8592 - val_loss: 1.6174 - val_accuracy: 0.4053 - val_top-5-accuracy: 0.8934
+157/157 [==============================] - 50s 316ms/step - loss: 1.7093 - accuracy: 0.3810 - top-5-accuracy: 0.8761 - val_loss: 1.5349 - val_accuracy: 0.4585 - val_top-5-accuracy: 0.9045
Epoch 4/100
-157/157 [==============================] - 10s 67ms/step - loss: 1.6215 - accuracy: 0.4092 - top-5-accuracy: 0.8983 - val_loss: 1.4239 - val_accuracy: 0.4903 - val_top-5-accuracy: 0.9216
+157/157 [==============================] - 49s 315ms/step - loss: 1.5473 - accuracy: 0.4374 - top-5-accuracy: 0.9090 - val_loss: 1.4257 - val_accuracy: 0.4862 - val_top-5-accuracy: 0.9298
Epoch 5/100
-157/157 [==============================] - 10s 66ms/step - loss: 1.5081 - accuracy: 0.4571 - top-5-accuracy: 0.9148 - val_loss: 1.3359 - val_accuracy: 0.5161 - val_top-5-accuracy: 0.9369
+157/157 [==============================] - 50s 316ms/step - loss: 1.4316 - accuracy: 0.4816 - top-5-accuracy: 0.9243 - val_loss: 1.4032 - val_accuracy: 0.5092 - val_top-5-accuracy: 0.9362
Epoch 6/100
-157/157 [==============================] - 11s 68ms/step - loss: 1.4282 - accuracy: 0.4868 - top-5-accuracy: 0.9249 - val_loss: 1.2929 - val_accuracy: 0.5347 - val_top-5-accuracy: 0.9404
+157/157 [==============================] - 50s 316ms/step - loss: 1.3588 - accuracy: 0.5131 - top-5-accuracy: 0.9333 - val_loss: 1.2893 - val_accuracy: 0.5411 - val_top-5-accuracy: 0.9457
Epoch 7/100
-157/157 [==============================] - 10s 66ms/step - loss: 1.3465 - accuracy: 0.5181 - top-5-accuracy: 0.9362 - val_loss: 1.2653 - val_accuracy: 0.5497 - val_top-5-accuracy: 0.9449
+157/157 [==============================] - 50s 316ms/step - loss: 1.2894 - accuracy: 0.5385 - top-5-accuracy: 0.9410 - val_loss: 1.2922 - val_accuracy: 0.5416 - val_top-5-accuracy: 0.9432
Epoch 8/100
-157/157 [==============================] - 10s 67ms/step - loss: 1.2907 - accuracy: 0.5400 - top-5-accuracy: 0.9416 - val_loss: 1.1919 - val_accuracy: 0.5753 - val_top-5-accuracy: 0.9515
+157/157 [==============================] - 49s 315ms/step - loss: 1.2388 - accuracy: 0.5568 - top-5-accuracy: 0.9468 - val_loss: 1.2100 - val_accuracy: 0.5733 - val_top-5-accuracy: 0.9545
Epoch 9/100
-157/157 [==============================] - 11s 67ms/step - loss: 1.2247 - accuracy: 0.5644 - top-5-accuracy: 0.9480 - val_loss: 1.1741 - val_accuracy: 0.5742 - val_top-5-accuracy: 0.9563
+157/157 [==============================] - 49s 315ms/step - loss: 1.2043 - accuracy: 0.5698 - top-5-accuracy: 0.9491 - val_loss: 1.2166 - val_accuracy: 0.5675 - val_top-5-accuracy: 0.9520
Epoch 10/100
-157/157 [==============================] - 11s 67ms/step - loss: 1.1983 - accuracy: 0.5760 - top-5-accuracy: 0.9505 - val_loss: 1.4545 - val_accuracy: 0.4804 - val_top-5-accuracy: 0.9198
+157/157 [==============================] - 49s 315ms/step - loss: 1.1694 - accuracy: 0.5861 - top-5-accuracy: 0.9528 - val_loss: 1.1738 - val_accuracy: 0.5883 - val_top-5-accuracy: 0.9541
Epoch 11/100
-157/157 [==============================] - 10s 66ms/step - loss: 1.2002 - accuracy: 0.5766 - top-5-accuracy: 0.9510 - val_loss: 1.1129 - val_accuracy: 0.6055 - val_top-5-accuracy: 0.9593
+157/157 [==============================] - 50s 316ms/step - loss: 1.1290 - accuracy: 0.5994 - top-5-accuracy: 0.9575 - val_loss: 1.1161 - val_accuracy: 0.6064 - val_top-5-accuracy: 0.9618
Epoch 12/100
-157/157 [==============================] - 10s 66ms/step - loss: 1.1309 - accuracy: 0.5990 - top-5-accuracy: 0.9575 - val_loss: 1.0369 - val_accuracy: 0.6341 - val_top-5-accuracy: 0.9638
+157/157 [==============================] - 50s 316ms/step - loss: 1.0861 - accuracy: 0.6157 - top-5-accuracy: 0.9602 - val_loss: 1.1220 - val_accuracy: 0.6133 - val_top-5-accuracy: 0.9576
Epoch 13/100
-157/157 [==============================] - 10s 66ms/step - loss: 1.0786 - accuracy: 0.6204 - top-5-accuracy: 0.9613 - val_loss: 1.0802 - val_accuracy: 0.6193 - val_top-5-accuracy: 0.9594
+157/157 [==============================] - 49s 315ms/step - loss: 1.0766 - accuracy: 0.6178 - top-5-accuracy: 0.9612 - val_loss: 1.0108 - val_accuracy: 0.6402 - val_top-5-accuracy: 0.9681
Epoch 14/100
-157/157 [==============================] - 10s 65ms/step - loss: 1.0438 - accuracy: 0.6330 - top-5-accuracy: 0.9640 - val_loss: 0.9584 - val_accuracy: 0.6596 - val_top-5-accuracy: 0.9713
+157/157 [==============================] - 49s 315ms/step - loss: 1.0179 - accuracy: 0.6416 - top-5-accuracy: 0.9658 - val_loss: 1.0196 - val_accuracy: 0.6405 - val_top-5-accuracy: 0.9667
Epoch 15/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.9957 - accuracy: 0.6496 - top-5-accuracy: 0.9684 - val_loss: 0.9530 - val_accuracy: 0.6636 - val_top-5-accuracy: 0.9712
+157/157 [==============================] - 50s 316ms/step - loss: 1.0028 - accuracy: 0.6470 - top-5-accuracy: 0.9678 - val_loss: 1.0113 - val_accuracy: 0.6415 - val_top-5-accuracy: 0.9672
Epoch 16/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.9710 - accuracy: 0.6599 - top-5-accuracy: 0.9696 - val_loss: 0.8856 - val_accuracy: 0.6863 - val_top-5-accuracy: 0.9756
+157/157 [==============================] - 50s 316ms/step - loss: 0.9613 - accuracy: 0.6611 - top-5-accuracy: 0.9710 - val_loss: 1.0516 - val_accuracy: 0.6406 - val_top-5-accuracy: 0.9596
Epoch 17/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.9316 - accuracy: 0.6706 - top-5-accuracy: 0.9721 - val_loss: 0.9919 - val_accuracy: 0.6480 - val_top-5-accuracy: 0.9671
+157/157 [==============================] - 50s 316ms/step - loss: 0.9262 - accuracy: 0.6740 - top-5-accuracy: 0.9729 - val_loss: 0.9010 - val_accuracy: 0.6844 - val_top-5-accuracy: 0.9750
Epoch 18/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.8899 - accuracy: 0.6884 - top-5-accuracy: 0.9763 - val_loss: 0.8753 - val_accuracy: 0.6949 - val_top-5-accuracy: 0.9752
+157/157 [==============================] - 50s 316ms/step - loss: 0.8768 - accuracy: 0.6916 - top-5-accuracy: 0.9769 - val_loss: 0.8862 - val_accuracy: 0.6908 - val_top-5-accuracy: 0.9767
Epoch 19/100
-157/157 [==============================] - 10s 64ms/step - loss: 0.8529 - accuracy: 0.6979 - top-5-accuracy: 0.9772 - val_loss: 0.8793 - val_accuracy: 0.6943 - val_top-5-accuracy: 0.9754
+157/157 [==============================] - 49s 315ms/step - loss: 0.8595 - accuracy: 0.6984 - top-5-accuracy: 0.9768 - val_loss: 0.8732 - val_accuracy: 0.6982 - val_top-5-accuracy: 0.9738
Epoch 20/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.8509 - accuracy: 0.7009 - top-5-accuracy: 0.9783 - val_loss: 0.8183 - val_accuracy: 0.7174 - val_top-5-accuracy: 0.9763
+157/157 [==============================] - 50s 317ms/step - loss: 0.8252 - accuracy: 0.7103 - top-5-accuracy: 0.9793 - val_loss: 0.9330 - val_accuracy: 0.6745 - val_top-5-accuracy: 0.9718
Epoch 21/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.8087 - accuracy: 0.7143 - top-5-accuracy: 0.9809 - val_loss: 0.7885 - val_accuracy: 0.7276 - val_top-5-accuracy: 0.9769
+157/157 [==============================] - 51s 322ms/step - loss: 0.8003 - accuracy: 0.7180 - top-5-accuracy: 0.9814 - val_loss: 0.8912 - val_accuracy: 0.6948 - val_top-5-accuracy: 0.9728
Epoch 22/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.8004 - accuracy: 0.7192 - top-5-accuracy: 0.9811 - val_loss: 0.7601 - val_accuracy: 0.7371 - val_top-5-accuracy: 0.9805
+157/157 [==============================] - 51s 326ms/step - loss: 0.7651 - accuracy: 0.7317 - top-5-accuracy: 0.9829 - val_loss: 0.7894 - val_accuracy: 0.7277 - val_top-5-accuracy: 0.9791
Epoch 23/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.7665 - accuracy: 0.7304 - top-5-accuracy: 0.9816 - val_loss: 0.7564 - val_accuracy: 0.7412 - val_top-5-accuracy: 0.9808
+157/157 [==============================] - 52s 328ms/step - loss: 0.7372 - accuracy: 0.7415 - top-5-accuracy: 0.9843 - val_loss: 0.7752 - val_accuracy: 0.7284 - val_top-5-accuracy: 0.9804
Epoch 24/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.7599 - accuracy: 0.7344 - top-5-accuracy: 0.9832 - val_loss: 0.7475 - val_accuracy: 0.7389 - val_top-5-accuracy: 0.9822
+157/157 [==============================] - 51s 327ms/step - loss: 0.7324 - accuracy: 0.7423 - top-5-accuracy: 0.9852 - val_loss: 0.7949 - val_accuracy: 0.7340 - val_top-5-accuracy: 0.9792
Epoch 25/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.7398 - accuracy: 0.7427 - top-5-accuracy: 0.9833 - val_loss: 0.7211 - val_accuracy: 0.7504 - val_top-5-accuracy: 0.9829
+157/157 [==============================] - 51s 323ms/step - loss: 0.7051 - accuracy: 0.7512 - top-5-accuracy: 0.9858 - val_loss: 0.7967 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9787
Epoch 26/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.7114 - accuracy: 0.7500 - top-5-accuracy: 0.9857 - val_loss: 0.7385 - val_accuracy: 0.7462 - val_top-5-accuracy: 0.9822
+157/157 [==============================] - 51s 323ms/step - loss: 0.6832 - accuracy: 0.7577 - top-5-accuracy: 0.9870 - val_loss: 0.7840 - val_accuracy: 0.7322 - val_top-5-accuracy: 0.9807
Epoch 27/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.6954 - accuracy: 0.7577 - top-5-accuracy: 0.9851 - val_loss: 0.7477 - val_accuracy: 0.7402 - val_top-5-accuracy: 0.9802
+157/157 [==============================] - 51s 322ms/step - loss: 0.6609 - accuracy: 0.7654 - top-5-accuracy: 0.9877 - val_loss: 0.7447 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9816
Epoch 28/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.6807 - accuracy: 0.7588 - top-5-accuracy: 0.9871 - val_loss: 0.7275 - val_accuracy: 0.7536 - val_top-5-accuracy: 0.9822
+157/157 [==============================] - 50s 319ms/step - loss: 0.6495 - accuracy: 0.7724 - top-5-accuracy: 0.9883 - val_loss: 0.7885 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9817
Epoch 29/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.6719 - accuracy: 0.7648 - top-5-accuracy: 0.9876 - val_loss: 0.7261 - val_accuracy: 0.7487 - val_top-5-accuracy: 0.9815
+157/157 [==============================] - 50s 317ms/step - loss: 0.6491 - accuracy: 0.7707 - top-5-accuracy: 0.9885 - val_loss: 0.7539 - val_accuracy: 0.7458 - val_top-5-accuracy: 0.9821
Epoch 30/100
-157/157 [==============================] - 10s 65ms/step - loss: 0.6578 - accuracy: 0.7696 - top-5-accuracy: 0.9871 - val_loss: 0.6932 - val_accuracy: 0.7641 - val_top-5-accuracy: 0.9833
+157/157 [==============================] - 50s 317ms/step - loss: 0.6213 - accuracy: 0.7823 - top-5-accuracy: 0.9888 - val_loss: 0.7571 - val_accuracy: 0.7470 - val_top-5-accuracy: 0.9815
Epoch 31/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.6489 - accuracy: 0.7740 - top-5-accuracy: 0.9877 - val_loss: 0.7400 - val_accuracy: 0.7486 - val_top-5-accuracy: 0.9820
+157/157 [==============================] - 50s 318ms/step - loss: 0.5976 - accuracy: 0.7902 - top-5-accuracy: 0.9906 - val_loss: 0.7430 - val_accuracy: 0.7508 - val_top-5-accuracy: 0.9817
Epoch 32/100
-157/157 [==============================] - 10s 65ms/step - loss: 0.6290 - accuracy: 0.7812 - top-5-accuracy: 0.9895 - val_loss: 0.6954 - val_accuracy: 0.7628 - val_top-5-accuracy: 0.9847
+157/157 [==============================] - 50s 318ms/step - loss: 0.5932 - accuracy: 0.7898 - top-5-accuracy: 0.9910 - val_loss: 0.7545 - val_accuracy: 0.7469 - val_top-5-accuracy: 0.9793
Epoch 33/100
-157/157 [==============================] - 10s 67ms/step - loss: 0.6194 - accuracy: 0.7826 - top-5-accuracy: 0.9894 - val_loss: 0.6913 - val_accuracy: 0.7619 - val_top-5-accuracy: 0.9842
+157/157 [==============================] - 50s 318ms/step - loss: 0.5977 - accuracy: 0.7850 - top-5-accuracy: 0.9913 - val_loss: 0.7200 - val_accuracy: 0.7569 - val_top-5-accuracy: 0.9830
Epoch 34/100
-157/157 [==============================] - 10s 65ms/step - loss: 0.5917 - accuracy: 0.7930 - top-5-accuracy: 0.9902 - val_loss: 0.6879 - val_accuracy: 0.7715 - val_top-5-accuracy: 0.9831
+157/157 [==============================] - 50s 317ms/step - loss: 0.5552 - accuracy: 0.8041 - top-5-accuracy: 0.9920 - val_loss: 0.7377 - val_accuracy: 0.7552 - val_top-5-accuracy: 0.9818
Epoch 35/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.5878 - accuracy: 0.7916 - top-5-accuracy: 0.9907 - val_loss: 0.6759 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9849
+157/157 [==============================] - 50s 319ms/step - loss: 0.5509 - accuracy: 0.8056 - top-5-accuracy: 0.9921 - val_loss: 0.8125 - val_accuracy: 0.7331 - val_top-5-accuracy: 0.9782
Epoch 36/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.5713 - accuracy: 0.8004 - top-5-accuracy: 0.9913 - val_loss: 0.6920 - val_accuracy: 0.7657 - val_top-5-accuracy: 0.9841
+157/157 [==============================] - 50s 317ms/step - loss: 0.5296 - accuracy: 0.8116 - top-5-accuracy: 0.9933 - val_loss: 0.6900 - val_accuracy: 0.7680 - val_top-5-accuracy: 0.9849
Epoch 37/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.5590 - accuracy: 0.8040 - top-5-accuracy: 0.9913 - val_loss: 0.6790 - val_accuracy: 0.7718 - val_top-5-accuracy: 0.9831
+157/157 [==============================] - 50s 316ms/step - loss: 0.5151 - accuracy: 0.8170 - top-5-accuracy: 0.9941 - val_loss: 0.7275 - val_accuracy: 0.7610 - val_top-5-accuracy: 0.9841
Epoch 38/100
-157/157 [==============================] - 11s 67ms/step - loss: 0.5445 - accuracy: 0.8114 - top-5-accuracy: 0.9926 - val_loss: 0.6756 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9852
+157/157 [==============================] - 50s 317ms/step - loss: 0.5069 - accuracy: 0.8217 - top-5-accuracy: 0.9936 - val_loss: 0.7067 - val_accuracy: 0.7703 - val_top-5-accuracy: 0.9835
Epoch 39/100
-157/157 [==============================] - 11s 67ms/step - loss: 0.5292 - accuracy: 0.8155 - top-5-accuracy: 0.9930 - val_loss: 0.6578 - val_accuracy: 0.7807 - val_top-5-accuracy: 0.9845
+157/157 [==============================] - 50s 318ms/step - loss: 0.4771 - accuracy: 0.8304 - top-5-accuracy: 0.9945 - val_loss: 0.7110 - val_accuracy: 0.7668 - val_top-5-accuracy: 0.9836
Epoch 40/100
-157/157 [==============================] - 11s 68ms/step - loss: 0.5169 - accuracy: 0.8181 - top-5-accuracy: 0.9926 - val_loss: 0.6582 - val_accuracy: 0.7795 - val_top-5-accuracy: 0.9849
+157/157 [==============================] - 50s 317ms/step - loss: 0.4675 - accuracy: 0.8350 - top-5-accuracy: 0.9956 - val_loss: 0.7130 - val_accuracy: 0.7688 - val_top-5-accuracy: 0.9829
Epoch 41/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.5108 - accuracy: 0.8217 - top-5-accuracy: 0.9937 - val_loss: 0.6344 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9855
+157/157 [==============================] - 50s 319ms/step - loss: 0.4586 - accuracy: 0.8382 - top-5-accuracy: 0.9959 - val_loss: 0.7331 - val_accuracy: 0.7598 - val_top-5-accuracy: 0.9806
Epoch 42/100
-157/157 [==============================] - 10s 65ms/step - loss: 0.5056 - accuracy: 0.8220 - top-5-accuracy: 0.9936 - val_loss: 0.6723 - val_accuracy: 0.7744 - val_top-5-accuracy: 0.9851
+157/157 [==============================] - 50s 318ms/step - loss: 0.4558 - accuracy: 0.8380 - top-5-accuracy: 0.9959 - val_loss: 0.7187 - val_accuracy: 0.7722 - val_top-5-accuracy: 0.9832
Epoch 43/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.4824 - accuracy: 0.8317 - top-5-accuracy: 0.9943 - val_loss: 0.6800 - val_accuracy: 0.7771 - val_top-5-accuracy: 0.9834
+157/157 [==============================] - 50s 320ms/step - loss: 0.4356 - accuracy: 0.8450 - top-5-accuracy: 0.9958 - val_loss: 0.7162 - val_accuracy: 0.7693 - val_top-5-accuracy: 0.9850
Epoch 44/100
-157/157 [==============================] - 10s 67ms/step - loss: 0.4719 - accuracy: 0.8339 - top-5-accuracy: 0.9938 - val_loss: 0.6742 - val_accuracy: 0.7785 - val_top-5-accuracy: 0.9840
+157/157 [==============================] - 49s 314ms/step - loss: 0.4425 - accuracy: 0.8433 - top-5-accuracy: 0.9958 - val_loss: 0.7061 - val_accuracy: 0.7698 - val_top-5-accuracy: 0.9853
Epoch 45/100
-157/157 [==============================] - 10s 65ms/step - loss: 0.4605 - accuracy: 0.8379 - top-5-accuracy: 0.9953 - val_loss: 0.6732 - val_accuracy: 0.7781 - val_top-5-accuracy: 0.9841
+157/157 [==============================] - 49s 314ms/step - loss: 0.4072 - accuracy: 0.8551 - top-5-accuracy: 0.9967 - val_loss: 0.7025 - val_accuracy: 0.7820 - val_top-5-accuracy: 0.9848
Epoch 46/100
-157/157 [==============================] - 10s 66ms/step - loss: 0.4608 - accuracy: 0.8390 - top-5-accuracy: 0.9947 - val_loss: 0.6547 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9852
+157/157 [==============================] - 49s 314ms/step - loss: 0.3865 - accuracy: 0.8644 - top-5-accuracy: 0.9970 - val_loss: 0.7178 - val_accuracy: 0.7740 - val_top-5-accuracy: 0.9844
+Epoch 47/100
+157/157 [==============================] - 49s 313ms/step - loss: 0.3718 - accuracy: 0.8694 - top-5-accuracy: 0.9973 - val_loss: 0.7216 - val_accuracy: 0.7768 - val_top-5-accuracy: 0.9828
+Epoch 48/100
+157/157 [==============================] - 49s 314ms/step - loss: 0.3733 - accuracy: 0.8673 - top-5-accuracy: 0.9970 - val_loss: 0.7440 - val_accuracy: 0.7713 - val_top-5-accuracy: 0.9841
+Epoch 49/100
+157/157 [==============================] - 49s 313ms/step - loss: 0.3531 - accuracy: 0.8741 - top-5-accuracy: 0.9979 - val_loss: 0.7220 - val_accuracy: 0.7738 - val_top-5-accuracy: 0.9848
+Epoch 50/100
+157/157 [==============================] - 49s 314ms/step - loss: 0.3502 - accuracy: 0.8738 - top-5-accuracy: 0.9980 - val_loss: 0.7245 - val_accuracy: 0.7734 - val_top-5-accuracy: 0.9836
TESTING
-40/40 [==============================] - 1s 22ms/step - loss: 0.6801 - accuracy: 0.7720 - top-5-accuracy: 0.9864
-Loss: 0.68
-Top 1 test accuracy: 77.20%
-Top 5 test accuracy: 98.64%
+40/40 [==============================] - 2s 56ms/step - loss: 0.7336 - accuracy: 0.7638 - top-5-accuracy: 0.9855
+Loss: 0.73
+Top 1 test accuracy: 76.38%
+Top 5 test accuracy: 98.55%
```
+---
+## Save trained model
+
+Since we created the model by Subclassing, we can't save the model in HDF5 format.
+
+It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well.
+
+
+```python
+model.save("ShiftViT")
+```
+
+---
+## Model inference
+
+**Download sample data for inference**
+
+
+```python
+!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
+!unzip -q inference_set.zip
+```
+
+**Load saved model**
+
+
+```python
+# Custom objects are not included when the model is saved.
+# At loading time, these objects need to be passed for reconstruction of the model
+saved_model = tf.keras.models.load_model(
+ "ShiftViT",
+ custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
+)
+```
+
+
+**Utility functions for inference**
+
+
+```python
+
+def process_image(img_path):
+ # read image file from string path
+ img = tf.io.read_file(img_path)
+
+ # decode jpeg to uint8 tensor
+ img = tf.io.decode_jpeg(img, channels=3)
+
+ # resize image to match input size accepted by model
+ # use `method` as `nearest` to preserve dtype of input passed to `resize()`
+ img = tf.image.resize(
+ img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+ )
+ return img
+
+
+def create_tf_dataset(image_dir):
+ data_dir = pathlib.Path(image_dir)
+
+ # create tf.data dataset using directory of images
+ predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
+
+ # use map to convert string paths to uint8 image tensors
+ # setting `num_parallel_calls' helps in processing multiple images parallely
+ predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+
+ # create a Prefetch Dataset for better latency & throughput
+ predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+ return predict_ds
+
+
+def predict(predict_ds):
+ # ShiftViT model returns logits (non-normalized predictions)
+ logits = saved_model.predict(predict_ds)
+
+ # normalize predictions by calling softmax()
+ probabilities = tf.nn.softmax(logits)
+ return probabilities
+
+
+def get_predicted_class(probabilities):
+ pred_label = np.argmax(probabilities)
+ predicted_class = config.label_map[pred_label]
+ return predicted_class
+
+
+def get_confidence_scores(probabilities):
+ # get the indices of the probability scores sorted in descending order
+ labels = np.argsort(probabilities)[::-1]
+ confidences = {
+ config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+ for label in labels
+ }
+ return confidences
+
+```
+
+**Get predictions**
+
+
+```python
+img_dir = "inference_set"
+predict_ds = create_tf_dataset(img_dir)
+probabilities = predict(predict_ds)
+print(f"probabilities: {probabilities[0]}")
+confidences = get_confidence_scores(probabilities[0])
+print(confidences)
+```
+
+
+```
+1/1 [==============================] - 2s 2s/step
+probabilities: [8.7329084e-01 1.3162658e-03 6.1781306e-05 1.9132349e-05 4.4482469e-05
+ 1.8182898e-06 2.2834571e-05 1.1466043e-05 1.2504059e-01 1.9084632e-04]
+{'airplane': 87.33, 'ship': 12.5, 'automobile': 0.13, 'truck': 0.02, 'bird': 0.01, 'deer': 0.0, 'frog': 0.0, 'cat': 0.0, 'horse': 0.0, 'dog': 0.0}
+
+```
+
+**View predictions**
+
+
+```python
+plt.figure(figsize=(10, 10))
+for images in predict_ds:
+ for i in range(min(6, probabilities.shape[0])):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(images[i].numpy().astype("uint8"))
+ predicted_class = get_predicted_class(probabilities[i])
+ plt.title(predicted_class)
+ plt.axis("off")
+```
+
+
+![png](/img/examples/vision/shiftvit/shiftvit_46_0.png)
+
+
---
## Conclusion
@@ -990,3 +1203,9 @@ GPU credits.
library.
- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for
helping us with the Learning Rate Schedule.
+
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) |
diff --git a/examples/vision/shiftvit.py b/examples/vision/shiftvit.py
index 3d3476d90c..e6400449c8 100644
--- a/examples/vision/shiftvit.py
+++ b/examples/vision/shiftvit.py
@@ -1,8 +1,8 @@
"""
Title: A Vision Transformer without Attention
-Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
Date created: 2022/02/24
-Last modified: 2022/03/01
+Last modified: 2022/10/15
Description: A minimal implementation of ShiftViT.
Accelerator: GPU
"""
@@ -26,12 +26,11 @@
In this example, we minimally implement the paper with close alignement to the author's
[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
-This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can
+This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
be installed using the following command:
-
-```shell
+"""
+"""shell
pip install -qq -U tensorflow-addons
-```
"""
"""
@@ -44,9 +43,11 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
-
import tensorflow_addons as tfa
+import pathlib
+import glob
+
# Setting seed for reproducibiltiy
SEED = 42
keras.utils.set_random_seed(SEED)
@@ -88,6 +89,21 @@ class Config(object):
# TRAINING
epochs = 100
+ # INFERENCE
+ label_map = {
+ 0: "airplane",
+ 1: "automobile",
+ 2: "bird",
+ 3: "cat",
+ 4: "deer",
+ 5: "dog",
+ 6: "frog",
+ 7: "horse",
+ 8: "ship",
+ 9: "truck",
+ }
+ tf_ds_batch_size = 20
+
config = Config()
@@ -200,7 +216,7 @@ def get_augmentation_model():
"""
#### The MLP block
-The MLP block is intended to be a stack of densely-connected layers.s
+The MLP block is intended to be a stack of densely-connected layers
"""
@@ -273,7 +289,7 @@ def call(self, x, training=False):
"""
#### Block
-The most important operation in this paper is the **shift opperation**. In this section,
+The most important operation in this paper is the **shift operation**. In this section,
we describe the shift operation and compare it with its original implementation provided
by the authors.
@@ -545,6 +561,24 @@ def call(self, x, training=False):
x = self.patch_merge(x)
return x
+ # Since this is a custom layer, we need to overwrite get_config()
+ # so that model can be easily saved & loaded after training
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "epsilon": self.epsilon,
+ "mlp_dropout_rate": self.mlp_dropout_rate,
+ "num_shift_blocks": self.num_shift_blocks,
+ "stochastic_depth_rate": self.stochastic_depth_rate,
+ "is_merge": self.is_merge,
+ "num_div": self.num_div,
+ "shift_pixel": self.shift_pixel,
+ "mlp_expand_ratio": self.mlp_expand_ratio,
+ }
+ )
+ return config
+
"""
## The ShiftViT model
@@ -617,6 +651,8 @@ def __init__(
)
self.global_avg_pool = layers.GlobalAveragePooling2D()
+ self.classifier = layers.Dense(config.num_classes)
+
def get_config(self):
config = super().get_config()
config.update(
@@ -625,6 +661,7 @@ def get_config(self):
"patch_projection": self.patch_projection,
"stages": self.stages,
"global_avg_pool": self.global_avg_pool,
+ "classifier": self.classifier,
}
)
return config
@@ -644,7 +681,8 @@ def _calculate_loss(self, data, training=False):
x = stage(x, training=training)
# Get the logits.
- logits = self.global_avg_pool(x)
+ x = self.global_avg_pool(x)
+ logits = self.classifier(x)
# Calculate the loss and return it.
total_loss = self.compiled_loss(labels, logits)
@@ -661,6 +699,7 @@ def train_step(self, inputs):
self.data_augmentation.trainable_variables,
self.patch_projection.trainable_variables,
self.global_avg_pool.trainable_variables,
+ self.classifier.trainable_variables,
]
train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
@@ -683,6 +722,15 @@ def test_step(self, data):
self.compiled_metrics.update_state(labels, logits)
return {m.name: m.result() for m in self.metrics}
+ def call(self, images):
+ augmented_images = self.data_augmentation(images)
+ x = self.patch_projection(augmented_images)
+ for stage in self.stages:
+ x = stage(x, training=False)
+ x = self.global_avg_pool(x)
+ logits = self.classifier(x)
+ return logits
+
"""
## Instantiate the model
@@ -787,11 +835,25 @@ def __call__(self, step):
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
+ def get_config(self):
+ config = {
+ "lr_start": self.lr_start,
+ "lr_max": self.lr_max,
+ "total_steps": self.total_steps,
+ "warmup_steps": self.warmup_steps,
+ }
+ return config
+
"""
## Compile and train the model
"""
+# pass sample data to the model so that input shape is available at the time of
+# saving the model
+sample_ds, _ = next(iter(train_ds))
+model(sample_ds, training=False)
+
# Get the total number of steps for training.
total_steps = int((len(x_train) / config.batch_size) * config.epochs)
@@ -843,6 +905,123 @@ def __call__(self, step):
print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
+"""
+## Save trained model
+
+Since we created the model by Subclassing, we can't save the model in HDF5 format.
+
+It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well.
+"""
+model.save("ShiftViT")
+
+"""
+## Model inference
+"""
+
+"""
+**Download sample data for inference**
+"""
+
+"""shell
+wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
+unzip -q inference_set.zip
+"""
+
+
+"""
+**Load saved model**
+"""
+# Custom objects are not included when the model is saved.
+# At loading time, these objects need to be passed for reconstruction of the model
+saved_model = tf.keras.models.load_model(
+ "ShiftViT",
+ custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
+)
+
+"""
+**Utility functions for inference**
+"""
+
+
+def process_image(img_path):
+ # read image file from string path
+ img = tf.io.read_file(img_path)
+
+ # decode jpeg to uint8 tensor
+ img = tf.io.decode_jpeg(img, channels=3)
+
+ # resize image to match input size accepted by model
+ # use `method` as `nearest` to preserve dtype of input passed to `resize()`
+ img = tf.image.resize(
+ img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+ )
+ return img
+
+
+def create_tf_dataset(image_dir):
+ data_dir = pathlib.Path(image_dir)
+
+ # create tf.data dataset using directory of images
+ predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
+
+ # use map to convert string paths to uint8 image tensors
+ # setting `num_parallel_calls' helps in processing multiple images parallely
+ predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+
+ # create a Prefetch Dataset for better latency & throughput
+ predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+ return predict_ds
+
+
+def predict(predict_ds):
+ # ShiftViT model returns logits (non-normalized predictions)
+ logits = saved_model.predict(predict_ds)
+
+ # normalize predictions by calling softmax()
+ probabilities = tf.nn.softmax(logits)
+ return probabilities
+
+
+def get_predicted_class(probabilities):
+ pred_label = np.argmax(probabilities)
+ predicted_class = config.label_map[pred_label]
+ return predicted_class
+
+
+def get_confidence_scores(probabilities):
+ # get the indices of the probability scores sorted in descending order
+ labels = np.argsort(probabilities)[::-1]
+ confidences = {
+ config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+ for label in labels
+ }
+ return confidences
+
+
+"""
+**Get predictions**
+"""
+
+img_dir = "inference_set"
+predict_ds = create_tf_dataset(img_dir)
+probabilities = predict(predict_ds)
+print(f"probabilities: {probabilities[0]}")
+confidences = get_confidence_scores(probabilities[0])
+print(confidences)
+
+"""
+**View predictions**
+"""
+
+plt.figure(figsize=(10, 10))
+for images in predict_ds:
+ for i in range(min(6, probabilities.shape[0])):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(images[i].numpy().astype("uint8"))
+ predicted_class = get_predicted_class(probabilities[i])
+ plt.title(predicted_class)
+ plt.axis("off")
+
"""
## Conclusion
@@ -866,3 +1045,11 @@ def __call__(self, step):
- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for
helping us with the Learning Rate Schedule.
"""
+
+"""
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) |
+"""