Skip to content

Commit

Permalink
Migrate vivit tutorial to Keras3-all backends (#1739)
Browse files Browse the repository at this point in the history
  • Loading branch information
SuryanarayanaY authored Jan 22, 2024
1 parent 7217316 commit 37c1af2
Show file tree
Hide file tree
Showing 3 changed files with 1,570 additions and 13,199 deletions.
36 changes: 18 additions & 18 deletions examples/vision/ipynb/vivit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush Thakur](https://twitter.com/ayushthakur0) (equal contribution)<br>\n",
"**Date created:** 2022/01/12<br>\n",
"**Last modified:** 2024/01/13<br>\n",
"**Last modified:** 2024/01/15<br>\n",
"**Description:** A Transformer-based architecture for video classification."
]
},
Expand Down Expand Up @@ -43,8 +43,8 @@
"the embedding scheme and one of the variants of the Transformer\n",
"architecture, for simplicity.\n",
"\n",
"This example requires the `medmnist`\n",
"package, which can be installed by running the code cell below."
"This example requires `medmnist` package, which can be installed\n",
"by running the code cell below."
]
},
{
Expand Down Expand Up @@ -81,11 +81,9 @@
"import medmnist\n",
"import ipywidgets\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf # for data preprocessing only\n",
"import keras\n",
"import tensorflow as tf\n",
"from keras import layers\n",
"from keras import ops\n",
"from keras import layers, ops\n",
"\n",
"# Setting seed for reproducibility\n",
"SEED = 42\n",
Expand Down Expand Up @@ -214,15 +212,17 @@
"outputs": [],
"source": [
"\n",
"def preprocess(frames, label):\n",
"def preprocess(frames: tf.Tensor, label: tf.Tensor):\n",
" \"\"\"Preprocess the frames tensors and parse the labels.\"\"\"\n",
" # Preprocess images\n",
" frames = ops.cast(frames, \"float32\")\n",
" frames = ops.expand_dims(\n",
" frames, axis=-1\n",
" ) # The new axis is to help for further processing with Conv3D layers\n",
" frames = tf.image.convert_image_dtype(\n",
" frames[\n",
" ..., tf.newaxis\n",
" ], # The new axis is to help for further processing with Conv3D layers\n",
" tf.float32,\n",
" )\n",
" # Parse label\n",
" label = ops.cast(label, \"float32\")\n",
" label = tf.cast(label, tf.float32)\n",
" return frames, label\n",
"\n",
"\n",
Expand Down Expand Up @@ -337,7 +337,7 @@
" self.position_embedding = layers.Embedding(\n",
" input_dim=num_tokens, output_dim=self.embed_dim\n",
" )\n",
" self.positions = ops.arange(start=0, stop=num_tokens, step=1)\n",
" self.positions = ops.arange(0, num_tokens, 1)\n",
"\n",
" def call(self, encoded_tokens):\n",
" # Encode the positions and add it to the encoded tokens\n",
Expand Down Expand Up @@ -411,8 +411,8 @@
" x3 = layers.LayerNormalization(epsilon=1e-6)(x2)\n",
" x3 = keras.Sequential(\n",
" [\n",
" layers.Dense(units=embed_dim * 4, activation=\"gelu\"),\n",
" layers.Dense(units=embed_dim, activation=\"gelu\"),\n",
" layers.Dense(units=embed_dim * 4, activation=ops.gelu),\n",
" layers.Dense(units=embed_dim, activation=ops.gelu),\n",
" ]\n",
" )(x3)\n",
"\n",
Expand Down Expand Up @@ -511,9 +511,9 @@
"\n",
"for i, (testsample, label) in enumerate(zip(testsamples, labels)):\n",
" # Generate gif\n",
" testsample = ops.reshape(testsample, (-1, 28, 28))\n",
" testsample = np.reshape(testsample.numpy(), (-1, 28, 28))\n",
" with io.BytesIO() as gif:\n",
" imageio.mimsave(gif, (testsample.numpy() * 255).astype(\"uint8\"), \"GIF\", fps=5)\n",
" imageio.mimsave(gif, (testsample * 255).astype(\"uint8\"), \"GIF\", fps=5)\n",
" videos.append(gif.getvalue())\n",
"\n",
" # Get model prediction\n",
Expand Down
Loading

0 comments on commit 37c1af2

Please sign in to comment.