attention_layer.py · 30 lines (23 loc) · 1.13 KB
from keras import backend as K
from keras.engine.topology import Layer


class AttentionLayer(Layer):
    """Computes pairwise attention scores between two embedded sequences."""

    def __init__(self, **kwargs):
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('An attention layer should be called '
                             'on a list of 2 inputs.')
        if input_shape[0][2] != input_shape[1][2]:
            raise ValueError('Embedding sizes should be the same')
        # Trainable bilinear weight matrix of shape (embedding_dim, embedding_dim).
        self.kernel = self.add_weight(shape=(input_shape[0][2], input_shape[0][2]),
                                      initializer='glorot_uniform',
                                      name='kernel',
                                      trainable=True)
        super(AttentionLayer, self).build(input_shape)

    def call(self, inputs):
        # a: (batch, len_x, dim) -- first sequence projected through the kernel.
        a = K.dot(inputs[0], self.kernel)
        # y_trans: (batch, dim, len_y) -- second sequence transposed.
        y_trans = K.permute_dimensions(inputs[1], (0, 2, 1))
        # b: (batch, len_x, len_y) -- bilinear alignment scores between all pairs.
        b = K.batch_dot(a, y_trans, axes=[2, 1])
        return K.tanh(b)

    def compute_output_shape(self, input_shape):
        # Output is (batch, len_x, len_y); all dimensions are left dynamic here.
        return (None, None, None)
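A minimal usage sketch, assuming Keras 2.x (matching the `keras.engine.topology` import above). The input names, sequence lengths, and embedding size below are illustrative, not taken from the repository:

```python
from keras.layers import Input
from keras.models import Model

# Two sequences with the same embedding size (64) but different lengths.
x = Input(shape=(10, 64))   # e.g. 10 timesteps
y = Input(shape=(15, 64))   # e.g. 15 timesteps

# Alignment scores of shape (batch, 10, 15): one tanh-squashed score per (x_i, y_j) pair.
scores = AttentionLayer()([x, y])

model = Model(inputs=[x, y], outputs=scores)
model.summary()
```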