Skip to content

Commit

Permalink
Ported to keras v3
Browse files Browse the repository at this point in the history
  • Loading branch information
di-kan committed Sep 7, 2024
1 parent f0daec7 commit 91a7c76
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions examples/vision/shiftvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,15 @@
In this example, we minimally implement the paper with close alignement to the author's
[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can
be installed using the following command:
"""
"""shell
pip install -qq -U tensorflow-addons
"""

"""
## Setup and imports
"""

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import keras
from keras import layers

import pathlib
import glob
Expand Down Expand Up @@ -280,7 +271,7 @@ def __init__(self, drop_path_prob, **kwargs):
def call(self, x, training=False):
if training:
keep_prob = 1 - self.drop_path_prob
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
Expand Down Expand Up @@ -871,7 +862,7 @@ def get_config(self):
)

# Get the optimizer.
optimizer = tfa.optimizers.AdamW(
optimizer = keras.optimizers.AdamW(
learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)

Expand Down Expand Up @@ -913,7 +904,7 @@ def get_config(self):
It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well.
"""
model.save("ShiftViT")
model.export("ShiftViT")

"""
## Model inference
Expand All @@ -932,12 +923,10 @@ def get_config(self):
"""
**Load saved model**
"""
# Custom objects are not included when the model is saved.
# At loading time, these objects need to be passed for reconstruction of the model
saved_model = tf.keras.models.load_model(
"ShiftViT",
custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
)
saved_layer = keras.layers.TFSMLayer("ShiftViT")
inputs = tf.keras.Input(shape=(config.input_shape)) # specify your input shape
outputs = saved_layer(inputs)
saved_model = tf.keras.Model(inputs, outputs)

"""
**Utility functions for inference**
Expand Down

0 comments on commit 91a7c76

Please sign in to comment.