api format modify
This commit is contained in:
parent
1de626df42
commit
fdb7a3b4f0
|
@ -26,6 +26,25 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. py:method:: load()
|
||||
|
||||
从给定(处理好的)路径加载数据,也可以在自己实现的Dataset类中实现这个方法。
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. py:method:: process()
|
||||
|
||||
针对argoverse数据集的处理方法,基于加载上来的原始数据集创建很多子图。
|
||||
数据预处理方法主要参考:https://github.com/xk-huang/yet-another-vectornet/blob/master/dataset.py。
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. py:method:: save()
|
||||
|
||||
将经过 `process` 函数处理后的数据以 numpy.npz 格式保存到磁盘中,也可以在自己实现的Dataset类中自己实现这个方法。
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -212,4 +212,12 @@ mindspore.dataset.CLUEDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -39,4 +39,12 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -120,6 +120,14 @@ mindspore.dataset.Caltech101Dataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -105,6 +105,14 @@ mindspore.dataset.Caltech256Dataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -126,6 +126,14 @@ mindspore.dataset.CelebADataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -90,6 +90,14 @@ mindspore.dataset.Cifar100Dataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -94,6 +94,14 @@ mindspore.dataset.Cifar10Dataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -129,6 +129,14 @@ mindspore.dataset.CityscapesDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -157,6 +157,14 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -144,6 +144,14 @@ mindspore.dataset.DIV2KDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -134,296 +134,3 @@
|
|||
|
||||
返回:
|
||||
int,数据集的input index信息。
|
||||
|
||||
.. py:method:: map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16, offload=None)
|
||||
|
||||
给定一组数据增强列表,按顺序将数据增强作用在数据集对象上。
|
||||
|
||||
每个数据增强操作将数据集对象中的一个或多个数据列作为输入,将数据增强的结果输出为一个或多个数据列。
|
||||
第一个数据增强操作将 `input_columns` 中指定的列作为输入。
|
||||
如果数据增强列表中存在多个数据增强操作,则上一个数据增强的输出列将作为下一个数据增强的输入列。
|
||||
|
||||
最后一个数据增强的输出列的列名由 `output_columns` 指定,如果没有指定 `output_columns` ,输出列名与 `input_columns` 一致。
|
||||
|
||||
参数:
|
||||
- **operations** (Union[list[TensorOperation], list[functions]]) - 一组数据增强操作,支持数据集增强算子或者用户自定义的Python Callable对象。map操作将按顺序将一组数据增强作用在数据集对象上。
|
||||
- **input_columns** (Union[str, list[str]], 可选) - 第一个数据增强操作的输入数据列。此列表的长度必须与 `operations` 列表中第一个数据增强的预期输入列数相匹配。默认值:None。表示所有数据列都将传递给第一个数据增强操作。
|
||||
- **output_columns** (Union[str, list[str]], 可选) - 最后一个数据增强操作的输出数据列。如果 `input_columns` 长度不等于 `output_columns` 长度,则必须指定此参数。列表的长度必须必须与最后一个数据增强的输出列数相匹配。默认值:None,输出列将与输入列具有相同的名称。
|
||||
- **column_order** (Union[str, list[str]], 可选) - 指定传递到下一个数据集操作的数据列的顺序。如果 `input_columns` 长度不等于 `output_columns` 长度,则必须指定此参数。注意:参数的列名不限定在 `input_columns` 和 `output_columns` 中指定的列,也可以是上一个操作输出的未被处理的数据列。默认值:None,按照原输入顺序排列。
|
||||
- **num_parallel_workers** (int, 可选) - 指定map操作的多进程/多线程并发数,加快处理速度。默认值:None,将使用 `set_num_parallel_workers` 设置的并发数。
|
||||
- **python_multiprocessing** (bool, 可选) - 启用Python多进程模式加速map操作。当传入的 `operations` 计算量很大时,开启此选项可能会有较好效果。默认值:False。
|
||||
- **cache** (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ 。默认值:None,不使用缓存。
|
||||
- **callbacks** (DSCallback, list[DSCallback], 可选) - 要调用的Dataset回调函数列表。默认值:None。
|
||||
- **max_rowsize** (int, 可选) - 指定在多进程之间复制数据时,共享内存分配的最大空间,仅当 `python_multiprocessing` 为True时,该选项有效。默认值:16,单位为MB。
|
||||
- **offload** (bool, 可选) - 是否进行异构硬件加速,详情请阅读 `数据准备异构加速 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/dataset_offload.html>`_ 。默认值:None。
|
||||
|
||||
.. note::
|
||||
- `operations` 参数接收 `TensorOperation` 类型的数据处理操作,以及用户定义的Python函数(PyFuncs)。
|
||||
- 不要将 `mindspore.nn` 和 `mindspore.ops` 或其他的网络计算算子添加到 `operations` 中。
|
||||
|
||||
返回:
|
||||
MapDataset,map操作后的数据集。
|
||||
|
||||
.. py:method:: num_classes()
|
||||
|
||||
获取数据集对象中所有样本的类别数目。
|
||||
|
||||
返回:
|
||||
int,类别的数目。
|
||||
|
||||
.. py:method:: output_shapes(estimate=False)
|
||||
|
||||
获取数据集对象中每列数据的shape。
|
||||
|
||||
参数:
|
||||
- **estimate** (bool) - 如果 `estimate` 为 False,将返回数据集第一条数据的shape。
|
||||
否则将遍历整个数据集以获取数据集的真实shape信息,其中动态变化的维度将被标记为-1(可用于动态shape数据集场景),默认值:False。
|
||||
|
||||
返回:
|
||||
list,每列数据的shape列表。
|
||||
|
||||
.. py:method:: output_types()
|
||||
|
||||
获取数据集对象中每列数据的数据类型。
|
||||
|
||||
返回:
|
||||
list,每列数据的数据类型列表。
|
||||
|
||||
.. py:method:: project(columns)
|
||||
|
||||
从数据集对象中选择需要的列,并按给定的列名的顺序进行排序,
|
||||
未指定的数据列将被丢弃。
|
||||
|
||||
参数:
|
||||
- **columns** (Union[str, list[str]]) - 要选择的数据列的列名列表。
|
||||
|
||||
返回:
|
||||
ProjectDataset,project操作后的数据集对象。
|
||||
|
||||
.. py:method:: rename(input_columns, output_columns)
|
||||
|
||||
对数据集对象按指定的列名进行重命名。
|
||||
|
||||
参数:
|
||||
- **input_columns** (Union[str, list[str]]) - 待重命名的列名列表。
|
||||
- **output_columns** (Union[str, list[str]]) - 重命名后的列名列表。
|
||||
|
||||
返回:
|
||||
RenameDataset,rename操作后的数据集对象。
|
||||
|
||||
.. py:method:: repeat(count=None)
|
||||
|
||||
重复此数据集 `count` 次。如果 `count` 为None或-1,则无限重复迭代。
|
||||
|
||||
.. note::
|
||||
repeat和batch的顺序反映了batch的数量。建议:repeat操作在batch操作之后使用。
|
||||
|
||||
参数:
|
||||
- **count** (int) - 数据集重复的次数。默认值:None。
|
||||
|
||||
返回:
|
||||
RepeatDataset,repeat操作后的数据集对象。
|
||||
|
||||
.. py:method:: reset()
|
||||
|
||||
重置下一个epoch的数据集对象。
|
||||
|
||||
.. py:method:: save(file_name, num_files=1, file_type='mindrecord')
|
||||
|
||||
将数据处理管道中正处理的数据保存为通用的数据集格式。支持的数据集格式:'mindrecord',然后可以使用'MindDataset'类来读取保存的'mindrecord'文件。
|
||||
|
||||
将数据保存为'mindrecord'格式时存在隐式类型转换。转换表展示如何执行类型转换。
|
||||
|
||||
.. list-table:: 保存为'mindrecord'格式时的隐式类型转换
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - 'dataset'类型
|
||||
- 'mindrecord'类型
|
||||
- 说明
|
||||
* - bool
|
||||
- None
|
||||
- 不支持
|
||||
* - int8
|
||||
- int32
|
||||
-
|
||||
* - uint8
|
||||
- bytes
|
||||
- 丢失维度信息
|
||||
* - int16
|
||||
- int32
|
||||
-
|
||||
* - uint16
|
||||
- int32
|
||||
-
|
||||
* - int32
|
||||
- int32
|
||||
-
|
||||
* - uint32
|
||||
- int64
|
||||
-
|
||||
* - int64
|
||||
- int64
|
||||
-
|
||||
* - uint64
|
||||
- None
|
||||
- 不支持
|
||||
* - float16
|
||||
- float32
|
||||
-
|
||||
* - float32
|
||||
- float32
|
||||
-
|
||||
* - float64
|
||||
- float64
|
||||
-
|
||||
* - string
|
||||
- string
|
||||
- 不支持多维字符串
|
||||
|
||||
.. note::
|
||||
1. 如需按顺序保存数据,将数据集的 `shuffle` 设置为False,将 `num_files` 设置为1。
|
||||
2. 在执行保存操作之前,不要使用batch操作、repeat操作或具有随机属性的数据增强的map操作。
|
||||
3. 当数据的维度可变时,只支持1维数组或者在第0维变化的多维数组。
|
||||
4. 不支持UINT64类型、多维的UINT8类型、多维STRING类型。
|
||||
|
||||
参数:
|
||||
- **file_name** (str) - 数据集文件的路径。
|
||||
- **num_files** (int, 可选) - 数据集文件的数量,默认值:1。
|
||||
- **file_type** (str, 可选) - 数据集格式,默认值:'mindrecord'。
|
||||
|
||||
.. py:method:: set_dynamic_columns(columns=None)
|
||||
|
||||
设置数据集的动态shape信息,需要在定义好完整的数据处理管道后进行设置。
|
||||
|
||||
参数:
|
||||
- **columns** (dict) - 包含数据集中每列shape信息的字典。shape[i]为 `None` 表示shape[i]的数据长度是动态的。
|
||||
|
||||
.. py:method:: shuffle(buffer_size)
|
||||
|
||||
使用以下策略混洗此数据集的行:
|
||||
|
||||
1. 生成一个混洗缓冲区包含 `buffer_size` 条数据行。
|
||||
|
||||
2. 从混洗缓冲区中随机选择一个数据行,传递给下一个操作。
|
||||
|
||||
3. 从上一个操作获取下一个数据行(如果有的话),并将其放入混洗缓冲区中。
|
||||
|
||||
4. 重复步骤2和3,直到混洗缓冲区中没有数据行为止。
|
||||
|
||||
在第一个epoch中可以通过 `dataset.config.set_seed` 来设置随机种子,在随后的每个epoch,种子都会被设置成一个新产生的随机值。
|
||||
|
||||
参数:
|
||||
- **buffer_size** (int) - 用于混洗的缓冲区大小(必须大于1)。将 `buffer_size` 设置为数据集大小将进行全局混洗。
|
||||
|
||||
返回:
|
||||
ShuffleDataset,混洗后的数据集对象。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 混洗前存在通过 `dataset.sync_wait` 进行同步操作。
|
||||
|
||||
.. py:method:: skip(count)
|
||||
|
||||
跳过此数据集对象的前 `count` 条数据。
|
||||
|
||||
参数:
|
||||
- **count** (int) - 要跳过数据的条数。
|
||||
|
||||
返回:
|
||||
SkipDataset,跳过指定条数据后的数据集对象。
|
||||
|
||||
.. py:method:: split(sizes, randomize=True)
|
||||
|
||||
将数据集拆分为多个不重叠的子数据集。
|
||||
|
||||
参数:
|
||||
- **sizes** (Union[list[int], list[float]]) - 如果指定了一列整数[s1, s2, …, sn],数据集将被拆分为n个大小为s1、s2、...、sn的数据集。如果所有输入大小的总和不等于原始数据集大小,则报错。如果指定了一列浮点数[f1, f2, …, fn],则所有浮点数必须介于0和1之间,并且总和必须为1,否则报错。数据集将被拆分为n个大小为round(f1*K)、round(f2*K)、...、round(fn*K)的数据集,其中K是原始数据集的大小。
|
||||
|
||||
如果round四舍五入计算后:
|
||||
|
||||
- 任何子数据集的的大小等于0,都将发生错误。
|
||||
- 如果子数据集大小的总和小于K,K - sigma(round(fi * k))的值将添加到第一个子数据集,sigma为求和操作。
|
||||
- 如果子数据集大小的总和大于K,sigma(round(fi * K)) - K的值将从第一个足够大的子数据集中删除,且删除后的子数据集大小至少大于1。
|
||||
|
||||
- **randomize** (bool, 可选) - 确定是否随机拆分数据,默认值:True,数据集将被随机拆分。否则将按顺序拆分为多个不重叠的子数据集。
|
||||
|
||||
.. note::
|
||||
1. 如果进行拆分操作的数据集对象为MappableDataset类型,则将自动调用一个优化后的split操作。
|
||||
2. 如果进行split操作,则不应对数据集对象进行分片操作(如指定num_shards或使用DistributerSampler)。相反,如果创建一个DistributerSampler,并在split操作拆分后的子数据集对象上进行分片操作,强烈建议在每个子数据集上设置相同的种子,否则每个分片可能不是同一个子数据集的一部分(请参见示例)。
|
||||
3. 强烈建议不要对数据集进行混洗,而是使用随机化(randomize=True)。对数据集进行混洗的结果具有不确定性,每个拆分后的子数据集中的数据在每个epoch可能都不同。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 数据集对象不支持 `get_dataset_size` 或者 `get_dataset_size` 返回None。
|
||||
- **RuntimeError** - `sizes` 是list[int],并且 `sizes` 中所有元素的总和不等于数据集大小。
|
||||
- **RuntimeError** - `sizes` 是list[float],并且计算后存在大小为0的拆分子数据集。
|
||||
- **RuntimeError** - 数据集对象在调用拆分之前已进行分片。
|
||||
- **ValueError** - `sizes` 是list[float],且并非所有float数值都在0和1之间,或者float数值的总和不等于1。
|
||||
|
||||
返回:
|
||||
tuple(Dataset),split操作后子数据集对象的元组。
|
||||
|
||||
.. py:method:: sync_update(condition_name, num_batch=None, data=None)
|
||||
|
||||
释放阻塞条件并使用给定数据触发回调函数。
|
||||
|
||||
参数:
|
||||
- **condition_name** (str) - 用于触发发送下一个数据行的条件名称。
|
||||
- **num_batch** (Union[int, None]) - 释放的batch(row)数。当 `num_batch` 为None时,将默认为 `sync_wait` 操作指定的值,默认值:None。
|
||||
- **data** (Any) - 用户自定义传递给回调函数的数据,默认值:None。
|
||||
|
||||
.. py:method:: sync_wait(condition_name, num_batch=1, callback=None)
|
||||
|
||||
为同步操作在数据集对象上添加阻塞条件。
|
||||
|
||||
参数:
|
||||
- **condition_name** (str) - 用于触发发送下一行数据的条件名称。
|
||||
- **num_batch** (int) - 每个epoch开始时无阻塞的batch数。
|
||||
- **callback** (function) - `sync_update` 操作中将调用的回调函数。
|
||||
|
||||
返回:
|
||||
SyncWaitDataset,添加了阻塞条件的数据集对象。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 条件名称已存在。
|
||||
|
||||
.. py:method:: take(count=-1)
|
||||
|
||||
从数据集中获取最多 `count` 的元素。
|
||||
|
||||
.. note::
|
||||
1. 如果 `count` 大于数据集中的数据条数或等于-1,则取数据集中的所有数据。
|
||||
2. take和batch操作顺序很重要,如果take在batch操作之前,则取给定条数,否则取给定batch数。
|
||||
|
||||
参数:
|
||||
- **count** (int, 可选) - 要从数据集对象中获取的数据条数,默认值:-1,获取所有数据。
|
||||
|
||||
返回:
|
||||
TakeDataset,take操作后的数据集对象。
|
||||
|
||||
.. py:method:: to_device(send_epoch_end=True, create_data_info_queue=False)
|
||||
|
||||
将数据从CPU传输到GPU、Ascend或其他设备。
|
||||
|
||||
参数:
|
||||
- **send_epoch_end** (bool, 可选) - 是否将epoch结束符 `end_of_sequence` 发送到设备,默认值:True。
|
||||
- **create_data_info_queue** (bool, 可选) - 是否创建存储数据类型和shape的队列,默认值:False。
|
||||
|
||||
.. note::
|
||||
该接口在将来会被删除或不可见,建议使用 `device_queue` 接口。
|
||||
如果设备为Ascend,则逐个传输数据。每次数据传输的限制为256M。
|
||||
|
||||
返回:
|
||||
TransferDataset,用于传输的数据集对象。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 如果提供了分布式训练的文件路径但读取失败。
|
||||
|
||||
.. py:method:: to_json(filename='')
|
||||
|
||||
将数据处理管道序列化为JSON字符串,如果提供了文件名,则转储到文件中。
|
||||
|
||||
参数:
|
||||
- **filename** (str) - 保存JSON文件的路径(包含文件名)。
|
||||
|
||||
返回:
|
||||
str,数据处理管道序列化后的JSON字符串。
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
.. py:method:: map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16, offload=None)
|
||||
|
||||
给定一组数据增强列表,按顺序将数据增强作用在数据集对象上。
|
||||
|
||||
每个数据增强操作将数据集对象中的一个或多个数据列作为输入,将数据增强的结果输出为一个或多个数据列。
|
||||
第一个数据增强操作将 `input_columns` 中指定的列作为输入。
|
||||
如果数据增强列表中存在多个数据增强操作,则上一个数据增强的输出列将作为下一个数据增强的输入列。
|
||||
|
||||
最后一个数据增强的输出列的列名由 `output_columns` 指定,如果没有指定 `output_columns` ,输出列名与 `input_columns` 一致。
|
||||
|
||||
参数:
|
||||
- **operations** (Union[list[TensorOperation], list[functions]]) - 一组数据增强操作,支持数据集增强算子或者用户自定义的Python Callable对象。map操作将按顺序将一组数据增强作用在数据集对象上。
|
||||
- **input_columns** (Union[str, list[str]], 可选) - 第一个数据增强操作的输入数据列。此列表的长度必须与 `operations` 列表中第一个数据增强的预期输入列数相匹配。默认值:None。表示所有数据列都将传递给第一个数据增强操作。
|
||||
- **output_columns** (Union[str, list[str]], 可选) - 最后一个数据增强操作的输出数据列。如果 `input_columns` 长度不等于 `output_columns` 长度,则必须指定此参数。列表的长度必须必须与最后一个数据增强的输出列数相匹配。默认值:None,输出列将与输入列具有相同的名称。
|
||||
- **column_order** (Union[str, list[str]], 可选) - 指定传递到下一个数据集操作的数据列的顺序。如果 `input_columns` 长度不等于 `output_columns` 长度,则必须指定此参数。注意:参数的列名不限定在 `input_columns` 和 `output_columns` 中指定的列,也可以是上一个操作输出的未被处理的数据列。默认值:None,按照原输入顺序排列。
|
||||
- **num_parallel_workers** (int, 可选) - 指定map操作的多进程/多线程并发数,加快处理速度。默认值:None,将使用 `set_num_parallel_workers` 设置的并发数。
|
||||
- **python_multiprocessing** (bool, 可选) - 启用Python多进程模式加速map操作。当传入的 `operations` 计算量很大时,开启此选项可能会有较好效果。默认值:False。
|
||||
- **cache** (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ 。默认值:None,不使用缓存。
|
||||
- **callbacks** (DSCallback, list[DSCallback], 可选) - 要调用的Dataset回调函数列表。默认值:None。
|
||||
- **max_rowsize** (int, 可选) - 指定在多进程之间复制数据时,共享内存分配的最大空间,仅当 `python_multiprocessing` 为True时,该选项有效。默认值:16,单位为MB。
|
||||
- **offload** (bool, 可选) - 是否进行异构硬件加速,详情请阅读 `数据准备异构加速 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/dataset_offload.html>`_ 。默认值:None。
|
||||
|
||||
.. note::
|
||||
- `operations` 参数接收 `TensorOperation` 类型的数据处理操作,以及用户定义的Python函数(PyFuncs)。
|
||||
- 不要将 `mindspore.nn` 和 `mindspore.ops` 或其他的网络计算算子添加到 `operations` 中。
|
||||
|
||||
返回:
|
||||
MapDataset,map操作后的数据集。
|
||||
|
||||
.. py:method:: num_classes()
|
||||
|
||||
获取数据集对象中所有样本的类别数目。
|
||||
|
||||
返回:
|
||||
int,类别的数目。
|
||||
|
||||
.. py:method:: output_shapes(estimate=False)
|
||||
|
||||
获取数据集对象中每列数据的shape。
|
||||
|
||||
参数:
|
||||
- **estimate** (bool) - 如果 `estimate` 为 False,将返回数据集第一条数据的shape。
|
||||
否则将遍历整个数据集以获取数据集的真实shape信息,其中动态变化的维度将被标记为-1(可用于动态shape数据集场景),默认值:False。
|
||||
|
||||
返回:
|
||||
list,每列数据的shape列表。
|
||||
|
||||
.. py:method:: output_types()
|
||||
|
||||
获取数据集对象中每列数据的数据类型。
|
||||
|
||||
返回:
|
||||
list,每列数据的数据类型列表。
|
|
@ -0,0 +1,39 @@
|
|||
.. py:method:: project(columns)
|
||||
|
||||
从数据集对象中选择需要的列,并按给定的列名的顺序进行排序,
|
||||
未指定的数据列将被丢弃。
|
||||
|
||||
参数:
|
||||
- **columns** (Union[str, list[str]]) - 要选择的数据列的列名列表。
|
||||
|
||||
返回:
|
||||
ProjectDataset,project操作后的数据集对象。
|
||||
|
||||
.. py:method:: rename(input_columns, output_columns)
|
||||
|
||||
对数据集对象按指定的列名进行重命名。
|
||||
|
||||
参数:
|
||||
- **input_columns** (Union[str, list[str]]) - 待重命名的列名列表。
|
||||
- **output_columns** (Union[str, list[str]]) - 重命名后的列名列表。
|
||||
|
||||
返回:
|
||||
RenameDataset,rename操作后的数据集对象。
|
||||
|
||||
.. py:method:: repeat(count=None)
|
||||
|
||||
重复此数据集 `count` 次。如果 `count` 为None或-1,则无限重复迭代。
|
||||
|
||||
.. note::
|
||||
repeat和batch的顺序反映了batch的数量。建议:repeat操作在batch操作之后使用。
|
||||
|
||||
参数:
|
||||
- **count** (int) - 数据集重复的次数。默认值:None。
|
||||
|
||||
返回:
|
||||
RepeatDataset,repeat操作后的数据集对象。
|
||||
|
||||
.. py:method:: reset()
|
||||
|
||||
重置下一个epoch的数据集对象。
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
.. py:method:: set_dynamic_columns(columns=None)
|
||||
|
||||
设置数据集的动态shape信息,需要在定义好完整的数据处理管道后进行设置。
|
||||
|
||||
参数:
|
||||
- **columns** (dict) - 包含数据集中每列shape信息的字典。shape[i]为 `None` 表示shape[i]的数据长度是动态的。
|
||||
|
||||
.. py:method:: shuffle(buffer_size)
|
||||
|
||||
使用以下策略混洗此数据集的行:
|
||||
|
||||
1. 生成一个混洗缓冲区包含 `buffer_size` 条数据行。
|
||||
|
||||
2. 从混洗缓冲区中随机选择一个数据行,传递给下一个操作。
|
||||
|
||||
3. 从上一个操作获取下一个数据行(如果有的话),并将其放入混洗缓冲区中。
|
||||
|
||||
4. 重复步骤2和3,直到混洗缓冲区中没有数据行为止。
|
||||
|
||||
在第一个epoch中可以通过 `dataset.config.set_seed` 来设置随机种子,在随后的每个epoch,种子都会被设置成一个新产生的随机值。
|
||||
|
||||
参数:
|
||||
- **buffer_size** (int) - 用于混洗的缓冲区大小(必须大于1)。将 `buffer_size` 设置为数据集大小将进行全局混洗。
|
||||
|
||||
返回:
|
||||
ShuffleDataset,混洗后的数据集对象。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 混洗前存在通过 `dataset.sync_wait` 进行同步操作。
|
||||
|
||||
.. py:method:: skip(count)
|
||||
|
||||
跳过此数据集对象的前 `count` 条数据。
|
||||
|
||||
参数:
|
||||
- **count** (int) - 要跳过数据的条数。
|
||||
|
||||
返回:
|
||||
SkipDataset,跳过指定条数据后的数据集对象。
|
||||
|
||||
.. py:method:: split(sizes, randomize=True)
|
||||
|
||||
将数据集拆分为多个不重叠的子数据集。
|
||||
|
||||
参数:
|
||||
- **sizes** (Union[list[int], list[float]]) - 如果指定了一列整数[s1, s2, …, sn],数据集将被拆分为n个大小为s1、s2、...、sn的数据集。如果所有输入大小的总和不等于原始数据集大小,则报错。如果指定了一列浮点数[f1, f2, …, fn],则所有浮点数必须介于0和1之间,并且总和必须为1,否则报错。数据集将被拆分为n个大小为round(f1*K)、round(f2*K)、...、round(fn*K)的数据集,其中K是原始数据集的大小。
|
||||
|
||||
如果round四舍五入计算后:
|
||||
|
||||
- 任何子数据集的的大小等于0,都将发生错误。
|
||||
- 如果子数据集大小的总和小于K,K - sigma(round(fi * k))的值将添加到第一个子数据集,sigma为求和操作。
|
||||
- 如果子数据集大小的总和大于K,sigma(round(fi * K)) - K的值将从第一个足够大的子数据集中删除,且删除后的子数据集大小至少大于1。
|
||||
|
||||
- **randomize** (bool, 可选) - 确定是否随机拆分数据,默认值:True,数据集将被随机拆分。否则将按顺序拆分为多个不重叠的子数据集。
|
||||
|
||||
.. note::
|
||||
1. 如果进行拆分操作的数据集对象为MappableDataset类型,则将自动调用一个优化后的split操作。
|
||||
2. 如果进行split操作,则不应对数据集对象进行分片操作(如指定num_shards或使用DistributerSampler)。相反,如果创建一个DistributerSampler,并在split操作拆分后的子数据集对象上进行分片操作,强烈建议在每个子数据集上设置相同的种子,否则每个分片可能不是同一个子数据集的一部分(请参见示例)。
|
||||
3. 强烈建议不要对数据集进行混洗,而是使用随机化(randomize=True)。对数据集进行混洗的结果具有不确定性,每个拆分后的子数据集中的数据在每个epoch可能都不同。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 数据集对象不支持 `get_dataset_size` 或者 `get_dataset_size` 返回None。
|
||||
- **RuntimeError** - `sizes` 是list[int],并且 `sizes` 中所有元素的总和不等于数据集大小。
|
||||
- **RuntimeError** - `sizes` 是list[float],并且计算后存在大小为0的拆分子数据集。
|
||||
- **RuntimeError** - 数据集对象在调用拆分之前已进行分片。
|
||||
- **ValueError** - `sizes` 是list[float],且并非所有float数值都在0和1之间,或者float数值的总和不等于1。
|
||||
|
||||
返回:
|
||||
tuple(Dataset),split操作后子数据集对象的元组。
|
||||
|
||||
.. py:method:: sync_update(condition_name, num_batch=None, data=None)
|
||||
|
||||
释放阻塞条件并使用给定数据触发回调函数。
|
||||
|
||||
参数:
|
||||
- **condition_name** (str) - 用于触发发送下一个数据行的条件名称。
|
||||
- **num_batch** (Union[int, None]) - 释放的batch(row)数。当 `num_batch` 为None时,将默认为 `sync_wait` 操作指定的值,默认值:None。
|
||||
- **data** (Any) - 用户自定义传递给回调函数的数据,默认值:None。
|
||||
|
||||
.. py:method:: sync_wait(condition_name, num_batch=1, callback=None)
|
||||
|
||||
为同步操作在数据集对象上添加阻塞条件。
|
||||
|
||||
参数:
|
||||
- **condition_name** (str) - 用于触发发送下一行数据的条件名称。
|
||||
- **num_batch** (int) - 每个epoch开始时无阻塞的batch数。
|
||||
- **callback** (function) - `sync_update` 操作中将调用的回调函数。
|
||||
|
||||
返回:
|
||||
SyncWaitDataset,添加了阻塞条件的数据集对象。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 条件名称已存在。
|
||||
|
||||
.. py:method:: take(count=-1)
|
||||
|
||||
从数据集中获取最多 `count` 的元素。
|
||||
|
||||
.. note::
|
||||
1. 如果 `count` 大于数据集中的数据条数或等于-1,则取数据集中的所有数据。
|
||||
2. take和batch操作顺序很重要,如果take在batch操作之前,则取给定条数,否则取给定batch数。
|
||||
|
||||
参数:
|
||||
- **count** (int, 可选) - 要从数据集对象中获取的数据条数,默认值:-1,获取所有数据。
|
||||
|
||||
返回:
|
||||
TakeDataset,take操作后的数据集对象。
|
||||
|
||||
.. py:method:: to_device(send_epoch_end=True, create_data_info_queue=False)
|
||||
|
||||
将数据从CPU传输到GPU、Ascend或其他设备。
|
||||
|
||||
参数:
|
||||
- **send_epoch_end** (bool, 可选) - 是否将epoch结束符 `end_of_sequence` 发送到设备,默认值:True。
|
||||
- **create_data_info_queue** (bool, 可选) - 是否创建存储数据类型和shape的队列,默认值:False。
|
||||
|
||||
.. note::
|
||||
该接口在将来会被删除或不可见,建议使用 `device_queue` 接口。
|
||||
如果设备为Ascend,则逐个传输数据。每次数据传输的限制为256M。
|
||||
|
||||
返回:
|
||||
TransferDataset,用于传输的数据集对象。
|
||||
|
||||
异常:
|
||||
- **RuntimeError** - 如果提供了分布式训练的文件路径但读取失败。
|
||||
|
||||
.. py:method:: to_json(filename='')
|
||||
|
||||
将数据处理管道序列化为JSON字符串,如果提供了文件名,则转储到文件中。
|
||||
|
||||
参数:
|
||||
- **filename** (str) - 保存JSON文件的路径(包含文件名)。
|
||||
|
||||
返回:
|
||||
str,数据处理管道序列化后的JSON字符串。
|
|
@ -0,0 +1,63 @@
|
|||
.. py:method:: save(file_name, num_files=1, file_type='mindrecord')
|
||||
|
||||
将数据处理管道中正处理的数据保存为通用的数据集格式。支持的数据集格式:'mindrecord',然后可以使用'MindDataset'类来读取保存的'mindrecord'文件。
|
||||
|
||||
将数据保存为'mindrecord'格式时存在隐式类型转换。转换表展示如何执行类型转换。
|
||||
|
||||
.. list-table:: 保存为'mindrecord'格式时的隐式类型转换
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - 'dataset'类型
|
||||
- 'mindrecord'类型
|
||||
- 说明
|
||||
* - bool
|
||||
- None
|
||||
- 不支持
|
||||
* - int8
|
||||
- int32
|
||||
-
|
||||
* - uint8
|
||||
- bytes
|
||||
- 丢失维度信息
|
||||
* - int16
|
||||
- int32
|
||||
-
|
||||
* - uint16
|
||||
- int32
|
||||
-
|
||||
* - int32
|
||||
- int32
|
||||
-
|
||||
* - uint32
|
||||
- int64
|
||||
-
|
||||
* - int64
|
||||
- int64
|
||||
-
|
||||
* - uint64
|
||||
- None
|
||||
- 不支持
|
||||
* - float16
|
||||
- float32
|
||||
-
|
||||
* - float32
|
||||
- float32
|
||||
-
|
||||
* - float64
|
||||
- float64
|
||||
-
|
||||
* - string
|
||||
- string
|
||||
- 不支持多维字符串
|
||||
|
||||
.. note::
|
||||
1. 如需按顺序保存数据,将数据集的 `shuffle` 设置为False,将 `num_files` 设置为1。
|
||||
2. 在执行保存操作之前,不要使用batch操作、repeat操作或具有随机属性的数据增强的map操作。
|
||||
3. 当数据的维度可变时,只支持1维数组或者在第0维变化的多维数组。
|
||||
4. 不支持UINT64类型、多维的UINT8类型、多维STRING类型。
|
||||
|
||||
参数:
|
||||
- **file_name** (str) - 数据集文件的路径。
|
||||
- **num_files** (int, 可选) - 数据集文件的数量,默认值:1。
|
||||
- **file_type** (str, 可选) - 数据集格式,默认值:'mindrecord'。
|
|
@ -77,6 +77,14 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -27,6 +27,12 @@ mindspore.dataset.Graph
|
|||
- **auto_shutdown** (bool, 可选) - 当工作模式设置为 'server' 时有效。当连接的客户端数量达到 `num_client` ,且没有客户端正在连接时,服务器将自动退出,默认值:True。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `edges` 不是list或NumPy array类型。
|
||||
- **TypeError** - 如果提供了 `node_feat` 但不是dict类型, 或者dict中的key不是string类型, 或者dict中的value不是NumPy array类型。
|
||||
- **TypeError** - 如果提供了 `edge_feat` 但不是dict类型, 或者dict中的key不是string类型, 或者dict中的value不是NumPy array类型。
|
||||
- **TypeError** - 如果提供了 `graph_feat` 但不是dict类型, 或者dict中的key不是string类型, 或者dict中的value不是NumPy array类型。
|
||||
- **TypeError** - 如果提供了 `node_type` 但不是list或NumPy array类型。
|
||||
- **TypeError** - 如果提供了 `edge_type` 但不是list或 NumPy array类型。
|
||||
- **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。
|
||||
- **ValueError** - `working_mode` 参数取值不为'local', 'client' 或 'server'。
|
||||
- **TypeError** - `hostname` 参数类型错误。
|
||||
|
@ -187,6 +193,19 @@ mindspore.dataset.Graph
|
|||
异常:
|
||||
- **TypeError** - 参数 `edge_list` 的类型不为列表或numpy.ndarray。
|
||||
|
||||
.. py:method:: get_graph_feature(edge_list, feature_types)
|
||||
|
||||
依据给定的 `feature_types` 获取存储在Graph中对应的特征。
|
||||
|
||||
参数:
|
||||
- **feature_types** (Union[list, numpy.ndarray]) - 包含给定特征类型的列表,列表中每个元素是string类型。
|
||||
|
||||
返回:
|
||||
numpy.ndarray,包含特征的数组。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。
|
||||
|
||||
.. py:method:: get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)
|
||||
|
||||
获取 `node_list` 列表中节所有点的负样本相邻节点,以 `neg_neighbor_type` 类型返回。
|
||||
|
@ -219,7 +238,6 @@ mindspore.dataset.Graph
|
|||
- **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
|
||||
- **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。
|
||||
|
||||
|
||||
.. py:method:: get_nodes_from_edges(edge_list)
|
||||
|
||||
从图中的边获取节点。
|
||||
|
|
|
@ -90,6 +90,14 @@ mindspore.dataset.ImageFolderDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -19,18 +19,6 @@
|
|||
- **python_multiprocessing** (bool,可选) - 启用Python多进程模式加速运算,默认值:True。当传入 `source` 的Python对象的计算量很大时,开启此选项可能会有较好效果。
|
||||
- **max_rowsize** (int,可选) - 指定在多进程之间复制数据时,共享内存分配的最大空间,默认值:6,单位为MB。仅当参数 `python_multiprocessing` 设为True时,此参数才会生效。
|
||||
|
||||
.. py:method:: process()
|
||||
|
||||
与原始数据集相关的处理方法,建议在自定义的Dataset中重写此方法。
|
||||
|
||||
.. py:method:: save()
|
||||
|
||||
将经过 `process` 函数处理后的数据以 numpy.npz 格式保存到磁盘中,也可以在自己实现的Dataset类中自己实现这个方法。
|
||||
|
||||
.. py:method:: load()
|
||||
|
||||
从给定(处理好的)路径加载数据,也可以在自己实现的Dataset类中实现这个方法。
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.add_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.rst
|
||||
|
@ -41,6 +29,24 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. py:method:: load()
|
||||
|
||||
从给定(处理好的)路径加载数据,也可以在自己实现的Dataset类中实现这个方法。
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. py:method:: process()
|
||||
|
||||
与原始数据集相关的处理方法,建议在自定义的Dataset中重写此方法。
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. py:method:: save()
|
||||
|
||||
将经过 `process` 函数处理后的数据以 numpy.npz 格式保存到磁盘中,也可以在自己实现的Dataset类中自己实现这个方法。
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -68,6 +68,14 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -70,6 +70,14 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -91,6 +91,14 @@ mindspore.dataset.MnistDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -66,6 +66,14 @@ mindspore.dataset.NumpySlicesDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -48,4 +48,12 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -23,6 +23,14 @@ mindspore.dataset.PaddedDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -43,4 +43,12 @@ mindspore.dataset.TFRecordDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
||||
|
|
|
@ -35,4 +35,12 @@
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -131,6 +131,14 @@ mindspore.dataset.VOCDataset
|
|||
|
||||
.. include:: mindspore.dataset.Dataset.d.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.e.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.f.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.save.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.g.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.use_sampler.rst
|
||||
|
||||
.. include:: mindspore.dataset.Dataset.zip.rst
|
|
@ -120,9 +120,9 @@ class GraphData:
|
|||
|
||||
Examples:
|
||||
>>> graph_dataset_dir = "/path/to/graph_dataset_file"
|
||||
>>> graph_dataset = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[1])
|
||||
>>> graph_data = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
|
||||
>>> nodes = graph_data.get_all_nodes(node_type=1)
|
||||
>>> features = graph_data.get_node_feature(node_list=nodes, feature_types=[1])
|
||||
"""
|
||||
|
||||
@check_gnn_graphdata
|
||||
|
@ -161,7 +161,7 @@ class GraphData:
|
|||
numpy.ndarray, array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> nodes = graph_data.get_all_nodes(node_type=1)
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_type` is not integer.
|
||||
|
@ -182,7 +182,7 @@ class GraphData:
|
|||
numpy.ndarray, array of edges.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_all_edges(edge_type=0)
|
||||
>>> edges = graph_data.get_all_edges(edge_type=0)
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_type` is not integer.
|
||||
|
@ -221,7 +221,7 @@ class GraphData:
|
|||
numpy.ndarray, array of edges ID.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
|
||||
>>> edges = graph_data.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_list` is not list or ndarray.
|
||||
|
@ -335,12 +335,12 @@ class GraphData:
|
|||
|
||||
Examples:
|
||||
>>> from mindspore.dataset.engine import OutputFormat
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
|
||||
>>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
|
||||
... output_format=OutputFormat.COO)
|
||||
>>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
|
||||
... output_format=OutputFormat.CSR)
|
||||
>>> nodes = graph_data.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph_data.get_all_neighbors(node_list=nodes, neighbor_type=2)
|
||||
>>> neighbors_coo = graph_data.get_all_neighbors(node_list=nodes, neighbor_type=2,
|
||||
... output_format=OutputFormat.COO)
|
||||
>>> offset_table, neighbors_csr = graph_data.get_all_neighbors(node_list=nodes, neighbor_type=2,
|
||||
... output_format=OutputFormat.CSR)
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -381,9 +381,9 @@ class GraphData:
|
|||
numpy.ndarray, array of neighbors.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
|
||||
... neighbor_types=[2, 1])
|
||||
>>> nodes = graph_data.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph_data.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
|
||||
... neighbor_types=[2, 1])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -411,9 +411,9 @@ class GraphData:
|
|||
numpy.ndarray, array of neighbors.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
|
||||
... neg_neighbor_type=2)
|
||||
>>> nodes = graph_data.get_all_nodes(node_type=1)
|
||||
>>> neg_neighbors = graph_data.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
|
||||
... neg_neighbor_type=2)
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -438,8 +438,8 @@ class GraphData:
|
|||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[2, 3])
|
||||
>>> nodes = graph_data.get_all_nodes(node_type=1)
|
||||
>>> features = graph_data.get_node_feature(node_list=nodes, feature_types=[2, 3])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -467,8 +467,8 @@ class GraphData:
|
|||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_all_edges(edge_type=0)
|
||||
>>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=[1])
|
||||
>>> edges = graph_data.get_all_edges(edge_type=0)
|
||||
>>> features = graph_data.get_edge_feature(edge_list=edges, feature_types=[1])
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_list` is not list or ndarray.
|
||||
|
@ -513,8 +513,8 @@ class GraphData:
|
|||
numpy.ndarray, array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> walks = graph_dataset.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
|
||||
>>> nodes = graph_data.get_all_nodes(node_type=1)
|
||||
>>> walks = graph_data.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
|
||||
|
||||
Raises:
|
||||
TypeError: If `target_nodes` is not list or ndarray.
|
||||
|
@ -575,23 +575,39 @@ class Graph(GraphData):
|
|||
when the number of connected clients reaches num_client and no client is being connected,
|
||||
the server automatically exits (default=True).
|
||||
|
||||
Raises:
|
||||
TypeError: If `edges` not list or NumPy array.
|
||||
TypeError: If `node_feat` provided but not dict, or key in dict is not string type, or value in dict not NumPy
|
||||
array.
|
||||
TypeError: If `edge_feat` provided but not dict, or key in dict is not string type, or value in dict not NumPy
|
||||
array.
|
||||
TypeError: If `graph_feat` provided but not dict, or key in dict is not string type, or value in dict not NumPy
|
||||
array.
|
||||
TypeError: If `node_type` provided but its type not list or NumPy array.
|
||||
TypeError: If `edge_type` provided but its type not list or NumPy array.
|
||||
ValueError: If `num_parallel_workers` exceeds the max thread numbers.
|
||||
ValueError: If `working_mode` is not 'local', 'client' or 'server'.
|
||||
TypeError: If `hostname` is illegal.
|
||||
ValueError: If `port` is not in range [1024, 65535].
|
||||
ValueError: If `num_client` is not in range [1, 255].
|
||||
|
||||
Examples:
|
||||
>> # 1) Only provide edges for creating graph, as this is the only required input parameter
|
||||
>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>> g = Graph(edges)
|
||||
>> graph_info = g.graph_info()
|
||||
>>
|
||||
>> # 2) Setting node_feat and edge_feat for corresponding node and edge
|
||||
>> # first dimension of feature shape should be corrsponding node num or edge num.
|
||||
>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>> node_feat = {"node_feature_1": np.array([[0], [1], [2]], dtype=np.int32)}
|
||||
>> edge_feat = {"edge_feature_1": np.array([[1, 2], [3, 4]], dtype=np.int32)}
|
||||
>> g = Graph(edges, node_feat, edge_feat)
|
||||
>>
|
||||
>> # 3) Setting graph feature for graph, there is shape limit for graph feature
|
||||
>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>> graph_feature = {"graph_feature_1": np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)}
|
||||
>> g = Graph(edges, graph_feat=graph_feature)
|
||||
>>> # 1) Only provide edges for creating graph, as this is the only required input parameter
|
||||
>>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>>> graph = Graph(edges)
|
||||
>>> graph_info = g.graph_info()
|
||||
>>>
|
||||
>>> # 2) Setting node_feat and edge_feat for corresponding node and edge
|
||||
>>> # first dimension of feature shape should be corresponding node num or edge num.
|
||||
>>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>>> node_feat = {"node_feature_1": np.array([[0], [1], [2]], dtype=np.int32)}
|
||||
>>> edge_feat = {"edge_feature_1": np.array([[1, 2], [3, 4]], dtype=np.int32)}
|
||||
>>> graph = Graph(edges, node_feat, edge_feat)
|
||||
>>>
|
||||
>>> # 3) Setting graph feature for graph, there is no shape limit for graph feature
|
||||
>>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>>> graph_feature = {"graph_feature_1": np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)}
|
||||
>>> graph = Graph(edges, graph_feat=graph_feature)
|
||||
"""
|
||||
|
||||
@check_gnn_graph
|
||||
|
@ -657,7 +673,7 @@ class Graph(GraphData):
|
|||
numpy.ndarray, array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type="0")
|
||||
>>> nodes = graph.get_all_nodes(node_type="0")
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_type` is not string.
|
||||
|
@ -684,7 +700,7 @@ class Graph(GraphData):
|
|||
numpy.ndarray, array of edges.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_all_edges(edge_type='0')
|
||||
>>> edges = graph.get_all_edges(edge_type='0')
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_type` is not string.
|
||||
|
@ -803,12 +819,12 @@ class Graph(GraphData):
|
|||
|
||||
Examples:
|
||||
>>> from mindspore.dataset.engine import OutputFormat
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type='0')
|
||||
>>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type='0',
|
||||
... output_format=OutputFormat.COO)
|
||||
>>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type='0',
|
||||
... output_format=OutputFormat.CSR)
|
||||
>>> nodes = graph.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph.get_all_neighbors(node_list=nodes, neighbor_type='0')
|
||||
>>> neighbors_coo = graph.get_all_neighbors(node_list=nodes, neighbor_type='0',
|
||||
... output_format=OutputFormat.COO)
|
||||
>>> offset_table, neighbors_csr = graph.get_all_neighbors(node_list=nodes, neighbor_type='0',
|
||||
... output_format=OutputFormat.CSR)
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -853,9 +869,9 @@ class Graph(GraphData):
|
|||
numpy.ndarray, array of neighbors.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
|
||||
... neighbor_types=[2, 1])
|
||||
>>> nodes = graph.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
|
||||
... neighbor_types=[2, 1])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -890,9 +906,9 @@ class Graph(GraphData):
|
|||
numpy.ndarray, array of neighbors.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
|
||||
... neg_neighbor_type='0')
|
||||
>>> nodes = graph.get_all_nodes(node_type=1)
|
||||
>>> neg_neighbors = graph.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
|
||||
... neg_neighbor_type='0')
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -921,8 +937,8 @@ class Graph(GraphData):
|
|||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type='0')
|
||||
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=["feature_1", "feature_2"])
|
||||
>>> nodes = graph.get_all_nodes(node_type='0')
|
||||
>>> features = graph.get_node_feature(node_list=nodes, feature_types=["feature_1", "feature_2"])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -957,8 +973,8 @@ class Graph(GraphData):
|
|||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_all_edges(edge_type='0')
|
||||
>>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=["feature_1"])
|
||||
>>> edges = graph.get_all_edges(edge_type='0')
|
||||
>>> features = graph.get_edge_feature(edge_list=edges, feature_types=["feature_1"])
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_list` is not list or ndarray.
|
||||
|
@ -983,7 +999,7 @@ class Graph(GraphData):
|
|||
@check_gnn_get_graph_feature
|
||||
def get_graph_feature(self, feature_types):
|
||||
"""
|
||||
Get `feature_types` feature of the nodes in `node_list`.
|
||||
Get `feature_types` feature that stored in Graph feature level.
|
||||
|
||||
Args:
|
||||
feature_types (Union[list, numpy.ndarray]): The given list of feature types, each element should be string.
|
||||
|
@ -992,10 +1008,9 @@ class Graph(GraphData):
|
|||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> features = graph_dataset.get_graph_feature(feature_types=['feature_1', 'feature_2'])
|
||||
>>> features = graph.get_graph_feature(feature_types=['feature_1', 'feature_2'])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
TypeError: If `feature_types` is not list or ndarray.
|
||||
"""
|
||||
if self._working_mode in ['server']:
|
||||
|
@ -1349,7 +1364,8 @@ class ArgoverseDataset(InMemoryGraphDataset):
|
|||
|
||||
def process(self):
|
||||
"""
|
||||
process method mainly refers to: https://github.com/xk-huang/yet-another-vectornet/blob/master/dataset.py
|
||||
Process method for argoverse dataset, here we load original dataset and create a lot of graphs based on it.
|
||||
Pre-processed method mainly refers to: https://github.com/xk-huang/yet-another-vectornet/blob/master/dataset.py.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
|
|
Loading…
Reference in New Issue