python - 如何在 TensorFlow 中应用渐变裁剪?

考虑到 example code .

我想知道如何在 RNN 上的这个网络上应用梯度裁剪,因为那里有可能发生梯度爆炸。

tf.clip_by_value(t, clip_value_min, clip_value_max, name=None)

这是一个可以使用的示例,但我在哪里介绍呢? 在RNN的def中

    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Split data because rnn cell needs a list of inputs for the RNN inner loop
    _X = tf.split(0, n_steps, _X) # n_steps
tf.clip_by_value(_X, -1, 1, name=None)

但这没有意义,因为张量 _X 是输入,而不是 grad 要剪裁什么?

我必须为此定义自己的优化器还是有更简单的选择?

最佳答案

梯度裁剪需要在计算梯度之后,但在应用它们更新模型参数之前进行。在您的示例中,这两件事都由 AdamOptimizer.minimize() 方法处理。

为了剪裁渐变,您需要按照 this section in TensorFlow's API documentation 中的说明显式计算、剪裁和应用它们。 .具体来说,您需要将 minimize() 方法的调用替换为以下内容:

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
gvs = optimizer.compute_gradients(cost)
capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs]
train_op = optimizer.apply_gradients(capped_gvs)

https://stackoverflow.com/questions/36498127/

相关文章:

python - 对带有异常的字符串进行标题化

python - 如何使用 Python 创建完整的压缩 tar 文件?

python - 如何绘制正态分布

python - 为什么循环导入似乎在调用堆栈中更靠前,但在更靠下的位置引发 ImportError

python - “模块”没有属性 'urlencode'

python - 非 ASCII 字符的语法错误

python - 如何在 python 中从变量参数(kwargs)设置类属性

python - 如何使用 Python 在 NLTK 中使用 Stanford Parser

python - 如何检查文件是否是有效的图像文件?

python - SQLAlchemy ORM 转换为 pandas DataFrame