modify format 928

This commit is contained in:
huodagu 2022-09-27 16:28:21 +08:00
parent 9edd678164
commit 9214e60957
34 changed files with 89 additions and 83 deletions

View File

@ -76,7 +76,7 @@
#### FrontEnd
- [BETA] Add `mindspore.Model.fit` API, add `mindspore.train.callback.EarlyStopping` and `mindspore.train.callback.ReduceLROnPlateau` in Callback.
- [BETA] Add `mindspore.train.Model.fit` API, add `mindspore.train.callback.EarlyStopping` and `mindspore.train.callback.ReduceLROnPlateau` in Callback.
- [BETA] Support custom operator implemented by Julia.
- [BETA] Support custom operator implemented by MindSpore Hybrid DSL.
- [STABLE] The export() interface supports the export of a model using a custom encryption algorithm, and the load() interface supports the import of a model using a custom decryption algorithm.
@ -229,7 +229,7 @@
##### Python API
- DVPP simulation algorithm is no longer supported. Remove `mindspore.dataset.vision.c_transforms.SoftDvppDecodeRandomCropResizeJpeg` and `mindspore.dataset.vision.c_transforms.SoftDvppDecodeResizeJpeg` interfaces.
- Add `on_train_epoch_end` method in LossMonitor, which implements printing metric information in the epoch level when it is used in `mindspore.Model.fit`.
- Add `on_train_epoch_end` method in LossMonitor, which implements printing metric information in the epoch level when it is used in `mindspore.train.Model.fit`.
- TimeMonitor printing content changes, and the printed content is added to "train" or "eval" to distinguish between training and inference phases.
- `filter_prefix` of `mindspore.load_checkpoint` interface: empty string ("") is no longer supported, and the matching rules are changed from strong matching to fuzzy matching.
@ -242,7 +242,7 @@ For examples:
- `mindspore.context.set_context` can be simplified to `mindspore.set_context`.
- `mindspore.parallel.set_algo_parameters` can be simplified to `mindspore.set_algo_parameters`.
- `mindspore.profiler.Profiler` can be simplified to `mindspore.Profiler`.
- `mindspore.train.callback.Callback` can be simplified to `mindspore.Callback`.
- `mindspore.train.callback.Callback` can be simplified to `mindspore.train.Callback`.
The API pages are aggregated to <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html>.

View File

@ -76,7 +76,7 @@
#### FrontEnd
- [BETA] 提供`mindspore.Model.fit` API增加两种callback方法 `mindspore.train.callback.EarlyStopping``mindspore.train.callback.ReduceLROnPlateau`
- [BETA] 提供`mindspore.train.Model.fit` API增加两种callback方法 `mindspore.train.callback.EarlyStopping``mindspore.train.callback.ReduceLROnPlateau`
- [BETA] 自定义算子支持Julia算子。
- [BETA] 自定义算子支持Hybrid DSL算子。
- [STABLE] export()接口支持自定义加密算法导出模型load()接口支持自定义解密算法导入模型。
@ -229,7 +229,7 @@
##### Python API
- 不再支持DVPP模拟算法删除 `mindspore.dataset.vision.c_transforms.SoftDvppDecodeRandomCropResizeJpeg``mindspore.dataset.vision.c_transforms.SoftDvppDecodeResizeJpeg` 接口。
- LossMonitor中增加`on_train_epoch_end` 方法,实现在 `mindspore.Model.fit` 中使用时打印epoch级别的metric信息。
- LossMonitor中增加`on_train_epoch_end` 方法,实现在 `mindspore.train.Model.fit` 中使用时打印epoch级别的metric信息。
- TimeMonitor打印内容变更打印内容加入"train"或"eval"用于区分训练和推理阶段。
- load_checkpoint 接口的`filter_prefix`:不再支持空字符串(""),匹配规则由强匹配修改为模糊匹配。
@ -242,7 +242,7 @@ mindspore.context、mindspore.parallel、mindspore.profiler、mindspore.train模
- `mindspore.context.set_context`可简化为`mindspore.set_context`。
- `mindspore.parallel.set_algo_parameters`可简化为`mindspore.set_algo_parameters`。
- `mindspore.profiler.Profiler`可简化为`mindspore.Profiler`。
- `mindspore.train.callback.Callback`可简化为`mindspore.Callback`。
- `mindspore.train.callback.Callback`可简化为`mindspore.train.Callback`。
API页面统一汇总至<https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.html>

