基于tf.data API,我们可以使用简单的代码来构建复杂的输入,tf.data API可以轻松处理大量数据、不同的数据格式以及复杂的转换。

使用该API构建数据管道,主要依靠两个API:

  • tf.data.Dataset 表示一系列元素

    每个元素包含一个或多个Tensor 对象。例如,在图片管道中,一个元素可能是单个训练样本,具有一对表示图片数据和标签的张量。

  • tf.data.Iterator 它的主要机制实际上就是创建了一个有枚举功能的迭代器对象,用来保存数据集。

本节主要是介绍tf.data.Dataset 。

创建Dataset

方法1:可以直接从Tensor中创建,主要使用Dataset.from_tensor_slices()来创建数据集Dataset。

方法2:也可以通过对一个或多个tf.data.Dataset对象来使用变换(例如Dataset.zip)来创建Dataset。

Dataset的属性由构成该Dataset的元素的属性映射得到,元素可以是单个张量、张量元组,也可以是张量的嵌套元组。

需要特别注意,一个Dataset 对象包含多个元素,每个元素的结构都相同。每个元素包含一个或多个tf.Tensor 对象,这些对象被称为组件。要是元素结构不相同,可能会报错。

维度不一致导致的报错

创建Dataset的代码示例

可以直接使用from_tensor_slices来创建。

import tensorflow as tf

# 主要就是使用Dataset.from_tensor_slices来创建

# 创建一个一维的数据集
dataset1 = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7])
# 内容:  <TensorSliceDataset shapes: (), types: tf.int32>

for ele in dataset:
    print(ele)         # 输出dataset里的内容
    print(ele.numpy())   # 转换为numpy类型
# 输出结果
'''
tf.Tensor(1, shape=(), dtype=int32)
1
tf.Tensor(2, shape=(), dtype=int32)
2
tf.Tensor(3, shape=(), dtype=int32)
3
tf.Tensor(4, shape=(), dtype=int32)
4
tf.Tensor(5, shape=(), dtype=int32)
5
tf.Tensor(6, shape=(), dtype=int32)
6
tf.Tensor(7, shape=(), dtype=int32)
7
'''

# 还可以创建一个二维的数据集
dataset2 = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])  # 二维列表,三个元素

# 还可以使用字典来创建dataset
dataset_dic = tf.data.Dataset.from_tensor_slices({'a':[1,2,3,4],
                                                  'b':[6,7,8,9],
                                                  'c':[12,13,14,15]
                                                 })

使用numpy创建的数据,也会被直接被转换为Tensor类型。

转换为Tensor类型

tf.data的用法(shuffle、batch、repeat)

dataset.shuffle作用是将数据进行打乱操作,传入参数为buffer_size,改参数为设置“打乱缓存区大小”,也就是说程序会维持一个buffer_size大小的缓存,每次都会随机在这个缓存区抽取一定数量的数据

dataset.batch作用是将数据打包成batch_size

dataset.repeat作用就是将数据重复使用多少epoch

  • 1、take() : 取出指定的数据。一般都是使用for循环来迭代。

take方法

  • 2、重复函数:repeat()

后面的参数是重复几次,当repeat()参数为空时,意思是重复无数遍。

  • 3、数据的变换之乱序:shuffle(shuffle_size),shuffle_size表示乱序的范围

shuffle乱序

  • 4、batch相当于一个定长度的切片

batch演示

那么,如何对里面的数据进行操作呢?需要用到dataset.map函数。下面举个例子:

map方法


博主个人公众号
版权声明 ▶ 本网站名称:陶小桃Blog
▶ 本文链接:https://www.52txr.cn/2022/tfdata1.html
▶ 本网站的文章部分内容可能来源于网络,仅供大家学习与参考,如有侵权,请联系站长进行核实删除。
▶ 转载本站文章需要遵守:商业转载请联系站长,非商业转载请注明出处!!

最后修改:2022 年 09 月 14 日
如果觉得我的文章对你有用,请随意赞赏