tensorflow - 如何将 tf.data.Dataset 与 kedro 一起使用?

我正在使用 tf.data.Dataset准备用于训练 tf.kears 模型的流数据集。用kedro , 有没有办法创建一个节点并返回创建的 tf.data.Dataset 以在下一个训练节点中使用它?

MemoryDataset可能无法工作,因为 tf.data.Dataset 不能被 pickle(deepcopy 不可能),另见 this SO question .根据issue #91 MemoryDataset 中的深拷贝是为了避免其他节点修改数据。有人可以详细说明为什么/如何发生这种并发修改吗?

来自docs ,似乎有一个 copy_mode = "assign"。如果数据不可 pickle ,是否可以使用此选项?

另一种解决方案(在 issue 91 中也提到)是仅使用一个函数在训练节点内生成流式 tf.data.Dataset,而无需前面的数据集生成节点。但是,我不确定这种方法的缺点是什么(如果有的话)。如果有人可以举一些例子,那就太好了。

此外,我想避免存储流数据集的完整输出,例如使用 tfrecordstf.data.experimental.save因为这些选项会占用大量磁盘存储空间。

有没有办法只传递创建的 tf.data.Dataset 对象以将其用于训练节点?

最佳答案

虽然在 kedro.community 中有介绍,但为了社区的利益在此提供解决方法由@DataEngineerOne 提供。

根据@DataEngineerOne。

With kedro, is there a way to create a node and return the created tf.data.Dataset to use it in the next training node?

是的,绝对!

Can someone please elaborate a bit more on why/how this concurrent modification could happen?

From the docs, there seems to be a copy_mode = "assign" . Would it be possible to use this option in case the data is not picklable?

我还没有尝试过这个选项,但理论上应该可行。您需要做的就是在包含 copy_mode 选项的 catalog.yml 文件中创建一个新的数据集条目。

例如:

# catalog.yml
tf_data:
  type: MemoryDataSet
  copy_mode: assign

# pipeline.py
node(
  tf_generator,
  inputs=...,
  outputs="tf_data",
)

我不能保证这个解决方案,但试一试,让我知道它是否适合你。

Another solution (also mentioned in issue 91) is to use just a function to generate the streaming tf.data.Dataset inside the training node, without having the preceding dataset generation node. However, I am not sure what the drawbacks of this approach will be (if any). Would be greate if someone could give some examples.

这也是一个很好的替代解决方案,我认为(猜测)MemoryDataSet 在这种情况下会自动使用 assign,而不是其正常的 deepcopy ,所以你应该没问题。

# node.py

def generate_tf_data(...):
  tensor_slices = [1, 2, 3]
  def _tf_data():
    dataset = tf.data.Dataset.from_tensor_slices(tensor_slices)
    return dataset
  return _tf_data

def use_tf_data(tf_data_func):
  dataset = tf_data_func()

# pipeline.py
Pipeline([
node(
  generate_tf_data,
  inputs=...,
  outputs='tf_data_func',
),
node(
  use_tf_data,
  inputs='tf_data_func',
  outputs=...
),
])

这里唯一的缺点是额外的复杂性。更多详情可以引用here .

https://stackoverflow.com/questions/63730066/

相关文章:

python-3.x - SendGrid 无法向 Yahoo 和 Outlook 发送列表-取消订

android - 如何处理 In-App Billing Library 中的多个用户? (最佳实

firebase - 如何在 flutter 中捕获 StorageException

authentication - 如何在 ReactJS 中实现登录亚马逊?

css - 关键帧动画问题

node.js - 在 node.js 中将字符串转换为 64 位 float

python - 如何在 Pandas DataFrames 中显示 IntEnum 的名称

python - 如何一次对每一行进行不同的 tensorflow 张量切片?

ios - iPhone 上的 Safari 是否支持为 <input type =“file

json - 使用 JSON Body on Rule 调用 REST API Business C