View File

@ -14,7 +14,7 @@ mindspore.Tensor.clip
参数:
- **xmin** (Tensor, scalar, None) - 最小值。如果值为None则不在间隔的下边缘执行裁剪操作。`xmin``xmax` 只能有一个为None。
- **xmax** (Tensor, scalar, None) - 最大值。如果值为None则不在间隔的上边缘执行裁剪操作。`xmin``xmax` 只能有一个为None。如果 `xmin``xmax` 是Tensor则三个Tensor将被广播进行shape匹配。
- **dtype** (`mindspore.dtype` , 可选) - 覆盖输出Tensor的dtype。默认值为None。
- **dtype** (mindspore.dtype, 可选) - 覆盖输出Tensor的dtype。默认值为None。
返回:
Tensor含有输入Tensor的元素其中values < `xmin` 被替换为 `xmin` values > `xmax` 被替换为 `xmax`

View File

@ -6,7 +6,7 @@ mindspore.Tensor.cumsum
返回指定轴方向上元素的累加值。
.. note::
如果 `dtype``int8` , `int16``bool` ,则结果 `dtype` 将提升为 `int32` ,不支持 `int64`
如果 `self.dtype``int8` , `int16``bool` ,则结果 `dtype` 将提升为 `int32` ,不支持 `int64`
参数:
- **axis** (int, 可选) - 轴,在该轴方向上的累积和。默认情况下,计算所有元素的累加和。

View File

@ -8,7 +8,7 @@ mindspore.Tensor.take
参数:
- **indices** (Tensor) - 待提取的值的shape为 `(Nj...)` 的索引。
- **axis** (int, 可选) - 在指定维度上选择值。默认情况下使用展开的输入数组。默认值None。
- **mode** ('raise', 'wrap', 'clip', 可选)
- **mode** ('raise', 'wrap', 'clip', 可选) -
- raise抛出错误。
- wrap绕接。

View File

@ -24,7 +24,7 @@ mindspore.SummaryLandscape
参数:
- **callback_fn** (python function) - Python函数对象用户需要写一个没有输入的函数返回值要求如下。
- mindspore.Model用户的模型。
- mindspore.train.Model用户的模型。
- mindspore.nn.Cell用户的网络。
- mindspore.dataset创建loss所需要的用户数据集。
- mindspore.train.Metrics用户的评估指标。

View File

@ -7,7 +7,7 @@ mindspore.set_auto_parallel_context
配置自动并行当前CPU仅支持数据并行。
.. note::
配置时,必须输入配置的名称。如果某个程序具有不同并行模式下的任务,需要提前调用reset_auto_parallel_context()为下一个任务设置新的并行模式。若要设置或更改并行模式必须在创建任何Initializer之前调用接口否则在编译网络时可能会出现RuntimeError。
配置时,必须输入配置的名称。如果某个程序具有不同并行模式下的任务,需要提前调用 :func:`mindspore.reset_auto_parallel_context` 为下一个任务设置新的并行模式。若要设置或更改并行模式必须在创建任何Initializer之前调用接口否则在编译网络时可能会出现RuntimeError。
某些配置适用于特定的并行模式,有关详细信息,请参见下表:

View File

@ -8,7 +8,7 @@ mindspore.set_context
在运行程序之前应配置context。如果没有配置默认情况下将根据设备目标进行自动设置。
.. note::
设置属性时,必须输入属性名称。
设置属性时,必须输入属性名称。net初始化后不建议更改模式因为一些操作的实现在Graph模式和PyNative模式下是不同的。默认值GRAPH_MODE。
某些配置适用于特定的设备,有关详细信息,请参见下表:

View File

@ -19,7 +19,7 @@ mindspore.nn.probability.distribution.Distribution
而指数分布的 `dist_spec_args``rate`
所有方法都包含一个 `dist_spec_args` 作为可选参数。
传入 `dist_spec_args` 可以让该方法基于新的分布的参数值进行运算但如此做不会改变原始分布的参数。
传入 `dist_spec_args` 可以让该方法基于新的分布的参数值进行运算但如此做不会改变原始分布的参数。
.. py:method:: cdf(value, *args, **kwargs)

View File

@ -25,7 +25,7 @@ mindspore.ops.AlltoAll
- **input_x** (Tensor) - shape为 :math:`(x_1, x_2, ..., x_R)`
输出:
Tensor设输入的shape是 :math:`(x_1, x_2, ..., x_R)`则输出的shape为 :math:`(y_1, y_2, ..., y_R),其中:
Tensor设输入的shape是 :math:`(x_1, x_2, ..., x_R)`则输出的shape为 :math:`(y_1, y_2, ..., y_R)`,其中:
:math:`y_{split\_dim} = x_{split\_dim} / split\_count`

View File

@ -8,7 +8,7 @@ mindspore.ops.assign_sub
输入 `variable``value` 会通过隐式数据类型转换使数据类型保持一致。如果数据类型不同,低精度的数据类型会被转换到高精度的数据类型。如果 `value` 为标量会被自动转换为Tensor其数据类型会与 `variable` 保持一致。
.. Note::
由于 `variable` 类型为 `Parameter` ,其数据类型不能改变因此只允许 `value` 的数据类型转变为 `variable` 的数据类型。而且由于不同设备支持的转换类型会有所不同,推荐在使用此操作时使用相同的数据类型。
由于 `variable` 类型为 `Parameter` ,其数据类型不能改变因此只允许 `value` 的数据类型转变为 `variable` 的数据类型。而且由于不同设备支持的转换类型会有所不同,推荐在使用此操作时使用相同的数据类型。
参数:
- **variable** (Parameter) - 待更新的网络参数shape: :math:`(N,*)` ,其中 :math:`*` 表示任何数量的附加维度。其轶应小于8。

View File

@ -25,19 +25,19 @@ mindspore.train.EarlyStopping
训练开始时初始化相关的变量。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: on_train_end(run_context)
打印是第几个epoch执行早停。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: on_train_epoch_end(run_context)
训练过程中,若监控指标在等待 `patience` 个epoch后仍没有改善则停止训练。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。

View File

@ -8,18 +8,18 @@ mindspore.train.History
用户不自定义训练网络或评估网络情况下,记录的内容将为损失值;用户自定义了训练网络/评估网络的情况下,如果定义的网络返回 `Tensor``numpy.ndarray`,则记录此返回值均值,如果返回 `tuple``list`,则记录第一个元素。
.. note::
通常使用在 `mindspore.Model.train``mindspore.Model.fit` 中。
通常使用在 `mindspore.train.Model.train``mindspore.train.Model.fit` 中。
.. py:method:: begin(run_context)
训练开始时初始化History对象的epoch属性。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: epoch_end(run_context)
epoch结束时记录网络输出和评估指标的相关信息。
参数:
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。

View File

@ -5,7 +5,7 @@ mindspore.train.LambdaCallback
用于自定义简单的callback。
使用匿名函数构建callback定义的匿名函数将在 `mindspore.Model.{train | eval | fit}` 的对应阶段被调用。
使用匿名函数构建callback定义的匿名函数将在 `mindspore.train.Model.{train | eval | fit}` 的对应阶段被调用。
请注意callback的每个阶段都需要一个位置参数`run_context`

View File

@ -21,11 +21,11 @@ mindspore.train.LossMonitor
LossMoniter用于 `model.fit`即边训练边推理场景时打印训练的loss和当前epoch推理的metrics。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: step_end(run_context)
step结束时打印训练loss。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。

View File

@ -83,7 +83,7 @@
如果 `valid_dataset` 不为None在训练过程中同时执行推理。
更多详细信息请参考 `mindspore.Model.train``mindspore.Model.eval`
更多详细信息请参考 `mindspore.train.Model.train``mindspore.train.Model.eval`
参数:
- **epoch** (int) - 训练执行轮次。通常每个epoch都会使用全量数据集进行训练。当 `dataset_sink_mode` 设置为True且 `sink_size` 大于零时则每个epoch训练次数为 `sink_size` 而不是数据集的总步数。如果 `epoch``initial_epoch` 一起使用,它表示训练的最后一个 `epoch` 是多少。

View File

@ -1,5 +1,5 @@
mindspore.train.OnRequestExit
===========================
=============================
.. py:class:: mindspore.train.OnRequestExit(save_ckpt=True, save_mindir=True, file_name='Net', directory='./', sig=signal.SIGUSR1)
@ -24,48 +24,46 @@ mindspore.train.OnRequestExit
在训练开始时,注册用户传入停止信号的处理程序。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: on_train_step_end(run_context)
在训练step结束时根据是否接收到停止信号`run_context`的`_stop_requested`属性置为True。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`
.. py:method:: on_train_step_end(run_context)
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_epoch_end(run_context)
在训练epoch结束时根据是否接收到停止信号`run_context`的`_stop_requested`属性置为True。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`
.. py:method:: on_train_step_end(run_context)
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`
.. py:method:: on_train_end(run_context)
在训练结束时根据是否接收到停止信号保存checkpoint或者mindir。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: on_eval_begin(run_context)
在推理开始时,注册用户传入停止信号的处理程序。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: on_eval_step_end(run_context)
在推理step结束时根据是否接收到停止信号`run_context`的`_stop_requested`属性置为True。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: on_eval_end(run_context)
在推理结束时根据是否接收到停止信号保存checkpoint或者mindir。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。

View File

@ -30,11 +30,11 @@ mindspore.train.ReduceLROnPlateau
训练开始时初始化相关的变量。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: on_train_epoch_end(run_context)
训练过程中,若监控指标在等待 `patience` 个epoch后仍没有改善则改变学习率。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。

View File

@ -16,11 +16,11 @@ mindspore.train.TimeMonitor
在epoch开始时记录时间。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
.. py:method:: epoch_end(run_context)
在epoch结束时打印epoch的耗时。
参数:
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。

View File

@ -464,7 +464,7 @@ def set_auto_parallel_context(**kwargs):
Note:
Attribute name is required for setting attributes.
If a program has tasks on different parallel modes, before setting a new parallel mode for the
next task, interface mindspore.reset_auto_parallel_context() should be called to reset
next task, interface :func:`mindspore.reset_auto_parallel_context` should be called to reset
the configuration.
Setting or changing parallel modes must be called before creating any Initializer, otherwise,
it may have RuntimeError when compiling the network.

View File

@ -432,7 +432,7 @@ def choice_with_mask(input_x, count=256, seed=0, seed2=0):
sample, while the mask tensor denotes which elements in the index tensor are valid.
Args:
input_x (Tensor): The input tensor.
input_x (Tensor[bool]): The input tensor.
The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
seed (int): Random seed. Default: 0.

View File

