update api format

This commit is contained in:
yingchen 2021-11-26 19:07:07 +08:00
parent 1b0a82fc30
commit 4d46962b11
14 changed files with 561 additions and 526 deletions

View File

@ -1,29 +1,27 @@
add_child(sampler)
.. py:method:: add_child(sampler)
为给定采样器添加子采样器。子采样器将接收父采样器输出的所有数据,并应用其采样逻辑返回新的采样。
为给定采样器添加子采样器。子采样器将接收父采样器输出的所有数据,并应用其采样逻辑返回新的采样。
**参数:**
- **sampler** (Sampler)用于从数据集中选择样本的对象。仅支持内置采样器DistributedSampler、PKSampler、RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler
**参数:**
**示例:**
>>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
>>> sampler.add_child(ds.RandomSampler(num_samples=2))
>>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
**sampler** (Sampler)用于从数据集中选择样本的对象。仅支持内置采样器DistributedSampler、PKSampler、RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler
get_child()
获取给定采样器的子采样器。
**样例:**
get_num_samples()
>>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
>>> sampler.add_child(ds.RandomSampler(num_samples=2))
>>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
所有采样器都可以包含num_samples数值也可以将其设置为None
子采样器可以存在也可以为None。
如果存在子采样器则子采样器计数可以是数值或None。
这些条件会影响最终的采样结果。
下表显示了调用此函数的可能结果。
.. py:method:: get_child()
获取给定采样器的子采样器。
.. py:method:: get_num_samples()
所有采样器都可以包含num_samples数值也可以将其设置为None。子采样器可以存在也可以为None。如果存在子采样器则子采样器计数可以是数值或None。这些条件会影响最终的采样结果。
下表显示了调用此函数的可能结果。
.. list-table::
:widths: 25 25 25 25
@ -58,5 +56,6 @@ get_num_samples()
- n/a
- None
**返回:**
int样本数可为None。
**返回:**
int样本数可为None。

View File

@ -1,70 +1,75 @@
Class mindspore.dataset.CLUEDataset(dataset_files, task='AFQMC', usage='train', num_samples=None, num_parallel_workers=None, shuffle=<Shuffle.GLOBAL: 'global'>, num_shards=None, shard_id=None, cache=None)
mindspore.dataset.CLUEDataset
=============================
读取和解析CLUE数据集的源数据集文件。
目前支持的CLUE分类任务包括`AFQMC``Tnews``IFLYTEK``CMNLI``WSC``CSL`
.. py:class:: mindspore.dataset.CLUEDataset(dataset_files, task='AFQMC', usage='train', num_samples=None, num_parallel_workers=None, shuffle=<Shuffle.GLOBAL: 'global'>, num_shards=None, shard_id=None, cache=None)
根据给定的`task`配置,数据集会生成不同的输出列:
读取和解析CLUE数据集的源数据集文件。目前支持的CLUE分类任务包括 `AFQMC``Tnews``IFLYTEK``CMNLI``WSC``CSL`
- task = :py:obj:`AFQMC`
- usage = :py:obj:`train`,输出列: :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
- usage = :py:obj:`test`,输出列: :py:obj:`[id, dtype=uint8]`, :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`.
- usage = :py:obj:`eval`,输出列: :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
根据给定的 `task` 配置,数据集会生成不同的输出列:
- task = :py:obj:`TNEWS`
- usage = :py:obj:`train`,输出列: :py:obj:`[label, dtype=string]`, :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`.
- usage = :py:obj:`test`,输出列: :py:obj:`[label, dtype=string]`, :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`.
- usage = :py:obj:`eval`,输出列: :py:obj:`[label, dtype=string]`, :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`, :py:obj:`[keywords, dtype=string]`.
- task = `AFQMC`
- usage = `train`,输出列: `[sentence1, dtype=string]`, `[sentence2, dtype=string]`, `[label, dtype=string]`.
- usage = `test`,输出列: `[id, dtype=uint8]`, `[sentence1, dtype=string]`, `[sentence2, dtype=string]`.
- usage = `eval`,输出列: `[sentence1, dtype=string]`, `[sentence2, dtype=string]`, `[label, dtype=string]`.
- task = :py:obj:`IFLYTEK`
- usage = :py:obj:`train`,输出列: :py:obj:`[label, dtype=string]`, :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`.
- usage = :py:obj:`test`,输出列: :py:obj:`[id, dtype=string]`, :py:obj:`[sentence, dtype=string]`.
- usage = :py:obj:`eval`,输出列: :py:obj:`[label, dtype=string]`, :py:obj:`[label_des, dtype=string]`, :py:obj:`[sentence, dtype=string]`.
- task = `TNEWS`
- usage = `train`,输出列: `[label, dtype=string]`, `[label_des, dtype=string]`, `[sentence, dtype=string]`, `[keywords, dtype=string]`.
- usage = `test`,输出列: `[label, dtype=string]`, `[label_des, dtype=string]`, `[sentence, dtype=string]`, `[keywords, dtype=string]`.
- usage = `eval`,输出列: `[label, dtype=string]`, `[label_des, dtype=string]`, `[sentence, dtype=string]`, `[keywords, dtype=string]`.
- task = :py:obj:`CMNLI`
- usage = :py:obj:`train`,输出列: :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
- usage = :py:obj:`test`,输出列: :py:obj:`[id, dtype=uint8]`, :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`.
- usage = :py:obj:`eval`,输出列: :py:obj:`[sentence1, dtype=string]`, :py:obj:`[sentence2, dtype=string]`, :py:obj:`[label, dtype=string]`.
- task = `IFLYTEK`
- usage = `train`,输出列: `[label, dtype=string]`, `[label_des, dtype=string]`, `[sentence, dtype=string]`.
- usage = `test`,输出列: `[id, dtype=string]`, `[sentence, dtype=string]`.
- usage = `eval`,输出列: `[label, dtype=string]`, `[label_des, dtype=string]`, `[sentence, dtype=string]`.
- task = :py:obj:`WSC`
- usage = :py:obj:`train`,输出列: :py:obj:`[span1_index, dtype=uint8]`, :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, :py:obj:`[text, dtype=string]`, :py:obj:`[label, dtype=string]`.
- usage = :py:obj:`test`,输出列: :py:obj:`[span1_index, dtype=uint8]`, :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, :py:obj:`[text, dtype=string]`.
- usage = :py:obj:`eval`,输出列: :py:obj:`[span1_index, dtype=uint8]`, :py:obj:`[span2_index, dtype=uint8]`, :py:obj:`[span1_text, dtype=string]`, :py:obj:`[span2_text, dtype=string]`, :py:obj:`[idx, dtype=uint8]`, :py:obj:`[text, dtype=string]`, :py:obj:`[label, dtype=string]`.
- task = `CMNLI`
- usage = `train`,输出列: `[sentence1, dtype=string]`, `[sentence2, dtype=string]`, `[label, dtype=string]`.
- usage = `test`,输出列: `[id, dtype=uint8]`, `[sentence1, dtype=string]`, `[sentence2, dtype=string]`.
- usage = `eval`,输出列: `[sentence1, dtype=string]`, `[sentence2, dtype=string]`, `[label, dtype=string]`.
- task = :py:obj:`CSL`
- usage = :py:obj:`train`,输出列: :py:obj:`[id, dtype=uint8]`, :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`, :py:obj:`[label, dtype=string]`.
- usage = :py:obj:`test`,输出列: :py:obj:`[id, dtype=uint8]`, :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`.
- usage = :py:obj:`eval`,输出列: :py:obj:`[id, dtype=uint8]`, :py:obj:`[abst, dtype=string]`, :py:obj:`[keyword, dtype=string]`, :py:obj:`[label, dtype=string]`.
- task = `WSC`
- usage = `train`,输出列: `[span1_index, dtype=uint8]`, `[span2_index, dtype=uint8]`, `[span1_text, dtype=string]`, `[span2_text, dtype=string]`, `[idx, dtype=uint8]`, `[text, dtype=string]`, `[label, dtype=string]`.
- usage = `test`,输出列: `[span1_index, dtype=uint8]`, `[span2_index, dtype=uint8]`, `[span1_text, dtype=string]`, `[span2_text, dtype=string]`, `[idx, dtype=uint8]`, `[text, dtype=string]`.
- usage = `eval`,输出列: `[span1_index, dtype=uint8]`, `[span2_index, dtype=uint8]`, `[span1_text, dtype=string]`, `[span2_text, dtype=string]`, `[idx, dtype=uint8]`, `[text, dtype=string]`, `[label, dtype=string]`.
- task = `CSL`
- usage = `train`,输出列: `[id, dtype=uint8]`, `[abst, dtype=string]`, `[keyword, dtype=string]`, `[label, dtype=string]`.
- usage = `test`,输出列: `[id, dtype=uint8]`, `[abst, dtype=string]`, `[keyword, dtype=string]`.
- usage = `eval`,输出列: `[id, dtype=uint8]`, `[abst, dtype=string]`, `[keyword, dtype=string]`, `[label, dtype=string]`.
**参数:**
- **dataset_files** (Union[str, list[str]])数据集文件路径支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串文件列表将在内部进行字典排序。
- **task** (str, 可选):任务类型,可取值为`AFQMC``Tnews``IFLYTEK``CMNLI``WSC``CSL`(默认为`AFQMC`)。
- **usage** (str, 可选):指定数据集的子集,可取值为`train``test``eval`(默认为`train`)。
- **num_samples** (int, 可选)指定从数据集中读取的样本数默认为None即读取所有图像样本
- **num_parallel_workers** (int, 可选):指定读取数据的工作线程数(默认值None即使用mindspore.dataset.config中配置的线程数
- **shuffle** (Union[bool, Shuffle level], 可选)每个epoch中数据混洗的模式默认为为mindspore.dataset.Shuffle.GLOBAL
如果为False则不混洗如果为True等同于将`shuffle`设置为mindspore.dataset.Shuffle.GLOBAL。另外也可以传入枚举变量设置shuffle级别
- Shuffle.GLOBAL混洗文件和样本。
- Shuffle.FILES仅混洗文件。
- **num_shards** (int, 可选)指定分布式训练时将数据集进行划分的分片数默认值None。指定此参数后, `num_samples` 表示每个分片的最大样本数。
- **shard_id** (int, 可选)指定分布式训练时使用的分片ID号默认值None。只有当指定了 `num_shards` 时才能指定此参数。
- **cache** (DatasetCache, 可选)数据缓存客户端实例用于加快数据集处理速度默认为None不使用缓存
- **dataset_files** (Union[str, list[str]])数据集文件路径支持单文件路径字符串、多文件路径字符串列表或可被glob库模式匹配的字符串文件列表将在内部进行字典排序。
- **task** (str, 可选):任务类型,可取值为 `AFQMC``Tnews``IFLYTEK``CMNLI``WSC``CSL`(默认为 `AFQMC` )。
- **usage** (str, 可选):指定数据集的子集,可取值为 `train``test``eval`(默认为 `train` )。
- **num_samples** (int, 可选)指定从数据集中读取的样本数默认为None即读取所有图像样本
- **num_parallel_workers** (int, 可选):指定读取数据的工作线程数(默认值None即使用mindspore.dataset.config中配置的线程数
- **shuffle** (Union[bool, Shuffle level], 可选)每个epoch中数据混洗的模式默认为为mindspore.dataset.Shuffle.GLOBAL。如果为False则不混洗如果为True等同于将 `shuffle` 设置为mindspore.dataset.Shuffle.GLOBAL。另外也可以传入枚举变量设置shuffle级别
- Shuffle.GLOBAL混洗文件和样本。
- Shuffle.FILES仅混洗文件。
- **num_shards** (int, 可选)指定分布式训练时将数据集进行划分的分片数默认值None。指定此参数后, `num_samples` 表示每个分片的最大样本数。
- **shard_id** (int, 可选)指定分布式训练时使用的分片ID号默认值None。只有当指定了 `num_shards` 时才能指定此参数。
- **cache** (DatasetCache, 可选)数据缓存客户端实例用于加快数据集处理速度默认为None不使用缓存
**异常:**
- **RuntimeError**`dataset_files` 所指的文件无效或不存在。
- **RuntimeError**`num_parallel_workers` 超过系统最大线程数。
- **RuntimeError**:指定了`num_shards`参数,但是未指定`shard_id`参数。
- **RuntimeError**:指定了`shard_id`参数,但是未指定`num_shards`参数。
**示例:**
>>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # 包含一个或多个CLUE数据集文件
>>> dataset = ds.CLUEDataset(dataset_files=clue_dataset_dir, task='AFQMC', usage='train')
- **RuntimeError**`dataset_files` 所指的文件无效或不存在。
- **RuntimeError**`num_parallel_workers` 超过系统最大线程数。
- **RuntimeError**:指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
- **RuntimeError**:指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
**样例:**
>>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # 包含一个或多个CLUE数据集文件
>>> dataset = ds.CLUEDataset(dataset_files=clue_dataset_dir, task='AFQMC', usage='train')
**关于CLUE数据集**
CLUE又名中文语言理解测评基准包含许多有代表性的数据集涵盖单句分类、句对分类和机器阅读理解等任务。
您可以将数据集解压成如下的文件结构并通过MindSpore的API进行读取`afqmc`数据集为例:
您可以将数据集解压成如下的文件结构并通过MindSpore的API进行读取 `afqmc` 数据集为例:
.. code-block::

View File

@ -1,31 +1,37 @@
Class mindspore.dataset.DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=None, offset=-1)
mindspore.dataset.DistributedSampler
====================================
.. py:class:: mindspore.dataset.DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=None, offset=-1)
分布式采样器,将数据集进行分片用于分布式训练。
**参数:**
- **num_shards** (int):数据集分片数量。
- **shard_id** (int)当前分片的分片ID应在[0, num_shards-1]范围内。
- **shuffle** (bool, optional)如果为True则索引将被打乱默认为True
- **num_samples** (int, optional)要采样的样本数默认为None对所有元素进行采样
- **offset** (int, optional)将数据集中的元素发送到的起始分片ID不应超过num_shards。仅当ConcatDataset以DistributedSampler为采样器时此参数才有效。此参数影响每个分片的样本数默认为-1每个分片具有相同的样本数
- **num_shards** (int):数据集分片数量。
- **shard_id** (int)当前分片的分片ID应在[0, num_shards-1]范围内。
- **shuffle** (bool, optional)如果为True则索引将被打乱默认为True
- **num_samples** (int, optional)要采样的样本数默认为None对所有元素进行采样
- **offset** (int, optional)将数据集中的元素发送到的起始分片ID不应超过 `num_shards` 。仅当ConcatDataset以DistributedSampler为采样器时此参数才有效。此参数影响每个分片的样本数默认为-1每个分片具有相同的样本数
**示例:**
>>> # 创建一个分布式采样器共10个分片。当前分片为分片5。
>>> sampler = ds.DistributedSampler(10, 5)
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
**样例:**
>>> # 创建一个分布式采样器共10个分片。当前分片为分片5。
>>> sampler = ds.DistributedSampler(10, 5)
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
... num_parallel_workers=8,
... sampler=sampler)
**异常:**
- **TypeError**num_shards不是整数值。
- **TypeError**shard_id不是整数值。
- **TypeError**shuffle不是Boolean值。
- **TypeError**num_samples不是整数值。
- **TypeError**offset不是整数值。
- **ValueError**num_samples为负值。
- **RuntimeError**num_shards不是正值。
- **RuntimeError**shard_id小于0或大于等于num_shards。
- **RuntimeError**offset大于num_shards。
- **TypeError**`num_shards` 不是整数值。
- **TypeError**`shard_id` 不是整数值。
- **TypeError**`shuffle` 不是Boolean值。
- **TypeError**`num_samples` 不是整数值。
- **TypeError**`offset` 不是整数值。
- **ValueError**`num_samples` 为负值。
- **RuntimeError**`num_shards` 不是正值。
- **RuntimeError**`shard_id` 小于0或大于等于 `num_shards`
- **RuntimeError**`offset` 大于 `num_shards`
.. include:: mindspore.dataset.BuiltinSampler.rst

View File

@ -1,321 +1,364 @@
Class mindspore.dataset.GraphData(dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True)
mindspore.dataset.GraphData
===========================
.. py:class:: mindspore.dataset.GraphData(dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True)
从共享文件和数据库中读取用于GNN训练的图数据集。
**参数:**
- **dataset_file** (str):数据集文件路径。
- **num_parallel_workers** (int, 可选)读取数据的工作线程数默认为None
- **working_mode** (str, 可选):设置工作模式,目前支持'local'/'client'/'server'(默认为'local')。
-'local',用于非分布式训练场景。
-'client',用于分布式训练场景。客户端不加载数据,而是从服务器获取数据。
-'server',用于分布式训练场景。服务器加载数据并可供客户端使用。
- **hostname** (str, 可选):图数据集服务器的主机名。该参数仅在工作模式设置为'client'或'server'时有效(默认为'127.0.0.1')。
- **port** (int, 可选)图数据服务器的端口取值范围为1024-65535。此参数仅当工作模式设置为'client'或'server'默认为50051时有效。
- **num_client** (int, 可选):期望连接到服务器的最大客户端数。服务器将根据该参数分配资源。该参数仅在工作模式设置为'server'时有效默认为1
- **auto_shutdown** (bool, 可选):当工作模式设置为'server'时有效。当连接的客户端数量达到num_client且没有客户端正在连接时服务器将自动退出默认为True
- **dataset_file** (str):数据集文件路径。
- **num_parallel_workers** (int, 可选)读取数据的工作线程数默认为None
- **working_mode** (str, 可选):设置工作模式,目前支持'local'/'client'/'server'(默认为'local')。
-'local',用于非分布式训练场景。
-'client',用于分布式训练场景。客户端不加载数据,而是从服务器获取数据。
-'server',用于分布式训练场景。服务器加载数据并可供客户端使用。
- **hostname** (str, 可选):图数据集服务器的主机名。该参数仅在工作模式设置为'client'或'server'时有效(默认为'127.0.0.1')。
- **port** (int, 可选)图数据服务器的端口取值范围为1024-65535。此参数仅当工作模式设置为'client'或'server'默认为50051时有效。
- **num_client** (int, 可选):期望连接到服务器的最大客户端数。服务器将根据该参数分配资源。该参数仅在工作模式设置为'server'时有效默认为1
- **auto_shutdown** (bool, 可选):当工作模式设置为'server'时有效。当连接的客户端数量达到 `num_client` 且没有客户端正在连接时服务器将自动退出默认为True
**样例:**
>>> graph_dataset_dir = "/path/to/graph_dataset_file"
>>> graph_dataset = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[1])
.. py:method:: get_all_edges(edge_type)
获取图的所有边。
**参数:**
**edge_type** (int):指定边的类型。
**返回:**
numpy.ndarray包含边的数组。
**样例:**
>>> edges = graph_dataset.get_all_edges(edge_type=0)
**异常:**
**TypeError**:参数`edge_type`的类型不为整型。
.. py:method:: get_all_neighbors(node_list, neighbor_type, output_format=<OutputFormat.NORMAL: 0。
获取 `node_list` 所有节点的邻居,以 `neighbor_type` 类型返回。格式的定义参见以下示例1表示两个节点之间连接0表示不连接。
.. list-table:: 邻接矩阵
:widths: 20 20 20 20 20
:header-rows: 1
* -
- 0
- 1
- 2
- 3
* - 0
- 0
- 1
- 0
- 0
* - 1
- 0
- 0
- 1
- 0
* - 2
- 1
- 0
- 0
- 1
* - 3
- 1
- 0
- 0
- 0
.. list-table:: 普通格式
:widths: 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 3
* - dst_0
- 1
- 2
- 0
- 1
* - dst_1
- -1
- -1
- 3
- -1
.. list-table:: COO格式
:widths: 20 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 2
- 3
* - dst
- 1
- 2
- 0
- 3
- 1
.. list-table:: CSR格式
:widths: 40 20 20 20 20 20
:header-rows: 1
* - offsetTable
- 0
- 1
- 2
- 4
-
* - dstTable
- 1
- 2
- 0
- 3
- 1
**参数:**
- **node_list** (Union[list, numpy.ndarray]):给定的节点列表。
- **neighbor_type** (int):指定邻居节点的类型。
- **output_format** (OutputFormat, 可选)输出存储格式默认为mindspore.dataset.engine.OutputFormat.NORMAL取值范围[OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR]。
**返回:**
对于普通格式或COO格式将返回numpy.ndarray类型的数组表示邻居节点。如果指定了CSR格式将返回两个numpy.ndarray数组第一个表示偏移表第二个表示邻居节点。
**样例:**
>>> from mindspore.dataset.engine import OutputFormat
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
>>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
... output_format=OutputFormat.COO)
>>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
... output_format=OutputFormat.CSR)
**异常:**
- **TypeError**:参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError**:参数 `neighbor_type` 的类型不为整型。
.. py:method:: get_all_nodes(node_type)
获取图中的所有节点。
**参数:**
**node_type** (int):指定节点的类型。
**返回:**
numpy.ndarray包含节点的数组。
**样例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
**异常:**
**TypeError**:参数`node_type`的类型不为整型。
.. py:method:: get_edges_from_nodes(node_list)
从节点获取边。
**参数:**
**node_list** (Union[list[tuple], numpy.ndarray])含一个或多个图节点ID对的列表。
**返回:**
numpy.ndarray含一个或多个边ID的数组。
**示例:**
>>> graph_dataset_dir = "/path/to/graph_dataset_file"
>>> graph_dataset = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[1])
get_all_edges(edge_type)
获取图的所有边。
**参数:**
edge_type (int):指定边的类型。
**返回:**
numpy.ndarray包含边的数组。
**示例:**
>>> edges = graph_dataset.get_all_edges(edge_type=0)
**异常:**
- **TypeError**:参数`edge_type`的类型不为整型。
get_all_neighbors(node_list, neighbor_type, output_format=<OutputFormat.NORMAL: 0。
获取`node_list`所有节点的邻居,以`neighbor_type`类型返回。
格式的定义参见以下示例1表示两个节点之间连接0表示不连接。
.. list-table:: 邻接矩阵
:widths: 20 20 20 20 20
:header-rows: 1
* -
- 0
- 1
- 2
- 3
* - 0
- 0
- 1
- 0
- 0
* - 1
- 0
- 0
- 1
- 0
* - 2
- 1
- 0
- 0
- 1
* - 3
- 1
- 0
- 0
- 0
.. list-table:: 普通格式
:widths: 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 3
* - dst_0
- 1
- 2
- 0
- 1
* - dst_1
- -1
- -1
- 3
- -1
.. list-table:: COO格式
:widths: 20 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 2
- 3
* - dst
- 1
- 2
- 0
- 3
- 1
.. list-table:: CSR格式
:widths: 40 20 20 20 20 20
:header-rows: 1
* - offsetTable
- 0
- 1
- 2
- 4
-
* - dstTable
- 1
- 2
- 0
- 3
- 1
**参数:**
- **node_list** (Union[list, numpy.ndarray]):给定的节点列表。
- **neighbor_type** (int):指定邻居节点的类型。
- **output_format** (OutputFormat, 可选)输出存储格式默认为mindspore.dataset.engine.OutputFormat.NORMAL取值范围[OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR]。
**返回:**
对于普通格式或COO格式
将返回numpy.ndarray类型的数组表示邻居节点。
如果指定了CSR格式将返回两个numpy.ndarray数组第一个表示偏移表第二个表示邻居节点。
**示例:**
>>> from mindspore.dataset.engine import OutputFormat
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
>>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
... output_format=OutputFormat.COO)
>>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
... output_format=OutputFormat.CSR)
**异常:**
- **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
- **TypeError**:参数`neighbor_type`的类型不为整型。
get_all_nodes(node_type)
获取图中的所有节点。
>>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
**参数:**
- **node_type** (int):指定节点的类型。
**异常:**
**返回:**
numpy.ndarray包含节点的数组。
**TypeError**:参数 `edge_list` 的类型不为列表或numpy.ndarray。
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
**异常:**
- **TypeError**:参数`node_type`的类型不为整型。
.. py:method:: get_edge_feature(edge_list, feature_types)
获取 `edge_list` 列表中边的特征,以 `feature_types` 类型返回。
get_edges_from_nodes(node_list)
**参数:**
从节点获取边。
- **edge_list** (Union[list, numpy.ndarray]):包含边的列表。
- **feature_types** (Union[list, numpy.ndarray]):包含给定特征类型的列表。
**参数:**
- **node_list** (Union[list[tuple], numpy.ndarray])含一个或多个图节点ID对的列表。
**返回:**
**返回:**
numpy.ndarray含一个或多个边ID的数组。
numpy.ndarray包含特征的数组。
**示例:**
>>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
**样例:**
**异常:**
- **TypeError**:参数`edge_list`的类型不为列表或numpy.ndarray。
>>> edges = graph_dataset.get_all_edges(edge_type=0)
>>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=[1])
**异常:**
get_edge_feature(edge_list, feature_types)
- **TypeError**:参数 `edge_list` 的类型不为列表或numpy.ndarray。
- **TypeError**:参数 `feature_types` 的类型不为列表或numpy.ndarray。
获取`edge_list`列表中边的特征,以`feature_types`类型返回。
**参数:**
- **edge_list** (Union[list, numpy.ndarray]):包含边的列表。
- **feature_types** (Union[list, numpy.ndarray]):包含给定特征类型的列表。
.. py:method:: get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)
**返回:**
numpy.ndarray包含特征的数组。
获取 `node_list` 列表中节所有点的负样本邻居,以 `neg_neighbor_type` 类型返回。
**示例:**
>>> edges = graph_dataset.get_all_edges(edge_type=0)
>>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=[1])
**参数:**
**异常:**
- **TypeError**:参数`edge_list`的类型不为列表或numpy.ndarray
- **TypeError**:参数`feature_types`的类型不为列表或numpy.ndarray
- **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
- **neg_neighbor_num** (int):采样的邻居数量。
- **neg_neighbor_type** (int):指定负样本邻居的类型。
**返回:**
get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)
numpy.ndarray包含邻居的数组。
获取`node_list`列表中节所有点的负样本邻居,以`neg_neighbor_type`类型返回。
**样例:**
**参数:**
- **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
- **neg_neighbor_num** (int):采样的邻居数量。
- **neg_neighbor_type** (int):指定负样本邻居的类型。
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
... neg_neighbor_type=2)
**返回:**
numpy.ndarray包含邻居的数组。
**异常:**
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
... neg_neighbor_type=2)
- **TypeError**:参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError**:参数 `neg_neighbor_num` 的类型不为整型。
- **TypeError**:参数 `neg_neighbor_type` 的类型不为整型。
**异常:**
- **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
- **TypeError**:参数`neg_neighbor_num`的类型不为整型。
- **TypeError**:参数`neg_neighbor_type`的类型不为整型。
.. py:method:: get_nodes_from_edges(edge_list)
get_nodes_from_edges(edge_list)
从图中的边获取节点。
从图中的边获取节点。
**参数:**
**参数:**
- **edge_list** (Union[list, numpy.ndarray]):包含边的列表。
**edge_list** (Union[list, numpy.ndarray]):包含边的列表。
**返回:**
numpy.ndarray包含节点的数组。
**返回:**
**异常:**
TypeError参数`edge_list`不为列表或ndarray。
numpy.ndarray包含节点的数组。
**异常:**
get_node_feature(node_list, feature_types)
**TypeError** 参数 `edge_list` 不为列表或ndarray。
获取`node_list`中节点的特征,以`feature_types`类型返回。
**参数:**
- **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
- **feature_types** (Union[list, numpy.ndarray]):指定特征的类型。
.. py:method:: get_node_feature(node_list, feature_types)
**返回:**
numpy.ndarray包含特征的数组。
获取 `node_list` 中节点的特征,以 `feature_types` 类型返回。
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[2, 3])
**参数:**
**异常:**
- **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
- **TypeError**:参数`feature_types`的类型不为列表或numpy.ndarray。
- **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
- **feature_types** (Union[list, numpy.ndarray]):指定特征的类型。
**返回:**
get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=<SamplingStrategy.RANDOM: 0>)
numpy.ndarray包含特征的数组。
获取已采样邻居信息。此API支持多跳邻居采样。即将上一次采样结果作为下一跳采样的输入最多允许6跳。
采样结果平铺成列表,格式为[input node, 1-hop sampling result, 2-hop samling result ...]
**示例:**
**参数:**
- **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
- **neighbor_nums** (Union[list, numpy.ndarray]):每跳采样的邻居数。
- **neighbor_types** (Union[list, numpy.ndarray]):每跳采样的邻居类型。
- **strategy** (SamplingStrategy, 可选)采样策略默认为mindspore.dataset.engine.SamplingStrategy.RANDOM。取值范围[SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]。
- SamplingStrategy.RANDOM随机抽样带放回采样。
- SamplingStrategy.EDGE_WEIGHT以边缘权重为概率进行采样。
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[2, 3])
**返回:**
numpy.ndarray包含邻居的数组。
**异常:**
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
... neighbor_types=[2, 1])
- **TypeError**:参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError**:参数 `feature_types` 的类型不为列表或numpy.ndarray。
**异常:**
- **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
- **TypeError**:参数`neighbor_nums`的类型不为列表或numpy.ndarray。
- **TypeError**:参数`neighbor_types`的类型不为列表或numpy.ndarray。
.. py:method:: get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=<SamplingStrategy.RANDOM: 0>)
graph_info()
获取已采样邻居信息。此API支持多跳邻居采样。即将上一次采样结果作为下一跳采样的输入最多允许6跳。采样结果平铺成列表格式为[input node, 1-hop sampling result, 2-hop samling result ...]
获取图的元信息,包括节点数、节点类型、节点特征信息、边数、边类型、边特征信息。
**参数:**
**返回:**
dict图的元信息。键为node_num、node_type、node_feature_type、edge_num、edge_type、和edge_feature_type。
- **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
- **neighbor_nums** (Union[list, numpy.ndarray]):每跳采样的邻居数。
- **neighbor_types** (Union[list, numpy.ndarray]):每跳采样的邻居类型。
- **strategy** (SamplingStrategy, 可选)采样策略默认为mindspore.dataset.engine.SamplingStrategy.RANDOM。取值范围[SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]。
- SamplingStrategy.RANDOM随机抽样带放回采样。
- SamplingStrategy.EDGE_WEIGHT以边缘权重为概率进行采样。
**返回:**
random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1)
numpy.ndarray包含邻居的数组。
在节点中的随机游走。
*样例:**
**参数:**
- **target_nodes** (list[int]):随机游走中的起始节点列表。
- **meta_path** (list[int]):每个步长的节点类型。
- **step_home_param** (float, 可选)返回node2vec算法中的超参默认为1.0)。
- **step_away_param** (float, 可选)node2vec算法中的in和out超参默认为1.0)。
- **default_node** (int, 可选):如果找不到更多邻居,则为默认节点(默认值为-1表示不给定节点
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
... neighbor_types=[2, 1])
**返回:**
numpy.ndarray包含节点的数组。
**异常:**
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> walks = graph_dataset.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
- **TypeError**:参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError**:参数 `neighbor_nums` 的类型不为列表或numpy.ndarray。
- **TypeError**:参数 `neighbor_types` 的类型不为列表或numpy.ndarray。
**异常:**
- **TypeError**:参数`target_nodes`的类型不为列表或numpy.ndarray。
- **TypeError**:参数`meta_path`的类型不为列表或numpy.ndarray。
.. py:method:: graph_info()
获取图的元信息,包括节点数、节点类型、节点特征信息、边数、边类型、边特征信息。
**返回:**
dict图的元信息。键为node_num、node_type、node_feature_type、edge_num、edge_type、和edge_feature_type。
.. py:method:: random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1)
在节点中的随机游走。
**参数:**
- **target_nodes** (list[int]):随机游走中的起始节点列表。
- **meta_path** (list[int]):每个步长的节点类型。
- **step_home_param** (float, 可选)返回node2vec算法中的超参默认为1.0)。
- **step_away_param** (float, 可选)node2vec算法中的in和out超参默认为1.0)。
- **default_node** (int, 可选):如果找不到更多邻居,则为默认节点(默认值为-1表示不给定节点
**返回:**
numpy.ndarray包含节点的数组。
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> walks = graph_dataset.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
**异常:**
- **TypeError**:参数 `target_nodes` 的类型不为列表或numpy.ndarray。
- **TypeError**:参数 `meta_path` 的类型不为列表或numpy.ndarray。

View File

@ -1,24 +1,28 @@
apply(apply_func)
.. py:method:: apply(apply_func)
对数据集对象执行给定操作函数。
参数:
apply_func (function):传入`Dataset`对象作为参数,并将返回处理后的`Dataset`对象。
**参数:**
返回:
执行了给定操作函数的数据集对象。
`apply_func` (function):传入 `Dataset` 对象作为参数,并将返回处理后的 `Dataset` 对象。
示例:
>>> # dataset是数据集类的实例化对象
>>>
>>> # 声明一个名为apply_func函数其返回值是一个Dataset对象
>>> def apply_func(data)
... data = data.batch(2)
... return data
>>>
>>> # 通过apply操作调用apply_func函数
>>> dataset = dataset.apply(apply_func)
**返回:**
异常:
TypeErrorapply_func不是一个函数。
TypeErrorapply_func未返回Dataset对象。
执行了给定操作函数的数据集对象。
**样例:**
>>> # dataset是数据集类的实例化对象
>>>
>>> # 声明一个名为apply_func函数其返回值是一个Dataset对象
>>> def apply_func(data)
... data = data.batch(2)
... return data
>>>
>>> # 通过apply操作调用apply_func函数
>>> dataset = dataset.apply(apply_func)
**异常:**
- **TypeError** `apply_func` 不是一个函数。
- **TypeError** `apply_func` 未返回Dataset对象。

View File

@ -1,57 +1,41 @@
batch(batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False)
.. py:method:: batch(batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False)
将dataset中连续`batch_size`行数据合并为一个批处理数据。
将dataset中连续 `batch_size` 行数据合并为一个批处理数据。
对一个批处理数据执行给定操作与对条数据进行给定操作用法一致。
对于任意列batch操作要求该列中的各条数据shape必须相同。
如果给定可执行函数`per_batch_map`,它将作用于批处理后的数据。
对一个批处理数据执行给定操作与对条数据进行给定操作用法一致。对于任意列batch操作要求该列中的各条数据shape必须相同。如果给定可执行函数 `per_batch_map` ,它将作用于批处理后的数据。
注:
执行`repeat``batch`操作的顺序,会影响数据批次的数量及`per_batch_map`操作。
建议在batch操作完成后执行repeat操作。
.. note::
执行 `repeat``batch` 操作的顺序,会影响数据批次的数量及 `per_batch_map` 操作。建议在batch操作完成后执行repeat操作。
参数:
batch_size (int or function)每个批处理数据包含的条数。参数需要是int或可调用对象该对象接收1个参数即BatchInfo。
drop_remainder (bool, optional)是否删除最后一个数据条数小于批处理大小的batch默认值为False
如果为True并且最后一个批次中数据行数少于`batch_size`,则这些数据将被丢弃,不会传递给后续的操作。
num_parallel_workers (int, optional)用于进行batch操作的的线程数threads默认值为None。
per_batch_map (callable, optional):是一个以(list[Tensor], list[Tensor], ..., BatchInfo)作为输入参数的可调用对象。
每个list[Tensor]代表给定列上的一批Tensor。入参中list[Tensor]的个数应与input_columns中传入列名的数量相匹配。
该可调用对象的最后一个参数始终是BatchInfo对象。`per_batch_map`应返回(list[Tensor], list[Tensor], ...)。
其出中list[Tensor]的个数应与输入相同。如果输出列数与输入列数不一致则需要指定output_columns。
input_columns (Union[str, list[str]], optional):由输入列名组成的列表。如果`per_batch_map`不为None
列表中列名的个数应与 `per_batch_map` 中包含的列数匹配默认为None
** 参数:**
output_columns (Union[str, list[str]], optional):当前操作所有输出列的列名列表。
如果len(input_columns) != len(output_columns),则此参数必须指定。
此列表中列名的数量必须与给定操作的输出列数相匹配默认为None输出列将与输入列具有相同的名称
- **batch_size** (int or function)每个批处理数据包含的条数。参数需要是int或可调用对象该对象接收1个参数即BatchInfo。
- **drop_remainder** (bool, optional)是否删除最后一个数据条数小于批处理大小的batch默认值为False。如果为True并且最后一个批次中数据行数少于 `batch_size`,则这些数据将被丢弃,不会传递给后续的操作。
- **num_parallel_workers** (int, optional)用于进行batch操作的的线程数threads默认值为None。
- **per_batch_map** (callable, optional):是一个以(list[Tensor], list[Tensor], ..., BatchInfo)作为输入参数的可调用对象。每个list[Tensor]代表给定列上的一批Tensor。入参中list[Tensor]的个数应与 `input_columns` 中传入列名的数量相匹配。该可调用对象的最后一个参数始终是BatchInfo对象。`per_batch_map`应返回(list[Tensor], list[Tensor], ...)。其出中list[Tensor]的个数应与输入相同。如果输出列数与输入列数不一致,则需要指定 `output_columns`。 - **input_columns** (Union[str, list[str]], optional):由输入列名组成的列表。如果 `per_batch_map` 不为None列表中列名的个数应与 `per_batch_map` 中包含的列数匹配默认为None
- **output_columns** (Union[str, list[str]], optional)当前操作所有输出列的列名列表。如果len(input_columns) != len(output_columns)则此参数必须指定。此列表中列名的数量必须与给定操作的输出列数相匹配默认为None输出列将与输入列具有相同的名称
- **column_order** (Union[str, list[str]], optional)指定整个数据集对象中包含的所有列名的顺序。如果len(input_column) != len(output_column),则此参数必须指定。 注意:这里的列名不仅仅是在 `input_columns``output_columns` 中指定的列。
- **pad_info** (dict, optional):用于对给定列进行填充。例如 `pad_info={"col1":([224,224],0)}` ,则将列名为"col1"的列填充到大小为[224,224]的张量并用0填充缺失的值默认为None)。
- **python_multiprocessing** (bool, optional):针对 `per_batch_map` 函数使用Python多进执行的方式进行调用。如果函数计算量大开启这个选项可能会很有帮助默认值为False
column_order (Union[str, list[str]], optional):指定整个数据集对象中包含的所有列名的顺序。
如果len(input_column) != len(output_column),则此参数必须指定。
注意:这里的列名不仅仅是在`input_columns``output_columns`中指定的列。
**返回:**
pad_info (dict, optional):用于对给定列进行填充。例如`pad_info={"col1":([224,224],0)}`
则将列名为"col1"的列填充到大小为[224,224]的张量并用0填充缺失的值默认为None)。
批处理后的数据集对象。
python_multiprocessing (bool, optional):针对`per_batch_map`函数使用Python多进执行的方式进行调用。
如果函数计算量大开启这个选项可能会很有帮助默认值为False
**样例:**
返回:
批处理后的数据集对象。
示例:
>>> # 创建一个数据集对象每100条数据合并成一个批次
>>> # 如果最后一个批次数据小于给定的批次大小batch_size),则丢弃这个批次
>>> dataset = dataset.batch(100, True)
>>> # 根据批次编号调整图像大小如果是第5批则图像大小调整为(5^2, 5^2) = (25, 25)
>>> def np_resize(col, batchInfo):
... output = col.copy()
... s = (batchInfo.get_batch_num() + 1) ** 2
... index = 0
... for c in col:
... img = Image.fromarray(c.astype('uint8')).convert('RGB')
... img = img.resize((s, s), Image.ANTIALIAS)
... output[index] = np.array(img)
... index += 1
... return (output,)
>>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
>>> # 创建一个数据集对象每100条数据合并成一个批次
>>> # 如果最后一个批次数据小于给定的批次大小batch_size),则丢弃这个批次
>>> dataset = dataset.batch(100, True)
>>> # 根据批次编号调整图像大小如果是第5批则图像大小调整为(5^2, 5^2) = (25, 25)
>>> def np_resize(col, batchInfo):
... output = col.copy()
... s = (batchInfo.get_batch_num() + 1) ** 2
... index = 0
... for c in col:
... img = Image.fromarray(c.astype('uint8')).convert('RGB')
... img = img.resize((s, s), Image.ANTIALIAS)
... output[index] = np.array(img)
... index += 1
... return (output,)
>>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)

