forked from mindspore-Ecosystem/mindspore
!26795 update api cn format
Merge pull request !26795 from yingchen/code_docs_api1125
This commit is contained in:
commit
efad2cbe9d
|
@ -1,5 +1,5 @@
|
|||
mindspore.dataset.DSCallback
|
||||
==============================
|
||||
=============================
|
||||
|
||||
.. py:class:: mindspore.dataset.DSCallback(step_size=1)
|
||||
|
||||
|
@ -7,22 +7,19 @@ mindspore.dataset.DSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **step_size** (int, optional):调用 `ds_step_begin` 和 `ds_step_end` 之间间隔的step数(默认为1)。
|
||||
**step_size** (int, optional):调用 `ds_step_begin` 和 `ds_step_end` 之间间隔的step数(默认为1)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.dataset import DSCallback
|
||||
>>>
|
||||
>>> class PrintInfo(DSCallback):
|
||||
... def ds_epoch_end(self, ds_run_context):
|
||||
... print(cb_params.cur_epoch_num)
|
||||
... print(cb_params.cur_step_num)
|
||||
>>>
|
||||
>>> # dataset为任意数据集实例,op为任意数据处理算子
|
||||
>>> dataset = dataset.map(operations=op, callbacks=PrintInfo())
|
||||
|
||||
>>> from mindspore.dataset import DSCallback
|
||||
>>>
|
||||
>>> class PrintInfo(DSCallback):
|
||||
... def ds_epoch_end(self, ds_run_context):
|
||||
... print(cb_params.cur_epoch_num)
|
||||
... print(cb_params.cur_step_num)
|
||||
>>>
|
||||
>>> # dataset为任意数据集实例,op为任意数据处理算子
|
||||
>>> dataset = dataset.map(operations=op, callbacks=PrintInfo())
|
||||
|
||||
.. py:method:: ds_begin(ds_run_context)
|
||||
|
||||
|
@ -30,8 +27,7 @@ mindspore.dataset.DSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_epoch_begin(ds_run_context)
|
||||
|
||||
|
@ -39,8 +35,7 @@ mindspore.dataset.DSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_epoch_end(ds_run_context)
|
||||
|
||||
|
@ -48,8 +43,7 @@ mindspore.dataset.DSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_step_begin(ds_run_context)
|
||||
|
||||
|
@ -57,8 +51,7 @@ mindspore.dataset.DSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_step_end(ds_run_context)
|
||||
|
||||
|
@ -66,4 +59,4 @@ mindspore.dataset.DSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
|
|
@ -10,27 +10,23 @@ mindspore.dataset.DatasetCache
|
|||
|
||||
**参数:**
|
||||
|
||||
- **session_id** (int):当前数据缓存客户端的会话ID,用户在命令行开启缓存服务端后可通过 `cache_admin -g` 获取。
|
||||
- **size** (int, optional):设置数据缓存服务可用的内存大小(默认为0,即内存使用没有上限。注意,这可能会产生计算机内存不足的风险)。
|
||||
- **spilling** (bool, optional):如果共享内存不足,是否将溢出部分缓存到磁盘(默认为False)。
|
||||
- **hostname** (str, optional):数据缓存服务客户端的主机IP(默认为None,使用默认主机名127.0.0.1)。
|
||||
- **port** (int, optional):指定连接到数据缓存服务端的端口号(默认为None,使用端口50052)。
|
||||
- **num_connections** (int, optional):TCP/IP连接数量(默认为None,使用默认值12)。
|
||||
- **prefetch_size** (int, optional):指定缓存队列大小,使用缓存功能算子时,将直接从缓存队列中获取数据(默认为None,使用默认值20)。
|
||||
|
||||
- **session_id** (int):当前数据缓存客户端的会话ID,用户在命令行开启缓存服务端后可通过 `cache_admin -g` 获取。
|
||||
- **size** (int, optional):设置数据缓存服务可用的内存大小(默认为0,即内存使用没有上限。注意,这可能会产生计算机内存不足的风险)。
|
||||
- **spilling** (bool, optional):如果共享内存不足,是否将溢出部分缓存到磁盘(默认为False)。
|
||||
- **hostname** (str, optional):数据缓存服务客户端的主机IP(默认为None,使用默认主机名127.0.0.1)。
|
||||
- **port** (int, optional):指定连接到数据缓存服务端的端口号(默认为None,使用端口50052)。
|
||||
- **num_connections** (int, optional):TCP/IP连接数量(默认为None,使用默认值12)。
|
||||
- **prefetch_size** (int, optional):指定缓存队列大小,使用缓存功能算子时,将直接从缓存队列中获取数据(默认为None,使用默认值20)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
>>> # 创建数据缓存客户端实例,其中 `session_id` 由命令 `cache_admin -g` 生成
|
||||
>>> some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
>>>
|
||||
>>> dataset_dir = "path/to/imagefolder_directory"
|
||||
>>> ds1 = ds.ImageFolderDataset(dataset_dir, cache=some_cache)
|
||||
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
>>> # 创建数据缓存客户端实例,其中 `session_id` 由命令 `cache_admin -g` 生成
|
||||
>>> some_cache = ds.DatasetCache(session_id=session_id, size=0)
|
||||
>>>
|
||||
>>> dataset_dir = "path/to/imagefolder_directory"
|
||||
>>> ds1 = ds.ImageFolderDataset(dataset_dir, cache=some_cache)
|
||||
|
||||
.. py:method:: get_stat()
|
||||
|
||||
|
|
|
@ -9,36 +9,34 @@ mindspore.dataset.MnistDataset
|
|||
|
||||
**参数:**
|
||||
|
||||
- **dataset_dir** (str): 包含数据集文件的根目录路径。
|
||||
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train`、`test` 或 `all`。使用 `train` 参数将会读取60,000个训练样本,`test` 将会读取10,000个测试样本,`all` 将会读取全部70,000个样本(默认值为None,即全部样本图片)。
|
||||
- **num_samples** (int, 可选): 指定从数据集中读取的样本数(可以小于数据集总数,默认值为None,即全部样本图片)。
|
||||
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值None,即使用mindspore.dataset.config中配置的线程数)。
|
||||
- **shuffle** (bool, 可选): 是否混洗数据集(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **sampler** (Sampler, 可选): 指定从数据集中选取样本的采样器(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选): 单节点数据缓存,能够加快数据加载和处理的速度(默认值None,即不使用缓存加速)。
|
||||
- **dataset_dir** (str): 包含数据集文件的根目录路径。
|
||||
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train`、`test` 或 `all`。使用 `train` 参数将会读取60,000个训练样本,`test` 将会读取10,000个测试样本,`all` 将会读取全部70,000个样本(默认值为None,即全部样本图片)。
|
||||
- **num_samples** (int, 可选): 指定从数据集中读取的样本数(可以小于数据集总数,默认值为None,即全部样本图片)。
|
||||
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值None,即使用mindspore.dataset.config中配置的线程数)。
|
||||
- **shuffle** (bool, 可选): 是否混洗数据集(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **sampler** (Sampler, 可选): 指定从数据集中选取样本的采样器(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选): 单节点数据缓存,能够加快数据加载和处理的速度(默认值None,即不使用缓存加速)。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **RuntimeError**: `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError**: `num_parallel_workers` 超过系统最大线程数。
|
||||
- **RuntimeError**: 同时指定了`sampler`和`shuffle`参数。
|
||||
- **RuntimeError**: 同时指定了`sampler`和`num_shards`参数。
|
||||
- **RuntimeError**: 指定了`num_shards`参数,但是未指定`shard_id`参数。
|
||||
- **RuntimeError**: 指定了`shard_id`参数,但是未指定`num_shards`参数。
|
||||
- **ValueError**: `shard_id`参数错误(小于0或者大于等于 `num_shards`)。
|
||||
- **RuntimeError**: `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError**: `num_parallel_workers` 超过系统最大线程数。
|
||||
- **RuntimeError**: 同时指定了`sampler`和`shuffle`参数。
|
||||
- **RuntimeError**: 同时指定了`sampler`和`num_shards`参数。
|
||||
- **RuntimeError**: 指定了`num_shards`参数,但是未指定`shard_id`参数。
|
||||
- **RuntimeError**: 指定了`shard_id`参数,但是未指定`num_shards`参数。
|
||||
- **ValueError**: `shard_id`参数错误(小于0或者大于等于 `num_shards`)。
|
||||
|
||||
**注:**
|
||||
.. note:: 此数据集可以指定 `sampler` 参数,但 `sampler` 和 `shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
|
||||
|
||||
此数据集可以指定`sampler`参数,但`sampler` 和 `shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
|
||||
|
||||
.. list-table:: 配置`sampler`和`shuffle`的不同组合得到的预期排序结果
|
||||
.. list-table:: 配置 `sampler` 和 `shuffle` 的不同组合得到的预期排序结果
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - 参数`sampler`
|
||||
- 参数`shuffle`
|
||||
* - 参数 `sampler`
|
||||
- 参数 `shuffle`
|
||||
- 预期数据顺序
|
||||
* - None
|
||||
- None
|
||||
|
@ -49,26 +47,24 @@ mindspore.dataset.MnistDataset
|
|||
* - None
|
||||
- False
|
||||
- 顺序排列
|
||||
* - 参数`sampler`
|
||||
* - 参数 `sampler`
|
||||
- None
|
||||
- 由`sampler`行为定义的顺序
|
||||
* - 参数`sampler`
|
||||
- 由 `sampler` 行为定义的顺序
|
||||
* - 参数 `sampler`
|
||||
- True
|
||||
- 不允许
|
||||
* - 参数`sampler`
|
||||
* - 参数 `sampler`
|
||||
- False
|
||||
- 不允许
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
|
||||
>>>
|
||||
>>> # 从MNIST数据集中随机读取3个样本
|
||||
>>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_samples=3)
|
||||
>>>
|
||||
>>> # 提示:在MNIST数据集生成的数据集对象中,每一次迭代得到的数据行都有"image"和"label"两个键
|
||||
>>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
|
||||
>>>
|
||||
>>> # 从MNIST数据集中随机读取3个样本
|
||||
>>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_samples=3)
|
||||
>>>
|
||||
>>> # 提示:在MNIST数据集生成的数据集对象中,每一次迭代得到的数据行都有"image"和"label"两个键
|
||||
|
||||
**关于MNIST数据集:**
|
||||
|
||||
|
|
|
@ -10,22 +10,18 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **step_size** (int, optional):每个step包含的数据行数。step大小通常与batch大小相等(默认值为1)。
|
||||
|
||||
**step_size** (int, optional):每个step包含的数据行数。step大小通常与batch大小相等(默认值为1)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.dataset import WaitedDSCallback
|
||||
>>>
|
||||
>>> my_cb = WaitedDSCallback(32)
|
||||
>>> # dataset为任意数据集实例
|
||||
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
|
||||
>>> data = data.batch(32)
|
||||
>>> # 定义网络
|
||||
>>> model.train(epochs, data, callbacks=[my_cb])
|
||||
|
||||
>>> from mindspore.dataset import WaitedDSCallback
|
||||
>>>
|
||||
>>> my_cb = WaitedDSCallback(32)
|
||||
>>> # dataset为任意数据集实例
|
||||
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
|
||||
>>> data = data.batch(32)
|
||||
>>> # 定义网络
|
||||
>>> model.train(epochs, data, callbacks=[my_cb])
|
||||
|
||||
.. py:method:: begin(run_context)
|
||||
|
||||
|
@ -33,8 +29,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext):网络训练运行信息。
|
||||
|
||||
**run_context** (RunContext):网络训练运行信息。
|
||||
|
||||
.. py:method:: ds_begin(ds_run_context)
|
||||
|
||||
|
@ -42,8 +37,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_epoch_begin(ds_run_context)
|
||||
|
||||
|
@ -52,8 +46,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context**:数据处理管道运行信息。
|
||||
|
||||
**ds_run_context**:数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_epoch_end(ds_run_context)
|
||||
|
||||
|
@ -61,8 +54,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_step_begin(ds_run_context)
|
||||
|
||||
|
@ -71,8 +63,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context**:数据处理管道运行信息。
|
||||
|
||||
**ds_run_context**:数据处理管道运行信息。
|
||||
|
||||
.. py:method:: ds_step_end(ds_run_context)
|
||||
|
||||
|
@ -80,8 +71,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
**ds_run_context** (RunContext):数据处理管道运行信息。
|
||||
|
||||
.. py:method:: end(run_context)
|
||||
|
||||
|
@ -89,8 +79,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **run_context**:网络训练运行信息。
|
||||
|
||||
**run_context**:网络训练运行信息。
|
||||
|
||||
.. py:method:: epoch_begin(run_context)
|
||||
|
||||
|
@ -98,8 +87,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext):网络训练运行信息。
|
||||
|
||||
**run_context** (RunContext):网络训练运行信息。
|
||||
|
||||
.. py:method:: epoch_end(run_context)
|
||||
|
||||
|
@ -107,8 +95,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **run_context**:网络训练运行信息。
|
||||
|
||||
**run_context**:网络训练运行信息。
|
||||
|
||||
.. py:method:: step_begin(run_context)
|
||||
|
||||
|
@ -116,8 +103,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **run_context** (RunContext):网络训练运行信息。
|
||||
|
||||
**run_context** (RunContext):网络训练运行信息。
|
||||
|
||||
.. py:method:: step_end(run_context)
|
||||
|
||||
|
@ -125,8 +111,7 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
**run_context**:网络训练运行信息。
|
||||
|
||||
**run_context**:网络训练运行信息。
|
||||
|
||||
.. py:method:: sync_epoch_begin(train_run_context, ds_run_context)
|
||||
|
||||
|
@ -134,9 +119,8 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **train_run_context**:包含前一个epoch的反馈信息的网络训练运行信息。
|
||||
- **ds_run_context**:数据处理管道运行信息。
|
||||
|
||||
- **train_run_context**:包含前一个epoch的反馈信息的网络训练运行信息。
|
||||
- **ds_run_context**:数据处理管道运行信息。
|
||||
|
||||
.. py:method:: sync_step_begin(train_run_context, ds_run_context)
|
||||
|
||||
|
@ -144,5 +128,5 @@ mindspore.dataset.WaitedDSCallback
|
|||
|
||||
**参数:**
|
||||
|
||||
- **train_run_context**:包含前一个step的反馈信息的网络训练运行信息。
|
||||
- **ds_run_context**:数据处理管道运行信息。
|
||||
- **train_run_context**:包含前一个step的反馈信息的网络训练运行信息。
|
||||
- **ds_run_context**:数据处理管道运行信息。
|
||||
|
|
|
@ -9,36 +9,34 @@ mindspore.dataset.Cifar100Dataset
|
|||
|
||||
**参数:**
|
||||
|
||||
- **dataset_dir** (str): 包含数据集文件的根目录路径。
|
||||
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train`,`test`或`all`。使用`train`参数将会读取50,000个训练样本,`test` 将会读取10,000个测试样本,`all` 将会读取全部60,000个样本(默认值为None,即全部样本图片)。
|
||||
- **num_samples** (int, 可选): 指定从数据集中读取的样本数(可以小于数据集总数,默认值为None,即全部样本图片)。
|
||||
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值None,即使用mindspore.dataset.config中配置的线程数)。
|
||||
- **shuffle** (bool, 可选): 是否混洗数据集(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **sampler** (Sampler, 可选): 指定从数据集中选取样本的采样器(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选): 单节点数据缓存,能够加快数据加载和处理的速度(默认值None,即不使用缓存加速)。
|
||||
- **dataset_dir** (str): 包含数据集文件的根目录路径。
|
||||
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train`,`test`或`all`。使用`train`参数将会读取50,000个训练样本,`test` 将会读取10,000个测试样本,`all` 将会读取全部60,000个样本(默认值为None,即全部样本图片)。
|
||||
- **num_samples** (int, 可选): 指定从数据集中读取的样本数(可以小于数据集总数,默认值为None,即全部样本图片)。
|
||||
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值None,即使用mindspore.dataset.config中配置的线程数)。
|
||||
- **shuffle** (bool, 可选): 是否混洗数据集(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **sampler** (Sampler, 可选): 指定从数据集中选取样本的采样器(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选): 单节点数据缓存,能够加快数据加载和处理的速度(默认值None,即不使用缓存加速)。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **RuntimeError:** `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError:** `num_parallel_workers` 超过系统最大线程数。
|
||||
- **RuntimeError:** 同时指定了`sampler`和`shuffle`参数。
|
||||
- **RuntimeError:** 同时指定了`sampler`和`num_shards`参数。
|
||||
- **RuntimeError:** 指定了`num_shards`参数,但是未指定`shard_id`参数。
|
||||
- **RuntimeError:** 指定了`shard_id`参数,但是未指定`num_shards`参数。
|
||||
- **ValueError:** `shard_id`参数错误(小于0或者大于等于 `num_shards`)。
|
||||
- **RuntimeError:** `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError:** `num_parallel_workers` 超过系统最大线程数。
|
||||
- **RuntimeError:** 同时指定了 `sampler` 和 `shuffle` 参数。
|
||||
- **RuntimeError:** 同时指定了`sampler`和`num_shards`参数。
|
||||
- **RuntimeError:** 指定了 `num_shards` 参数,但是未指定 `shard_id` 参数。
|
||||
- **RuntimeError:** 指定了 `shard_id` 参数,但是未指定 `num_shards` 参数。
|
||||
- **ValueError:** `shard_id` 参数错误(小于0或者大于等于 `num_shards`)。
|
||||
|
||||
**注:**
|
||||
.. note:: 此数据集可以指定 `sampler` 参数,但`sampler` 和 `shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
|
||||
|
||||
此数据集可以指定`sampler`参数,但`sampler` 和 `shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
|
||||
|
||||
.. list-table:: 配置`sampler`和`shuffle`的不同组合得到的预期排序结果
|
||||
.. list-table:: 配置 `sampler` 和 `shuffle` 的不同组合得到的预期排序结果
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - 参数`sampler`
|
||||
- 参数`shuffle`
|
||||
* - 参数 `sampler`
|
||||
- 参数 `shuffle`
|
||||
- 预期数据顺序
|
||||
* - None
|
||||
- None
|
||||
|
@ -49,29 +47,27 @@ mindspore.dataset.Cifar100Dataset
|
|||
* - None
|
||||
- False
|
||||
- 顺序排列
|
||||
* - 参数`sampler`
|
||||
* - 参数 `sampler`
|
||||
- None
|
||||
- 由`sampler`行为定义的顺序
|
||||
* - 参数`sampler`
|
||||
- 由 `sampler` 行为定义的顺序
|
||||
* - 参数 `sampler`
|
||||
- True
|
||||
- 不允许
|
||||
* - 参数`sampler`
|
||||
* - 参数 `sampler`
|
||||
- False
|
||||
- 不允许
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> cifar100_dataset_dir = "/path/to/cifar100_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) 按数据集文件的读取顺序,依次获取CIFAR-100数据集中的所有样本
|
||||
>>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, shuffle=False)
|
||||
>>>
|
||||
>>> # 2) 从CIFAR100数据集中随机抽取350个样本
|
||||
>>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, num_samples=350, shuffle=True)
|
||||
>>>
|
||||
>>> # 提示: 在CIFAR-100数据集生成的数据集对象中,每一次迭代得到的数据行都有"image", "fine_label" 和 "coarse_label"三个键
|
||||
>>> cifar100_dataset_dir = "/path/to/cifar100_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) 按数据集文件的读取顺序,依次获取CIFAR-100数据集中的所有样本
|
||||
>>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, shuffle=False)
|
||||
>>>
|
||||
>>> # 2) 从CIFAR100数据集中随机抽取350个样本
|
||||
>>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, num_samples=350, shuffle=True)
|
||||
>>>
|
||||
>>> # 提示: 在CIFAR-100数据集生成的数据集对象中,每一次迭代得到的数据行都有"image", "fine_label" 和 "coarse_label"三个键
|
||||
|
||||
**关于CIFAR-100数据集:**
|
||||
|
||||
|
|
|
@ -9,29 +9,27 @@ mindspore.dataset.Cifar10Dataset
|
|||
|
||||
**参数:**
|
||||
|
||||
- **dataset_dir** (str): 包含数据集文件的根目录路径。
|
||||
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train`,`test`或`all`。使用`train`参数将会读取50,000个训练样本,`test` 将会读取10,000个测试样本,`all` 将会读取全部60,000个样本(默认值为None,即全部样本图片)。
|
||||
- **num_samples** (int, 可选): 指定从数据集中读取的样本数(可以小于数据集总数,默认值为None,即全部样本图片)。
|
||||
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值None,即使用mindspore.dataset.config中配置的线程数)。
|
||||
- **shuffle** (bool, 可选): 是否混洗数据集(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **sampler** (Sampler, 可选): 指定从数据集中选取样本的采样器(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选): 单节点数据缓存,能够加快数据加载和处理的速度(默认值None,即不使用缓存加速)。
|
||||
- **dataset_dir** (str): 包含数据集文件的根目录路径。
|
||||
- **usage** (str, 可选): 指定数据集的子集,可取值为 `train`,`test`或`all`。使用`train`参数将会读取50,000个训练样本,`test` 将会读取10,000个测试样本,`all` 将会读取全部60,000个样本(默认值为None,即全部样本图片)。
|
||||
- **num_samples** (int, 可选): 指定从数据集中读取的样本数(可以小于数据集总数,默认值为None,即全部样本图片)。
|
||||
- **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值None,即使用mindspore.dataset.config中配置的线程数)。
|
||||
- **shuffle** (bool, 可选): 是否混洗数据集(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **sampler** (Sampler, 可选): 指定从数据集中选取样本的采样器(默认为None,下表中会展示不同配置的预期行为)。
|
||||
- **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后, `num_samples` 表示每个分片的最大样本数。
|
||||
- **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
|
||||
- **cache** (DatasetCache, 可选): 单节点数据缓存,能够加快数据加载和处理的速度(默认值None,即不使用缓存加速)。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **RuntimeError:** `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError:** `num_parallel_workers` 超过系统最大线程数。
|
||||
- **RuntimeError:** 同时指定了`sampler`和`shuffle`参数。
|
||||
- **RuntimeError:** 同时指定了`sampler`和`num_shards`参数。
|
||||
- **RuntimeError:** 指定了`num_shards`参数,但是未指定`shard_id`参数。
|
||||
- **RuntimeError:** 指定了`shard_id`参数,但是未指定`num_shards`参数。
|
||||
- **ValueError:** `shard_id`参数错误(小于0或者大于等于 `num_shards`)。
|
||||
- **RuntimeError:** `dataset_dir` 路径下不包含数据文件。
|
||||
- **RuntimeError:** `num_parallel_workers` 超过系统最大线程数。
|
||||
- **RuntimeError:** 同时指定了`sampler`和`shuffle`参数。
|
||||
- **RuntimeError:** 同时指定了`sampler`和`num_shards`参数。
|
||||
- **RuntimeError:** 指定了`num_shards`参数,但是未指定`shard_id`参数。
|
||||
- **RuntimeError:** 指定了`shard_id`参数,但是未指定`num_shards`参数。
|
||||
- **ValueError:** `shard_id`参数错误(小于0或者大于等于 `num_shards`)。
|
||||
|
||||
**注:**
|
||||
|
||||
此数据集可以通过`sampler`指定任意采样器,但参数`sampler` 和 `shuffle` 的行为是互斥的。下表展示了几种合法的输入参数及预期的行为。
|
||||
.. note:: 此数据集可以通过`sampler`指定任意采样器,但参数`sampler` 和 `shuffle` 的行为是互斥的。下表展示了几种合法的输入参数及预期的行为。
|
||||
|
||||
.. list-table:: 配置`sampler`和`shuffle`的不同组合得到的预期排序结果
|
||||
:widths: 25 25 50
|
||||
|
@ -60,21 +58,19 @@ mindspore.dataset.Cifar10Dataset
|
|||
- 不允许
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) 按数据集文件的读取顺序,获取CIFAR-10数据集中的所有样本
|
||||
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, shuffle=False)
|
||||
>>>
|
||||
>>> # 2) 从CIFAR10数据集中随机抽取350个样本
|
||||
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=350, shuffle=True)
|
||||
>>>
|
||||
>>> # 3) 对CIFAR10数据集进行分布式训练,并将数据集拆分为2个分片,当前数据集仅加载分片ID号为0的数据
|
||||
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_shards=2, shard_id=0)
|
||||
>>>
|
||||
>>> # 提示:在CIFAR-10数据集生成的数据集对象中,每一次迭代得到的数据行都有"image"和"label"两个键
|
||||
>>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
|
||||
>>>
|
||||
>>> # 1) 按数据集文件的读取顺序,获取CIFAR-10数据集中的所有样本
|
||||
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, shuffle=False)
|
||||
>>>
|
||||
>>> # 2) 从CIFAR10数据集中随机抽取350个样本
|
||||
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_samples=350, shuffle=True)
|
||||
>>>
|
||||
>>> # 3) 对CIFAR10数据集进行分布式训练,并将数据集拆分为2个分片,当前数据集仅加载分片ID号为0的数据
|
||||
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, num_shards=2, shard_id=0)
|
||||
>>>
|
||||
>>> # 提示:在CIFAR-10数据集生成的数据集对象中,每一次迭代得到的数据行都有"image"和"label"两个键
|
||||
|
||||
**关于CIFAR-10数据集:**
|
||||
|
||||
|
|
|
@ -7,18 +7,16 @@ mindspore.dataset.compare
|
|||
|
||||
**参数:**
|
||||
|
||||
- **pipeline1** (Dataset):数据处理管道。
|
||||
- **pipeline2** (Dataset):数据处理管道。
|
||||
- **pipeline1** (Dataset):数据处理管道。
|
||||
- **pipeline2** (Dataset):数据处理管道。
|
||||
|
||||
**返回:**
|
||||
|
||||
bool,两个数据处理管道是否相等。
|
||||
bool,两个数据处理管道是否相等。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
|
||||
>>> ds.compare(pipeline1, pipeline2)
|
||||
>>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
|
||||
>>> ds.compare(pipeline1, pipeline2)
|
||||
|
|
@ -10,30 +10,28 @@ mindspore.dataset.deserialize
|
|||
|
||||
**参数:**
|
||||
|
||||
- **input_dict** (dict):包含序列化数据集图的Python字典。
|
||||
- **json_filepath** (str):JSON文件的路径,用户可通过 `mindspore.dataset.serialize()` 接口生成。
|
||||
- **input_dict** (dict):包含序列化数据集图的Python字典。
|
||||
- **json_filepath** (str):JSON文件的路径,用户可通过 `mindspore.dataset.serialize()` 接口生成。
|
||||
|
||||
**返回:**
|
||||
|
||||
成功时,返回Dataset对象;失败时,则返回None。
|
||||
成功时,返回Dataset对象;失败时,则返回None。
|
||||
|
||||
**异常:**
|
||||
|
||||
**OSError:** 无法打开JSON文件。
|
||||
**OSError:** 无法打开JSON文件。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes是输入参数
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> # 用例1:序列化/反序列化 JSON文件
|
||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> # 用例2:序列化/反序列化 Python字典
|
||||
>>> serialized_data = ds.engine.serialize(dataset)
|
||||
>>> dataset = ds.engine.deserialize(input_dict=serialized_data)
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes是输入参数
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> # 用例1:序列化/反序列化 JSON文件
|
||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> # 用例2:序列化/反序列化 Python字典
|
||||
>>> serialized_data = ds.engine.serialize(dataset)
|
||||
>>> dataset = ds.engine.deserialize(input_dict=serialized_data)
|
||||
|
||||
|
|
@ -10,26 +10,24 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **dataset** (Dataset): 数据处理管道对象。
|
||||
- **json_filepath** (str): 生成序列化JSON文件的路径。
|
||||
- **dataset** (Dataset): 数据处理管道对象。
|
||||
- **json_filepath** (str): 生成序列化JSON文件的路径。
|
||||
|
||||
**返回:**
|
||||
|
||||
Dict,包含序列化数据集图的字典。
|
||||
Dict,包含序列化数据集图的字典。
|
||||
|
||||
**异常:**
|
||||
|
||||
**OSError:** 无法打开文件。
|
||||
**OSError:** 无法打开文件。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes是输入参数
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> # 将其序列化为JSON文件
|
||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> serialized_data = ds.engine.serialize(dataset) # 将其序列化为Python字典
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes是输入参数
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> # 将其序列化为JSON文件
|
||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||
>>> serialized_data = ds.engine.serialize(dataset) # 将其序列化为Python字典
|
||||
|
|
@ -7,16 +7,14 @@ mindspore.dataset.show
|
|||
|
||||
**参数:**
|
||||
|
||||
- **dataset** (Dataset): 数据处理管道对象。
|
||||
- **indentation** (int, optional): 设置MindSpore的INFO级别日志文件打印时的缩进字符数。若为None,则不缩进。
|
||||
- **dataset** (Dataset): 数据处理管道对象。
|
||||
- **indentation** (int, optional): 设置MindSpore的INFO级别日志文件打印时的缩进字符数。若为None,则不缩进。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10)
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> ds.show(dataset)
|
||||
>>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
|
||||
>>> one_hot_encode = c_transforms.OneHot(10)
|
||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||
>>> ds.show(dataset)
|
||||
|
|
@ -7,22 +7,22 @@
|
|||
|
||||
**参数:**
|
||||
|
||||
- **image** (ndarray): 待绘制的图像,shape为(C, H, W)或(H, W, C),通道顺序为RGB。
|
||||
- **bboxes** (ndarray): 边界框(包含类别置信度),shape为(N, 4)或(N, 5),格式为(N,X,Y,W,H)。
|
||||
- **labels** (ndarray): 边界框的类别,shape为(N, 1)。
|
||||
- **segm** (ndarray): 图像分割掩码,shape为(M, H, W),M表示类别总数(默认值None,不绘制掩码)。
|
||||
- **class_names** (list[str], dict): 类别索引到类别名的映射表(默认值None,仅显示类别索引)。
|
||||
- **score_threshold** (float): 绘制边界框的类别置信度阈值(默认值0,绘制所有边界框)。
|
||||
- **bbox_color** (tuple(int)): 指定绘制边界框时线条的颜色,顺序为BGR(默认值(0,255,0),表示'green')。
|
||||
- **text_color** (tuple(int)):指定类别文本的显示颜色,顺序为BGR(默认值(203, 192, 255),表示'pink')。
|
||||
- **mask_color** (tuple(int)):指定掩码的显示颜色,顺序为BGR(默认值(128, 0, 128),表示'purple')。
|
||||
- **thickness** (int): 指定边界框和类别文本的线条粗细(默认值2)。
|
||||
- **font_size** (int, float): 指定类别文本字体大小(默认值0.8)。
|
||||
- **show** (bool): 是否显示图像(默认值为True)。
|
||||
- **win_name** (str): 指定窗口名称(默认值"win")。
|
||||
- **wait_time** (int): 指定cv2.waitKey的时延,单位为ms,即图像显示的自动切换间隔(默认值2000,表示间隔为2000ms)。
|
||||
- **out_file** (str, optional): 输出图像的文件名,用于在绘制后将结果存储到本地(默认值None,不保存)。
|
||||
- **image** (ndarray): 待绘制的图像,shape为(C, H, W)或(H, W, C),通道顺序为RGB。
|
||||
- **bboxes** (ndarray): 边界框(包含类别置信度),shape为(N, 4)或(N, 5),格式为(N,X,Y,W,H)。
|
||||
- **labels** (ndarray): 边界框的类别,shape为(N, 1)。
|
||||
- **segm** (ndarray): 图像分割掩码,shape为(M, H, W),M表示类别总数(默认值None,不绘制掩码)。
|
||||
- **class_names** (list[str], dict): 类别索引到类别名的映射表(默认值None,仅显示类别索引)。
|
||||
- **score_threshold** (float): 绘制边界框的类别置信度阈值(默认值0,绘制所有边界框)。
|
||||
- **bbox_color** (tuple(int)): 指定绘制边界框时线条的颜色,顺序为BGR(默认值(0,255,0),表示'green')。
|
||||
- **text_color** (tuple(int)):指定类别文本的显示颜色,顺序为BGR(默认值(203, 192, 255),表示'pink')。
|
||||
- **mask_color** (tuple(int)):指定掩码的显示颜色,顺序为BGR(默认值(128, 0, 128),表示'purple')。
|
||||
- **thickness** (int): 指定边界框和类别文本的线条粗细(默认值2)。
|
||||
- **font_size** (int, float): 指定类别文本字体大小(默认值0.8)。
|
||||
- **show** (bool): 是否显示图像(默认值为True)。
|
||||
- **win_name** (str): 指定窗口名称(默认值"win")。
|
||||
- **wait_time** (int): 指定cv2.waitKey的时延,单位为ms,即图像显示的自动切换间隔(默认值2000,表示间隔为2000ms)。
|
||||
- **out_file** (str, optional): 输出图像的文件名,用于在绘制后将结果存储到本地(默认值None,不保存)。
|
||||
|
||||
**返回:**
|
||||
|
||||
ndarray,带边界框和类别置信度的图像。
|
||||
ndarray,带边界框和类别置信度的图像。
|
||||
|
|
|
@ -7,17 +7,15 @@ mindspore.dataset.audio.transforms.AllpassBiquad
|
|||
|
||||
**参数:**
|
||||
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考 https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考 https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.AllpassBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.AllpassBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,20 +7,16 @@ mindspore.dataset.audio.transforms.AmplitudeToDB
|
|||
|
||||
**参数:**
|
||||
|
||||
- **stype** (ScaleType, optional):输入音频的原始标度(默认值为ScaleType.POWER)。
|
||||
取值可为ScaleType.MAGNITUDE或ScaleType.POWER。
|
||||
- **ref_value** (float, optional):系数参考值,用于计算分贝系数 `db_multiplier` ,
|
||||
:math:`db_multiplier = Log10(max(ref_value, amin))`。
|
||||
- **amin** (float, optional):波形取值下界,低于该值的波形将会被裁切。取值必须大于0。
|
||||
- **top_db** (float, optional):最小负截止分贝值,建议的取值为80.0(默认值为80.0)。
|
||||
- **stype** (ScaleType, optional):输入音频的原始标度(默认值为ScaleType.POWER)。取值可为ScaleType.MAGNITUDE或ScaleType.POWER。
|
||||
- **ref_value** (float, optional):系数参考值,用于计算分贝系数 `db_multiplier` , :math: `db_multiplier = Log10(max(ref_value, amin))`。
|
||||
- **amin** (float, optional):波形取值下界,低于该值的波形将会被裁切。取值必须大于0。
|
||||
- **top_db** (float, optional):最小负截止分贝值,建议的取值为80.0(默认值为80.0)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 400//2+1, 30])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.AmplitudeToDB(stype=ScaleType.POWER)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 400//2+1, 30])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.AmplitudeToDB(stype=ScaleType.POWER)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,11 +7,9 @@ mindspore.dataset.audio.transforms.Angle
|
|||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[1.43, 5.434], [23.54, 89.38]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.Angle()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[1.43, 5.434], [23.54, 89.38]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.Angle()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,19 +7,16 @@ mindspore.dataset.audio.transforms.BandBiquad
|
|||
|
||||
**参数:**
|
||||
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
- **noise** (bool, optional):若为True,则使用非音调音频(如打击乐)模式;
|
||||
若为False,则使用音调音频(如语音、歌曲或器乐)模式(默认为False)。
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
- **noise** (bool, optional):若为True,则使用非音调音频(如打击乐)模式;若为False,则使用音调音频(如语音、歌曲或器乐)模式(默认为False)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BandBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BandBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,19 +7,16 @@ mindspore.dataset.audio.transforms.BandpassBiquad
|
|||
|
||||
**参数:**
|
||||
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
- **const_skirt_gain** (bool, optional):若为True,则使用恒定裙边增益(峰值增益为Q)。
|
||||
若为False,则使用恒定的0dB峰值增益(默认为False)。
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
- **const_skirt_gain** (bool, optional):若为True,则使用恒定裙边增益(峰值增益为Q)。若为False,则使用恒定的0dB峰值增益(默认为False)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BandpassBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BandpassBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,17 +7,15 @@ mindspore.dataset.audio.transforms.BandrejectBiquad
|
|||
|
||||
**参数:**
|
||||
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03],[9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BandrejectBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03],[9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BandrejectBiquad(44100, 200.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,18 +7,16 @@ mindspore.dataset.audio.transforms.BassBiquad
|
|||
|
||||
**参数:**
|
||||
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **gain** (float):期望提升(或衰减)的音频增益,单位为dB。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
- **sample_rate** (int):采样率,例如44100 (Hz),不能为零。
|
||||
- **gain** (float):期望提升(或衰减)的音频增益,单位为dB。
|
||||
- **central_freq** (float):中心频率(单位:Hz)。
|
||||
- **Q** (float, optional):品质因子,参考https://en.wikipedia.org/wiki/Q_factor,取值范围(0, 1](默认值为0.707)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BassBiquad(44100, 100.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.BassBiquad(44100, 100.0)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,15 +7,13 @@ mindspore.dataset.audio.transforms.ComplexNorm
|
|||
|
||||
**参数:**
|
||||
|
||||
- **power** (float, optional):范数的幂,取值非负(默认为1.0)。
|
||||
**power** (float, optional):范数的幂,取值非负(默认为1.0)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([2, 4, 2])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.ComplexNorm()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([2, 4, 2])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.ComplexNorm()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -3,21 +3,17 @@ mindspore.dataset.audio.transforms.Contrast
|
|||
|
||||
.. py:class:: mindspore.dataset.audio.transforms.Contrast(enhancement_amount=75.0)
|
||||
|
||||
给形如(..., time)维度的音频波形施加对比度增强效果。实现方式类似于SoX库。
|
||||
与音频压缩相比,该效果通过修改音频信号使其听起来更响亮。
|
||||
给形如(..., time)维度的音频波形施加对比度增强效果。实现方式类似于SoX库。与音频压缩相比,该效果通过修改音频信号使其听起来更响亮。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **enhancement_amount** (float):控制音频增益的量。取值范围为[0,100](默认为75.0)。
|
||||
注意当 `enhancement_amount` 等于0时,对比度增强效果仍然会很显著。
|
||||
**enhancement_amount** (float):控制音频增益的量。取值范围为[0,100](默认为75.0)。注意当 `enhancement_amount` 等于0时,对比度增强效果仍然会很显著。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.Contrast()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.Contrast()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,21 +7,16 @@ mindspore.dataset.audio.transforms.FrequencyMasking
|
|||
|
||||
**参数:**
|
||||
|
||||
- **iid_masks** (bool, optional):是否添加随机掩码(默认为False)。
|
||||
- **frequency_mask_param** (int):当 `iid_masks` 为True时,掩码长度将从[0, frequency_mask_param]中均匀采样;
|
||||
当 `iid_masks` 为False时,使用该值作为掩码的长度。取值范围为[0, freq_length],其中 `freq_length` 为波形
|
||||
在频域的长度(默认为0)。
|
||||
- **mask_start** (int):添加掩码的起始位置,只有当 `iid_masks` 为True时,该值才会生效。
|
||||
取值范围为[0, freq_length - frequency_mask_param],其中 `freq_length` 为波形在频域的长度(默认为0)。
|
||||
- **mask_value** (double):添加掩码的取值(默认为0.0)。
|
||||
- **iid_masks** (bool, optional):是否添加随机掩码(默认为False)。
|
||||
- **frequency_mask_param** (int):当 `iid_masks` 为True时,掩码长度将从[0, frequency_mask_param]中均匀采样;当 `iid_masks` 为False时,使用该值作为掩码的长度。取值范围为[0, freq_length],其中 `freq_length` 为波形在频域的长度(默认为0)。
|
||||
- **mask_start** (int):添加掩码的起始位置,只有当 `iid_masks` 为True时,该值才会生效。取值范围为[0, freq_length - frequency_mask_param],其中 `freq_length` 为波形在频域的长度(默认为0)。
|
||||
- **mask_value** (double):添加掩码的取值(默认为0.0)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 3, 2])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.FrequencyMasking(frequency_mask_param=1)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 3, 2])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.FrequencyMasking(frequency_mask_param=1)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -13,12 +13,10 @@ mindspore.dataset.audio.transforms.LowpassBiquad
|
|||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[0.8236, 0.2049, 0.3335], [0.5933, 0.9911, 0.2482],
|
||||
... [0.3007, 0.9054, 0.7598], [0.5394, 0.2842, 0.5634], [0.6363, 0.2226, 0.2288]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.LowpassBiquad(4000, 1500, 0.7)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[0.8236, 0.2049, 0.3335], [0.5933, 0.9911, 0.2482],
|
||||
... [0.3007, 0.9054, 0.7598], [0.5394, 0.2842, 0.5634], [0.6363, 0.2226, 0.2288]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.LowpassBiquad(4000, 1500, 0.7)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,21 +7,16 @@ mindspore.dataset.audio.transforms.TimeMasking
|
|||
|
||||
**参数:**
|
||||
|
||||
- **iid_masks** (bool, optional):是否添加随机掩码(默认为False)。
|
||||
- **time_mask_param** (int): 当 `iid_masks` 为True时,掩码长度将从[0, time_mask_param]中均匀采样;
|
||||
当 `iid_masks` 为False时,使用该值作为掩码的长度。取值范围为[0, time_length],其中 `time_length` 为波形
|
||||
在时域的长度(默认为0)。
|
||||
- **mask_start** (int):添加掩码的起始位置,只有当 `iid_masks` 为True时,该值才会生效。
|
||||
取值范围为[0, time_length - time_mask_param],其中 `time_length` 为波形在时域的长度(默认为0)。
|
||||
- **mask_value** (double):添加掩码的取值(默认为0.0)。
|
||||
- **iid_masks** (bool, optional):是否添加随机掩码(默认为False)。
|
||||
- **time_mask_param** (int): 当 `iid_masks` 为True时,掩码长度将从[0, time_mask_param]中均匀采样;当 `iid_masks` 为False时,使用该值作为掩码的长度。取值范围为[0, time_length],其中 `time_length` 为波形在时域的长度(默认为0)。
|
||||
- **mask_start** (int):添加掩码的起始位置,只有当 `iid_masks` 为True时,该值才会生效。取值范围为[0, time_length - time_mask_param],其中 `time_length` 为波形在时域的长度(默认为0)。
|
||||
- **mask_value** (double):添加掩码的取值(默认为0.0)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 3, 2])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.TimeMasking(time_mask_param=1)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 3, 2])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.TimeMasking(time_mask_param=1)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -7,17 +7,15 @@ mindspore.dataset.audio.transforms.TimeStretch
|
|||
|
||||
**参数:**
|
||||
|
||||
- **hop_length** (int, optional):STFT窗之间每跳的长度,即连续帧之间的样本数(默认为None,取 `n_freq - 1`)。
|
||||
- **n_freq** (int, optional):STFT中的滤波器组数(默认为201)。
|
||||
- **fixed_rate** (float, optional):频谱在时域加快或减缓的比例(默认为None,取1.0)。
|
||||
- **hop_length** (int, optional):STFT窗之间每跳的长度,即连续帧之间的样本数(默认为None,取 `n_freq - 1`)。
|
||||
- **n_freq** (int, optional):STFT中的滤波器组数(默认为201)。
|
||||
- **fixed_rate** (float, optional):频谱在时域加快或减缓的比例(默认为None,取1.0)。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 30])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.TimeStretch()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 30])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.TimeStretch()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
|
|
|
@ -6,5 +6,6 @@ mindspore.dataset.audio.utils.ScaleType
|
|||
音频标度枚举类。
|
||||
|
||||
可选枚举值为:ScaleType.MAGNITUDE和ScaleType.POWER。
|
||||
|
||||
- **ScaleType.MAGNITUDE**:代表输入音频的标度为振幅。
|
||||
- **ScaleType.POWER**:代表输入音频的标度为功率。
|
|
@ -18,80 +18,68 @@ mindspore.train.summary
|
|||
|
||||
**参数:**
|
||||
|
||||
- **log_dir** (str):`log_dir` 是用来保存summary的目录。
|
||||
- **log_dir** (str):`log_dir` 是用来保存summary的目录。
|
||||
- **file_prefix** (str):文件的前缀。默认值:events。
|
||||
- **file_suffix** (str):文件的后缀。默认值:_MS。
|
||||
- **network** (Cell):通过网络获取用于保存图形summary的管道。默认值:None。
|
||||
- **max_file_size** (int, optional):可写入磁盘的每个文件的最大大小(以字节为单位)。例如,如果不大于4GB,则设置 `max_file_size=4*1024**3` 。默认值:None,表示无限制。
|
||||
- **raise_exception** (bool, 可选):设置在记录数据中发生RuntimeError或OSError异常时是否抛出异常。默认值:False,表示打印错误日志,不抛出异常。
|
||||
- **export_options** (Union[None, dict]):可以将保存在summary中的数据导出,并使用字典自定义所需的数据和文件格式。注:导出的文件大小不受 `max_file_size` 的限制。例如,您可以设置{'tensor_format':'npy'}将Tensor导出为NPY文件。支持控制的数据如下所示。默认值:None,表示不导出数据。
|
||||
|
||||
- **file_prefix** (str):文件的前缀。默认值:events。
|
||||
- **tensor_format** (Union[str, None]):自定义导出的Tensor的格式。支持["npy", None]。默认值:None,表示不导出Tensor。
|
||||
|
||||
- **file_suffix** (str):文件的后缀。默认值:_MS。
|
||||
|
||||
- **network** (Cell):通过网络获取用于保存图形summary的管道。默认值:None。
|
||||
|
||||
- **max_file_size** (int, optional):可写入磁盘的每个文件的最大大小(以字节为单位)。例如,如果不大于4GB,则设置 `max_file_size=4*1024**3` 。默认值:None,表示无限制。
|
||||
|
||||
- **raise_exception** (bool, 可选):设置在记录数据中发生RuntimeError或OSError异常时是否抛出异常。默认值:False,表示打印错误日志,不抛出异常。
|
||||
|
||||
- **export_options** (Union[None, dict]):可以将保存在summary中的数据导出,并使用字典自定义所需的数据和文件格式。注:导出的文件大小不受 `max_file_size` 的限制。例如,您可以设置{'tensor_format':'npy'}将Tensor导出为NPY文件。支持控制的数据如下所示。默认值:None,表示不导出数据。
|
||||
|
||||
- **tensor_format** (Union[str, None]):自定义导出的Tensor的格式。支持["npy", None]。默认值:None,表示不导出Tensor。
|
||||
|
||||
- **npy**:将Tensor导出为NPY文件。
|
||||
- **npy**:将Tensor导出为NPY文件。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError:** 参数类型不正确。
|
||||
- **RuntimeError** :运行时错误。
|
||||
- **OSError:** 系统错误。
|
||||
- **TypeError:** 参数类型不正确。
|
||||
- **RuntimeError** :运行时错误。
|
||||
- **OSError:** 系统错误。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... # 在with语句中使用以自动关闭
|
||||
... with SummaryRecord(log_dir="./summary_dir") as summary_record:
|
||||
... pass
|
||||
...
|
||||
... # 在try .. finally .. 语句中使用以确保关闭
|
||||
... try:
|
||||
... summary_record = SummaryRecord(log_dir="./summary_dir")
|
||||
... finally:
|
||||
... summary_record.close()
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... # 在with语句中使用以自动关闭
|
||||
... with SummaryRecord(log_dir="./summary_dir") as summary_record:
|
||||
... pass
|
||||
...
|
||||
... # 在try .. finally .. 语句中使用以确保关闭
|
||||
... try:
|
||||
... summary_record = SummaryRecord(log_dir="./summary_dir")
|
||||
... finally:
|
||||
... summary_record.close()
|
||||
|
||||
|
||||
.. py:method:: add_value(plugin, name, value)
|
||||
|
||||
添加稍后记录的值。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **plugin** (str):数据类型标签。
|
||||
- **name** (str):数据名称。
|
||||
- **value** (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): 待存储的值。
|
||||
- **plugin** (str):数据类型标签。
|
||||
- **name** (str):数据名称。
|
||||
- **value** (Union[Tensor, GraphProto, TrainLineage, EvaluationLineage, DatasetGraph, UserDefinedInfo]): 待存储的值。
|
||||
|
||||
- 当plugin为"graph"时,参数值的数据类型应为"GraphProto"对象。具体详情,请参见 mindspore/ccsrc/anf_ir.proto。
|
||||
- 当plugin为"scalar"、"image"、"tensor"或"histogram"时,参数值的数据类型应为"Tensor"对象。
|
||||
- 当plugin为"train_lineage"时,参数值的数据类型应为"TrainLineage"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"eval_lineage"时,参数值的数据类型应为"EvaluationLineage"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"dataset_graph"时,参数值的数据类型应为"DatasetGraph"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"custom_lineage_data"时,参数值的数据类型应为"UserDefinedInfo"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"explainer"时,参数值的数据类型应为"Explain"对象。具体详情,请参见 mindspore/ccsrc/summary.proto。
|
||||
- 当plugin为"graph"时,参数值的数据类型应为"GraphProto"对象。具体详情,请参见 mindspore/ccsrc/anf_ir.proto。
|
||||
- 当plugin为"scalar"、"image"、"tensor"或"histogram"时,参数值的数据类型应为"Tensor"对象。
|
||||
- 当plugin为"train_lineage"时,参数值的数据类型应为"TrainLineage"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"eval_lineage"时,参数值的数据类型应为"EvaluationLineage"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"dataset_graph"时,参数值的数据类型应为"DatasetGraph"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"custom_lineage_data"时,参数值的数据类型应为"UserDefinedInfo"对象。具体详情,请参见 mindspore/ccsrc/lineage.proto。
|
||||
- 当plugin为"explainer"时,参数值的数据类型应为"Explain"对象。具体详情,请参见 mindspore/ccsrc/summary.proto。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **ValueError:** 参数值无效。
|
||||
|
||||
- **TypeError:** 参数类型错误。
|
||||
- **ValueError:** 参数值无效。
|
||||
- **TypeError:** 参数类型错误。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.add_value('scalar', 'loss', Tensor(0.1))
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.add_value('scalar', 'loss', Tensor(0.1))
|
||||
|
||||
|
||||
.. py:method:: close()
|
||||
|
@ -99,15 +87,13 @@ mindspore.train.summary
|
|||
将所有事件持久化并关闭SummaryRecord。请使用with语句或try…finally语句进行自动关闭。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... try:
|
||||
... summary_record = SummaryRecord(log_dir="./summary_dir")
|
||||
... finally:
|
||||
... summary_record.close()
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... try:
|
||||
... summary_record = SummaryRecord(log_dir="./summary_dir")
|
||||
... finally:
|
||||
... summary_record.close()
|
||||
|
||||
|
||||
.. py:method:: flush()
|
||||
|
@ -117,13 +103,11 @@ mindspore.train.summary
|
|||
调用该函数以确保所有挂起事件都已写入到磁盘。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.flush()
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.flush()
|
||||
|
||||
|
||||
.. py:method:: log_dir
|
||||
|
@ -133,16 +117,14 @@ mindspore.train.summary
|
|||
|
||||
**返回:**
|
||||
|
||||
str,日志文件的完整路径。
|
||||
str,日志文件的完整路径。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... log_dir = summary_record.log_dir
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... log_dir = summary_record.log_dir
|
||||
|
||||
|
||||
.. py:method:: record(step, train_network=None, plugin_filter=None)
|
||||
|
@ -151,29 +133,27 @@ mindspore.train.summary
|
|||
|
||||
**参数:**
|
||||
|
||||
- **step** (int):表示训练step的编号。
|
||||
- **train_network** (Cell):表示用于保存图形的备用网络。默认值:None,表示当原始网络图为None时,不保存图形summary。
|
||||
- **plugin_filter** (Optional[Callable[[str], bool]]):过滤器函数,用于通过返回False来过滤正在写入的插件。默认值:None。
|
||||
- **step** (int):表示训练step的编号。
|
||||
- **train_network** (Cell):表示用于保存图形的备用网络。默认值:None,表示当原始网络图为None时,不保存图形summary。
|
||||
- **plugin_filter** (Optional[Callable[[str], bool]]):过滤器函数,用于通过返回False来过滤正在写入的插件。默认值:None。
|
||||
|
||||
**返回:**
|
||||
|
||||
bool,表示记录进程是否成功。
|
||||
bool,表示记录进程是否成功。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError:** 参数类型错误。
|
||||
- **RuntimeError:** 磁盘空间不足。
|
||||
- **TypeError:** 参数类型错误。
|
||||
- **RuntimeError:** 磁盘空间不足。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.record(step=2)
|
||||
...
|
||||
True
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.record(step=2)
|
||||
...
|
||||
True
|
||||
|
||||
|
||||
.. py:method:: set_mode(mode)
|
||||
|
@ -182,17 +162,15 @@ mindspore.train.summary
|
|||
|
||||
**参数:**
|
||||
|
||||
**mode** (str):待设置的模式,为"train"或"eval"。当模式为"eval"时,`summary_record` 不记录summary算子的数据。
|
||||
**mode** (str):待设置的模式,为"train"或"eval"。当模式为"eval"时,`summary_record` 不记录summary算子的数据。
|
||||
|
||||
**异常:**
|
||||
|
||||
**ValueError:** 无法识别模式。
|
||||
**ValueError:** 无法识别模式。
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.set_mode('eval')
|
||||
>>> from mindspore.train.summary import SummaryRecord
|
||||
>>> if __name__ == '__main__':
|
||||
... with SummaryRecord(log_dir="./summary_dir", file_prefix="xx_", file_suffix="_yy") as summary_record:
|
||||
... summary_record.set_mode('eval')
|
|
@ -14,9 +14,9 @@ mindspore.ParameterTuple
|
|||
|
||||
**参数:**
|
||||
|
||||
- **prefix** (str):参数的命名空间。
|
||||
- **init** (Union[Tensor, str, numbers.Number]):初始化参数的shape和dtype。 `init` 的定义与 `Parameter` API中的定义相同。默认值:'same'。
|
||||
- **prefix** (str):参数的命名空间。
|
||||
- **init** (Union[Tensor, str, numbers.Number]):初始化参数的shape和dtype。 `init` 的定义与 `Parameter` API中的定义相同。默认值:'same'。
|
||||
|
||||
**返回:**
|
||||
|
||||
新的参数元组。
|
||||
新的参数元组。
|
|
@ -9,53 +9,51 @@ mindspore.ms_function
|
|||
|
||||
**参数:**
|
||||
|
||||
- **fn** (Function):要编译成图的Python函数。默认值:None。
|
||||
- **obj** (Object):用于区分编译后函数的Python对象。默认值:None。
|
||||
- **input_signature** (Tensor):用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。如果指定了 `input_signature` ,则 `fn` 的每个输入都必须是 `Tensor` 。
|
||||
- **fn** (Function):要编译成图的Python函数。默认值:None。
|
||||
- **obj** (Object):用于区分编译后函数的Python对象。默认值:None。
|
||||
- **input_signature** (Tensor):用于表示输入参数的Tensor。Tensor的shape和dtype将作为函数的输入shape和dtype。如果指定了 `input_signature` ,则 `fn` 的每个输入都必须是 `Tensor` 。
|
||||
|
||||
并且 `fn` 的输入参数将不会接受 `\**kwargs` 参数。实际输入的shape和dtype必须与 `input_signature` 的相同。否则,将引发TypeError。默认值:None。
|
||||
|
||||
**返回:**
|
||||
|
||||
函数,如果 `fn` 不是None,则返回一个已经将输入 `fn` 编译成图的可执行函数;如果 `fn` 为None,则返回一个装饰器。当这个装饰器使用单个 `fn` 参数进行调用时,等价于 `fn` 不是None的场景。
|
||||
函数,如果 `fn` 不是None,则返回一个已经将输入 `fn` 编译成图的可执行函数;如果 `fn` 为None,则返回一个装饰器。当这个装饰器使用单个 `fn` 参数进行调用时,等价于 `fn` 不是None的场景。
|
||||
|
||||
**支持平台:**
|
||||
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
**样例:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import ms_function
|
||||
...
|
||||
>>> x =tensor(np.ones([1,1,3,3]).astype(np.float32))
|
||||
>>> y =tensor(np.ones([1,1,3,3]).astype(np.float32))
|
||||
...
|
||||
>>> # 通过调用ms_function创建可调用的MindSpore图
|
||||
>>> def tensor_add(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> tensor_add_graph = ms_function(fn=tensor_add)
|
||||
>>> out = tensor_add_graph(x, y)
|
||||
...
|
||||
>>> # 通过装饰器@ms_function创建一个可调用的MindSpore图
|
||||
>>> @ms_function
|
||||
... def tensor_add_with_dec(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> out = tensor_add_with_dec(x, y)
|
||||
...
|
||||
>>> # 通过带有input_signature参数的装饰器@ms_function创建一个可调用的MindSpore图
|
||||
>>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
|
||||
... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
|
||||
... def tensor_add_with_sig(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> out = tensor_add_with_sig(x, y)
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import ms_function
|
||||
...
|
||||
>>> x =tensor(np.ones([1,1,3,3]).astype(np.float32))
|
||||
>>> y =tensor(np.ones([1,1,3,3]).astype(np.float32))
|
||||
...
|
||||
>>> # 通过调用ms_function创建可调用的MindSpore图
|
||||
>>> def tensor_add(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> tensor_add_graph = ms_function(fn=tensor_add)
|
||||
>>> out = tensor_add_graph(x, y)
|
||||
...
|
||||
>>> # 通过装饰器@ms_function创建一个可调用的MindSpore图
|
||||
>>> @ms_function
|
||||
... def tensor_add_with_dec(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> out = tensor_add_with_dec(x, y)
|
||||
...
|
||||
>>> # 通过带有input_signature参数的装饰器@ms_function创建一个可调用的MindSpore图
|
||||
>>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
|
||||
... Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
|
||||
... def tensor_add_with_sig(x, y):
|
||||
... z = x + y
|
||||
... return z
|
||||
...
|
||||
>>> out = tensor_add_with_sig(x, y)
|
||||
|
Loading…
Reference in New Issue