Replaced expand_dims with reshape ops to avoid data copies (#630)
fhieber authored Jan 28, 2019
1 parent 9c5ed6e commit 1f27c2a
Showing 9 changed files with 50 additions and 38 deletions.
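
The change is mechanical: everywhere a singleton axis was added with mx.sym.expand_dims, the code now uses mx.sym.reshape with MXNet's special shape values (0 keeps an input dimension, -1 infers one, -2 copies all remaining input dimensions), which yields the same output shape and, per the commit message, avoids data copies. The lines below are a minimal sketch (not part of the commit, with made-up sizes) checking the shape equivalence on NDArrays:

import mxnet as mx

x = mx.nd.zeros((4, 7))                  # e.g. (batch_size, source_seq_len)
a = mx.nd.expand_dims(x, axis=2)         # (4, 7, 1)
b = mx.nd.reshape(x, shape=(-2, 1))      # -2 copies all input dims, then a trailing 1 is appended
assert a.shape == b.shape == (4, 7, 1)

y = mx.nd.zeros((4, 512))                # e.g. (batch_size, num_hidden)
c = mx.nd.expand_dims(y, axis=1)         # (4, 1, 512)
d = mx.nd.reshape(y, shape=(0, 1, -1))   # 0 keeps the batch dim, 1 inserts the axis, -1 infers the rest
assert c.shape == d.shape == (4, 1, 512)
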
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -10,7 +10,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.72]
### Changed
- Removed use of `expand_dims` in favor of `reshape` to save memory.


## [1.18.71]
### Fixed
- Fixed default setting of source factor combination to be 'concat' for backwards compatibility.

## [1.18.70]
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.71'
__version__ = '1.18.72'
2 changes: 1 addition & 1 deletion sockeye/convolution.py
@@ -155,7 +155,7 @@ def step(self, data):
bias=self.conv_bias,
num_hidden=num_hidden)
# (batch_size, num_hidden, 1)
data_conv = mx.sym.expand_dims(data_conv, axis=2)
data_conv = mx.sym.reshape(data_conv, shape=(-2, 1))
return self._post_convolution(data_conv)

def _post_convolution(self, data_conv):
21 changes: 10 additions & 11 deletions sockeye/coverage.py
@@ -129,7 +129,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
:param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
:return: Updated coverage matrix. Shape: (batch_size, source_seq_len, coverage_num_hidden).
"""
return prev_coverage + mx.sym.expand_dims(attention_prob_scores, axis=2)
return prev_coverage + mx.sym.reshape(attention_prob_scores, shape=(-2, 1))

return update_coverage

@@ -183,9 +183,8 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
"""

# (batch_size, source_seq_len, 1)
expanded_att_scores = mx.sym.expand_dims(data=attention_prob_scores,
axis=2,
name="%sexpand_attention_scores" % self.prefix)
expanded_att_scores = mx.sym.reshape(attention_prob_scores, shape=(-2, 1),
name="%sexpand_attention_scores" % self.prefix)

# (batch_size, source_seq_len, 1)
new_coverage = scaled_fertility * expanded_att_scores
@@ -237,13 +236,13 @@ def update_coverage(prev_hidden: mx.sym.Symbol,

# (batch_size, source_seq_len, decoder_num_hidden)
expanded_decoder = mx.sym.broadcast_axis(
data=mx.sym.expand_dims(data=prev_hidden, axis=1, name="%sexpand_decoder" % self.prefix),
data=mx.sym.reshape(data=prev_hidden, shape=(0, 1, -1), name="%sexpand_decoder" % self.prefix),
axis=1, size=source_seq_len, name="%sbroadcast_decoder" % self.prefix)

# (batch_size, source_seq_len, 1)
expanded_att_scores = mx.sym.expand_dims(data=attention_prob_scores,
axis=2,
name="%sexpand_attention_scores" % self.prefix)
expanded_att_scores = mx.sym.reshape(data=attention_prob_scores,
shape=(-2, 1),
name="%sexpand_attention_scores" % self.prefix)

# (batch_size, source_seq_len, encoder_num_hidden + decoder_num_hidden + 1)
# +1 for the attention_prob_score for the source word
@@ -332,7 +331,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
name="%sprevious_hidden_fc" % self.prefix)

# (batch_size, source_seq_len, 1)
attention_prob_scores = mx.sym.expand_dims(attention_prob_scores, axis=2)
attention_prob_scores = mx.sym.reshape(attention_prob_scores, shape=(-2, 1))

# (batch_size, source_seq_len, coverage_num_hidden)
attention_hidden = mx.sym.FullyConnected(data=attention_prob_scores,
@@ -347,8 +346,8 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
num_hidden=self.num_hidden, name="%sdecoder_hidden")

# (batch_size, 1, coverage_num_hidden)
prev_hidden = mx.sym.expand_dims(data=prev_hidden, axis=1,
name="%sinput_decoder_hidden_expanded" % self.prefix)
prev_hidden = mx.sym.reshape(data=prev_hidden, shape=(0, 1, -1),
name="%sinput_decoder_hidden_expanded" % self.prefix)

# (batch_size, source_seq_len, coverage_num_hidden)
intermediate = mx.sym.broadcast_add(lhs=source_hidden, rhs=prev_hidden,
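
The accumulation-style coverage update above only needs the attention scores reshaped to (batch_size, source_seq_len, 1) so they broadcast against the coverage tensor. A small sketch of that broadcast, not part of the commit and using dummy sizes:

import mxnet as mx

batch_size, source_seq_len, coverage_num_hidden = 2, 5, 3
prev_coverage = mx.nd.zeros((batch_size, source_seq_len, coverage_num_hidden))
attention_prob_scores = mx.nd.random.uniform(shape=(batch_size, source_seq_len))
expanded = mx.nd.reshape(attention_prob_scores, shape=(-2, 1))    # (2, 5, 1)
new_coverage = mx.nd.broadcast_add(prev_coverage, expanded)       # (2, 5, 3)
assert new_coverage.shape == (batch_size, source_seq_len, coverage_num_hidden)
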
12 changes: 6 additions & 6 deletions sockeye/decoder.py
@@ -257,7 +257,7 @@ def decode_sequence(self,
fold_heads=True,
name="%ssource_bias" % self.prefix)
# (batch_size * heads, 1, max_length)
source_bias = mx.sym.expand_dims(source_bias, axis=1)
source_bias = mx.sym.reshape(source_bias, shape=(0, 1, -1))

# (1, target_max_length, target_max_length)
target_bias = transformer.get_autoregressive_bias(target_embed_max_length, name="%starget_bias" % self.prefix)
@@ -302,7 +302,7 @@ def decode_step(self,
# (batch_size, num_embed)
target_embed_prev = self.pos_embedding.encode_positions(indices, target_embed_prev)
# (batch_size, 1, num_embed)
target = mx.sym.expand_dims(target_embed_prev, axis=1)
target = mx.sym.reshape(target_embed_prev, shape=(0, 1, -1))

# (batch_size * heads, max_length)
source_bias = transformer.get_variable_length_bias(lengths=source_encoded_lengths,
@@ -311,7 +311,7 @@
fold_heads=True,
name="%ssource_bias" % self.prefix)
# (batch_size * heads, 1, max_length)
source_bias = mx.sym.expand_dims(source_bias, axis=1)
source_bias = mx.sym.reshape(source_bias, shape=(0, 1, -1))

# auto-regressive bias for last position in sequence
# (1, target_max_length, target_max_length)
@@ -779,7 +779,7 @@ def get_initial_state(self,
# we derive the shape of hidden and layer_states from some input to enable
# shape inference for the batch dimension during inference.
# (batch_size, 1)
zeros = mx.sym.expand_dims(mx.sym.zeros_like(source_encoded_length), axis=1)
zeros = mx.sym.reshape(mx.sym.zeros_like(source_encoded_length), shape=(-1, 1))
# last encoder state: (batch, num_hidden)
source_encoded_last = mx.sym.SequenceLast(data=source_encoded,
axis=1,
@@ -807,7 +807,7 @@ def get_initial_state(self,
elif self.config.state_init == C.RNN_DEC_INIT_AVG:
# (batch_size, encoder_num_hidden)
init = mx.sym.broadcast_div(mx.sym.sum(source_masked, axis=1, keepdims=False),
mx.sym.expand_dims(source_encoded_length, axis=1))
mx.sym.reshape(source_encoded_length, shape=(-1, 1)))
else:
raise ValueError("Unknown decoder state init type '%s'" % self.config.state_init)

@@ -1139,7 +1139,7 @@ def decode_step(self,
weight=self.i2h_weight)
# re-arrange outcoming layer to the dimensions of the output
# (batch_size, 1, num_hidden)
target_hidden_step = mx.sym.expand_dims(target_hidden_step, axis=1)
target_hidden_step = mx.sym.reshape(target_hidden_step, shape=(0, 1, -1))
# (batch_size, kernel_width, num_hidden)
target_hidden = mx.sym.concat(embed_layer_state, target_hidden_step, dim=1)

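
In the RNN decoder's average initialization (C.RNN_DEC_INIT_AVG) above, the sequence lengths only need a trailing singleton axis so the division broadcasts over the hidden dimension, which reshape(..., shape=(-1, 1)) provides. A rough shape check, not from the commit, with made-up sizes:

import mxnet as mx

batch_size, seq_len, num_hidden = 2, 4, 3
source_masked = mx.nd.ones((batch_size, seq_len, num_hidden))   # stand-in for masked encoder states
source_encoded_length = mx.nd.array([4., 2.])                   # (batch_size,)
summed = mx.nd.sum(source_masked, axis=1, keepdims=False)       # (batch_size, num_hidden)
init = mx.nd.broadcast_div(summed, mx.nd.reshape(source_encoded_length, shape=(-1, 1)))
assert init.shape == (batch_size, num_hidden)
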
18 changes: 9 additions & 9 deletions sockeye/encoder.py
@@ -548,11 +548,11 @@ def encode_positions(self,
:return: (batch_size, num_embed)
"""
# (batch_size, 1)
positions = mx.sym.expand_dims(positions, axis=1)
positions = mx.sym.reshape(positions, shape=(-1, 1))
# (num_embed,)
channels = mx.sym.arange(0, self.num_embed // 2)
# (1, num_embed,)
scaling = mx.sym.expand_dims(1. / mx.sym.pow(10000, (2 * channels) / self.num_embed), axis=0)
# (1, num_embed)
scaling = mx.sym.reshape(1. / mx.sym.pow(10000, (2 * channels) / self.num_embed), shape=(1, -1))

# (batch_size, num_embed/2)
scaled_positions = mx.sym.dot(positions, scaling)
@@ -614,7 +614,7 @@ def encode(self,
"""

# (1, source_seq_len)
positions = mx.sym.expand_dims(data=mx.sym.arange(start=0, stop=seq_len, step=1), axis=0)
positions = mx.sym.reshape(data=mx.sym.arange(start=0, stop=seq_len, step=1), shape=(1, -1))

# (1, source_seq_len, num_embed)
pos_embedding = mx.sym.Embedding(data=positions,
@@ -1043,11 +1043,11 @@ def encode(self,
data = mx.sym.Dropout(data=data, p=self.config.dropout_prepost)

# (batch_size * heads, 1, max_length)
bias = mx.sym.expand_dims(transformer.get_variable_length_bias(lengths=data_length,
max_length=seq_len,
num_heads=self.config.attention_heads,
fold_heads=True,
name="%sbias" % self.prefix), axis=1)
bias = mx.sym.reshape(transformer.get_variable_length_bias(lengths=data_length,
max_length=seq_len,
num_heads=self.config.attention_heads,
fold_heads=True,
name="%sbias" % self.prefix), shape=(0, 1, -1))
bias = utils.cast_conditionally(bias, self.dtype)
for i, layer in enumerate(self.layers):
# (batch_size, seq_len, config.model_size)
2 changes: 1 addition & 1 deletion sockeye/inference.py
@@ -2265,7 +2265,7 @@ def hybrid_forward(self, F, best_word_indices, max_output_lengths, finished, sco

# Update lengths of all items, except those that were already finished. This updates
# the lengths for inactive items, too, but that doesn't matter since they are ignored anyway.
lengths = lengths + F.cast(1 - F.expand_dims(finished, axis=1), dtype='float32')
lengths = lengths + F.cast(1 - F.reshape(finished, shape=(-1, 1)), dtype='float32')

# Now, recompute finished. Hypotheses are finished if they are
# - extended with <pad>, or
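
In the beam-search length update above, finished is a 0/1 vector with one entry per hypothesis; reshaping it to a trailing singleton axis lines it up with lengths so that only unfinished hypotheses are incremented. A tiny illustration, not part of the commit:

import mxnet as mx

lengths = mx.nd.array([[3.], [5.]])      # (num_hyps, 1)
finished = mx.nd.array([0., 1.])         # (num_hyps,), 1 marks an already finished hypothesis
lengths = lengths + mx.nd.cast(1 - mx.nd.reshape(finished, shape=(-1, 1)), dtype='float32')
print(lengths.asnumpy())                 # [[4.], [5.]]: only the unfinished hypothesis grew
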
11 changes: 9 additions & 2 deletions sockeye/layers.py
@@ -272,8 +272,15 @@ def broadcast_to_heads(x: mx.sym.Symbol, num_heads: int, ndim: int, fold_heads:
Shape: (batch * heads, d1 ... dn-1) if fold_heads == True, (batch, heads, d1 ... dn-1) else.
"""
dims = [0] * (ndim - 1)
# x: (batch, 1)
x = mx.sym.expand_dims(x, axis=1)
if ndim == 1:
# x: (batch, 1)
x = mx.sym.reshape(x, shape=(-1, 1))
elif ndim == 2:
# x: (batch, 1, d1)
x = mx.sym.reshape(x, shape=(0, 1, -1))
else:
# x: (batch, 1, d1 ... dn - 1)
x = mx.sym.reshape(x, shape=(0, 1, -2))
# x: (batch, heads, dims...)
x = mx.sym.broadcast_to(x, shape=[0, num_heads] + dims)
if fold_heads:
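
broadcast_to_heads is the one call site where a single reshape spec cannot cover every input rank, hence the new branch on ndim: (-1, 1) for vectors, (0, 1, -1) for matrices, and (0, 1, -2) when there are further trailing dimensions to carry over. A quick check of all three branches, not part of the commit, with arbitrary sizes:

import mxnet as mx

for shape, spec in [((4,), (-1, 1)),           # ndim == 1 -> (4, 1)
                    ((4, 8), (0, 1, -1)),      # ndim == 2 -> (4, 1, 8)
                    ((4, 8, 16), (0, 1, -2))]: # ndim == 3 -> (4, 1, 8, 16)
    x = mx.nd.zeros(shape)
    assert mx.nd.reshape(x, shape=spec).shape == mx.nd.expand_dims(x, axis=1).shape
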
14 changes: 7 additions & 7 deletions sockeye/rnn_attention.py
@@ -269,7 +269,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta
:return: Updated attention state.
"""
# (batch_size, decoder_num_hidden, 1)
query = mx.sym.expand_dims(att_input.query, axis=2)
query = mx.sym.reshape(att_input.query, shape=(-2, 1))

# in: (batch_size, source_seq_len, self.num_hidden) X (batch_size, self.num_hidden, 1)
# out: (batch_size, source_seq_len, 1).
@@ -368,7 +368,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta
query = query * self.scale

# (batch_size, decoder_num_hidden, 1)
expanded_decoder_state = mx.sym.expand_dims(query, axis=2)
expanded_decoder_state = mx.sym.reshape(query, shape=(-2, 1))

# batch_dot: (batch, M, K) X (batch, K, N) –> (batch, M, N).
# (batch_size, seq_len, 1)
@@ -479,7 +479,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta

# combine heads
# (batch*heads, 1, num_hidden/head)
context = mx.sym.expand_dims(context, axis=1)
context = mx.sym.reshape(context, shape=(0, 1, -1))
# (batch, 1, num_hidden)
context = layers.combine_heads(context, self.num_hidden_per_head, heads=self.heads)
# (batch, num_hidden)
@@ -585,7 +585,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta
end=source_seq_len)

# attention_scores: (batch_size, seq_len, 1)
attention_scores = mx.sym.expand_dims(data=attention_scores, axis=2)
attention_scores = mx.sym.reshape(data=attention_scores, shape=(-2, 1))

context, attention_probs = get_context_and_attention_probs(source, source_length, attention_scores,
self.dtype)
@@ -687,9 +687,9 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta
name="%squery_hidden" % self.prefix)

# (batch_size, 1, attention_num_hidden)
query_hidden = mx.sym.expand_dims(data=query_hidden,
axis=1,
name="%squery_hidden_expanded" % self.prefix)
query_hidden = mx.sym.reshape(data=query_hidden,
shape=(0, 1, -1),
name="%squery_hidden_expanded" % self.prefix)

attention_hidden_lhs = source_hidden
if self.coverage:
