add part of dataset api

This commit is contained in:
ms_yan 2021-11-22 16:38:21 +08:00
parent a78b4fd60f
commit 556f1e5bc0
9 changed files with 230 additions and 3 deletions

View File

@ -0,0 +1,24 @@
apply(apply_func)
对数据集对象执行给定操作函数。
参数:
apply_func (function):传入`Dataset`对象作为参数,并将返回处理后的`Dataset`对象。
返回:
执行了给定操作函数的数据集对象。
示例:
>>> # dataset是数据集类的实例化对象
>>>
>>> # 声明一个名为apply_func函数其返回值是一个Dataset对象
>>> def apply_func(data)
... data = data.batch(2)
... return data
>>>
>>> # 通过apply操作调用apply_func函数
>>> dataset = dataset.apply(apply_func)
异常:
TypeErrorapply_func不是一个函数。
TypeErrorapply_func未返回Dataset对象。

View File

@ -0,0 +1,57 @@
batch(batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False)
将dataset中连续`batch_size`行数据合并为一个批处理数据。
对一个批处理数据执行给定操作与对条数据进行给定操作用法一致。
对于任意列batch操作要求该列中的各条数据shape必须相同。
如果给定可执行函数`per_batch_map`,它将作用于批处理后的数据。
注:
执行`repeat``batch`操作的顺序,会影响数据批次的数量及`per_batch_map`操作。
建议在batch操作完成后执行repeat操作。
参数:
batch_size (int or function)每个批处理数据包含的条数。参数需要是int或可调用对象该对象接收1个参数即BatchInfo。
drop_remainder (bool, optional)是否删除最后一个数据条数小于批处理大小的batch默认值为False
如果为True并且最后一个批次中数据行数少于`batch_size`,则这些数据将被丢弃,不会传递给后续的操作。
num_parallel_workers (int, optional)用于进行batch操作的的线程数threads默认值为None。
per_batch_map (callable, optional):是一个以(list[Tensor], list[Tensor], ..., BatchInfo)作为输入参数的可调用对象。
每个list[Tensor]代表给定列上的一批Tensor。入参中list[Tensor]的个数应与input_columns中传入列名的数量相匹配。
该可调用对象的最后一个参数始终是BatchInfo对象。`per_batch_map`应返回(list[Tensor], list[Tensor], ...)。
其出中list[Tensor]的个数应与输入相同。如果输出列数与输入列数不一致则需要指定output_columns。
input_columns (Union[str, list[str]], optional):由输入列名组成的列表。如果`per_batch_map`不为None
列表中列名的个数应与 `per_batch_map` 中包含的列数匹配默认为None
output_columns (Union[str, list[str]], optional):当前操作所有输出列的列名列表。
如果len(input_columns) != len(output_columns),则此参数必须指定。
此列表中列名的数量必须与给定操作的输出列数相匹配默认为None输出列将与输入列具有相同的名称
column_order (Union[str, list[str]], optional):指定整个数据集对象中包含的所有列名的顺序。
如果len(input_column) != len(output_column),则此参数必须指定。
注意:这里的列名不仅仅是在`input_columns``output_columns`中指定的列。
pad_info (dict, optional):用于对给定列进行填充。例如`pad_info={"col1":([224,224],0)}`
则将列名为"col1"的列填充到大小为[224,224]的张量并用0填充缺失的值默认为None)。
python_multiprocessing (bool, optional):针对`per_batch_map`函数使用Python多进执行的方式进行调用。
如果函数计算量大开启这个选项可能会很有帮助默认值为False
返回:
批处理后的数据集对象。
示例:
>>> # 创建一个数据集对象每100条数据合并成一个批次
>>> # 如果最后一个批次数据小于给定的批次大小batch_size),则丢弃这个批次
>>> dataset = dataset.batch(100, True)
>>> # 根据批次编号调整图像大小如果是第5批则图像大小调整为(5^2, 5^2) = (25, 25)
>>> def np_resize(col, batchInfo):
... output = col.copy()
... s = (batchInfo.get_batch_num() + 1) ** 2
... index = 0
... for c in col:
... img = Image.fromarray(c.astype('uint8')).convert('RGB')
... img = img.resize((s, s), Image.ANTIALIAS)
... output[index] = np.array(img)
... index += 1
... return (output,)
>>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)

View File

@ -0,0 +1,51 @@
bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False)
依据数据中元素长度进行分桶。每个桶将在满了的时候进行元素填充和批处理操作。
对数据集中的每一条数据执行长度计算函数。然后,根据该条数据的长度和桶的边界将该数据归到特定的桶里面。
当桶中数据条数达到指定的大小`bucket_batch_sizes`时,将根据`pad_info`对桶中元素进行填充,再进行批处理。
这样每个批次都是满的但也有特殊情况每个桶的最后一个批次batch可能不满。
参数:
column_names (list[str]):传递给长度计算函数的所有列名。
bucket_boundaries (list[int])由各个桶的上边界值组成的列表必须严格递增。如果有n个边界则创建n+1个桶分配后桶的边界如下
[0, bucket_boundaries[0])[bucket_boundaries[i], bucket_boundaries[i+1])其中0<i<n-1[bucket_boundaries[n-1], inf)。
bucket_batch_sizes (list[int]):由每个桶的批次大小组成的列表,必须包含`len(bucket_boundaries)+1`个元素。
element_length_function (Callable, optional)输入包含M个参数的函数其中M等于`len(column_names)`,并返回一个整数。
如果未指定该参数,则`len(column_names)`必须为1并且该列数据第一维的shape值将用作长度默认为None
pad_info (dict, optional)有关如何对指定列进行填充的字典对象。字典中键对应要填充的列名值必须是包含2个元素的元组。
元组中第一个元素对应要填充成的shape第二个元素对应要填充的值。
如果某一列未指定将要填充后的shape和填充值则当前批次中该列上的每条数据都将填充至该批次中最长数据的长度填充值为0。
除非`pad_to_bucket_boundary`为True否则`pad_info`中任何填充shape为None的列
其每条数据长度都将被填充为当前批处理中最数据的长度。如果不需要填充,请将`pad_info`设置为None默认为None
pad_to_bucket_boundary (bool, optional)如果为True则pad_info中填充shape为None的列
其长度都会被填充至`bucket_boundary-1`长度。如果有任何元素落入最后一个桶中则将报错默认为False
drop_remainder (bool, optional)如果为True则丢弃每个桶中最后不足一个批次数据默认为False
返回:
BucketBatchByLengthDataset按长度进行分桶和批处理操作后的数据集对象。
示例:
>>> # 创建一个数据集对象,其中给定条数的数据会被组成一个批次数据
>>> # 如果最后一个批次数据小于给定的批次大小batch_size),则丢弃这个批次
>>> import numpy as np
>>> def generate_2_columns(n):
... for i in range(n):
... yield (np.array([i]), np.array([j for j in range(i + 1)]))
>>>
>>> column_names = ["col1", "col2"]
>>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names)
>>> bucket_boundaries = [5, 10]
>>> bucket_batch_sizes = [2, 1, 1]
>>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
>>> # 将对列名为"col2"的列进行填充填充后的shape为[bucket_boundaries[i]]其中i是当前正在批处理的桶的索引
>>> pad_info = {"col2": ([None], -1)}
>>> pad_to_bucket_boundary = True
>>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
... bucket_batch_sizes,
... element_length_function, pad_info,
... pad_to_bucket_boundary)

View File

@ -0,0 +1,25 @@
build_sentencepiece_vocab(columns, vocab_size, character_coverage, model_type, params)
用于从源数据集对象创建句子词表的函数。
参数:
columns(list[str]):指定从哪一列中获取单词。
vocab_size(int):词汇表大小。
character_coverage(int)模型涵盖的字符百分比必须介于0.98和1.0之间。
默认值如0.9995适用于具有丰富字符集的语言如日语或中文字符集1.0适用于其他字符集较小的语言,比如英语或拉丁文。
model_type(SentencePieceModel)模型类型枚举值包括unigram默认值、bpe、char及word。当类型为word时输入句子必须预先标记。
params(dict):依据原始数据内容构建祠表的附加参数,无附加参数时取值可以是空字典。
返回:
SentencePieceVocab从数据集构建的词汇表。
示例:
>>> from mindspore.dataset.text import SentencePieceModel
>>>
>>> # DE_C_INTER_SENTENCEPIECE_MODE 是一个映射字典
>>> from mindspore.dataset.text.utils import DE_C_INTER_SENTENCEPIECE_MODE
>>> dataset = ds.TextFileDataset("/path/to/sentence/piece/vocab/file", shuffle=False)
>>> dataset = dataset.build_sentencepiece_vocab(["text"], 5000, 0.9995,
... DE_C_INTER_SENTENCEPIECE_MODE[SentencePieceModel.UNIGRAM],
... {})

View File

@ -0,0 +1,3 @@
close_pool()
关闭数据集对象中的多进程池。如果您熟悉多进程库,可以将此视为进程池对象的析构函数。

View File

@ -0,0 +1,20 @@
concat(datasets)
对传入的多个数据集对象进行拼接操作。
重载“+”运算符来进行数据集对象拼接操作。
注:
用于拼接的多个数据集对象其列名、每列数据的维度rank)和类型必须相同。
参数:
datasets (Union[list, class Dataset]):与当前数据集对象拼接的数据集对象列表或单个数据集对象。
返回:
ConcatDataset拼接后的数据集对象。
示例:
>>> # 通过使用“+”运算符拼接dataset_1和dataset_2获得拼接后的数据集对象
>>> dataset = dataset_1 + dataset_2
>>> # 通过concat操作拼接dataset_1和dataset_2获得拼接后的数据集对象
>>> dataset = dataset_1.concat(dataset_2)

View File

@ -0,0 +1,22 @@
create_dict_iterator(num_epochs=-1, output_numpy=False)
基于数据集对象创建迭代器,输出数据为字典类型。
字典中列的顺序可能与数据集对象中原始顺序不同。
参数:
num_epochs (int, optional):迭代器可以迭代的最多轮次数(默认为-1迭代器可以迭代无限次
output_numpy (bool, optional)是否输出NumPy数据类型如果`output_numpy`为False
迭代器输出的每列数据类型为MindSpore.Tensor默认为False
返回:
DictIterator基于数据集对象创建的字典迭代器。
示例:
>>> # dataset是数据集类的实例化对象
>>> iterator = dataset.create_dict_iterator()
>>> for item in iterator:
... # item 是一个dict
... print(type(item))
... break
<class 'dict'>

View File

@ -0,0 +1,27 @@
create_tuple_iterator(columns=None, num_epochs=-1, output_numpy=False, do_copy=True)
基于数据集对象创建迭代器输出数据为ndarray组成的列表。
可以使用columns指定输出的所有列名及列的顺序。如果columns未指定列的顺序将保持不变。
参数:
columns (list[str], optional):用于指定列顺序的列名列表
默认为None表示所有列
num_epochs (int, optional):迭代器可以迭代的最多轮次数
(默认为-1迭代器可以迭代无限次
output_numpy (bool, optional)是否输出NumPy数据类型
如果output_numpy为False迭代器输出的每列数据类型为MindSpore.Tensor默认为False
do_copy (bool, optional)当输出数据类型为mindspore.Tensor时
通过此参数指定转换方法采用False主要考虑以获得更好的性能默认为True
返回:
TupleIterator基于数据集对象创建的元组迭代器。
示例:
>>> # dataset是数据集类的实例化对象
>>> iterator = dataset.create_tuple_iterator()
>>> for item in iterator
... # item 是一个列表
... print(type(item))
... break
<class 'list'>

View File

@ -480,7 +480,7 @@ class Dataset:
A length function is called on each row in the dataset. The row is then
bucketed based on its length and bucket boundaries. When a bucket reaches its
corresponding size specified in bucket_batch_sizes, the entire bucket will be
padded according to batch_info, and then form a batch.
padded according to pad_info, and then form a batch.
Each batch will be full, except one special case: the last batch for each bucket may not be full.
Args:
@ -1315,8 +1315,6 @@ class Dataset:
"""
Function to create a SentencePieceVocab from source dataset
Build a SentencePieceVocab from a dataset.
Args:
columns(list[str]): Column names to get words from.