https://github.com/Kyubyong/transformer/blob/master/model.py
In this code (lines 176–181 of `model.py`), you are using `==` inside the TensorFlow graph, which won't work:

```python
for _ in tqdm(range(self.hp.maxlen2)):
    logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
    if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break

    _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
    ys = (_decoder_inputs, y, y_seqlen, sents2)
```

The `==` compares a symbolic tensor against a Python int at graph-construction time rather than comparing the tensor's value, so the `break` never fires.
As a result, decoding never stops at the `<pad>` output and keeps iterating until `maxlen2` is reached. This is a minor issue, but it makes the eval function slower than it needs to be.
Using something like this instead would make the eval function faster (note: `sents2`, not `sent2`, to match the variable name in the loop above):

```python
logits, y_hat, y, sents2 = tf.cond(
    tf.equal(y_hat[0][-1], self.token2idx["<pad>"]),
    lambda: (logits, y_hat, y, sents2),
    lambda: self.decode(ys, memory, src_masks, False))
```
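To make the intended semantics concrete, here is a minimal pure-Python sketch (not TensorFlow) of the early-stopping behavior the fix is after: once the decoder emits the pad token, stop extending the sequence instead of looping to `maxlen`. `PAD` and the toy `decode_step` are hypothetical stand-ins for `self.token2idx["<pad>"]` and `self.decode(...)`.

```python
PAD = 0  # hypothetical stand-in for self.token2idx["<pad>"]

def decode_step(ys):
    """Toy decoder: emits 1, 2, 3, then PAD forever after."""
    nxt = len(ys) + 1 if len(ys) < 3 else PAD
    return ys + [nxt]

def greedy_decode(maxlen):
    ys = []
    for _ in range(maxlen):
        # Mirrors the tf.cond fix: once the latest token is PAD,
        # keep the current outputs instead of decoding further.
        if ys and ys[-1] == PAD:
            break
        ys = decode_step(ys)
    return ys

print(greedy_decode(10))  # [1, 2, 3, 0] -- stops well before maxlen
```

In the eager Python loop the `break` works because `ys[-1] == PAD` is an ordinary value comparison; in the TF1 graph the equivalent check has to go through `tf.equal` plus `tf.cond`, which is exactly what the suggested fix does.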
I wonder why you use `y_hat[0][-1]`: the first dimension of `y_hat` is `self.hp.batch_size`, so why use only the first example in the batch to decide whether the whole batch has reached `<pad>`?
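One batch-safe alternative would be to stop only when *every* sequence in the batch has emitted `<pad>` as its latest token. A small NumPy sketch of that check (names are hypothetical; `PAD` stands in for `self.token2idx["<pad>"]`):

```python
import numpy as np

PAD = 0  # hypothetical stand-in for self.token2idx["<pad>"]

def all_finished(y_hat, pad=PAD):
    """y_hat: (batch_size, seq_len) array of token ids.
    True only if every row's latest token is pad."""
    return bool(np.all(y_hat[:, -1] == pad))

# Both rows end in PAD -> finished.
print(all_finished(np.array([[5, 7, PAD],
                             [2, PAD, PAD]])))  # True

# First row is still producing tokens -> keep decoding.
print(all_finished(np.array([[5, 7, 1],
                             [2, PAD, PAD]])))  # False
```

Inside the TF1 graph the same condition could be expressed with `tf.reduce_all(tf.equal(y_hat[:, -1], pad))` as the predicate passed to `tf.cond`, rather than checking only `y_hat[0][-1]`.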