@ -769,7 +769,7 @@ class AlltoAll(PrimitiveWithInfer):
Outputs:
Tensor. If the shape of input tensor is :math:`(x_1, x_2, ..., x_R)`, then the shape of output tensor is
:math:`(y_1, y_2, ..., y_R), where:
:math:`(y_1, y_2, ..., y_R)`, where:
:math:`y_{split\_dim} = x_{split\_dim} / split\_count`

View File

@ -126,7 +126,7 @@ class EarlyStopping(Callback):
Args:
run_context (RunContext): Context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
self.wait = 0
@ -144,7 +144,7 @@ class EarlyStopping(Callback):
Args:
run_context (RunContext): Context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
cb_params = run_context.original_args()
@ -185,7 +185,7 @@ class EarlyStopping(Callback):
Args:
run_context (RunContext): Context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
if self.stopped_epoch > 0 and self.verbose:
@ -201,7 +201,7 @@ class EarlyStopping(Callback):
Args:
cb_params (dict): A dictionary stores context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
monitor_candidates = {}
if self.monitor == "loss":

View File

@ -31,7 +31,7 @@ class History(Callback):
outputs will be recorded.
Note:
Normally used in `mindspore.Model.train` or `mindspore.Model.fit`.
Normally used in `mindspore.train.Model.train` or `mindspore.train.Model.fit`.
Examples:
>>> import numpy as np
@ -61,8 +61,8 @@ class History(Callback):
Initialize the `epoch` property at the begin of training.
Args:
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. For more details,
please refer to :class:`mindspore.RunContext`.
run_context (RunContext): Context of the `mindspore.train.Model.{train | eval}`. For more details,
please refer to :class:`mindspore.train.RunContext`.
"""
self.epoch = {"epoch": []}
@ -71,8 +71,8 @@ class History(Callback):
Records the first element of network outputs and metrics information at the end of epoch.
Args:
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. For more details,
please refer to :class:`mindspore.RunContext`.
run_context (RunContext): Context of the `mindspore.train.Model.{train | eval}`. For more details,
please refer to :class:`mindspore.train.RunContext`.
"""
cb_params = run_context.original_args()
epoch = cb_params.get("cur_epoch_num", 1)

View File

@ -23,7 +23,7 @@ class LambdaCallback(Callback):
Callback for creating simple, custom callbacks.
This callback is constructed with anonymous functions that will be called
at the appropriate time (during `mindspore.Model.{train | eval | fit}`). Note that
at the appropriate time (during `mindspore.train.Model.{train | eval | fit}`). Note that
each stage of callbacks expects one positional arguments: `run_context`.
Note:

View File

@ -262,7 +262,7 @@ class SummaryLandscape:
callback_fn (python function): A python function object. User needs to write a function,
it has no input, and the return requirements are as follows.
- mindspore.Model: User's model object.
- mindspore.train.Model: User's model object.
- mindspore.nn.Cell: User's network object.
- mindspore.dataset: User's dataset object for create loss landscape.
- mindspore.train.Metrics: User's metrics object.

View File

@ -63,7 +63,7 @@ class LossMonitor(Callback):
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
cb_params = run_context.original_args()
@ -94,7 +94,7 @@ class LossMonitor(Callback):
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
cb_params = run_context.original_args()
metrics = cb_params.get("metrics")

View File

@ -87,7 +87,7 @@ class OnRequestExit(Callback):
Args:
run_context (RunContext): Context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
if os.path.isfile(f"{self.train_file_path}.ckpt"):
cb_params = run_context.original_args()
@ -101,7 +101,7 @@ class OnRequestExit(Callback):
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
if self.stop:
run_context.request_stop()
@ -112,7 +112,7 @@ class OnRequestExit(Callback):
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
if self.stop:
run_context.request_stop()
@ -123,7 +123,7 @@ class OnRequestExit(Callback):
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
if not self.stop:
return
@ -141,7 +141,7 @@ class OnRequestExit(Callback):
Args:
run_context (RunContext): Context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
cb_params = run_context.original_args()
eval_net = cb_params.eval_network
@ -157,7 +157,7 @@ class OnRequestExit(Callback):
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
if self.stop:
run_context.request_stop()
@ -168,7 +168,7 @@ class OnRequestExit(Callback):
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
if not self.stop:
return

View File

@ -124,7 +124,7 @@ class ReduceLROnPlateau(Callback):
Args:
run_context (RunContext): Context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
self.cooldown_counter = 0
self.wait = 0
@ -140,7 +140,7 @@ class ReduceLROnPlateau(Callback):
Args:
run_context (RunContext): Context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
cb_params = run_context.original_args()
cur_lr = cb_params.optimizer.learning_rate
@ -188,7 +188,7 @@ class ReduceLROnPlateau(Callback):
Args:
cb_params (dict): A dictionary stores context information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
monitor_candidates = {}
if self.monitor == "loss":

View File

@ -58,7 +58,7 @@ class TimeMonitor(Callback):
Args:
run_context (RunContext): Context of the process running. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
self.epoch_time = time.time()
@ -68,7 +68,7 @@ class TimeMonitor(Callback):
Args:
run_context (RunContext): Context of the process running. For more details,
please refer to :class:`mindspore.RunContext`.
please refer to :class:`mindspore.train.RunContext`.
"""
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size

View File

@ -186,9 +186,11 @@ def connect_network_with_dataset(network, dataset_helper):
Note:
In the case of running the network on Ascend/GPU in graph mode, this function will wrap the input network with
'GetNext', in other cases, the input network will be returned with no change.
The 'GetNext' is required to get data only in sink mode, so this function is not applicable to no-sink mode.
:class:`mindspore.ops.GetNext`. In other cases, the input network will be returned with no change.
The :class:`mindspore.ops.GetNext` is required to get data only in sink mode,
so this function is not applicable to no-sink mode.
when dataset_helper's dataset_sink_mode is True, it can only be connected to one network.
Args:
network (Cell): The training network for dataset.
dataset_helper (DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue

View File

@ -72,8 +72,8 @@ class Dice(Metric):
Updates the internal evaluation result :math:`y_pred` and :math:`y`.
Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, ...)`.
inputs (tuple): Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, ...)`.
Raises:
ValueError: If the number of the inputs is not 2.

View File

@ -1088,7 +1088,7 @@ class Model:
Evaluation process will be performed during training process if `valid_dataset` is provided.
More details please refer to `mindspore.Model.train` and `mindspore.Model.eval`.
More details please refer to `mindspore.train.Model.train` and `mindspore.train.Model.eval`.
Args:
epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch.
@ -1231,7 +1231,7 @@ class Model:
def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
"""
Evaluation process in `mindspore.Model.fit`.
Evaluation process in `mindspore.train.Model.fit`.
Args:
valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process

View File

@ -46,7 +46,10 @@ def define_model():
class MyCallbackOldMethod(Callback):
""" Raise warning in `mindspore.Model.train` and `mindspore.Model.eval`; raise error in `mindspore.Model.fit`"""
"""
Raise warning in `mindspore.train.Model.train` and `mindspore.train.Model.eval`;
raise error in `mindspore.train.Model.fit`.
"""
def begin(self, run_context):
print("custom callback: print on begin, just for test.")
@ -62,7 +65,10 @@ class MyCallbackOldMethod(Callback):
class MyCallbackNewMethod(Callback):
""" Custom callback running in `mindspore.Model.train`, `mindspore.Model.eval`, `mindspore.Model.fit`"""
"""
Custom callback running in `mindspore.train.Model.train`, `mindspore.train.Model.eval`,
`mindspore.train.Model.fit`.
"""
def on_train_epoch_end(self, run_context):
cb_params = run_context.original_args()
print("custom callback: train epoch end, loss is %s" % (cb_params.get("net_outputs")))
@ -74,7 +80,7 @@ class MyCallbackNewMethod(Callback):
def test_fit_train_dataset_non_sink_mode():
"""
Feature: `mindspore.Model.fit` with train dataset in non-sink mode.
Feature: `mindspore.train.Model.fit` with train dataset in non-sink mode.
Description: test fit with train dataset in non-sink mode.
Expectation: run in non-sink mode.
"""
@ -87,7 +93,7 @@ def test_fit_train_dataset_non_sink_mode():
def test_fit_train_dataset_sink_mode():
"""
Feature: `mindspore.Model.fit` with train dataset in sink mode.
Feature: `mindspore.train.Model.fit` with train dataset in sink mode.
Description: test fit with train dataset in sink mode.
Expectation: run in sink mode.
"""
@ -100,7 +106,7 @@ def test_fit_train_dataset_sink_mode():
def test_fit_valid_dataset_non_sink_mode():
"""
Feature: `mindspore.Model.fit` with valid dataset in non-sink mode.
Feature: `mindspore.train.Model.fit` with valid dataset in non-sink mode.
Description: test fit with valid dataset in non-sink mode.
Expectation: run in non-sink mode.
"""
@ -113,7 +119,7 @@ def test_fit_valid_dataset_non_sink_mode():
def test_fit_valid_dataset_sink_mode():
"""
Feature: `mindspore.Model.fit` with valid dataset in sink mode.
Feature: `mindspore.train.Model.fit` with valid dataset in sink mode.
Description: test fit with valid dataset in sink mode.
Expectation: run in sink mode.
"""
@ -126,7 +132,7 @@ def test_fit_valid_dataset_sink_mode():
def test_fit_without_valid_dataset():
"""
Feature: `mindspore.Model.fit` without `valid_dataset` input .
Feature: `mindspore.train.Model.fit` without `valid_dataset` input .
Description: test fit when `valid_dataset` is None and `valid_dataset_sink_mode` is True or False.
Expectation: network train without eval process, `valid_dataset_sink_mode` does not take effect.
"""
@ -139,7 +145,7 @@ def test_fit_without_valid_dataset():
def test_fit_valid_frequency():
"""
Feature: check `valid_frequency` input in `mindspore.Model.fit`.
Feature: check `valid_frequency` input in `mindspore.train.Model.fit`.
Description: when `valid_frequency` is integer, list or other types.
Expectation: raise ValueError when the type of valid_frequency is not int or list.
"""
@ -156,7 +162,7 @@ def test_fit_valid_frequency():
def test_fit_callbacks():
"""
Feature: check `callbacks` input in `mindspore.Model.fit`.
Feature: check `callbacks` input in `mindspore.train.Model.fit`.
Description: test internal or custom callbacks in fit.
Expectation: raise ValueError when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
"""
@ -178,7 +184,7 @@ def test_fit_callbacks():
def test_train_eval_callbacks():
"""
Feature: check `callbacks` input in `mindspore.Model.train` or `mindspore.Model.eval`.
Feature: check `callbacks` input in `mindspore.train.Model.train` or `mindspore.train.Model.eval`.
Description: test internal or custom callbacks in train or eval.
Expectation: raise warning when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
"""