View File

@ -1,35 +1,25 @@
bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False)
.. py:method:: bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False)
依据数据中元素长度进行分桶。每个桶将在满了的时候进行元素填充和批处理操作。
对数据集中的每一条数据执行长度计算函数。然后,根据该条数据的长度和桶的边界将该数据归到特定的桶里面。
当桶中数据条数达到指定的大小`bucket_batch_sizes`时,将根据`pad_info`对桶中元素进行填充,再进行批处理。
这样每个批次都是满的但也有特殊情况每个桶的最后一个批次batch可能不满。
对数据集中的每一条数据执行长度计算函数。然后,根据该条数据的长度和桶的边界将该数据归到特定的桶里面。当桶中数据条数达到指定的大小 `bucket_batch_sizes` 时,将根据 `pad_info` 对桶中元素进行填充再进行批处理。这样每个批次都是满的但也有特殊情况每个桶的最后一个批次batch可能不满。
参数:
column_names (list[str]):传递给长度计算函数的所有列名。
bucket_boundaries (list[int])由各个桶的上边界值组成的列表必须严格递增。如果有n个边界则创建n+1个桶分配后桶的边界如下
[0, bucket_boundaries[0])[bucket_boundaries[i], bucket_boundaries[i+1])其中0<i<n-1[bucket_boundaries[n-1], inf)。
**参数:**
bucket_batch_sizes (list[int]):由每个桶的批次大小组成的列表,必须包含`len(bucket_boundaries)+1`个元素。
element_length_function (Callable, optional)输入包含M个参数的函数其中M等于`len(column_names)`,并返回一个整数。
如果未指定该参数,则`len(column_names)`必须为1并且该列数据第一维的shape值将用作长度默认为None
- **column_names** (list[str]):传递给长度计算函数的所有列名。
- **bucket_boundaries** (list[int])由各个桶的上边界值组成的列表必须严格递增。如果有n个边界则创建n+1个桶分配后桶的边界如下[0, bucket_boundaries[0])[bucket_boundaries[i], bucket_boundaries[i+1])其中0<i<n-1[bucket_boundaries[n-1], inf)。
- **bucket_batch_sizes** (list[int]):由每个桶的批次大小组成的列表,必须包含 `len(bucket_boundaries)+1` 个元素。
- **element_length_function** (Callable, optional)输入包含M个参数的函数其中M等于 `len(column_names)` ,并返回一个整数。如果未指定该参数,则 `len(column_names)` 必须为1并且该列数据第一维的shape值将用作长度默认为None
- **pad_info** (dict, optional)有关如何对指定列进行填充的字典对象。字典中键对应要填充的列名值必须是包含2个元素的元组。元组中第一个元素对应要填充成的shape第二个元素对应要填充的值。如果某一列未指定将要填充后的shape和填充值则当前批次中该列上的每条数据都将填充至该批次中最长数据的长度填充值为0。除非 `pad_to_bucket_boundary` 为True否则 `pad_info` 中任何填充shape为None的列其每条数据长度都将被填充为当前批处理中最数据的长度。如果不需要填充请将 `pad_info` 设置为None默认为None
- **pad_to_bucket_boundary** (bool, optional)如果为True`pad_info` 中填充shape为None的列其长度都会被填充至 `bucket_boundary-1` 长度。如果有任何元素落入最后一个桶中则将报错默认为False
- **drop_remainder** (bool, optional)如果为True则丢弃每个桶中最后不足一个批次数据默认为False
pad_info (dict, optional)有关如何对指定列进行填充的字典对象。字典中键对应要填充的列名值必须是包含2个元素的元组。
元组中第一个元素对应要填充成的shape第二个元素对应要填充的值。
如果某一列未指定将要填充后的shape和填充值则当前批次中该列上的每条数据都将填充至该批次中最长数据的长度填充值为0。
除非`pad_to_bucket_boundary`为True否则`pad_info`中任何填充shape为None的列
其每条数据长度都将被填充为当前批处理中最数据的长度。如果不需要填充,请将`pad_info`设置为None默认为None
**返回:**
pad_to_bucket_boundary (bool, optional)如果为True则pad_info中填充shape为None的列
其长度都会被填充至`bucket_boundary-1`长度。如果有任何元素落入最后一个桶中则将报错默认为False
BucketBatchByLengthDataset按长度进行分桶和批处理操作后的数据集对象。
drop_remainder (bool, optional)如果为True则丢弃每个桶中最后不足一个批次数据默认为False
返回:
BucketBatchByLengthDataset按长度进行分桶和批处理操作后的数据集对象。
示例:
**样例:**
>>> # 创建一个数据集对象,其中给定条数的数据会被组成一个批次数据
>>> # 如果最后一个批次数据小于给定的批次大小batch_size),则丢弃这个批次
>>> import numpy as np

View File

@ -1,25 +1,26 @@
build_sentencepiece_vocab(columns, vocab_size, character_coverage, model_type, params)
.. py:method:: build_sentencepiece_vocab(columns, vocab_size, character_coverage, model_type, params)
用于从源数据集对象创建句子词表的函数。
参数:
**参数:**
columns(list[str]):指定从哪一列中获取单词。
vocab_size(int):词汇表大小。
character_coverage(int)模型涵盖的字符百分比必须介于0.98和1.0之间。
默认值如0.9995适用于具有丰富字符集的语言如日语或中文字符集1.0适用于其他字符集较小的语言,比如英语或拉丁文。
model_type(SentencePieceModel)模型类型枚举值包括unigram默认值、bpe、char及word。当类型为word时输入句子必须预先标记。
params(dict):依据原始数据内容构建祠表的附加参数,无附加参数时取值可以是空字典。
- **columns** (list[str]):指定从哪一列中获取单词。
- **vocab_size** (int):词汇表大小。
- **character_coverage** (int)模型涵盖的字符百分比必须介于0.98和1.0之间。默认值如0.9995适用于具有丰富字符集的语言如日语或中文字符集1.0适用于其他字符集较小的语言,比如英语或拉丁文。
- **model_type** (SentencePieceModel)模型类型枚举值包括unigram默认值、bpe、char及word。当类型为word时输入句子必须预先标记。
- **params** (dict):依据原始数据内容构建祠表的附加参数,无附加参数时取值可以是空字典。
返回:
SentencePieceVocab从数据集构建的词汇表。
**返回:**
示例:
>>> from mindspore.dataset.text import SentencePieceModel
>>>
>>> # DE_C_INTER_SENTENCEPIECE_MODE 是一个映射字典
>>> from mindspore.dataset.text.utils import DE_C_INTER_SENTENCEPIECE_MODE
>>> dataset = ds.TextFileDataset("/path/to/sentence/piece/vocab/file", shuffle=False)
>>> dataset = dataset.build_sentencepiece_vocab(["text"], 5000, 0.9995,
... DE_C_INTER_SENTENCEPIECE_MODE[SentencePieceModel.UNIGRAM],
... {})
SentencePieceVocab从数据集构建的词汇表。
**样例:**
>>> from mindspore.dataset.text import SentencePieceModel
>>>
>>> # DE_C_INTER_SENTENCEPIECE_MODE 是一个映射字典
>>> from mindspore.dataset.text.utils import DE_C_INTER_SENTENCEPIECE_MODE
>>> dataset = ds.TextFileDataset("/path/to/sentence/piece/vocab/file", shuffle=False)
>>> dataset = dataset.build_sentencepiece_vocab(["text"], 5000, 0.9995,
... DE_C_INTER_SENTENCEPIECE_MODE[SentencePieceModel.UNIGRAM],
... {})

