forked from mindspore-Ecosystem/mindspore
modify format 928
This commit is contained in:
parent
9edd678164
commit
9214e60957
|
@ -76,7 +76,7 @@
|
|||
|
||||
#### FrontEnd
|
||||
|
||||
- [BETA] Add `mindspore.Model.fit` API, add `mindspore.train.callback.EarlyStopping` and `mindspore.train.callback.ReduceLROnPlateau` in Callback.
|
||||
- [BETA] Add `mindspore.train.Model.fit` API, add `mindspore.train.callback.EarlyStopping` and `mindspore.train.callback.ReduceLROnPlateau` in Callback.
|
||||
- [BETA] Support custom operator implemented by Julia.
|
||||
- [BETA] Support custom operator implemented by MindSpore Hybrid DSL.
|
||||
- [STABLE] The export() interface supports the export of a model using a custom encryption algorithm, and the load() interface supports the import of a model using a custom decryption algorithm.
|
||||
|
@ -229,7 +229,7 @@
|
|||
##### Python API
|
||||
|
||||
- DVPP simulation algorithm is no longer supported. Remove `mindspore.dataset.vision.c_transforms.SoftDvppDecodeRandomCropResizeJpeg` and `mindspore.dataset.vision.c_transforms.SoftDvppDecodeResizeJpeg` interfaces.
|
||||
- Add `on_train_epoch_end` method in LossMonitor, which implements printing metric information in the epoch level when it is used in `mindspore.Model.fit`.
|
||||
- Add `on_train_epoch_end` method in LossMonitor, which implements printing metric information in the epoch level when it is used in `mindspore.train.Model.fit`.
|
||||
- TimeMonitor printing content changes, and the printed content is added to "train" or "eval" to distinguish between training and inference phases.
|
||||
- `filter_prefix` of `mindspore.load_checkpoint` interface: empty string ("") is no longer supported, and the matching rules are changed from strong matching to fuzzy matching.
|
||||
|
||||
|
@ -242,7 +242,7 @@ For examples:
|
|||
- `mindspore.context.set_context` can be simplified to `mindspore.set_context`.
|
||||
- `mindspore.parallel.set_algo_parameters` can be simplified to `mindspore.set_algo_parameters`.
|
||||
- `mindspore.profiler.Profiler` can be simplified to `mindspore.Profiler`.
|
||||
- `mindspore.train.callback.Callback` can be simplified to `mindspore.Callback`.
|
||||
- `mindspore.train.callback.Callback` can be simplified to `mindspore.train.Callback`.
|
||||
|
||||
The API pages are aggregated to <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html>.
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@
|
|||
|
||||
#### FrontEnd
|
||||
|
||||
- [BETA] 提供`mindspore.Model.fit` API,增加两种callback方法 `mindspore.train.callback.EarlyStopping` 和 `mindspore.train.callback.ReduceLROnPlateau`。
|
||||
- [BETA] 提供`mindspore.train.Model.fit` API,增加两种callback方法 `mindspore.train.callback.EarlyStopping` 和 `mindspore.train.callback.ReduceLROnPlateau`。
|
||||
- [BETA] 自定义算子支持Julia算子。
|
||||
- [BETA] 自定义算子支持Hybrid DSL算子。
|
||||
- [STABLE] export()接口支持自定义加密算法导出模型,load()接口支持自定义解密算法导入模型。
|
||||
|
@ -229,7 +229,7 @@
|
|||
##### Python API
|
||||
|
||||
- 不再支持DVPP模拟算法,删除 `mindspore.dataset.vision.c_transforms.SoftDvppDecodeRandomCropResizeJpeg` 和 `mindspore.dataset.vision.c_transforms.SoftDvppDecodeResizeJpeg` 接口。
|
||||
- LossMonitor中增加`on_train_epoch_end` 方法,实现在 `mindspore.Model.fit` 中使用时,打印epoch级别的metric信息。
|
||||
- LossMonitor中增加`on_train_epoch_end` 方法,实现在 `mindspore.train.Model.fit` 中使用时,打印epoch级别的metric信息。
|
||||
- TimeMonitor打印内容变更,打印内容加入"train"或"eval"用于区分训练和推理阶段。
|
||||
- load_checkpoint 接口的`filter_prefix`:不再支持空字符串(""),匹配规则由强匹配修改为模糊匹配。
|
||||
|
||||
|
@ -242,7 +242,7 @@ mindspore.context、mindspore.parallel、mindspore.profiler、mindspore.train模
|
|||
- `mindspore.context.set_context`可简化为`mindspore.set_context`。
|
||||
- `mindspore.parallel.set_algo_parameters`可简化为`mindspore.set_algo_parameters`。
|
||||
- `mindspore.profiler.Profiler`可简化为`mindspore.Profiler`。
|
||||
- `mindspore.train.callback.Callback`可简化为`mindspore.Callback`。
|
||||
- `mindspore.train.callback.Callback`可简化为`mindspore.train.Callback`。
|
||||
|
||||
API页面统一汇总至:<https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.html>。
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ mindspore.Tensor.clip
|
|||
参数:
|
||||
- **xmin** (Tensor, scalar, None) - 最小值。如果值为None,则不在间隔的下边缘执行裁剪操作。`xmin` 或 `xmax` 只能有一个为None。
|
||||
- **xmax** (Tensor, scalar, None) - 最大值。如果值为None,则不在间隔的上边缘执行裁剪操作。`xmin` 或 `xmax` 只能有一个为None。如果 `xmin` 或 `xmax` 是Tensor,则三个Tensor将被广播进行shape匹配。
|
||||
- **dtype** (`mindspore.dtype` , 可选) - 覆盖输出Tensor的dtype。默认值为None。
|
||||
- **dtype** (mindspore.dtype, 可选) - 覆盖输出Tensor的dtype。默认值为None。
|
||||
|
||||
返回:
|
||||
Tensor,含有输入Tensor的元素,其中values < `xmin` 被替换为 `xmin` ,values > `xmax` 被替换为 `xmax` 。
|
||||
|
|
|
@ -6,7 +6,7 @@ mindspore.Tensor.cumsum
|
|||
返回指定轴方向上元素的累加值。
|
||||
|
||||
.. note::
|
||||
如果 `dtype` 为 `int8` , `int16` 或 `bool` ,则结果 `dtype` 将提升为 `int32` ,不支持 `int64` 。
|
||||
如果 `self.dtype` 为 `int8` , `int16` 或 `bool` ,则结果 `dtype` 将提升为 `int32` ,不支持 `int64` 。
|
||||
|
||||
参数:
|
||||
- **axis** (int, 可选) - 轴,在该轴方向上的累积和。默认情况下,计算所有元素的累加和。
|
||||
|
|
|
@ -8,7 +8,7 @@ mindspore.Tensor.take
|
|||
参数:
|
||||
- **indices** (Tensor) - 待提取的值的shape为 `(Nj...)` 的索引。
|
||||
- **axis** (int, 可选) - 在指定维度上选择值。默认情况下,使用展开的输入数组。默认值:None。
|
||||
- **mode** ('raise', 'wrap', 'clip', 可选)
|
||||
- **mode** ('raise', 'wrap', 'clip', 可选) -
|
||||
|
||||
- raise:抛出错误。
|
||||
- wrap:绕接。
|
||||
|
|
|
@ -24,7 +24,7 @@ mindspore.SummaryLandscape
|
|||
参数:
|
||||
- **callback_fn** (python function) - Python函数对象,用户需要写一个没有输入的函数,返回值要求如下。
|
||||
|
||||
- mindspore.Model:用户的模型。
|
||||
- mindspore.train.Model:用户的模型。
|
||||
- mindspore.nn.Cell:用户的网络。
|
||||
- mindspore.dataset:创建loss所需要的用户数据集。
|
||||
- mindspore.train.Metrics:用户的评估指标。
|
||||
|
|
|
@ -7,7 +7,7 @@ mindspore.set_auto_parallel_context
|
|||
配置自动并行,当前CPU仅支持数据并行。
|
||||
|
||||
.. note::
|
||||
配置时,必须输入配置的名称。如果某个程序具有不同并行模式下的任务,需要提前调用reset_auto_parallel_context()为下一个任务设置新的并行模式。若要设置或更改并行模式,必须在创建任何Initializer之前调用接口,否则,在编译网络时,可能会出现RuntimeError。
|
||||
配置时,必须输入配置的名称。如果某个程序具有不同并行模式下的任务,需要提前调用 :func:`mindspore.reset_auto_parallel_context` 为下一个任务设置新的并行模式。若要设置或更改并行模式,必须在创建任何Initializer之前调用接口,否则,在编译网络时,可能会出现RuntimeError。
|
||||
|
||||
某些配置适用于特定的并行模式,有关详细信息,请参见下表:
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ mindspore.set_context
|
|||
在运行程序之前,应配置context。如果没有配置,默认情况下将根据设备目标进行自动设置。
|
||||
|
||||
.. note::
|
||||
设置属性时,必须输入属性名称。
|
||||
设置属性时,必须输入属性名称。net初始化后不建议更改模式,因为一些操作的实现在Graph模式和PyNative模式下是不同的。默认值:GRAPH_MODE。
|
||||
|
||||
某些配置适用于特定的设备,有关详细信息,请参见下表:
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ mindspore.nn.probability.distribution.Distribution
|
|||
而指数分布的 `dist_spec_args` 为 `rate`。
|
||||
|
||||
所有方法都包含一个 `dist_spec_args` 作为可选参数。
|
||||
传入 `dist_spec_args` 可以让该方法基于新的分布的参数值进行运算。但如此做不会改变原始分布的参数。
|
||||
传入 `dist_spec_args` 可以让该方法基于新的分布的参数值进行运算,但如此做不会改变原始分布的参数。
|
||||
|
||||
.. py:method:: cdf(value, *args, **kwargs)
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ mindspore.ops.AlltoAll
|
|||
- **input_x** (Tensor) - shape为 :math:`(x_1, x_2, ..., x_R)`。
|
||||
|
||||
输出:
|
||||
Tensor,设输入的shape是 :math:`(x_1, x_2, ..., x_R)`,则输出的shape为 :math:`(y_1, y_2, ..., y_R),其中:
|
||||
Tensor,设输入的shape是 :math:`(x_1, x_2, ..., x_R)`,则输出的shape为 :math:`(y_1, y_2, ..., y_R)`,其中:
|
||||
|
||||
:math:`y_{split\_dim} = x_{split\_dim} / split\_count`
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ mindspore.ops.assign_sub
|
|||
输入 `variable` 和 `value` 会通过隐式数据类型转换使数据类型保持一致。如果数据类型不同,低精度的数据类型会被转换到高精度的数据类型。如果 `value` 为标量会被自动转换为Tensor,其数据类型会与 `variable` 保持一致。
|
||||
|
||||
.. Note::
|
||||
由于 `variable` 类型为 `Parameter` ,其数据类型不能改变。因此只允许 `value` 的数据类型转变为 `variable` 的数据类型。而且由于不同设备支持的转换类型会有所不同,推荐在使用此操作时使用相同的数据类型。
|
||||
由于 `variable` 类型为 `Parameter` ,其数据类型不能改变,因此只允许 `value` 的数据类型转变为 `variable` 的数据类型。而且由于不同设备支持的转换类型会有所不同,推荐在使用此操作时使用相同的数据类型。
|
||||
|
||||
参数:
|
||||
- **variable** (Parameter) - 待更新的网络参数,shape: :math:`(N,*)` ,其中 :math:`*` 表示任何数量的附加维度。其轶应小于8。
|
||||
|
|
|
@ -25,19 +25,19 @@ mindspore.train.EarlyStopping
|
|||
训练开始时初始化相关的变量。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_end(run_context)
|
||||
|
||||
打印是第几个epoch执行早停。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_epoch_end(run_context)
|
||||
|
||||
训练过程中,若监控指标在等待 `patience` 个epoch后仍没有改善,则停止训练。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
|
|
|
@ -8,18 +8,18 @@ mindspore.train.History
|
|||
用户不自定义训练网络或评估网络情况下,记录的内容将为损失值;用户自定义了训练网络/评估网络的情况下,如果定义的网络返回 `Tensor` 或 `numpy.ndarray`,则记录此返回值均值,如果返回 `tuple` 或 `list`,则记录第一个元素。
|
||||
|
||||
.. note::
|
||||
通常使用在 `mindspore.Model.train` 和 `mindspore.Model.fit` 中。
|
||||
通常使用在 `mindspore.train.Model.train` 和 `mindspore.train.Model.fit` 中。
|
||||
|
||||
.. py:method:: begin(run_context)
|
||||
|
||||
训练开始时初始化History对象的epoch属性。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: epoch_end(run_context)
|
||||
|
||||
epoch结束时记录网络输出和评估指标的相关信息。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的一些基本信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -5,7 +5,7 @@ mindspore.train.LambdaCallback
|
|||
|
||||
用于自定义简单的callback。
|
||||
|
||||
使用匿名函数构建callback,定义的匿名函数将在 `mindspore.Model.{train | eval | fit}` 的对应阶段被调用。
|
||||
使用匿名函数构建callback,定义的匿名函数将在 `mindspore.train.Model.{train | eval | fit}` 的对应阶段被调用。
|
||||
|
||||
请注意,callback的每个阶段都需要一个位置参数:`run_context`。
|
||||
|
||||
|
|
|
@ -21,11 +21,11 @@ mindspore.train.LossMonitor
|
|||
LossMoniter用于 `model.fit`,即边训练边推理场景时,打印训练的loss和当前epoch推理的metrics。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: step_end(run_context)
|
||||
|
||||
step结束时打印训练loss。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -83,7 +83,7 @@
|
|||
|
||||
如果 `valid_dataset` 不为None,在训练过程中同时执行推理。
|
||||
|
||||
更多详细信息请参考 `mindspore.Model.train` 和 `mindspore.Model.eval`。
|
||||
更多详细信息请参考 `mindspore.train.Model.train` 和 `mindspore.train.Model.eval`。
|
||||
|
||||
参数:
|
||||
- **epoch** (int) - 训练执行轮次。通常每个epoch都会使用全量数据集进行训练。当 `dataset_sink_mode` 设置为True且 `sink_size` 大于零时,则每个epoch训练次数为 `sink_size` 而不是数据集的总步数。如果 `epoch` 与 `initial_epoch` 一起使用,它表示训练的最后一个 `epoch` 是多少。
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
mindspore.train.OnRequestExit
|
||||
===========================
|
||||
=============================
|
||||
|
||||
.. py:class:: mindspore.train.OnRequestExit(save_ckpt=True, save_mindir=True, file_name='Net', directory='./', sig=signal.SIGUSR1)
|
||||
|
||||
|
@ -24,48 +24,46 @@ mindspore.train.OnRequestExit
|
|||
在训练开始时,注册用户传入停止信号的处理程序。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_step_end(run_context)
|
||||
|
||||
在训练step结束时,根据是否接收到停止信号,将`run_context`的`_stop_requested`属性置为True。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
.. py:method:: on_train_step_end(run_context)
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_epoch_end(run_context)
|
||||
|
||||
在训练epoch结束时,根据是否接收到停止信号,将`run_context`的`_stop_requested`属性置为True。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
.. py:method:: on_train_step_end(run_context)
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_end(run_context)
|
||||
|
||||
在训练结束时,根据是否接收到停止信号,保存checkpoint或者mindir。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_begin(run_context)
|
||||
|
||||
在推理开始时,注册用户传入停止信号的处理程序。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_step_end(run_context)
|
||||
|
||||
在推理step结束时,根据是否接收到停止信号,将`run_context`的`_stop_requested`属性置为True。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_eval_end(run_context)
|
||||
|
||||
在推理结束时,根据是否接收到停止信号,保存checkpoint或者mindir。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -30,11 +30,11 @@ mindspore.train.ReduceLROnPlateau
|
|||
训练开始时初始化相关的变量。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: on_train_epoch_end(run_context)
|
||||
|
||||
训练过程中,若监控指标在等待 `patience` 个epoch后仍没有改善,则改变学习率。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -16,11 +16,11 @@ mindspore.train.TimeMonitor
|
|||
在epoch开始时记录时间。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
||||
.. py:method:: epoch_end(run_context)
|
||||
|
||||
在epoch结束时打印epoch的耗时。
|
||||
|
||||
参数:
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
|
||||
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.train.RunContext`。
|
||||
|
|
|
@ -464,7 +464,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
Note:
|
||||
Attribute name is required for setting attributes.
|
||||
If a program has tasks on different parallel modes, before setting a new parallel mode for the
|
||||
next task, interface mindspore.reset_auto_parallel_context() should be called to reset
|
||||
next task, interface :func:`mindspore.reset_auto_parallel_context` should be called to reset
|
||||
the configuration.
|
||||
Setting or changing parallel modes must be called before creating any Initializer, otherwise,
|
||||
it may have RuntimeError when compiling the network.
|
||||
|
|
|
@ -432,7 +432,7 @@ def choice_with_mask(input_x, count=256, seed=0, seed2=0):
|
|||
sample, while the mask tensor denotes which elements in the index tensor are valid.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): The input tensor.
|
||||
input_x (Tensor[bool]): The input tensor.
|
||||
The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
|
||||
count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
|
||||
seed (int): Random seed. Default: 0.
|
||||
|
|
|
@ -769,7 +769,7 @@ class AlltoAll(PrimitiveWithInfer):
|
|||
|
||||
Outputs:
|
||||
Tensor. If the shape of input tensor is :math:`(x_1, x_2, ..., x_R)`, then the shape of output tensor is
|
||||
:math:`(y_1, y_2, ..., y_R), where:
|
||||
:math:`(y_1, y_2, ..., y_R)`, where:
|
||||
|
||||
:math:`y_{split\_dim} = x_{split\_dim} / split\_count`
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ class EarlyStopping(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
|
||||
self.wait = 0
|
||||
|
@ -144,7 +144,7 @@ class EarlyStopping(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
|
@ -185,7 +185,7 @@ class EarlyStopping(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
|
||||
if self.stopped_epoch > 0 and self.verbose:
|
||||
|
@ -201,7 +201,7 @@ class EarlyStopping(Callback):
|
|||
|
||||
Args:
|
||||
cb_params (dict): A dictionary stores context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
monitor_candidates = {}
|
||||
if self.monitor == "loss":
|
||||
|
|
|
@ -31,7 +31,7 @@ class History(Callback):
|
|||
outputs will be recorded.
|
||||
|
||||
Note:
|
||||
Normally used in `mindspore.Model.train` or `mindspore.Model.fit`.
|
||||
Normally used in `mindspore.train.Model.train` or `mindspore.train.Model.fit`.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
|
@ -61,8 +61,8 @@ class History(Callback):
|
|||
Initialize the `epoch` property at the begin of training.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
run_context (RunContext): Context of the `mindspore.train.Model.{train | eval}`. For more details,
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
self.epoch = {"epoch": []}
|
||||
|
||||
|
@ -71,8 +71,8 @@ class History(Callback):
|
|||
Records the first element of network outputs and metrics information at the end of epoch.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
run_context (RunContext): Context of the `mindspore.train.Model.{train | eval}`. For more details,
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
epoch = cb_params.get("cur_epoch_num", 1)
|
||||
|
|
|
@ -23,7 +23,7 @@ class LambdaCallback(Callback):
|
|||
Callback for creating simple, custom callbacks.
|
||||
|
||||
This callback is constructed with anonymous functions that will be called
|
||||
at the appropriate time (during `mindspore.Model.{train | eval | fit}`). Note that
|
||||
at the appropriate time (during `mindspore.train.Model.{train | eval | fit}`). Note that
|
||||
each stage of callbacks expects one positional arguments: `run_context`.
|
||||
|
||||
Note:
|
||||
|
|
|
@ -262,7 +262,7 @@ class SummaryLandscape:
|
|||
callback_fn (python function): A python function object. User needs to write a function,
|
||||
it has no input, and the return requirements are as follows.
|
||||
|
||||
- mindspore.Model: User's model object.
|
||||
- mindspore.train.Model: User's model object.
|
||||
- mindspore.nn.Cell: User's network object.
|
||||
- mindspore.dataset: User's dataset object for create loss landscape.
|
||||
- mindspore.train.Metrics: User's metrics object.
|
||||
|
|
|
@ -63,7 +63,7 @@ class LossMonitor(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
|
@ -94,7 +94,7 @@ class LossMonitor(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
metrics = cb_params.get("metrics")
|
||||
|
|
|
@ -87,7 +87,7 @@ class OnRequestExit(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
if os.path.isfile(f"{self.train_file_path}.ckpt"):
|
||||
cb_params = run_context.original_args()
|
||||
|
@ -101,7 +101,7 @@ class OnRequestExit(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
if self.stop:
|
||||
run_context.request_stop()
|
||||
|
@ -112,7 +112,7 @@ class OnRequestExit(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
if self.stop:
|
||||
run_context.request_stop()
|
||||
|
@ -123,7 +123,7 @@ class OnRequestExit(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
if not self.stop:
|
||||
return
|
||||
|
@ -141,7 +141,7 @@ class OnRequestExit(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
eval_net = cb_params.eval_network
|
||||
|
@ -157,7 +157,7 @@ class OnRequestExit(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
if self.stop:
|
||||
run_context.request_stop()
|
||||
|
@ -168,7 +168,7 @@ class OnRequestExit(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
if not self.stop:
|
||||
return
|
||||
|
|
|
@ -124,7 +124,7 @@ class ReduceLROnPlateau(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
self.cooldown_counter = 0
|
||||
self.wait = 0
|
||||
|
@ -140,7 +140,7 @@ class ReduceLROnPlateau(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
cur_lr = cb_params.optimizer.learning_rate
|
||||
|
@ -188,7 +188,7 @@ class ReduceLROnPlateau(Callback):
|
|||
|
||||
Args:
|
||||
cb_params (dict): A dictionary stores context information of the model. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
monitor_candidates = {}
|
||||
if self.monitor == "loss":
|
||||
|
|
|
@ -58,7 +58,7 @@ class TimeMonitor(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
self.epoch_time = time.time()
|
||||
|
||||
|
@ -68,7 +68,7 @@ class TimeMonitor(Callback):
|
|||
|
||||
Args:
|
||||
run_context (RunContext): Context of the process running. For more details,
|
||||
please refer to :class:`mindspore.RunContext`.
|
||||
please refer to :class:`mindspore.train.RunContext`.
|
||||
"""
|
||||
epoch_seconds = (time.time() - self.epoch_time) * 1000
|
||||
step_size = self.data_size
|
||||
|
|
|
@ -186,9 +186,11 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
|
||||
Note:
|
||||
In the case of running the network on Ascend/GPU in graph mode, this function will wrap the input network with
|
||||
'GetNext', in other cases, the input network will be returned with no change.
|
||||
The 'GetNext' is required to get data only in sink mode, so this function is not applicable to no-sink mode.
|
||||
:class:`mindspore.ops.GetNext`. In other cases, the input network will be returned with no change.
|
||||
The :class:`mindspore.ops.GetNext` is required to get data only in sink mode,
|
||||
so this function is not applicable to no-sink mode.
|
||||
when dataset_helper's dataset_sink_mode is True, it can only be connected to one network.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network for dataset.
|
||||
dataset_helper (DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue
|
||||
|
|
|
@ -72,8 +72,8 @@ class Dice(Metric):
|
|||
Updates the internal evaluation result :math:`y_pred` and :math:`y`.
|
||||
|
||||
Args:
|
||||
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
|
||||
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, ...)`.
|
||||
inputs (tuple): Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
|
||||
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, ...)`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of the inputs is not 2.
|
||||
|
|
|
@ -1088,7 +1088,7 @@ class Model:
|
|||
|
||||
Evaluation process will be performed during training process if `valid_dataset` is provided.
|
||||
|
||||
More details please refer to `mindspore.Model.train` and `mindspore.Model.eval`.
|
||||
More details please refer to `mindspore.train.Model.train` and `mindspore.train.Model.eval`.
|
||||
|
||||
Args:
|
||||
epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch.
|
||||
|
@ -1231,7 +1231,7 @@ class Model:
|
|||
|
||||
def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
|
||||
"""
|
||||
Evaluation process in `mindspore.Model.fit`.
|
||||
Evaluation process in `mindspore.train.Model.fit`.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
|
||||
|
|
|
@ -46,7 +46,10 @@ def define_model():
|
|||
|
||||
|
||||
class MyCallbackOldMethod(Callback):
|
||||
""" Raise warning in `mindspore.Model.train` and `mindspore.Model.eval`; raise error in `mindspore.Model.fit`"""
|
||||
"""
|
||||
Raise warning in `mindspore.train.Model.train` and `mindspore.train.Model.eval`;
|
||||
raise error in `mindspore.train.Model.fit`.
|
||||
"""
|
||||
def begin(self, run_context):
|
||||
print("custom callback: print on begin, just for test.")
|
||||
|
||||
|
@ -62,7 +65,10 @@ class MyCallbackOldMethod(Callback):
|
|||
|
||||
|
||||
class MyCallbackNewMethod(Callback):
|
||||
""" Custom callback running in `mindspore.Model.train`, `mindspore.Model.eval`, `mindspore.Model.fit`"""
|
||||
"""
|
||||
Custom callback running in `mindspore.train.Model.train`, `mindspore.train.Model.eval`,
|
||||
`mindspore.train.Model.fit`.
|
||||
"""
|
||||
def on_train_epoch_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
print("custom callback: train epoch end, loss is %s" % (cb_params.get("net_outputs")))
|
||||
|
@ -74,7 +80,7 @@ class MyCallbackNewMethod(Callback):
|
|||
|
||||
def test_fit_train_dataset_non_sink_mode():
|
||||
"""
|
||||
Feature: `mindspore.Model.fit` with train dataset in non-sink mode.
|
||||
Feature: `mindspore.train.Model.fit` with train dataset in non-sink mode.
|
||||
Description: test fit with train dataset in non-sink mode.
|
||||
Expectation: run in non-sink mode.
|
||||
"""
|
||||
|
@ -87,7 +93,7 @@ def test_fit_train_dataset_non_sink_mode():
|
|||
|
||||
def test_fit_train_dataset_sink_mode():
|
||||
"""
|
||||
Feature: `mindspore.Model.fit` with train dataset in sink mode.
|
||||
Feature: `mindspore.train.Model.fit` with train dataset in sink mode.
|
||||
Description: test fit with train dataset in sink mode.
|
||||
Expectation: run in sink mode.
|
||||
"""
|
||||
|
@ -100,7 +106,7 @@ def test_fit_train_dataset_sink_mode():
|
|||
|
||||
def test_fit_valid_dataset_non_sink_mode():
|
||||
"""
|
||||
Feature: `mindspore.Model.fit` with valid dataset in non-sink mode.
|
||||
Feature: `mindspore.train.Model.fit` with valid dataset in non-sink mode.
|
||||
Description: test fit with valid dataset in non-sink mode.
|
||||
Expectation: run in non-sink mode.
|
||||
"""
|
||||
|
@ -113,7 +119,7 @@ def test_fit_valid_dataset_non_sink_mode():
|
|||
|
||||
def test_fit_valid_dataset_sink_mode():
|
||||
"""
|
||||
Feature: `mindspore.Model.fit` with valid dataset in sink mode.
|
||||
Feature: `mindspore.train.Model.fit` with valid dataset in sink mode.
|
||||
Description: test fit with valid dataset in sink mode.
|
||||
Expectation: run in sink mode.
|
||||
"""
|
||||
|
@ -126,7 +132,7 @@ def test_fit_valid_dataset_sink_mode():
|
|||
|
||||
def test_fit_without_valid_dataset():
|
||||
"""
|
||||
Feature: `mindspore.Model.fit` without `valid_dataset` input .
|
||||
Feature: `mindspore.train.Model.fit` without `valid_dataset` input .
|
||||
Description: test fit when `valid_dataset` is None and `valid_dataset_sink_mode` is True or False.
|
||||
Expectation: network train without eval process, `valid_dataset_sink_mode` does not take effect.
|
||||
"""
|
||||
|
@ -139,7 +145,7 @@ def test_fit_without_valid_dataset():
|
|||
|
||||
def test_fit_valid_frequency():
|
||||
"""
|
||||
Feature: check `valid_frequency` input in `mindspore.Model.fit`.
|
||||
Feature: check `valid_frequency` input in `mindspore.train.Model.fit`.
|
||||
Description: when `valid_frequency` is integer, list or other types.
|
||||
Expectation: raise ValueError when the type of valid_frequency is not int or list.
|
||||
"""
|
||||
|
@ -156,7 +162,7 @@ def test_fit_valid_frequency():
|
|||
|
||||
def test_fit_callbacks():
|
||||
"""
|
||||
Feature: check `callbacks` input in `mindspore.Model.fit`.
|
||||
Feature: check `callbacks` input in `mindspore.train.Model.fit`.
|
||||
Description: test internal or custom callbacks in fit.
|
||||
Expectation: raise ValueError when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
|
||||
"""
|
||||
|
@ -178,7 +184,7 @@ def test_fit_callbacks():
|
|||
|
||||
def test_train_eval_callbacks():
|
||||
"""
|
||||
Feature: check `callbacks` input in `mindspore.Model.train` or `mindspore.Model.eval`.
|
||||
Feature: check `callbacks` input in `mindspore.train.Model.train` or `mindspore.train.Model.eval`.
|
||||
Description: test internal or custom callbacks in train or eval.
|
||||
Expectation: raise warning when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue