Keras get_weights() 的错误用法

结论:get_weights() 放到模型里不会随迭代更新

起因是在构建 MemN2N 模型的时候需要最后一个 Dense 层的权重等于前面 Embedding 层的转置。于是我一拍脑袋就是一通操作

1
2
answer = Lambda(lambda x: np.array(emb_A.get_weights()[0]) * x)(u_temp)
answer = Lambda(lambda x: K.sum(x, axis=2))(answer)

然后就出现问题了,发现迭代过程中emb_A.get_weights()不会更新,一直都是初始值,相当于没通过一个不断更新权重的Dense层,而只是每次输出前乘一个常量而已……

想想还是挺合理的,最后研究了半天也没想出来要怎么把这个 Embedding 层取反,没办法用现有层,只能直接搞个 MemN2N 层里面直接定义权重矩阵运算了……