关于 Keras 的 Embedding Initializer

结论:我的脑子坏掉了orz


起因是在逐步检查模型的时候发现这样的小模型:

1
2
3
4
5
input = Input(shape=(5,))
output = Embedding(input_dim=4, output_dim=3, embeddings_initializer='ones')(input)
model = Model(input, output)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
metrics=['accuracy'])

不进行训练,而直接进行预测

1
2
input = np.array([[1, 1, 2, 0, 3]])
output = model_1.predict(input)

得到的值是:

1
2
3
4
5
array([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]], dtype=float32)

然后我就惊呆了,觉得怎么会这样,居然初始化的是输出而不是权重 Σ(°Д°;

吓得我赶紧去看了看源码,实际上初始化不管是Dense层还是Embedding层都是initializers.py里完成的,拿上面用的ones为例

1
2
3
4
5
6
class Ones(Initializer):
"""Initializer that generates tensors initialized to 1.
"""

def __call__(self, shape, dtype=None):
return K.constant(1, shape=shape, dtype=dtype)

然后看这里面用的shape就能证明初始化的到底是输出还是权重矩阵。这里的shape来自embeddings.py

1
2
3
4
5
6
7
8
9
def build(self, input_shape):
self.embeddings = self.add_weight(
shape=(self.input_dim, self.output_dim),
initializer=self.embeddings_initializer,
name='embeddings',
regularizer=self.embeddings_regularizer,
constraint=self.embeddings_constraint,
dtype=self.dtype)
self.built = True

shape不是妥妥的权重矩阵嘛……

然后我就只能倒回去看看之前的小模型,然后就发现权重矩阵全为 1 输出算出来就该全是 1 ……

input = [1, 1, 2, 0, 3] one-hot 展开一下:

1
2
3
4
5
6
7
[
[0, 1, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 0, 0, 1]
]

然后上面那个例子里的 Embedding 矩阵是这样:

1
2
3
4
5
6
[
[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]
]

这两个矩阵乘一下不是全是 1 是什么orz

当然实际上 Embedding 层在实现的时候不是通过矩阵乘法,而是通过K.gather() 来实现的

1
2
3
4
5
def call(self, inputs):
if K.dtype(inputs) != 'int32':
inputs = K.cast(inputs, 'int32')
out = K.gather(self.embeddings, inputs)
return out

然后K.gather()在文档中的定义为

Retrieves the elements of indices indices in the tensor reference.

源码中对应到 tensorflow 的 tf.nn.embedding_lookup(reference, indices),theano 和 CNTK 的实现要更复杂一点,就不贴了。

反正完成的功能就是在给定张量中检索给定下标的向量。

所以也就知道 Embedding 类不是通过将 input 转换为 one-hot 向量之后再进行向量乘法来获得输出的,而是直接检索 input 的每个值对应到 Embedding 矩阵中的向量。这也符合之前在看 Embedding 实现原理的时候提到的,向量乘法运算量太大了(矩阵很大),通过检索的方法更有效率。

所以 Embedding 权重矩阵全为 1 的时候,无论查到哪一列向量不都全是 1 嘛……