View File

@ -5,12 +5,12 @@ mindspore.dataset.Cifar100Dataset
用于读取和解析CIFAR-100数据集的源数据文件。
生成的数据集有三列: `[image, coarse_label, fine_label]``image` 列的数据类型是uint8。`coarse_label``fine_labels` 列的数据是uint32类型的标量。
生成的数据集有三列: `[image, coarse_label, fine_label]` `image` 列的数据类型是uint8。 `coarse_label``fine_labels` 列的数据是uint32类型的标量。
**参数:**
- **dataset_dir** (str): 包含数据集文件的根目录路径。
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train``test``all`。使用`train`参数将会读取50,000个训练样本`test` 将会读取10,000个测试样本`all` 将会读取全部60,000个样本默认值为None即全部样本图片
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train` `test` `all`。使用 `train` 参数将会读取50,000个训练样本 `test` 将会读取10,000个测试样本 `all` 将会读取全部60,000个样本默认值为None即全部样本图片
- **num_samples** (int, 可选): 指定从数据集中读取的样本数可以小于数据集总数默认值为None即全部样本图片)。
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数默认值None即使用mindspore.dataset.config中配置的线程数
- **shuffle** (bool, 可选): 是否混洗数据集默认为None下表中会展示不同配置的预期行为
@ -24,12 +24,12 @@ mindspore.dataset.Cifar100Dataset
- **RuntimeError:** `dataset_dir` 路径下不包含数据文件。
- **RuntimeError:** `num_parallel_workers` 超过系统最大线程数。
- **RuntimeError:** 同时指定了 `sampler``shuffle` 参数。
- **RuntimeError:** 同时指定了`sampler``num_shards`参数。
- **RuntimeError:** 同时指定了 `sampler` `num_shards` 参数。
- **RuntimeError:** 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
- **RuntimeError:** 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
- **ValueError:** `shard_id` 参数错误小于0或者大于等于 `num_shards`)。
.. note:: 此数据集可以指定 `sampler` 参数,但`sampler``shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
.. note:: 此数据集可以指定 `sampler` 参数,但 `sampler``shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
.. list-table:: 配置 `sampler``shuffle` 的不同组合得到的预期排序结果
:widths: 25 25 50

