diff --git a/examples/vision/img/shiftvit/shiftvit_46_0.png b/examples/vision/img/shiftvit/shiftvit_46_0.png new file mode 100644 index 0000000000..ad616c3bd1 Binary files /dev/null and b/examples/vision/img/shiftvit/shiftvit_46_0.png differ diff --git a/examples/vision/ipynb/shiftvit.ipynb b/examples/vision/ipynb/shiftvit.ipynb index 13a30f6c2c..682d161424 100644 --- a/examples/vision/ipynb/shiftvit.ipynb +++ b/examples/vision/ipynb/shiftvit.ipynb @@ -8,9 +8,9 @@ "source": [ "# A Vision Transformer without Attention\n", "\n", - "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
\n", + "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
\n", "**Date created:** 2022/02/24
\n", - "**Last modified:** 2022/03/01
\n", + "**Last modified:** 2022/10/15
\n", "**Description:** A minimal implementation of ShiftViT." ] }, @@ -39,12 +39,19 @@ "In this example, we minimally implement the paper with close alignement to the author's\n", "[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).\n", "\n", - "This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can\n", - "be installed using the following command:\n", - "\n", - "```shell\n", - "pip install -qq -U tensorflow-addons\n", - "```" + "This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can\n", + "be installed using the following command:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!pip install -qq -U tensorflow-addons" ] }, { @@ -70,9 +77,11 @@ "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", - "\n", "import tensorflow_addons as tfa\n", "\n", + "import pathlib\n", + "import glob\n", + "\n", "# Setting seed for reproducibiltiy\n", "SEED = 42\n", "keras.utils.set_random_seed(SEED)" @@ -128,6 +137,21 @@ " # TRAINING\n", " epochs = 100\n", "\n", + " # INFERENCE\n", + " label_map = {\n", + " 0: \"airplane\",\n", + " 1: \"automobile\",\n", + " 2: \"bird\",\n", + " 3: \"cat\",\n", + " 4: \"deer\",\n", + " 5: \"dog\",\n", + " 6: \"frog\",\n", + " 7: \"horse\",\n", + " 8: \"ship\",\n", + " 9: \"truck\",\n", + " }\n", + " tf_ds_batch_size = 20\n", + "\n", "\n", "config = Config()" ] @@ -284,7 +308,7 @@ "source": [ "#### The MLP block\n", "\n", - "The MLP block is intended to be a stack of densely-connected layers.s" + "The MLP block is intended to be a stack of densely-connected layers" ] }, { @@ -315,7 +339,10 @@ "\n", " self.mlp = keras.Sequential(\n", " [\n", - " layers.Dense(units=initial_filters, activation=tf.nn.gelu,),\n", + " layers.Dense(\n", + " units=initial_filters,\n", + " activation=tf.nn.gelu,\n", + " ),\n", " layers.Dropout(rate=self.mlp_dropout_rate),\n", " layers.Dense(units=input_channels),\n", " layers.Dropout(rate=self.mlp_dropout_rate),\n", @@ -382,7 +409,7 @@ "source": [ "#### Block\n", "\n", - "The most important operation in this paper is the **shift opperation**. In this section,\n", + "The most important operation in this paper is the **shift operation**. 
In this section,\n", "we describe the shift operation and compare it with its original implementation provided\n", "by the authors.\n", "\n", @@ -693,6 +720,24 @@ " if self.is_merge:\n", " x = self.patch_merge(x)\n", " return x\n", + "\n", + " # Since this is a custom layer, we need to overwrite get_config()\n", + " # so that model can be easily saved & loaded after training\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update(\n", + " {\n", + " \"epsilon\": self.epsilon,\n", + " \"mlp_dropout_rate\": self.mlp_dropout_rate,\n", + " \"num_shift_blocks\": self.num_shift_blocks,\n", + " \"stochastic_depth_rate\": self.stochastic_depth_rate,\n", + " \"is_merge\": self.is_merge,\n", + " \"num_div\": self.num_div,\n", + " \"shift_pixel\": self.shift_pixel,\n", + " \"mlp_expand_ratio\": self.mlp_expand_ratio,\n", + " }\n", + " )\n", + " return config\n", "" ] }, @@ -780,6 +825,8 @@ " )\n", " self.global_avg_pool = layers.GlobalAveragePooling2D()\n", "\n", + " self.classifier = layers.Dense(config.num_classes)\n", + "\n", " def get_config(self):\n", " config = super().get_config()\n", " config.update(\n", @@ -788,6 +835,7 @@ " \"patch_projection\": self.patch_projection,\n", " \"stages\": self.stages,\n", " \"global_avg_pool\": self.global_avg_pool,\n", + " \"classifier\": self.classifier,\n", " }\n", " )\n", " return config\n", @@ -807,7 +855,8 @@ " x = stage(x, training=training)\n", "\n", " # Get the logits.\n", - " logits = self.global_avg_pool(x)\n", + " x = self.global_avg_pool(x)\n", + " logits = self.classifier(x)\n", "\n", " # Calculate the loss and return it.\n", " total_loss = self.compiled_loss(labels, logits)\n", @@ -824,6 +873,7 @@ " self.data_augmentation.trainable_variables,\n", " self.patch_projection.trainable_variables,\n", " self.global_avg_pool.trainable_variables,\n", + " self.classifier.trainable_variables,\n", " ]\n", " train_vars = train_vars + [stage.trainable_variables for stage in self.stages]\n", "\n", @@ -845,6 +895,15 @@ " # Update the metrics\n", " self.compiled_metrics.update_state(labels, logits)\n", " return {m.name: m.result() for m in self.metrics}\n", + "\n", + " def call(self, images):\n", + " augmented_images = self.data_augmentation(images)\n", + " x = self.patch_projection(augmented_images)\n", + " for stage in self.stages:\n", + " x = stage(x, training=False)\n", + " x = self.global_avg_pool(x)\n", + " logits = self.classifier(x)\n", + " return logits\n", "" ] }, @@ -976,6 +1035,15 @@ " return tf.where(\n", " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n", " )\n", + "\n", + " def get_config(self):\n", + " config = {\n", + " \"lr_start\": self.lr_start,\n", + " \"lr_max\": self.lr_max,\n", + " \"total_steps\": self.total_steps,\n", + " \"warmup_steps\": self.warmup_steps,\n", + " }\n", + " return config\n", "" ] }, @@ -996,6 +1064,11 @@ }, "outputs": [], "source": [ + "# pass sample data to the model so that input shape is available at the time of\n", + "# saving the model\n", + "sample_ds, _ = next(iter(train_ds))\n", + "model(sample_ds, training=False)\n", + "\n", "# Get the total number of steps for training.\n", "total_steps = int((len(x_train) / config.batch_size) * config.epochs)\n", "\n", @@ -1005,7 +1078,10 @@ "\n", "# Initialize the warmupcosine schedule.\n", "scheduled_lrs = WarmUpCosine(\n", - " lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,\n", + " lr_start=1e-5,\n", + " lr_max=1e-3,\n", + " warmup_steps=warmup_steps,\n", + " 
total_steps=total_steps,\n", ")\n", "\n", "# Get the optimizer.\n", @@ -1029,7 +1105,11 @@ " epochs=config.epochs,\n", " validation_data=val_ds,\n", " callbacks=[\n", - " keras.callbacks.EarlyStopping(monitor=\"val_accuracy\", patience=5, mode=\"auto\",)\n", + " keras.callbacks.EarlyStopping(\n", + " monitor=\"val_accuracy\",\n", + " patience=5,\n", + " mode=\"auto\",\n", + " )\n", " ],\n", ")\n", "\n", @@ -1041,6 +1121,211 @@ "print(f\"Top 5 test accuracy: {acc_top5*100:0.2f}%\")" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Save trained model\n", + "\n", + "Since we created the model by Subclassing, we can't save the model in HDF5 format.\n", + "\n", + "It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "model.save(\"ShiftViT\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Model inference" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Download sample data for inference**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip\n", + "!unzip -q inference_set.zip" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Load saved model**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "# Custom objects are not included when the model is saved.\n", + "# At loading time, these objects need to be passed for reconstruction of the model\n", + "saved_model = tf.keras.models.load_model(\n", + " \"ShiftViT\",\n", + " custom_objects={\"WarmUpCosine\": WarmUpCosine, \"AdamW\": tfa.optimizers.AdamW},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Utility functions for inference**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def process_image(img_path):\n", + " # read image file from string path\n", + " img = tf.io.read_file(img_path)\n", + "\n", + " # decode jpeg to uint8 tensor\n", + " img = tf.io.decode_jpeg(img, channels=3)\n", + "\n", + " # resize image to match input size accepted by model\n", + " # use `method` as `nearest` to preserve dtype of input passed to `resize()`\n", + " img = tf.image.resize(\n", + " img, [config.input_shape[0], config.input_shape[1]], method=\"nearest\"\n", + " )\n", + " return img\n", + "\n", + "\n", + "def create_tf_dataset(image_dir):\n", + " data_dir = pathlib.Path(image_dir)\n", + "\n", + " # create tf.data dataset using directory of images\n", + " predict_ds = tf.data.Dataset.list_files(str(data_dir / \"*.jpg\"), shuffle=False)\n", + "\n", + " # use map to convert string paths to uint8 image tensors\n", + " # setting `num_parallel_calls' helps in processing multiple images parallely\n", + " predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)\n", + "\n", + " # create a Prefetch Dataset for better latency & throughput\n", + " predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)\n", + " return predict_ds\n", + "\n", + "\n", + 
"def predict(predict_ds):\n", + " # ShiftViT model returns logits (non-normalized predictions)\n", + " logits = saved_model.predict(predict_ds)\n", + "\n", + " # normalize predictions by calling softmax()\n", + " probabilities = tf.nn.softmax(logits)\n", + " return probabilities\n", + "\n", + "\n", + "def get_predicted_class(probabilities):\n", + " pred_label = np.argmax(probabilities)\n", + " predicted_class = config.label_map[pred_label]\n", + " return predicted_class\n", + "\n", + "\n", + "def get_confidence_scores(probabilities):\n", + " # get the indices of the probability scores sorted in descending order\n", + " labels = np.argsort(probabilities)[::-1]\n", + " confidences = {\n", + " config.label_map[label]: np.round((probabilities[label]) * 100, 2)\n", + " for label in labels\n", + " }\n", + " return confidences\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Get predictions**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "img_dir = \"inference_set\"\n", + "predict_ds = create_tf_dataset(img_dir)\n", + "probabilities = predict(predict_ds)\n", + "print(f\"probabilities: {probabilities[0]}\")\n", + "confidences = get_confidence_scores(probabilities[0])\n", + "print(confidences)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**View predictions**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(10, 10))\n", + "for images in predict_ds:\n", + " for i in range(min(6, probabilities.shape[0])):\n", + " ax = plt.subplot(3, 3, i + 1)\n", + " plt.imshow(images[i].numpy().astype(\"uint8\"))\n", + " predicted_class = get_predicted_class(probabilities[i])\n", + " plt.title(predicted_class)\n", + " plt.axis(\"off\")" + ] + }, { "cell_type": "markdown", "metadata": { @@ -1069,6 +1354,19 @@ "- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for\n", "helping us with the Learning Rate Schedule." ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Example available on HuggingFace**\n", + "\n", + "| Trained Model | Demo |\n", + "| :--: | :--: |\n", + "| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) |" + ] } ], "metadata": { @@ -1100,4 +1398,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/examples/vision/md/shiftvit.md b/examples/vision/md/shiftvit.md index 6738bfb777..0b583bd1c2 100644 --- a/examples/vision/md/shiftvit.md +++ b/examples/vision/md/shiftvit.md @@ -1,8 +1,8 @@ # A Vision Transformer without Attention -**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
**Date created:** 2022/02/24
-**Last modified:** 2022/03/01
+**Last modified:** 2022/10/15
**Description:** A minimal implementation of ShiftViT. @@ -30,13 +30,15 @@ operation with a shifting operation. In this example, we minimally implement the paper with close alignement to the author's [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py). -This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can +This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can be installed using the following command: -```shell -pip install -qq -U tensorflow-addons + +```python +!pip install -qq -U tensorflow-addons ``` + --- ## Setup and imports @@ -48,9 +50,11 @@ import matplotlib.pyplot as plt import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers - import tensorflow_addons as tfa +import pathlib +import glob + # Setting seed for reproducibiltiy SEED = 42 keras.utils.set_random_seed(SEED) @@ -94,6 +98,21 @@ class Config(object): # TRAINING epochs = 100 + # INFERENCE + label_map = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck", + } + tf_ds_batch_size = 20 + config = Config() ``` @@ -127,14 +146,12 @@ test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
``` +Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz +170498071/170498071 [==============================] - 3s 0us/step Training samples: 40000 Validation samples: 10000 Testing samples: 10000 -2022-03-01 03:10:21.342684: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA -To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. -2022-03-01 03:10:21.850844: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38420 MB memory: -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:61:00.0, compute capability: 8.0 - ```
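+
+As a quick, optional sanity check of the data pipeline, we can look at a few training
+images with their class names. This is a small illustrative sketch that only uses
+pieces already defined above: the batched `train_ds` and the `label_map` from `config`.
+
+
+```python
+sample_images, sample_labels = next(iter(train_ds))
+
+plt.figure(figsize=(6, 6))
+for i in range(9):
+    plt.subplot(3, 3, i + 1)
+    # CIFAR-10 images are uint8 in [0, 255], so they can be displayed directly
+    plt.imshow(sample_images[i].numpy().astype("uint8"))
+    # each label has shape (1,), so index into it to get the scalar class id
+    plt.title(config.label_map[int(sample_labels[i][0])])
+    plt.axis("off")
+```
+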
--- @@ -219,7 +236,7 @@ The Shift Block as shown in Fig. 3, comprises of the following: #### The MLP block -The MLP block is intended to be a stack of densely-connected layers.s +The MLP block is intended to be a stack of densely-connected layers ```python @@ -243,7 +260,10 @@ class MLP(layers.Layer): self.mlp = keras.Sequential( [ - layers.Dense(units=initial_filters, activation=tf.nn.gelu,), + layers.Dense( + units=initial_filters, + activation=tf.nn.gelu, + ), layers.Dropout(rate=self.mlp_dropout_rate), layers.Dense(units=input_channels), layers.Dropout(rate=self.mlp_dropout_rate), @@ -291,7 +311,7 @@ class DropPath(layers.Layer): #### Block -The most important operation in this paper is the **shift opperation**. In this section, +The most important operation in this paper is the **shift operation**. In this section, we describe the shift operation and compare it with its original implementation provided by the authors. @@ -563,6 +583,24 @@ class StackedShiftBlocks(layers.Layer): x = self.patch_merge(x) return x + # Since this is a custom layer, we need to overwrite get_config() + # so that model can be easily saved & loaded after training + def get_config(self): + config = super().get_config() + config.update( + { + "epsilon": self.epsilon, + "mlp_dropout_rate": self.mlp_dropout_rate, + "num_shift_blocks": self.num_shift_blocks, + "stochastic_depth_rate": self.stochastic_depth_rate, + "is_merge": self.is_merge, + "num_div": self.num_div, + "shift_pixel": self.shift_pixel, + "mlp_expand_ratio": self.mlp_expand_ratio, + } + ) + return config + ``` --- @@ -637,6 +675,8 @@ class ShiftViTModel(keras.Model): ) self.global_avg_pool = layers.GlobalAveragePooling2D() + self.classifier = layers.Dense(config.num_classes) + def get_config(self): config = super().get_config() config.update( @@ -645,6 +685,7 @@ class ShiftViTModel(keras.Model): "patch_projection": self.patch_projection, "stages": self.stages, "global_avg_pool": self.global_avg_pool, + "classifier": self.classifier, } ) return config @@ -664,7 +705,8 @@ class ShiftViTModel(keras.Model): x = stage(x, training=training) # Get the logits. - logits = self.global_avg_pool(x) + x = self.global_avg_pool(x) + logits = self.classifier(x) # Calculate the loss and return it. 
total_loss = self.compiled_loss(labels, logits) @@ -681,6 +723,7 @@ class ShiftViTModel(keras.Model): self.data_augmentation.trainable_variables, self.patch_projection.trainable_variables, self.global_avg_pool.trainable_variables, + self.classifier.trainable_variables, ] train_vars = train_vars + [stage.trainable_variables for stage in self.stages] @@ -703,6 +746,15 @@ class ShiftViTModel(keras.Model): self.compiled_metrics.update_state(labels, logits) return {m.name: m.result() for m in self.metrics} + def call(self, images): + augmented_images = self.data_augmentation(images) + x = self.patch_projection(augmented_images) + for stage in self.stages: + x = stage(x, training=False) + x = self.global_avg_pool(x) + logits = self.classifier(x) + return logits + ``` --- @@ -810,6 +862,15 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): step > self.total_steps, 0.0, learning_rate, name="learning_rate" ) + def get_config(self): + config = { + "lr_start": self.lr_start, + "lr_max": self.lr_max, + "total_steps": self.total_steps, + "warmup_steps": self.warmup_steps, + } + return config + ``` --- @@ -817,6 +878,11 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): ```python +# pass sample data to the model so that input shape is available at the time of +# saving the model +sample_ds, _ = next(iter(train_ds)) +model(sample_ds, training=False) + # Get the total number of steps for training. total_steps = int((len(x_train) / config.batch_size) * config.epochs) @@ -826,7 +892,10 @@ warmup_steps = int(total_steps * warmup_epoch_percentage) # Initialize the warmupcosine schedule. scheduled_lrs = WarmUpCosine( - lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps, + lr_start=1e-5, + lr_max=1e-3, + warmup_steps=warmup_steps, + total_steps=total_steps, ) # Get the optimizer. @@ -850,7 +919,11 @@ history = model.fit( epochs=config.epochs, validation_data=val_ds, callbacks=[ - keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5, mode="auto",) + keras.callbacks.EarlyStopping( + monitor="val_accuracy", + patience=5, + mode="auto", + ) ], ) @@ -865,109 +938,249 @@ print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
``` Epoch 1/100 - -2022-03-01 03:10:41.373231: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8202 -2022-03-01 03:10:43.145958: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once. - -157/157 [==============================] - 34s 84ms/step - loss: 3.2975 - accuracy: 0.1084 - top-5-accuracy: 0.4806 - val_loss: 2.1575 - val_accuracy: 0.2017 - val_top-5-accuracy: 0.7184 +157/157 [==============================] - 72s 332ms/step - loss: 2.3844 - accuracy: 0.1444 - top-5-accuracy: 0.6051 - val_loss: 2.0984 - val_accuracy: 0.2610 - val_top-5-accuracy: 0.7638 Epoch 2/100 -157/157 [==============================] - 11s 67ms/step - loss: 2.1727 - accuracy: 0.2289 - top-5-accuracy: 0.7516 - val_loss: 1.8819 - val_accuracy: 0.3182 - val_top-5-accuracy: 0.8386 +157/157 [==============================] - 49s 314ms/step - loss: 1.9457 - accuracy: 0.2893 - top-5-accuracy: 0.8103 - val_loss: 1.9459 - val_accuracy: 0.3356 - val_top-5-accuracy: 0.8614 Epoch 3/100 -157/157 [==============================] - 10s 67ms/step - loss: 1.8169 - accuracy: 0.3426 - top-5-accuracy: 0.8592 - val_loss: 1.6174 - val_accuracy: 0.4053 - val_top-5-accuracy: 0.8934 +157/157 [==============================] - 50s 316ms/step - loss: 1.7093 - accuracy: 0.3810 - top-5-accuracy: 0.8761 - val_loss: 1.5349 - val_accuracy: 0.4585 - val_top-5-accuracy: 0.9045 Epoch 4/100 -157/157 [==============================] - 10s 67ms/step - loss: 1.6215 - accuracy: 0.4092 - top-5-accuracy: 0.8983 - val_loss: 1.4239 - val_accuracy: 0.4903 - val_top-5-accuracy: 0.9216 +157/157 [==============================] - 49s 315ms/step - loss: 1.5473 - accuracy: 0.4374 - top-5-accuracy: 0.9090 - val_loss: 1.4257 - val_accuracy: 0.4862 - val_top-5-accuracy: 0.9298 Epoch 5/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.5081 - accuracy: 0.4571 - top-5-accuracy: 0.9148 - val_loss: 1.3359 - val_accuracy: 0.5161 - val_top-5-accuracy: 0.9369 +157/157 [==============================] - 50s 316ms/step - loss: 1.4316 - accuracy: 0.4816 - top-5-accuracy: 0.9243 - val_loss: 1.4032 - val_accuracy: 0.5092 - val_top-5-accuracy: 0.9362 Epoch 6/100 -157/157 [==============================] - 11s 68ms/step - loss: 1.4282 - accuracy: 0.4868 - top-5-accuracy: 0.9249 - val_loss: 1.2929 - val_accuracy: 0.5347 - val_top-5-accuracy: 0.9404 +157/157 [==============================] - 50s 316ms/step - loss: 1.3588 - accuracy: 0.5131 - top-5-accuracy: 0.9333 - val_loss: 1.2893 - val_accuracy: 0.5411 - val_top-5-accuracy: 0.9457 Epoch 7/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.3465 - accuracy: 0.5181 - top-5-accuracy: 0.9362 - val_loss: 1.2653 - val_accuracy: 0.5497 - val_top-5-accuracy: 0.9449 +157/157 [==============================] - 50s 316ms/step - loss: 1.2894 - accuracy: 0.5385 - top-5-accuracy: 0.9410 - val_loss: 1.2922 - val_accuracy: 0.5416 - val_top-5-accuracy: 0.9432 Epoch 8/100 -157/157 [==============================] - 10s 67ms/step - loss: 1.2907 - accuracy: 0.5400 - top-5-accuracy: 0.9416 - val_loss: 1.1919 - val_accuracy: 0.5753 - val_top-5-accuracy: 0.9515 +157/157 [==============================] - 49s 315ms/step - loss: 1.2388 - accuracy: 0.5568 - top-5-accuracy: 0.9468 - val_loss: 1.2100 - val_accuracy: 0.5733 - val_top-5-accuracy: 0.9545 Epoch 9/100 -157/157 [==============================] - 11s 67ms/step - loss: 1.2247 - accuracy: 0.5644 - top-5-accuracy: 0.9480 - val_loss: 1.1741 - 
val_accuracy: 0.5742 - val_top-5-accuracy: 0.9563 +157/157 [==============================] - 49s 315ms/step - loss: 1.2043 - accuracy: 0.5698 - top-5-accuracy: 0.9491 - val_loss: 1.2166 - val_accuracy: 0.5675 - val_top-5-accuracy: 0.9520 Epoch 10/100 -157/157 [==============================] - 11s 67ms/step - loss: 1.1983 - accuracy: 0.5760 - top-5-accuracy: 0.9505 - val_loss: 1.4545 - val_accuracy: 0.4804 - val_top-5-accuracy: 0.9198 +157/157 [==============================] - 49s 315ms/step - loss: 1.1694 - accuracy: 0.5861 - top-5-accuracy: 0.9528 - val_loss: 1.1738 - val_accuracy: 0.5883 - val_top-5-accuracy: 0.9541 Epoch 11/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.2002 - accuracy: 0.5766 - top-5-accuracy: 0.9510 - val_loss: 1.1129 - val_accuracy: 0.6055 - val_top-5-accuracy: 0.9593 +157/157 [==============================] - 50s 316ms/step - loss: 1.1290 - accuracy: 0.5994 - top-5-accuracy: 0.9575 - val_loss: 1.1161 - val_accuracy: 0.6064 - val_top-5-accuracy: 0.9618 Epoch 12/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.1309 - accuracy: 0.5990 - top-5-accuracy: 0.9575 - val_loss: 1.0369 - val_accuracy: 0.6341 - val_top-5-accuracy: 0.9638 +157/157 [==============================] - 50s 316ms/step - loss: 1.0861 - accuracy: 0.6157 - top-5-accuracy: 0.9602 - val_loss: 1.1220 - val_accuracy: 0.6133 - val_top-5-accuracy: 0.9576 Epoch 13/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.0786 - accuracy: 0.6204 - top-5-accuracy: 0.9613 - val_loss: 1.0802 - val_accuracy: 0.6193 - val_top-5-accuracy: 0.9594 +157/157 [==============================] - 49s 315ms/step - loss: 1.0766 - accuracy: 0.6178 - top-5-accuracy: 0.9612 - val_loss: 1.0108 - val_accuracy: 0.6402 - val_top-5-accuracy: 0.9681 Epoch 14/100 -157/157 [==============================] - 10s 65ms/step - loss: 1.0438 - accuracy: 0.6330 - top-5-accuracy: 0.9640 - val_loss: 0.9584 - val_accuracy: 0.6596 - val_top-5-accuracy: 0.9713 +157/157 [==============================] - 49s 315ms/step - loss: 1.0179 - accuracy: 0.6416 - top-5-accuracy: 0.9658 - val_loss: 1.0196 - val_accuracy: 0.6405 - val_top-5-accuracy: 0.9667 Epoch 15/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.9957 - accuracy: 0.6496 - top-5-accuracy: 0.9684 - val_loss: 0.9530 - val_accuracy: 0.6636 - val_top-5-accuracy: 0.9712 +157/157 [==============================] - 50s 316ms/step - loss: 1.0028 - accuracy: 0.6470 - top-5-accuracy: 0.9678 - val_loss: 1.0113 - val_accuracy: 0.6415 - val_top-5-accuracy: 0.9672 Epoch 16/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.9710 - accuracy: 0.6599 - top-5-accuracy: 0.9696 - val_loss: 0.8856 - val_accuracy: 0.6863 - val_top-5-accuracy: 0.9756 +157/157 [==============================] - 50s 316ms/step - loss: 0.9613 - accuracy: 0.6611 - top-5-accuracy: 0.9710 - val_loss: 1.0516 - val_accuracy: 0.6406 - val_top-5-accuracy: 0.9596 Epoch 17/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.9316 - accuracy: 0.6706 - top-5-accuracy: 0.9721 - val_loss: 0.9919 - val_accuracy: 0.6480 - val_top-5-accuracy: 0.9671 +157/157 [==============================] - 50s 316ms/step - loss: 0.9262 - accuracy: 0.6740 - top-5-accuracy: 0.9729 - val_loss: 0.9010 - val_accuracy: 0.6844 - val_top-5-accuracy: 0.9750 Epoch 18/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8899 - accuracy: 0.6884 - top-5-accuracy: 0.9763 - val_loss: 0.8753 - val_accuracy: 0.6949 - val_top-5-accuracy: 
0.9752 +157/157 [==============================] - 50s 316ms/step - loss: 0.8768 - accuracy: 0.6916 - top-5-accuracy: 0.9769 - val_loss: 0.8862 - val_accuracy: 0.6908 - val_top-5-accuracy: 0.9767 Epoch 19/100 -157/157 [==============================] - 10s 64ms/step - loss: 0.8529 - accuracy: 0.6979 - top-5-accuracy: 0.9772 - val_loss: 0.8793 - val_accuracy: 0.6943 - val_top-5-accuracy: 0.9754 +157/157 [==============================] - 49s 315ms/step - loss: 0.8595 - accuracy: 0.6984 - top-5-accuracy: 0.9768 - val_loss: 0.8732 - val_accuracy: 0.6982 - val_top-5-accuracy: 0.9738 Epoch 20/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8509 - accuracy: 0.7009 - top-5-accuracy: 0.9783 - val_loss: 0.8183 - val_accuracy: 0.7174 - val_top-5-accuracy: 0.9763 +157/157 [==============================] - 50s 317ms/step - loss: 0.8252 - accuracy: 0.7103 - top-5-accuracy: 0.9793 - val_loss: 0.9330 - val_accuracy: 0.6745 - val_top-5-accuracy: 0.9718 Epoch 21/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8087 - accuracy: 0.7143 - top-5-accuracy: 0.9809 - val_loss: 0.7885 - val_accuracy: 0.7276 - val_top-5-accuracy: 0.9769 +157/157 [==============================] - 51s 322ms/step - loss: 0.8003 - accuracy: 0.7180 - top-5-accuracy: 0.9814 - val_loss: 0.8912 - val_accuracy: 0.6948 - val_top-5-accuracy: 0.9728 Epoch 22/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8004 - accuracy: 0.7192 - top-5-accuracy: 0.9811 - val_loss: 0.7601 - val_accuracy: 0.7371 - val_top-5-accuracy: 0.9805 +157/157 [==============================] - 51s 326ms/step - loss: 0.7651 - accuracy: 0.7317 - top-5-accuracy: 0.9829 - val_loss: 0.7894 - val_accuracy: 0.7277 - val_top-5-accuracy: 0.9791 Epoch 23/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7665 - accuracy: 0.7304 - top-5-accuracy: 0.9816 - val_loss: 0.7564 - val_accuracy: 0.7412 - val_top-5-accuracy: 0.9808 +157/157 [==============================] - 52s 328ms/step - loss: 0.7372 - accuracy: 0.7415 - top-5-accuracy: 0.9843 - val_loss: 0.7752 - val_accuracy: 0.7284 - val_top-5-accuracy: 0.9804 Epoch 24/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7599 - accuracy: 0.7344 - top-5-accuracy: 0.9832 - val_loss: 0.7475 - val_accuracy: 0.7389 - val_top-5-accuracy: 0.9822 +157/157 [==============================] - 51s 327ms/step - loss: 0.7324 - accuracy: 0.7423 - top-5-accuracy: 0.9852 - val_loss: 0.7949 - val_accuracy: 0.7340 - val_top-5-accuracy: 0.9792 Epoch 25/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7398 - accuracy: 0.7427 - top-5-accuracy: 0.9833 - val_loss: 0.7211 - val_accuracy: 0.7504 - val_top-5-accuracy: 0.9829 +157/157 [==============================] - 51s 323ms/step - loss: 0.7051 - accuracy: 0.7512 - top-5-accuracy: 0.9858 - val_loss: 0.7967 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9787 Epoch 26/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7114 - accuracy: 0.7500 - top-5-accuracy: 0.9857 - val_loss: 0.7385 - val_accuracy: 0.7462 - val_top-5-accuracy: 0.9822 +157/157 [==============================] - 51s 323ms/step - loss: 0.6832 - accuracy: 0.7577 - top-5-accuracy: 0.9870 - val_loss: 0.7840 - val_accuracy: 0.7322 - val_top-5-accuracy: 0.9807 Epoch 27/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6954 - accuracy: 0.7577 - top-5-accuracy: 0.9851 - val_loss: 0.7477 - val_accuracy: 0.7402 - val_top-5-accuracy: 0.9802 +157/157 
[==============================] - 51s 322ms/step - loss: 0.6609 - accuracy: 0.7654 - top-5-accuracy: 0.9877 - val_loss: 0.7447 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9816 Epoch 28/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6807 - accuracy: 0.7588 - top-5-accuracy: 0.9871 - val_loss: 0.7275 - val_accuracy: 0.7536 - val_top-5-accuracy: 0.9822 +157/157 [==============================] - 50s 319ms/step - loss: 0.6495 - accuracy: 0.7724 - top-5-accuracy: 0.9883 - val_loss: 0.7885 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9817 Epoch 29/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6719 - accuracy: 0.7648 - top-5-accuracy: 0.9876 - val_loss: 0.7261 - val_accuracy: 0.7487 - val_top-5-accuracy: 0.9815 +157/157 [==============================] - 50s 317ms/step - loss: 0.6491 - accuracy: 0.7707 - top-5-accuracy: 0.9885 - val_loss: 0.7539 - val_accuracy: 0.7458 - val_top-5-accuracy: 0.9821 Epoch 30/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.6578 - accuracy: 0.7696 - top-5-accuracy: 0.9871 - val_loss: 0.6932 - val_accuracy: 0.7641 - val_top-5-accuracy: 0.9833 +157/157 [==============================] - 50s 317ms/step - loss: 0.6213 - accuracy: 0.7823 - top-5-accuracy: 0.9888 - val_loss: 0.7571 - val_accuracy: 0.7470 - val_top-5-accuracy: 0.9815 Epoch 31/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6489 - accuracy: 0.7740 - top-5-accuracy: 0.9877 - val_loss: 0.7400 - val_accuracy: 0.7486 - val_top-5-accuracy: 0.9820 +157/157 [==============================] - 50s 318ms/step - loss: 0.5976 - accuracy: 0.7902 - top-5-accuracy: 0.9906 - val_loss: 0.7430 - val_accuracy: 0.7508 - val_top-5-accuracy: 0.9817 Epoch 32/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.6290 - accuracy: 0.7812 - top-5-accuracy: 0.9895 - val_loss: 0.6954 - val_accuracy: 0.7628 - val_top-5-accuracy: 0.9847 +157/157 [==============================] - 50s 318ms/step - loss: 0.5932 - accuracy: 0.7898 - top-5-accuracy: 0.9910 - val_loss: 0.7545 - val_accuracy: 0.7469 - val_top-5-accuracy: 0.9793 Epoch 33/100 -157/157 [==============================] - 10s 67ms/step - loss: 0.6194 - accuracy: 0.7826 - top-5-accuracy: 0.9894 - val_loss: 0.6913 - val_accuracy: 0.7619 - val_top-5-accuracy: 0.9842 +157/157 [==============================] - 50s 318ms/step - loss: 0.5977 - accuracy: 0.7850 - top-5-accuracy: 0.9913 - val_loss: 0.7200 - val_accuracy: 0.7569 - val_top-5-accuracy: 0.9830 Epoch 34/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.5917 - accuracy: 0.7930 - top-5-accuracy: 0.9902 - val_loss: 0.6879 - val_accuracy: 0.7715 - val_top-5-accuracy: 0.9831 +157/157 [==============================] - 50s 317ms/step - loss: 0.5552 - accuracy: 0.8041 - top-5-accuracy: 0.9920 - val_loss: 0.7377 - val_accuracy: 0.7552 - val_top-5-accuracy: 0.9818 Epoch 35/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5878 - accuracy: 0.7916 - top-5-accuracy: 0.9907 - val_loss: 0.6759 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9849 +157/157 [==============================] - 50s 319ms/step - loss: 0.5509 - accuracy: 0.8056 - top-5-accuracy: 0.9921 - val_loss: 0.8125 - val_accuracy: 0.7331 - val_top-5-accuracy: 0.9782 Epoch 36/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5713 - accuracy: 0.8004 - top-5-accuracy: 0.9913 - val_loss: 0.6920 - val_accuracy: 0.7657 - val_top-5-accuracy: 0.9841 +157/157 [==============================] - 50s 
317ms/step - loss: 0.5296 - accuracy: 0.8116 - top-5-accuracy: 0.9933 - val_loss: 0.6900 - val_accuracy: 0.7680 - val_top-5-accuracy: 0.9849 Epoch 37/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5590 - accuracy: 0.8040 - top-5-accuracy: 0.9913 - val_loss: 0.6790 - val_accuracy: 0.7718 - val_top-5-accuracy: 0.9831 +157/157 [==============================] - 50s 316ms/step - loss: 0.5151 - accuracy: 0.8170 - top-5-accuracy: 0.9941 - val_loss: 0.7275 - val_accuracy: 0.7610 - val_top-5-accuracy: 0.9841 Epoch 38/100 -157/157 [==============================] - 11s 67ms/step - loss: 0.5445 - accuracy: 0.8114 - top-5-accuracy: 0.9926 - val_loss: 0.6756 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9852 +157/157 [==============================] - 50s 317ms/step - loss: 0.5069 - accuracy: 0.8217 - top-5-accuracy: 0.9936 - val_loss: 0.7067 - val_accuracy: 0.7703 - val_top-5-accuracy: 0.9835 Epoch 39/100 -157/157 [==============================] - 11s 67ms/step - loss: 0.5292 - accuracy: 0.8155 - top-5-accuracy: 0.9930 - val_loss: 0.6578 - val_accuracy: 0.7807 - val_top-5-accuracy: 0.9845 +157/157 [==============================] - 50s 318ms/step - loss: 0.4771 - accuracy: 0.8304 - top-5-accuracy: 0.9945 - val_loss: 0.7110 - val_accuracy: 0.7668 - val_top-5-accuracy: 0.9836 Epoch 40/100 -157/157 [==============================] - 11s 68ms/step - loss: 0.5169 - accuracy: 0.8181 - top-5-accuracy: 0.9926 - val_loss: 0.6582 - val_accuracy: 0.7795 - val_top-5-accuracy: 0.9849 +157/157 [==============================] - 50s 317ms/step - loss: 0.4675 - accuracy: 0.8350 - top-5-accuracy: 0.9956 - val_loss: 0.7130 - val_accuracy: 0.7688 - val_top-5-accuracy: 0.9829 Epoch 41/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5108 - accuracy: 0.8217 - top-5-accuracy: 0.9937 - val_loss: 0.6344 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9855 +157/157 [==============================] - 50s 319ms/step - loss: 0.4586 - accuracy: 0.8382 - top-5-accuracy: 0.9959 - val_loss: 0.7331 - val_accuracy: 0.7598 - val_top-5-accuracy: 0.9806 Epoch 42/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.5056 - accuracy: 0.8220 - top-5-accuracy: 0.9936 - val_loss: 0.6723 - val_accuracy: 0.7744 - val_top-5-accuracy: 0.9851 +157/157 [==============================] - 50s 318ms/step - loss: 0.4558 - accuracy: 0.8380 - top-5-accuracy: 0.9959 - val_loss: 0.7187 - val_accuracy: 0.7722 - val_top-5-accuracy: 0.9832 Epoch 43/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.4824 - accuracy: 0.8317 - top-5-accuracy: 0.9943 - val_loss: 0.6800 - val_accuracy: 0.7771 - val_top-5-accuracy: 0.9834 +157/157 [==============================] - 50s 320ms/step - loss: 0.4356 - accuracy: 0.8450 - top-5-accuracy: 0.9958 - val_loss: 0.7162 - val_accuracy: 0.7693 - val_top-5-accuracy: 0.9850 Epoch 44/100 -157/157 [==============================] - 10s 67ms/step - loss: 0.4719 - accuracy: 0.8339 - top-5-accuracy: 0.9938 - val_loss: 0.6742 - val_accuracy: 0.7785 - val_top-5-accuracy: 0.9840 +157/157 [==============================] - 49s 314ms/step - loss: 0.4425 - accuracy: 0.8433 - top-5-accuracy: 0.9958 - val_loss: 0.7061 - val_accuracy: 0.7698 - val_top-5-accuracy: 0.9853 Epoch 45/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.4605 - accuracy: 0.8379 - top-5-accuracy: 0.9953 - val_loss: 0.6732 - val_accuracy: 0.7781 - val_top-5-accuracy: 0.9841 +157/157 [==============================] - 49s 314ms/step - loss: 0.4072 - accuracy: 0.8551 
- top-5-accuracy: 0.9967 - val_loss: 0.7025 - val_accuracy: 0.7820 - val_top-5-accuracy: 0.9848 Epoch 46/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.4608 - accuracy: 0.8390 - top-5-accuracy: 0.9947 - val_loss: 0.6547 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9852 +157/157 [==============================] - 49s 314ms/step - loss: 0.3865 - accuracy: 0.8644 - top-5-accuracy: 0.9970 - val_loss: 0.7178 - val_accuracy: 0.7740 - val_top-5-accuracy: 0.9844 +Epoch 47/100 +157/157 [==============================] - 49s 313ms/step - loss: 0.3718 - accuracy: 0.8694 - top-5-accuracy: 0.9973 - val_loss: 0.7216 - val_accuracy: 0.7768 - val_top-5-accuracy: 0.9828 +Epoch 48/100 +157/157 [==============================] - 49s 314ms/step - loss: 0.3733 - accuracy: 0.8673 - top-5-accuracy: 0.9970 - val_loss: 0.7440 - val_accuracy: 0.7713 - val_top-5-accuracy: 0.9841 +Epoch 49/100 +157/157 [==============================] - 49s 313ms/step - loss: 0.3531 - accuracy: 0.8741 - top-5-accuracy: 0.9979 - val_loss: 0.7220 - val_accuracy: 0.7738 - val_top-5-accuracy: 0.9848 +Epoch 50/100 +157/157 [==============================] - 49s 314ms/step - loss: 0.3502 - accuracy: 0.8738 - top-5-accuracy: 0.9980 - val_loss: 0.7245 - val_accuracy: 0.7734 - val_top-5-accuracy: 0.9836 TESTING -40/40 [==============================] - 1s 22ms/step - loss: 0.6801 - accuracy: 0.7720 - top-5-accuracy: 0.9864 -Loss: 0.68 -Top 1 test accuracy: 77.20% -Top 5 test accuracy: 98.64% +40/40 [==============================] - 2s 56ms/step - loss: 0.7336 - accuracy: 0.7638 - top-5-accuracy: 0.9855 +Loss: 0.73 +Top 1 test accuracy: 76.38% +Top 5 test accuracy: 98.55% ```
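+
+As an optional sanity check before saving the model, we can plot the accuracy curves
+recorded in the `history` object returned by `model.fit()` above. This is a small
+sketch, not part of the training loop itself:
+
+
+```python
+# plot training vs. validation accuracy over the epochs that actually ran
+plt.figure(figsize=(8, 4))
+plt.plot(history.history["accuracy"], label="train accuracy")
+plt.plot(history.history["val_accuracy"], label="validation accuracy")
+plt.xlabel("Epoch")
+plt.ylabel("Accuracy")
+plt.legend()
+plt.show()
+```
+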
+---
+## Save trained model
+
+Since we created the model by subclassing, it cannot be saved in the HDF5 format; it
+can only be saved in the TF SavedModel format. This is also the recommended format for
+saving models in general.
+
+
+```python
+model.save("ShiftViT")
+```
+
+---
+## Model inference
+
+**Download sample data for inference**
+
+
+```python
+!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
+!unzip -q inference_set.zip
+```
+
+**Load saved model**
+
+
+```python
+# Custom objects are not included when the model is saved.
+# At loading time, these objects need to be passed in to reconstruct the model.
+saved_model = tf.keras.models.load_model(
+    "ShiftViT",
+    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
+)
+```
+
+
+**Utility functions for inference**
+
+
+```python
+
+def process_image(img_path):
+    # read the image file from the string path
+    img = tf.io.read_file(img_path)
+
+    # decode jpeg to a uint8 tensor
+    img = tf.io.decode_jpeg(img, channels=3)
+
+    # resize the image to match the input size accepted by the model
+    # use `method` as `nearest` to preserve the dtype of the input passed to `resize()`
+    img = tf.image.resize(
+        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+    )
+    return img
+
+
+def create_tf_dataset(image_dir):
+    data_dir = pathlib.Path(image_dir)
+
+    # create a tf.data dataset from the directory of images
+    predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
+
+    # use map to convert the string paths to uint8 image tensors
+    # setting `num_parallel_calls` helps process multiple images in parallel
+    predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+
+    # batch and prefetch the dataset for better latency & throughput
+    predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+    return predict_ds
+
+
+def predict(predict_ds):
+    # the ShiftViT model returns logits (non-normalized predictions)
+    logits = saved_model.predict(predict_ds)
+
+    # normalize the logits into probabilities by applying softmax()
+    probabilities = tf.nn.softmax(logits)
+    return probabilities
+
+
+def get_predicted_class(probabilities):
+    pred_label = np.argmax(probabilities)
+    predicted_class = config.label_map[pred_label]
+    return predicted_class
+
+
+def get_confidence_scores(probabilities):
+    # get the indices of the probability scores sorted in descending order
+    labels = np.argsort(probabilities)[::-1]
+    confidences = {
+        config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+        for label in labels
+    }
+    return confidences
+
+```
+
+**Get predictions**
+
+
+```python
+img_dir = "inference_set"
+predict_ds = create_tf_dataset(img_dir)
+probabilities = predict(predict_ds)
+print(f"probabilities: {probabilities[0]}")
+confidences = get_confidence_scores(probabilities[0])
+print(confidences)
+```
+
+``` +1/1 [==============================] - 2s 2s/step +probabilities: [8.7329084e-01 1.3162658e-03 6.1781306e-05 1.9132349e-05 4.4482469e-05 + 1.8182898e-06 2.2834571e-05 1.1466043e-05 1.2504059e-01 1.9084632e-04] +{'airplane': 87.33, 'ship': 12.5, 'automobile': 0.13, 'truck': 0.02, 'bird': 0.01, 'deer': 0.0, 'frog': 0.0, 'cat': 0.0, 'horse': 0.0, 'dog': 0.0} + +``` +
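+
+The same helpers work for a single image as well. A quick sketch, assuming at least one
+`.jpg` file in the downloaded `inference_set` directory (we pick one with `glob`, which
+is imported in the setup):
+
+
+```python
+# pick an arbitrary image from the inference directory
+sample_path = sorted(glob.glob("inference_set/*.jpg"))[0]
+img = process_image(sample_path)
+
+# the model expects a batch dimension, so add one before predicting
+single_probabilities = predict(tf.expand_dims(img, axis=0))
+print(get_predicted_class(single_probabilities[0]))
+```
+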
+**View predictions** + + +```python +plt.figure(figsize=(10, 10)) +for images in predict_ds: + for i in range(min(6, probabilities.shape[0])): + ax = plt.subplot(3, 3, i + 1) + plt.imshow(images[i].numpy().astype("uint8")) + predicted_class = get_predicted_class(probabilities[i]) + plt.title(predicted_class) + plt.axis("off") +``` + + +![png](/img/examples/vision/shiftvit/shiftvit_46_0.png) + + --- ## Conclusion @@ -990,3 +1203,9 @@ GPU credits. library. - A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for helping us with the Learning Rate Schedule. + +**Example available on HuggingFace** + +| Trained Model | Demo | +| :--: | :--: | +| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) | diff --git a/examples/vision/shiftvit.py b/examples/vision/shiftvit.py index 3d3476d90c..e6400449c8 100644 --- a/examples/vision/shiftvit.py +++ b/examples/vision/shiftvit.py @@ -1,8 +1,8 @@ """ Title: A Vision Transformer without Attention -Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha) +Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/) Date created: 2022/02/24 -Last modified: 2022/03/01 +Last modified: 2022/10/15 Description: A minimal implementation of ShiftViT. Accelerator: GPU """ @@ -26,12 +26,11 @@ In this example, we minimally implement the paper with close alignement to the author's [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py). -This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can +This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can be installed using the following command: - -```shell +""" +"""shell pip install -qq -U tensorflow-addons -``` """ """ @@ -44,9 +43,11 @@ import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers - import tensorflow_addons as tfa +import pathlib +import glob + # Setting seed for reproducibiltiy SEED = 42 keras.utils.set_random_seed(SEED) @@ -88,6 +89,21 @@ class Config(object): # TRAINING epochs = 100 + # INFERENCE + label_map = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck", + } + tf_ds_batch_size = 20 + config = Config() @@ -200,7 +216,7 @@ def get_augmentation_model(): """ #### The MLP block -The MLP block is intended to be a stack of densely-connected layers.s +The MLP block is intended to be a stack of densely-connected layers """ @@ -273,7 +289,7 @@ def call(self, x, training=False): """ #### Block -The most important operation in this paper is the **shift opperation**. In this section, +The most important operation in this paper is the **shift operation**. In this section, we describe the shift operation and compare it with its original implementation provided by the authors. 
@@ -545,6 +561,24 @@ def call(self, x, training=False): x = self.patch_merge(x) return x + # Since this is a custom layer, we need to overwrite get_config() + # so that model can be easily saved & loaded after training + def get_config(self): + config = super().get_config() + config.update( + { + "epsilon": self.epsilon, + "mlp_dropout_rate": self.mlp_dropout_rate, + "num_shift_blocks": self.num_shift_blocks, + "stochastic_depth_rate": self.stochastic_depth_rate, + "is_merge": self.is_merge, + "num_div": self.num_div, + "shift_pixel": self.shift_pixel, + "mlp_expand_ratio": self.mlp_expand_ratio, + } + ) + return config + """ ## The ShiftViT model @@ -617,6 +651,8 @@ def __init__( ) self.global_avg_pool = layers.GlobalAveragePooling2D() + self.classifier = layers.Dense(config.num_classes) + def get_config(self): config = super().get_config() config.update( @@ -625,6 +661,7 @@ def get_config(self): "patch_projection": self.patch_projection, "stages": self.stages, "global_avg_pool": self.global_avg_pool, + "classifier": self.classifier, } ) return config @@ -644,7 +681,8 @@ def _calculate_loss(self, data, training=False): x = stage(x, training=training) # Get the logits. - logits = self.global_avg_pool(x) + x = self.global_avg_pool(x) + logits = self.classifier(x) # Calculate the loss and return it. total_loss = self.compiled_loss(labels, logits) @@ -661,6 +699,7 @@ def train_step(self, inputs): self.data_augmentation.trainable_variables, self.patch_projection.trainable_variables, self.global_avg_pool.trainable_variables, + self.classifier.trainable_variables, ] train_vars = train_vars + [stage.trainable_variables for stage in self.stages] @@ -683,6 +722,15 @@ def test_step(self, data): self.compiled_metrics.update_state(labels, logits) return {m.name: m.result() for m in self.metrics} + def call(self, images): + augmented_images = self.data_augmentation(images) + x = self.patch_projection(augmented_images) + for stage in self.stages: + x = stage(x, training=False) + x = self.global_avg_pool(x) + logits = self.classifier(x) + return logits + """ ## Instantiate the model @@ -787,11 +835,25 @@ def __call__(self, step): step > self.total_steps, 0.0, learning_rate, name="learning_rate" ) + def get_config(self): + config = { + "lr_start": self.lr_start, + "lr_max": self.lr_max, + "total_steps": self.total_steps, + "warmup_steps": self.warmup_steps, + } + return config + """ ## Compile and train the model """ +# pass sample data to the model so that input shape is available at the time of +# saving the model +sample_ds, _ = next(iter(train_ds)) +model(sample_ds, training=False) + # Get the total number of steps for training. total_steps = int((len(x_train) / config.batch_size) * config.epochs) @@ -843,6 +905,123 @@ def __call__(self, step): print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%") print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%") +""" +## Save trained model + +Since we created the model by Subclassing, we can't save the model in HDF5 format. + +It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well. +""" +model.save("ShiftViT") + +""" +## Model inference +""" + +""" +**Download sample data for inference** +""" + +"""shell +wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip +unzip -q inference_set.zip +""" + + +""" +**Load saved model** +""" +# Custom objects are not included when the model is saved. 
+# At loading time, these objects need to be passed in to reconstruct the model.
+saved_model = tf.keras.models.load_model(
+    "ShiftViT",
+    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
+)
+
+"""
+**Utility functions for inference**
+"""
+
+
+def process_image(img_path):
+    # read the image file from the string path
+    img = tf.io.read_file(img_path)
+
+    # decode jpeg to a uint8 tensor
+    img = tf.io.decode_jpeg(img, channels=3)
+
+    # resize the image to match the input size accepted by the model
+    # use `method` as `nearest` to preserve the dtype of the input passed to `resize()`
+    img = tf.image.resize(
+        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+    )
+    return img
+
+
+def create_tf_dataset(image_dir):
+    data_dir = pathlib.Path(image_dir)
+
+    # create a tf.data dataset from the directory of images
+    predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
+
+    # use map to convert the string paths to uint8 image tensors
+    # setting `num_parallel_calls` helps process multiple images in parallel
+    predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+
+    # batch and prefetch the dataset for better latency & throughput
+    predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+    return predict_ds
+
+
+def predict(predict_ds):
+    # the ShiftViT model returns logits (non-normalized predictions)
+    logits = saved_model.predict(predict_ds)
+
+    # normalize the logits into probabilities by applying softmax()
+    probabilities = tf.nn.softmax(logits)
+    return probabilities
+
+
+def get_predicted_class(probabilities):
+    pred_label = np.argmax(probabilities)
+    predicted_class = config.label_map[pred_label]
+    return predicted_class
+
+
+def get_confidence_scores(probabilities):
+    # get the indices of the probability scores sorted in descending order
+    labels = np.argsort(probabilities)[::-1]
+    confidences = {
+        config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+        for label in labels
+    }
+    return confidences
+
+
+"""
+**Get predictions**
+"""
+
+img_dir = "inference_set"
+predict_ds = create_tf_dataset(img_dir)
+probabilities = predict(predict_ds)
+print(f"probabilities: {probabilities[0]}")
+confidences = get_confidence_scores(probabilities[0])
+print(confidences)
+
+"""
+**View predictions**
+"""
+
+plt.figure(figsize=(10, 10))
+for images in predict_ds:
+    for i in range(min(6, probabilities.shape[0])):
+        ax = plt.subplot(3, 3, i + 1)
+        plt.imshow(images[i].numpy().astype("uint8"))
+        predicted_class = get_predicted_class(probabilities[i])
+        plt.title(predicted_class)
+        plt.axis("off")
+
 """
 ## Conclusion
 
@@ -866,3 +1045,11 @@ def __call__(self, step):
 - A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for
 helping us with the Learning Rate Schedule.
 """
+
+"""
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) |
+"""