fix formatting issue with example (#1488)
pcoet committed Aug 15, 2023
1 parent 88f5103 commit bec7e1f
Showing 3 changed files with 32 additions and 34 deletions.
20 changes: 10 additions & 10 deletions examples/nlp/ipynb/text_classification_with_switch_transformer.ipynb
@@ -83,9 +83,7 @@
 "(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)\n",
 "print(len(x_train), \"Training sequences\")\n",
 "print(len(x_val), \"Validation sequences\")\n",
-"x_train = keras.utils.pad_sequences(\n",
-"    x_train, maxlen=num_tokens_per_example\n",
-")\n",
+"x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)\n",
 "x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)"
 ]
 },
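The notebook change above is purely cosmetic: the `pad_sequences` call is collapsed onto one line, with no behavioral difference. For readers new to the API, a minimal sketch of what `keras.utils.pad_sequences` does (toy sequences invented for illustration):

```python
import keras

# Two toy sequences of unequal length (hypothetical values).
seqs = [[1, 2, 3], [4, 5, 6, 7, 8]]

# Pads short sequences with zeros and truncates long ones to maxlen.
# By default both padding and truncation happen at the front ("pre").
padded = keras.utils.pad_sequences(seqs, maxlen=4)
print(padded)
# [[0 1 2 3]
#  [5 6 7 8]]
```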
@@ -175,9 +173,9 @@
 "outputs": [],
 "source": [
 "\n",
-"def create_feedforward_network(ff_dim, name=None):\n",
+"def create_feedforward_network(ff_dim, embed_dim, name=None):\n",
 "    return keras.Sequential(\n",
-"        [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(ff_dim)], name=name\n",
+"        [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim)], name=name\n",
 "    )\n",
 ""
 ]
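This hunk is the substantive fix. Each expert network is swapped back into the token stream by the Switch layer, so its final projection must be `embed_dim` wide, not `ff_dim`. The published example appears to use `embed_dim = ff_dim = 32`, which is why the old signature happened to work; the sketch below uses deliberately unequal (hypothetical) dimensions to show what the fix guarantees:

```python
import numpy as np
import keras
from keras import layers


def create_feedforward_network(ff_dim, embed_dim, name=None):
    # Expand to ff_dim, then project back to embed_dim so the expert's
    # output has the same width as the token embeddings it replaces.
    return keras.Sequential(
        [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)], name=name
    )


ffn = create_feedforward_network(ff_dim=64, embed_dim=32)
out = ffn(np.zeros((2, 10, 32), dtype="float32"))
assert out.shape == (2, 10, 32)  # the old version would have produced (2, 10, 64)
```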
@@ -218,7 +216,7 @@
 "    # num_expert elements. The two vectors will be pushed towards uniform allocation\n",
 "    # when the dot product is minimized.\n",
 "    loss = tf.reduce_mean(density_proxy * density) * tf.cast(\n",
-"        (num_experts ** 2), tf.dtypes.float32\n",
+"        (num_experts**2), tf.dtypes.float32\n",
 "    )\n",
 "    return loss\n",
 ""
@@ -321,11 +319,13 @@
 "source": [
 "\n",
 "class Switch(layers.Layer):\n",
-"    def __init__(self, num_experts, embed_dim, num_tokens_per_batch, capacity_factor=1):\n",
+"    def __init__(\n",
+"        self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1\n",
+"    ):\n",
 "        self.num_experts = num_experts\n",
 "        self.embed_dim = embed_dim\n",
 "        self.experts = [\n",
-"            create_feedforward_network(embed_dim) for _ in range(num_experts)\n",
+"            create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)\n",
 "        ]\n",
 "\n",
 "        self.expert_capacity = num_tokens_per_batch // self.num_experts\n",
@@ -430,8 +430,8 @@
 "source": [
 "\n",
 "def create_classifier():\n",
-"    switch = Switch(num_experts, embed_dim, num_tokens_per_batch)\n",
-"    transformer_block = TransformerBlock(ff_dim, num_heads, switch)\n",
+"    switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)\n",
+"    transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)\n",
 "\n",
 "    inputs = layers.Input(shape=(num_tokens_per_example,))\n",
 "    embedding_layer = TokenAndPositionEmbedding(\n",
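The second substantive fix: `TransformerBlock`'s first argument (used as the attention head size in the example, though the class itself is outside this diff) should be the per-head dimension `embed_dim // num_heads`, not `ff_dim`. A shape check with Keras's `MultiHeadAttention` under the assumed configuration `embed_dim = 32`, `num_heads = 2`:

```python
import numpy as np
from keras import layers

embed_dim, num_heads = 32, 2  # assumed example configuration
attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)

x = np.zeros((2, 200, embed_dim), dtype="float32")
out = attn(x, x)  # self-attention
# The output is projected back to the query's last dimension by default,
# so each of the 2 heads works in a 16-dimensional subspace.
assert out.shape == (2, 200, embed_dim)
```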
31 changes: 16 additions & 15 deletions examples/nlp/md/text_classification_with_switch_transformer.md
@@ -49,15 +49,14 @@ num_tokens_per_example = 200  # Only consider the first 200 words of each movie
 (x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
 print(len(x_train), "Training sequences")
 print(len(x_val), "Validation sequences")
-x_train = keras.utils.pad_sequences(
-    x_train, maxlen=num_tokens_per_example
-)
+x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)
 x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)
 ```
 
 <div class="k-default-codeblock">
 ```
 Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
 17464789/17464789 [==============================] - 1s 0us/step
 25000 Training sequences
 25000 Validation sequences
@@ -119,9 +118,9 @@ This is used as the Mixture of Experts in the Switch Transformer.
 
 ```python
 
-def create_feedforward_network(ff_dim, name=None):
+def create_feedforward_network(ff_dim, embed_dim, name=None):
     return keras.Sequential(
-        [layers.Dense(ff_dim, activation="relu"), layers.Dense(ff_dim)], name=name
+        [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)], name=name
     )
 
 ```
@@ -150,7 +149,7 @@ def load_balanced_loss(router_probs, expert_mask):
     # num_expert elements. The two vectors will be pushed towards uniform allocation
     # when the dot product is minimized.
     loss = tf.reduce_mean(density_proxy * density) * tf.cast(
-        (num_experts ** 2), tf.dtypes.float32
+        (num_experts**2), tf.dtypes.float32
     )
     return loss
 
@@ -227,11 +226,13 @@ class Router(layers.Layer):
 ```python
 
 class Switch(layers.Layer):
-    def __init__(self, num_experts, embed_dim, num_tokens_per_batch, capacity_factor=1):
+    def __init__(
+        self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1
+    ):
         self.num_experts = num_experts
         self.embed_dim = embed_dim
         self.experts = [
-            create_feedforward_network(embed_dim) for _ in range(num_experts)
+            create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)
         ]
 
         self.expert_capacity = num_tokens_per_batch // self.num_experts
@@ -312,8 +313,8 @@ of it to classify text.
 
 ```python
 
 def create_classifier():
-    switch = Switch(num_experts, embed_dim, num_tokens_per_batch)
-    transformer_block = TransformerBlock(ff_dim, num_heads, switch)
+    switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
+    transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)
 
     inputs = layers.Input(shape=(num_tokens_per_example,))
     embedding_layer = TokenAndPositionEmbedding(
@@ -362,13 +363,13 @@ run_experiment(classifier)
 <div class="k-default-codeblock">
 ```
 Epoch 1/3
-500/500 [==============================] - 575s 1s/step - loss: 1.5311 - accuracy: 0.7151 - val_loss: 1.2915 - val_accuracy: 0.8772
+500/500 [==============================] - 645s 1s/step - loss: 1.4064 - accuracy: 0.8070 - val_loss: 1.3201 - val_accuracy: 0.8642
 Epoch 2/3
-500/500 [==============================] - 575s 1s/step - loss: 1.1971 - accuracy: 0.9262 - val_loss: 1.3073 - val_accuracy: 0.8708
+500/500 [==============================] - 625s 1s/step - loss: 1.2073 - accuracy: 0.9218 - val_loss: 1.3140 - val_accuracy: 0.8713
 Epoch 3/3
-500/500 [==============================] - 624s 1s/step - loss: 1.1284 - accuracy: 0.9563 - val_loss: 1.3547 - val_accuracy: 0.8637
+500/500 [==============================] - 637s 1s/step - loss: 1.1428 - accuracy: 0.9494 - val_loss: 1.3530 - val_accuracy: 0.8618
-<tensorflow.python.keras.callbacks.History at 0x1495461d0>
+<keras.src.callbacks.History at 0x136fb5450>
 ```
 </div>
15 changes: 6 additions & 9 deletions examples/nlp/text_classification_with_switch_transformer.py
@@ -177,15 +177,12 @@ def call(self, inputs, training=False):
         expert_gate *= expert_mask_flat
         # Combine expert outputs and scaling with router probability.
         # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
-        combined_tensor = (
-            tf.expand_dims(
-                expert_gate
-                * expert_mask_flat
-                * tf.squeeze(tf.one_hot(expert_index, depth=self.num_experts), 1),
-                -1,
-            )
-            * tf.squeeze(tf.one_hot(position_in_expert, depth=self.expert_capacity), 1)
-        )
+        combined_tensor = tf.expand_dims(
+            expert_gate
+            * expert_mask_flat
+            * tf.squeeze(tf.one_hot(expert_index, depth=self.num_experts), 1),
+            -1,
+        ) * tf.squeeze(tf.one_hot(position_in_expert, depth=self.expert_capacity), 1)
         # Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
         # that is 1 if the token gets routed to the corresponding expert.
         dispatch_tensor = tf.cast(combined_tensor, tf.dtypes.float32)
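This last hunk only re-wraps the `combined_tensor` expression (Black drops the redundant outer parentheses); the computation is identical. A toy sketch of what the expression builds, with tiny assumed shapes (4 tokens, 2 experts, `expert_capacity = 2`) and an explicit `expand_dims` on both factors for clarity:

```python
import tensorflow as tf

# Hypothetical routing results for 4 tokens, 2 experts, expert_capacity = 2.
expert_index = tf.constant([[0], [1], [0], [1]])  # chosen expert per token
position_in_expert = tf.constant([[0], [0], [1], [1]])  # slot inside that expert
expert_gate = tf.constant([[0.9], [0.8], [0.7], [0.6]])  # router probability

one_hot_expert = tf.squeeze(tf.one_hot(expert_index, depth=2), 1)  # [4, 2]
one_hot_slot = tf.squeeze(tf.one_hot(position_in_expert, depth=2), 1)  # [4, 2]

# combined[t, e, c] = token t's gate if it is routed to expert e, slot c; else 0.
combined = tf.expand_dims(expert_gate * one_hot_expert, -1) * tf.expand_dims(
    one_hot_slot, 1
)
# Binary dispatch mask: 1 wherever a token occupies an (expert, slot) cell.
dispatch = tf.cast(tf.cast(combined, tf.bool), tf.dtypes.float32)
print(combined.shape)  # (4, 2, 2)
```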
