python - 在 TensorFlow 中使用预训练的词嵌入(word2vec 或 Glove)

我最近查看了 convolutional text classification 的一个有趣的实现。 .然而,我查看过的所有 TensorFlow 代码都使用随机(未预训练)嵌入向量,如下所示:

with tf.device('/cpu:0'), tf.name_scope("embedding"):
    W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

有人知道如何使用 Word2vec 或 GloVe 预训练词嵌入的结果而不是随机的吗?

最佳答案

您可以通过多种方式在 TensorFlow 中使用预训练嵌入。假设您在一个名为 embedding 的 NumPy 数组中嵌入了 vocab_size 行和 embedding_dim 列,并且您想创建一个张量 W 可用于调用 tf.nn.embedding_lookup() .

  1. 只需将 W 创建为 tf.constant()embedding 作为其值:

    W = tf.constant(embedding, name="W")
    

    这是最简单的方法,但内存效率不高,因为 tf.constant() 的值在内存中存储了多次。由于 embedding 可能非常大,因此您应该只将这种方法用于玩具示例。

  2. 创建 W 作为 tf.Variable 并通过 tf.placeholder() 从 NumPy 数组初始化它:

    W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]),
                    trainable=False, name="W")
    
    embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim])
    embedding_init = W.assign(embedding_placeholder)
    
    # ...
    sess = tf.Session()
    
    sess.run(embedding_init, feed_dict={embedding_placeholder: embedding})
    

    这避免了在图中存储 embedding 的副本,但它确实需要足够的内存来一次在内存中保存矩阵的两个副本(一个用于 NumPy 数组,一个用于 tf.变量)。请注意,我假设您希望在训练期间保持嵌入矩阵不变,因此 W 是使用 trainable=False 创建的。

  3. 如果嵌入是作为另一个 TensorFlow 模型的一部分进行训练的,您可以使用 tf.train.Saver从另一个模型的检查点文件加载值。这意味着嵌入矩阵可以完全绕过 Python。像选项 2 一样创建 W,然后执行以下操作:

    W = tf.Variable(...)
    
    embedding_saver = tf.train.Saver({"name_of_variable_in_other_model": W})
    
    # ...
    sess = tf.Session()
    embedding_saver.restore(sess, "checkpoint_filename.ckpt")
    

https://stackoverflow.com/questions/35687678/

相关文章:

python - matplotlib 中的曲面图

python - 有效地检查 Python/numpy/pandas 中的任意对象是否为 NaN?

python - sqlite3.ProgrammingError : You must not u

python - 断言 numpy.array 相等性的最佳方法?

python - 使用 pandas GroupBy.agg() 对同一列进行多个聚合

python - SQLAlchemy:如何过滤日期字段?

python - 如何从具有透明背景的 matplotlib 导出绘图?

python - 在系统范围内安装 pip 和 virtualenv 的官方 "preferred"

python - Python 多处理模块的 .join() 方法到底在做什么?

python - 如何在 Tesseract 和 OpenCV 之间进行选择?