View File

@ -10,7 +10,7 @@ mindspore.dataset.Cifar10Dataset
**参数:**
- **dataset_dir** (str): 包含数据集文件的根目录路径。
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train``test``all`。使用`train`参数将会读取50,000个训练样本`test` 将会读取10,000个测试样本`all` 将会读取全部60,000个样本默认值为None即全部样本图片
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train``test 或`all`。使用 `train` 参数将会读取50,000个训练样本`test` 将会读取10,000个测试样本 `all` 将会读取全部60,000个样本默认值为None即全部样本图片
- **num_samples** (int, 可选): 指定从数据集中读取的样本数可以小于数据集总数默认值为None即全部样本图片)。
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数默认值None即使用mindspore.dataset.config中配置的线程数
- **shuffle** (bool, 可选): 是否混洗数据集默认为None下表中会展示不同配置的预期行为
@ -23,20 +23,20 @@ mindspore.dataset.Cifar10Dataset
- **RuntimeError:** `dataset_dir` 路径下不包含数据文件。
- **RuntimeError:** `num_parallel_workers` 超过系统最大线程数。
- **RuntimeError:** 同时指定了`sampler``shuffle`参数。
- **RuntimeError:** 同时指定了`sampler``num_shards`参数。
- **RuntimeError:** 指定了`num_shards`参数,但是未指定`shard_id`参数。
- **RuntimeError:** 指定了`shard_id`参数,但是未指定`num_shards`参数。
- **ValueError:** `shard_id`参数错误小于0或者大于等于 `num_shards`)。
- **RuntimeError:** 同时指定了 `sampler` `shuffle` 参数。
- **RuntimeError:** 同时指定了 `sampler` `num_shards` 参数。
- **RuntimeError:** 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
- **RuntimeError:** 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
- **ValueError:** `shard_id` 参数错误小于0或者大于等于 `num_shards` )。
.. note:: 此数据集可以通过`sampler`指定任意采样器,但参数`sampler``shuffle` 的行为是互斥的。下表展示了几种合法的输入参数及预期的行为。
.. note:: 此数据集可以通过 `sampler` 指定任意采样器,但参数 `sampler``shuffle` 的行为是互斥的。下表展示了几种合法的输入参数及预期的行为。
.. list-table:: 配置`sampler``shuffle`的不同组合得到的预期排序结果
.. list-table:: 配置 `sampler` `shuffle` 的不同组合得到的预期排序结果
:widths: 25 25 50
:header-rows: 1
* - 参数`sampler`
- 参数`shuffle`
* - 参数 `sampler`
- 参数 `shuffle`
- 预期数据顺序
* - None
- None
@ -47,13 +47,13 @@ mindspore.dataset.Cifar10Dataset
* - None
- False
- 顺序排列
* - 参数`sampler`
* - 参数 `sampler`
- None
- 由`sampler`行为定义的顺序
* - 参数`sampler`
- 由 `sampler` 行为定义的顺序
* - 参数 `sampler`
- True
- 不允许
* - 参数`sampler`
* - 参数 `sampler`
- False
- 不允许

View File

@ -1,3 +1,3 @@
close_pool()
.. py:method:: close_pool()
关闭数据集对象中的多进程池。如果您熟悉多进程库,可以将此视为进程池对象的析构函数。

View File

@ -1,20 +1,24 @@
concat(datasets)
mindspore.dataset.concat
=========================
对传入的多个数据集对象进行拼接操作。
重载“+”运算符来进行数据集对象拼接操作。
.. py:method:: concat(datasets)
注:
用于拼接的多个数据集对象其列名、每列数据的维度rank)和类型必须相同。
对传入的多个数据集对象进行拼接操作。重载“+”运算符来进行数据集对象拼接操作。
参数:
datasets (Union[list, class Dataset]):与当前数据集对象拼接的数据集对象列表或单个数据集对象。
.. note::用于拼接的多个数据集对象其列名、每列数据的维度rank)和类型必须相同。
**参数:**
**datasets** (Union[list, class Dataset]):与当前数据集对象拼接的数据集对象列表或单个数据集对象。
返回:
ConcatDataset拼接后的数据集对象。
**返回:**
示例:
>>> # 通过使用“+”运算符拼接dataset_1和dataset_2获得拼接后的数据集对象
>>> dataset = dataset_1 + dataset_2
>>> # 通过concat操作拼接dataset_1和dataset_2获得拼接后的数据集对象
>>> dataset = dataset_1.concat(dataset_2)
ConcatDataset拼接后的数据集对象。
**样例:**
>>> # 通过使用“+”运算符拼接dataset_1和dataset_2获得拼接后的数据集对象
>>> dataset = dataset_1 + dataset_2
>>> # 通过concat操作拼接dataset_1和dataset_2获得拼接后的数据集对象
>>> dataset = dataset_1.concat(dataset_2)

View File

@ -1,22 +1,22 @@
create_dict_iterator(num_epochs=-1, output_numpy=False)
.. py:method:: create_dict_iterator(num_epochs=-1, output_numpy=False)
基于数据集对象创建迭代器,输出数据为字典类型。
基于数据集对象创建迭代器,输出数据为字典类型。
字典中列的顺序可能与数据集对象中原始顺序不同。
字典中列的顺序可能与数据集对象中原始顺序不同。
参数:
num_epochs (int, optional):迭代器可以迭代的最多轮次数(默认为-1迭代器可以迭代无限次
output_numpy (bool, optional)是否输出NumPy数据类型如果`output_numpy`为False
迭代器输出的每列数据类型为MindSpore.Tensor默认为False
**参数:**
返回:
DictIterator基于数据集对象创建的字典迭代器
- **num_epochs** (int, optional):迭代器可以迭代的最多轮次数(默认为-1迭代器可以迭代无限次
- **output_numpy** (bool, optional)是否输出NumPy数据类型如果 `output_numpy` 为False迭代器输出的每列数据类型为MindSpore.Tensor默认为False
示例:
>>> # dataset是数据集类的实例化对象
>>> iterator = dataset.create_dict_iterator()
>>> for item in iterator:
... # item 是一个dict
... print(type(item))
... break
<class 'dict'>
返回:
DictIterator基于数据集对象创建的字典迭代器。
示例:
>>> # dataset是数据集类的实例化对象
>>> iterator = dataset.create_dict_iterator()
>>> for item in iterator:
... # item 是一个dict
... print(type(item))
... break
<class 'dict'>

View File

@ -1,27 +1,26 @@
create_tuple_iterator(columns=None, num_epochs=-1, output_numpy=False, do_copy=True)
.. py:method:: create_tuple_iterator(columns=None, num_epochs=-1, output_numpy=False, do_copy=True)
基于数据集对象创建迭代器输出数据为ndarray组成的列表。
可以使用columns指定输出的所有列名及列的顺序。如果columns未指定列的顺序将保持不变。
参数:
columns (list[str], optional):用于指定列顺序的列名列表
默认为None表示所有列
num_epochs (int, optional):迭代器可以迭代的最多轮次数
(默认为-1迭代器可以迭代无限次
output_numpy (bool, optional)是否输出NumPy数据类型
如果output_numpy为False迭代器输出的每列数据类型为MindSpore.Tensor默认为False
do_copy (bool, optional)当输出数据类型为mindspore.Tensor时
通过此参数指定转换方法采用False主要考虑以获得更好的性能默认为True
**参数:**
返回:
TupleIterator基于数据集对象创建的元组迭代器。
- **columns** (list[str], optional)用于指定列顺序的列名列表默认为None表示所有列
- **num_epochs** (int, optional):迭代器可以迭代的最多轮次数(默认为-1迭代器可以迭代无限次
- **output_numpy** (bool, optional)是否输出NumPy数据类型如果output_numpy为False迭代器输出的每列数据类型为MindSpore.Tensor默认为False
- **do_copy** (bool, optional)当输出数据类型为mindspore.Tensor时通过此参数指定转换方法采用False主要考虑以获得更好的性能默认为True
示例:
>>> # dataset是数据集类的实例化对象
>>> iterator = dataset.create_tuple_iterator()
>>> for item in iterator
... # item 是一个列表
... print(type(item))
... break
<class 'list'>
**返回:**
TupleIterator基于数据集对象创建的元组迭代器。
**样例:**
>>> # dataset是数据集类的实例化对象
>>> iterator = dataset.create_tuple_iterator()
>>> for item in iterator
... # item 是一个列表
... print(type(item))
... break
<class 'list'>