数据集 Dataset
仓颉 TensorBoost 提供数据加载及处理能力,支持读取多种数据集格式和自定义数据集,支持常用的数据处理、数据增强方法,数据加载和处理需要导入 dataset 包:
from CangjieTB import dataset.*
加载数据集
使用仓颉 TensorBoost 的 数据集接口可以加载数据集, 示例代码中使用MnistDataset
接口加载 MNIST 数据集,并使用随机采样器获取 5 个样本。其他格式的数据集接口可参见 附录 C: 数据集
var sampler5 = RandomSampler(numSamples: 5)
let dataPath: String = "./data/mnist/train"
var mnistDs = MnistDataset(dataPath, sampler: sampler5)
// 将图像数据映射到固定的范围
var rescale = rescale(1.0 / 255.0, 0.0)
mnistDs.datasetMap([rescale], "image")
将数据集读取到 Tensor 中
调用数据集类的 getNext
方法可以将数据读取到 Tensor 中,如下所示:
var input: Tensor = parameter(zerosTensor(Array<Int64>([1, 28, 28]), dtype: FLOAT32), "data")
var label: Tensor = parameter(zerosTensor(Array<Int64>([1]), dtype: INT32), "label")
while (mnistDs.getNext([input, label])) {
print("---------------\n")
print("input: ", input)
print("label: ", label)
}
输出为:
---------------
input:
Tensor(shape=[1, 28, 28], dtype=Float32, value=
[[[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
...
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]]])
label:
Tensor(shape=[1], dtype=Int32, value= [0])
...
【说明】:
parameter
函数用于定义数据可变的 Tensor 对象,作为网络训练的 input 和 label 。- 如果数据已经取完,
getNext
返回 false,循环结束。
【注意】:在定义 Parameter 时,要确保 shape 和 dtype 和数据文件中的数据一致。
自定义数据集
对于目前仓颉 TensorBoost 不支持直接加载的数据集,可以先将数据集转换成 MindRecord 或 TFRecord 格式,再用仓颉 TensorBoost 进行读取。
将数据转换成 MindRecord 格式的方法见自定义数据集转换为 MindRecord。将数据转换成 TFRecord 格式的方法见TFRecord。
读取自定义数据集./data/train.mindrecord
如下所示:
// 加载数据
let mindrecordSampler = RandomSampler(numSamples: 6)
let mindRecordDataPath: String = "./data/train.mindrecord"
let columnNames: Array<String> = ["R", "F", "E"]
var msDataset = MindDataDataset(mindRecordDataPath, columnNames, sampler: mindrecordSampler)
// 读取数据
var R = parameter(zerosTensor(Array<Int64>([9, 3])), "R")
var F = parameter(zerosTensor(Array<Int64>([9, 3])), "F")
var E = parameter(zerosTensor(Array<Int64>([1])), "E")
while (msDataset.getNext([R, F, E])) {
print("---------------\n")
print("R", R)
print("F", F)
print("E", E)
}
输出为:
---------------
R
Tensor(shape=[9, 3], dtype=Float32, value=
[[-1.95923567e-01 -5.75768769e-01 -8.86928290e-02]
[-9.65070069e-01 6.16460443e-01 5.17787218e-01]
[1.09309697e+00 -9.55241770e-02 -3.60066712e-01]
...
[-5.41572392e-01 1.01731408e+00 1.46117091e+00]
[-1.10684466e+00 1.27925050e+00 -2.28156179e-01]
[9.93210256e-01 6.73578858e-01 -9.84581828e-01]])
F
Tensor(shape=[9, 3], dtype=Float32, value=
[[-2.55999870e+01 3.66684914e+01 1.54593821e+01]
[4.35021973e+01 -5.85456161e+01 6.23012238e+01]
[2.99424801e+01 3.68474412e+00 -2.41483364e+01]
...
[7.96970320e+00 2.15402818e+00 -1.43688583e+01]
[-4.69919205e+00 6.80967560e+01 -5.65661163e+01]
[3.59104991e+00 -1.91799850e+01 1.60485039e+01]])
E
Tensor(shape=[1], dtype=Float32, value= [-9.71969688e+04])
...
和 MnistDataset 类似,使用 getNext
函数可以将输入读取到 Tensor 中。和 MnistDataset 不同的是,MindDataDataset 的构造方法中需要额外的一个列名的入参,以指定在数据文件中读取哪几列数据。
【注意】 传入 getNext
函数的 Tensor
数组元素的个数要与 MindRecord 数据集中保存的数据类别个数相同,并且对应的 shape 和 dtype 也要相同
数据处理与增强
数据处理
仓颉 TensorBoost 提供的数据集接口具备常用的数据处理方法,用户只需调用相应的函数接口即可快速进行数据处理。 如下代码展示了仓颉 TensorBoost 的数据处理过程:将数据集进行 shuffle 处理,然后将数据两两组成一个批次。
// 重新读取数据
msDataset = MindDataDataset(mindRecordDataPath, columnNames, sampler: mindrecordSampler)
let batchSize: Int32 = 2
let bufferSize: Int32 = 1000
// 随机打乱数据顺序
msDataset.shuffle(bufferSize)
// 对数据集进行分批
msDataset.batch(batchSize, true)
// 读取数据
R = parameter(zerosTensor([Int64(batchSize), 9, 3]), "R")
F = parameter(zerosTensor([Int64(batchSize), 9, 3]), "F")
E = parameter(zerosTensor([Int64(batchSize), 1]), "E")
while (msDataset.getNext([R, F, E])) {
print("---------------\n")
print("R", R)
print("F", F)
print("E", E)
}
输出为:
---------------
R
Tensor(shape=[2, 9, 3], dtype=Float32, value=
[[[-1.81845412e-01 -2.61216700e-01 -5.36276340e-01]
[-5.81744909e-02 1.27089608e+00 -2.06515580e-01]
[1.85678512e-01 -9.63150203e-01 6.43111765e-01]
...
[-6.72326028e-01 1.42653990e+00 7.66738772e-01]
[1.07878494e+00 1.39543390e+00 -3.93244214e-02]
[5.54795861e-01 -3.44734281e-01 1.29838037e+00]]
[[3.33174407e-01 4.40440863e-01 1.62186071e-01]
[-7.58461058e-01 5.63840449e-01 -8.67527306e-01]
[4.60158527e-01 -9.15827274e-01 6.33802950e-01]
...
[-1.68348837e+00 6.67325556e-01 -3.36017698e-01]
[-8.81266594e-01 1.50722134e+00 -1.51811504e+00]
[-3.96128386e-01 -1.10900962e+00 1.10739207e+00]]])
F
Tensor(shape=[2, 9, 3], dtype=Float32, value=
[[[5.50469017e+01 2.22117062e+01 -2.59304409e+01]
[-4.75894117e+00 -2.55197372e+01 6.58721209e+00]
[-6.24587297e+00 -2.03564777e+01 8.99683416e-01]
...
[2.27004929e+01 2.90634203e+00 -3.15368729e+01]
[-3.87679062e+01 1.21524820e+01 -4.42834806e+00]
[-1.44927382e+00 9.46627378e-01 -5.96580458e+00]]
[[3.88336296e+01 3.53328371e+00 5.44805384e+00]
[-1.79748745e+01 1.74029770e+01 -6.41082306e+01]
[-2.62239990e+01 -5.71368599e+00 1.24214802e+01]
...
[-2.74372044e+01 -1.68999844e+01 8.07555008e+00]
[2.52471638e+01 -1.31735935e+01 2.34950371e+01]
[2.11702538e+01 -2.78234661e-01 -1.10375137e+01]]])
E
Tensor(shape=[2, 1], dtype=Float32, value=
[[-9.71979297e+04]
[-9.71984062e+04]])
...
其中:
bufferSize
:数据集中进行 shuffle 操作的缓存区的大小。batchSize
:每组包含的数据个数,现设置每组包含 2 个数据。
数据增强
数据量过小或是样本场景单一等问题会影响模型的训练效果,用户可以通过数据增强操作扩充样本多样性,从而提升模型的泛化能力。
如下代码可以对数据集进行数据增强:定义数据处理算子,对数据集进行 resize
,rescale
和 randomCrop
操作,然后通过 datasetMap
设置数据处理的管道,达到生成新数据的目的。
// 重新读取数据
mnistDs = MnistDataset(dataPath, sampler: sampler5)
// 定义数据增强算子
var resize = resize(Array<Int32>([16, 16]))
var randomCropOp = randomCrop(Array<Int32>([8, 8]), Array<Int32>([0, 0, 0, 0]))
// 进行数据增强操作
mnistDs.datasetMap([resize, rescale, randomCropOp], "image")
// 查看增强后的数据
input = parameter(initialize(Array<Int64>([8, 8]), initType: InitType.ZERO, dtype: FLOAT32), "data")
while (mnistDs.getNext([input, label])) {
print("---------------\n")
print("input: ", input)
print("label: ", label)
}
输出为:
---------------
input:
Tensor(shape=[8, 8], dtype=Float32, value=
[[ 0.00000000e+00 1.41176477e-01 4.50980425e-01 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 3.72549027e-01 1.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 0.00000000e+00 3.72549027e-01 1.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
...
[ 1.00000000e+00 1.00000000e+00 7.64705956e-01 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 1.00000000e+00 6.70588255e-01 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ 6.07843161e-01 0.00000000e+00 0.00000000e+00 ... 0.00000000e+00 0.00000000e+00 0.00000000e+00]])
label:
Tensor(shape=[1], dtype=Int32, value= [1])
...