add model.fit and refactor callbacks

This commit is contained in:
liutongtong 2022-05-18 11:11:03 +08:00
parent f35e313d33
commit f6c8053c2b
15 changed files with 869 additions and 94 deletions

View File

@ -6,8 +6,7 @@ mindspore.Callback
用于构建Callback函数的基类。Callback函数是一个上下文管理器在运行模型时被调用。
可以使用此机制进行一些自定义操作。
Callback类的每个方法对应了训练或推理过程的不同阶段这些方法有相同的入参 `run_context`,用于保存模型
训练或推理过程模型的相关信息。定义Callback子类或自定义Callback时请根据需要重写对应的方法。
Callback类的每个方法对应了训练或推理过程的不同阶段这些方法有相同的入参 `run_context`用于保存训练或推理过程中模型的相关信息。定义Callback子类或自定义Callback时请根据需要重写名称前缀为“on_train”或“on_eval”的方法否则自定义的Callback在 `model.fit` 中使用时会产生错误。
自定义Callback场景下在类方法中通过 `RunContext.original_args()` 方法可以获取模型训练或推理过程中已有
的上下文信息,此信息为一个存储了已有属性的字典型变量;用户也可以在此信息中添加其他的自定义属性;此外,
@ -16,7 +15,7 @@ mindspore.Callback
.. py:method:: begin(run_context)
在网络执行之前被调用一次。
在网络执行之前被调用一次。`on_train_begin``on_eval_begin` 方法具有兼容性。
**参数:**
@ -24,7 +23,7 @@ mindspore.Callback
.. py:method:: end(run_context)
网络执行后被调用一次。
网络执行后被调用一次。`on_train_end``on_eval_end` 方法具有兼容性。
**参数:**
@ -32,7 +31,7 @@ mindspore.Callback
.. py:method:: epoch_begin(run_context)
在每个epoch开始之前被调用。
在每个epoch开始之前被调用。`on_train_epoch_begin``on_eval_epoch_begin` 方法具有兼容性。
**参数:**
@ -40,7 +39,7 @@ mindspore.Callback
.. py:method:: epoch_end(run_context)
在每个epoch结束后被调用。
在每个epoch结束后被调用。`on_train_epoch_end``on_eval_epoch_end` 方法具有兼容性。
**参数:**
@ -52,7 +51,7 @@ mindspore.Callback
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
- **run_context** (RunContext) - 包含模型的一些基本信息。`on_train_step_begin``on_eval_step_begin` 方法具有兼容性。
.. py:method:: step_end(run_context)
@ -60,4 +59,100 @@ mindspore.Callback
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。与 `on_train_step_end``on_eval_step_end` 方法具有兼容性。
.. py:method:: on_train_begin(run_context)
在网络执行训练之前调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method::on_train_end(run_context)
网络训练执行结束时调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_train_epoch_begin(run_context)
在训练的每个epoch开始之前被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_train_epoch_end(run_context)
在训练的每个epoch结束后被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_train_step_begin(run_context)
在训练的每个step开始之前被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_train_step_end(run_context)
在训练的每个step完成后被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_eval_begin(run_context)
在网络执行推理之前调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_eval_end(run_context)
网络执行推理之后调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_eval_epoch_begin(run_context)
在推理的epoch开始之前被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_evalepoch_end(run_context)
在推理的epoch结束后被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_eval_step_begin(run_context)
在推理的每个step开始之前被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。
.. py:method:: on_eval_step_end(run_context)
在推理的每个step完成后被调用。
**参数:**
- **run_context** (RunContext) - 包含模型的一些基本信息。

View File

@ -3,12 +3,12 @@ mindspore.History
.. py:class:: mindspore.History
将网络输出的相关信息记录到 `History` 对象中。
将网络输出和评估指标的相关信息记录到 `History` 对象中。
用户不自定义训练网络或评估网络情况下,记录的内容将为损失值;用户自定义了训练网络/评估网络的情况下,如果定义的网络返回 `Tensor``numpy.ndarray`,则记录此返回值均值,如果返回 `tuple``list`,则记录第一个元素。
.. note::
通常使用在 `mindspore.Model.train` 中。
通常使用在 `mindspore.Model.train` `mindspore.Model.fit` 中。
.. py:method:: begin(run_context)
@ -20,7 +20,7 @@ mindspore.History
.. py:method:: epoch_end(run_context)
epoch结束时记录网络输出的相关信息。
epoch结束时记录网络输出和评估指标的相关信息。
**参数:**

View File

@ -5,7 +5,7 @@ mindspore.LambdaCallback
用于自定义简单的callback。
使用匿名函数构建callback定义的匿名函数将在 `mindspore.Model.{train | eval}` 的对应阶段被调用。
使用匿名函数构建callback定义的匿名函数将在 `mindspore.Model.{train | eval | fit}` 的对应阶段被调用。
请注意callback的每个阶段都需要一个位置参数`run_context`
@ -14,9 +14,15 @@ mindspore.LambdaCallback
**参数:**
- **epoch_begin** (Function) - 每个epoch开始时被调用。
- **epoch_end** (Function) - 每个epoch结束时被调用。
- **step_begin** (Function) - 每个step开始时被调用。
- **step_end** (Function) - 每个step结束时被调用。
- **begin** (Function) - 模型训练、评估开始时被调用。
- **end** (Function) - 模型训练、评估结束时被调用。
- **on_train_epoch_begin** (Function) - 训练每个epoch开始时被调用。
- **on_train_epoch_end** (Function) - 训练每个epoch结束时被调用。
- **on_train_step_begin** (Function) - 训练每个step开始时被调用。
- **on_train_step_end** (Function) - 训练每个step结束时被调用。
- **on_train_begin** (Function) - 模型训练开始时被调用。
- **on_train_end** (Function) - 模型训练结束时被调用。
- **on_eval_epoch_begin** (Function) - 推理的epoch开始时被调用。
- **on_eval_epoch_end** (Function) - 推理的epoch结束时被调用。
- **on_eval_step_begin** (Function) - 推理的每个step开始时被调用。
- **on_eval_step_end** (Function) - 推理的每个step结束时被调用。
- **on_eval_begin** (Function) - 模型推理开始时被调用。
- **on_eval_end** (Function) - 模型推理结束时被调用。

View File

@ -3,7 +3,7 @@ mindspore.LossMonitor
.. py:class:: mindspore.LossMonitor(per_print_times=1)
监控训练的loss。
训练场景下,监控训练的loss边训练边推理场景下监控训练的loss和推理的metrics。
如果loss是NAN或INF则终止训练。
@ -25,3 +25,11 @@ mindspore.LossMonitor
**参数:**
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`
.. py:method:: on_train_epoch_end(run_context)
LossMoniter用于 `model.fit`即边训练边推理场景时打印训练的loss和当前epoch推理的metrics。
**参数:**
- **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`

View File

@ -162,6 +162,24 @@
- **sink_size** (int) 控制每次数据下沉的数据量。`dataset_sink_mode` 为False时 `sink_size` 无效。如果sink_size=-1则每一次epoch下沉完整数据集。如果sink_size>0则每一次epoch下沉数据量为sink_size的数据集。默认值-1。
- **initial_epoch** (int) - 从哪个epoch开始训练一般用于中断恢复训练场景。
.. py:method:: fit(epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None, dataset_sink_mode=True, valid_dataset_sink_mode=True, sink_size=-1, initial_epoch=0)
模型边训练边推理接口。
如果 `valid_dataset` 不为None在训练过程中同时执行推理。更多详细信息请参考 `mindspore.Model.model.train`
**参数:**
- **epoch** (int) 训练执行轮次。通常每个epoch都会使用全量数据集进行训练。当 `dataset_sink_mode` 设置为True且 `sink_size` 大于零时则每个epoch训练次数为 `sink_size` 而不是数据集的总步数。如果 `epoch``initial_epoch` 一起使用,它表示训练的最后一个 `epoch` 是多少。
- **train_dataset** (Dataset) 训练数据集迭代器。如果定义了 `loss_fn` ,则数据和标签会被分别传给 `network``loss_fn` 此时数据集需要返回一个元组data, label。如果数据集中有多个数据或者标签可以设置 `loss_fn` 为None并在 `network` 中实现损失函数计算此时数据集返回的所有数据组成的元组data1, data2, data3, ...)会传给 `network`
- **valid_dataset** (Dataset) 评估模型的数据集迭代器。默认值None。
- **valid_frequency** (Dataset) 此参数只有在valid_dataset不为None时生效。如果为int类型表示执行推理的频率例如 `valid_frequency=2`则每2个训练epoch执行一次推理如果为list类型指明在哪几个epoch时执行推理例如 `valid_frequency=[1, 5]`则在第1个和第5个epoch执行推理。默认值1。
- **callbacks** (Optional[list[Callback], Callback]) 训练过程中需要执行的回调对象或者回调对象列表。默认值None。
- **dataset_sink_mode** (bool) 训练数据是否直接下沉至处理器进行处理。使用PYNATIVE_MODE模式或CPU处理器时模型训练流程将以非下沉模式执行。默认值True。
- **valid_dataset_sink_mode** (bool) - 推理数据是否直接下沉至处理器进行处理。默认值True。
- **sink_size** (int) 控制每次数据下沉的数据量。`dataset_sink_mode` 为False时 `sink_size` 无效。如果sink_size=-1则每一次epoch下沉完整数据集。如果sink_size>0则每一次epoch下沉数据量为sink_size的数据集。默认值-1。
- **initial_epoch** (int) - 从哪个epoch开始训练一般用于中断恢复训练场景。
.. py:method:: train_network
:property:
@ -169,4 +187,4 @@
**返回:**
预测网络实例。
预测网络实例。

View File

@ -3,7 +3,7 @@ mindspore.TimeMonitor
.. py:class:: mindspore.TimeMonitor(data_size=None)
监控训练时间。
监控训练或推理的时间。
**参数:**

View File

@ -82,7 +82,8 @@ class Callback:
Each method of Callback class corresponds to a stage in training or eval process, and those methods
have the same input `run_context`, which hold context information of the model in
training or eval process. When defining a Callback subclass or creating a custom Callback,
override these methods.
Note that you should override methods with names prefixed with "on_train" or "on_eval",
otherwise ValueError will be raised if the custimized Callbacks used in `model.fit`.
When creating a custom Callback, model context information can be obtained in Callback
methods by calling `RunContext.original_args()`, which is a dictionary varivable
@ -122,6 +123,7 @@ class Callback:
def begin(self, run_context):
"""
Called once before the network executing.
A backwards compatibility alias for `on_train_begin` and `on_eval_begin`.
Args:
run_context (RunContext): Include some information of the model.
@ -130,6 +132,7 @@ class Callback:
def epoch_begin(self, run_context):
"""
Called before each epoch beginning.
A backwards compatibility alias for `on_train_epoch_begin` and `on_eval_epoch_begin`.
Args:
run_context (RunContext): Include some information of the model.
@ -138,6 +141,7 @@ class Callback:
def epoch_end(self, run_context):
"""
Called after each epoch finished.
A backwards compatibility alias for `on_train_epoch_end` and `on_eval_epoch_end`.
Args:
run_context (RunContext): Include some information of the model.
@ -146,6 +150,7 @@ class Callback:
def step_begin(self, run_context):
"""
Called before each step beginning.
A backwards compatibility alias for `on_train_step_begin` and `on_eval_step_begin`.
Args:
run_context (RunContext): Include some information of the model.
@ -154,6 +159,7 @@ class Callback:
def step_end(self, run_context):
"""
Called after each step finished.
A backwards compatibility alias for `on_train_step_end` and `on_eval_step_end`.
Args:
run_context (RunContext): Include some information of the model.
@ -162,11 +168,120 @@ class Callback:
def end(self, run_context):
"""
Called once after network training.
A backwards compatibility alias for `on_train_end` and `on_eval_end`.
Args:
run_context (RunContext): Include some information of the model.
"""
def on_train_begin(self, run_context):
"""
Called once before the network training.
Args:
run_context (RunContext): Include some information of the model.
"""
self.begin(run_context)
def on_train_epoch_begin(self, run_context):
"""
Called before each training epoch begin.
Args:
run_context (RunContext): Include some information of the model.
"""
self.epoch_begin(run_context)
def on_train_epoch_end(self, run_context):
"""
Called after each training epoch end.
Args:
run_context (RunContext): Include some information of the model.
"""
self.epoch_end(run_context)
def on_train_step_begin(self, run_context):
"""
Called before each training step begin.
Args:
run_context (RunContext): Include some information of the model.
"""
self.step_begin(run_context)
def on_train_step_end(self, run_context):
"""
Called after each training step end.
Args:
run_context (RunContext): Include some information of the model.
"""
self.step_end(run_context)
def on_train_end(self, run_context):
"""
Called after training end.
Args:
run_context (RunContext): Include some information of the model.
"""
self.end(run_context)
def on_eval_begin(self, run_context):
"""
Called before eval begin.
Args:
run_context (RunContext): Include some information of the model.
"""
self.begin(run_context)
def on_eval_epoch_begin(self, run_context):
"""
Called before eval epoch begin.
Args:
run_context (RunContext): Include some information of the model.
"""
self.epoch_begin(run_context)
def on_eval_epoch_end(self, run_context):
"""
Called after eval epoch end.
Args:
run_context (RunContext): Include some information of the model.
"""
self.epoch_end(run_context)
def on_eval_step_begin(self, run_context):
"""
Called before each eval step begin.
Args:
run_context (RunContext): Include some information of the model.
"""
self.step_begin(run_context)
def on_eval_step_end(self, run_context):
"""
Called after each eval step end.
Args:
run_context (RunContext): Include some information of the model.
"""
self.step_end(run_context)
def on_eval_end(self, run_context):
"""
Called after eval end.
Args:
run_context (RunContext): Include some information of the model.
"""
self.end(run_context)
class CallbackManager(Callback):
"""
@ -208,7 +323,7 @@ class CallbackManager(Callback):
return self._stack.__exit__(*err)
def begin(self, run_context):
"""Called once before network training."""
"""Called once before network train or eval."""
for cb in self._callbacks:
cb.begin(run_context)
@ -223,7 +338,7 @@ class CallbackManager(Callback):
cb.epoch_end(run_context)
def step_begin(self, run_context):
"""Called before each epoch begin."""
"""Called before each step begin."""
for cb in self._callbacks:
cb.step_begin(run_context)
@ -233,10 +348,70 @@ class CallbackManager(Callback):
cb.step_end(run_context)
def end(self, run_context):
"""Called once after network training."""
"""Called once after network train or eval."""
for cb in self._callbacks:
cb.end(run_context)
def on_train_begin(self, run_context):
"""Called before network train."""
for cb in self._callbacks:
cb.on_train_begin(run_context)
def on_train_epoch_begin(self, run_context):
"""Called before each train epoch begin."""
for cb in self._callbacks:
cb.on_train_epoch_begin(run_context)
def on_train_epoch_end(self, run_context):
"""Called after each train epoch finished."""
for cb in self._callbacks:
cb.on_train_epoch_end(run_context)
def on_train_step_begin(self, run_context):
"""Called before each train step begin."""
for cb in self._callbacks:
cb.on_train_step_begin(run_context)
def on_train_step_end(self, run_context):
"""Called after each train step finished."""
for cb in self._callbacks:
cb.step_end(run_context)
def on_train_end(self, run_context):
"""Called after network train end."""
for cb in self._callbacks:
cb.on_train_end(run_context)
def on_eval_begin(self, run_context):
"""Called before network eval."""
for cb in self._callbacks:
cb.on_eval_begin(run_context)
def on_eval_epoch_begin(self, run_context):
"""Called before eval epoch begin."""
for cb in self._callbacks:
cb.on_eval_epoch_begin(run_context)
def on_eval_epoch_end(self, run_context):
"""Called after eval epoch finished."""
for cb in self._callbacks:
cb.on_eval_epoch_end(run_context)
def on_eval_step_begin(self, run_context):
"""Called before each eval step begin."""
for cb in self._callbacks:
cb.on_eval_step_begin(run_context)
def on_eval_step_end(self, run_context):
"""Called after each eval step finished."""
for cb in self._callbacks:
cb.on_eval_step_end(run_context)
def on_eval_end(self, run_context):
"""Called after network eval end."""
for cb in self._callbacks:
cb.on_eval_end(run_context)
class InternalCallbackParam(dict):
"""Internal callback object's parameters."""

View File

@ -21,7 +21,7 @@ from ._callback import Callback
class History(Callback):
"""
Records the network outputs information into a `History` object.
Records the network outputs and metrics information into a `History` object.
The network outputs information will be the loss value if not custimizing the train network or eval network;
if the custimized network returns a `Tensor` or `numpy.ndarray`, the mean value of network output
@ -29,7 +29,7 @@ class History(Callback):
outputs will be recorded.
Note:
Normally used in `mindspore.Model.train`.
Normally used in `mindspore.Model.train` or `mindspore.Model.fit`.
Examples:
>>> import numpy as np
@ -65,7 +65,7 @@ class History(Callback):
def epoch_end(self, run_context):
"""
Records the first element of network outputs at the end of epoch.
Records the first element of network outputs and metrics information at the end of epoch.
Args:
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. For more details,

View File

@ -22,19 +22,25 @@ class LambdaCallback(Callback):
Callback for creating simple, custom callbacks.
This callback is constructed with anonymous functions that will be called
at the appropriate time (during `mindspore.Model.{train | eval}`). Note that
at the appropriate time (during `mindspore.Model.{train | eval | fit}`). Note that
each stage of callbacks expects one positional arguments: `run_context`.
Note:
This is an experimental interface that is subject to change or deletion.
Args:
epoch_begin (Function): called at the beginning of every epoch.
epoch_end (Function): called at the end of every epoch.
step_begin (Function): called at the beginning of every batch.
step_end (Function): called at the end of every batch.
begin (Function): called at the beginning of model train/eval.
end (Function): called at the end of model train/eval.
on_train_epoch_begin (Function): called at each train epoch begin.
on_train_epoch_end (Function): called at each train epoch end.
on_train_step_begin (Function): called at each train step begin.
on_train_step_end (Function): called at each train step end.
on_train_begin (Function): called at the beginning of model train.
on_train_end (Function): called at the end of model train.
on_eval_epoch_begin (Function): called at eval epoch begin.
on_eval_epoch_end (Function): called at eval epoch end.
on_eval_step_begin (Function): called at each eval step begin.
on_eval_step_end (Function): called at each eval step end.
on_eval_begin (Function): called at the beginning of model eval.
on_eval_end (Function): called at the end of model eval.
Examples:
>>> import numpy as np
@ -46,19 +52,28 @@ class LambdaCallback(Callback):
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> lambda_callback = LambdaCallback(epoch_end=
>>> lambda_callback = LambdaCallback(on_train_epoch_end=
... lambda run_context: print("loss: ", run_context.original_args().net_outputs))
>>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model.train(2, train_dataset, callbacks=[lambda_callback])
loss: 1.6127687
loss: 1.6106578
"""
def __init__(self, epoch_begin=None, epoch_end=None, step_begin=None,
step_end=None, begin=None, end=None):
def __init__(self, on_train_epoch_begin=None, on_train_epoch_end=None, on_train_step_begin=None,
on_train_step_end=None, on_train_begin=None, on_train_end=None,
on_eval_epoch_begin=None, on_eval_epoch_end=None, on_eval_step_begin=None,
on_eval_step_end=None, on_eval_begin=None, on_eval_end=None):
super(LambdaCallback, self).__init__()
self.epoch_begin = epoch_begin if epoch_begin else lambda run_context: None
self.epoch_end = epoch_end if epoch_end else lambda run_context: None
self.step_begin = step_begin if step_begin else lambda run_context: None
self.step_end = step_end if step_end else lambda run_context: None
self.begin = begin if begin else lambda run_context: None
self.end = end if end else lambda run_context: None
self.on_train_epoch_begin = on_train_epoch_begin if on_train_epoch_begin else lambda run_context: None
self.on_train_epoch_end = on_train_epoch_end if on_train_epoch_end else lambda run_context: None
self.on_train_step_begin = on_train_step_begin if on_train_step_begin else lambda run_context: None
self.on_train_step_end = on_train_step_end if on_train_step_end else lambda run_context: None
self.on_train_begin = on_train_begin if on_train_begin else lambda run_context: None
self.on_train_end = on_train_end if on_train_end else lambda run_context: None
self.on_eval_epoch_begin = on_eval_epoch_begin if on_eval_epoch_begin else lambda run_context: None
self.on_eval_epoch_end = on_eval_epoch_end if on_eval_epoch_end else lambda run_context: None
self.on_eval_step_begin = on_eval_step_begin if on_eval_step_begin else lambda run_context: None
self.on_eval_step_end = on_eval_step_end if on_eval_step_end else lambda run_context: None
self.on_eval_begin = on_eval_begin if on_eval_begin else lambda run_context: None
self.on_eval_end = on_eval_end if on_eval_end else lambda run_context: None

View File

@ -23,7 +23,7 @@ from ._callback import Callback
class LossMonitor(Callback):
"""
Monitor the loss in training.
Monitor the loss in train or monitor the loss and eval metrics in fit.
If the loss is NAN or INF, it will terminate training.
@ -90,3 +90,17 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
self._last_print_time = cb_params.cur_step_num
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
def on_train_epoch_end(self, run_context):
"""
When LossMoniter used in `model.fit`, print eval metrics at the end of epoch if current epoch
should do evaluation.
Args:
run_context (RunContext): Include some information of the model. For more details,
please refer to :class:`mindspore.RunContext`.
"""
cb_params = run_context.original_args()
metrics = cb_params.get("metrics")
if metrics:
print("Eval result: epoch %d, metrics: %s" % (cb_params.cur_epoch_num, metrics))

View File

@ -22,7 +22,7 @@ from ._callback import Callback
class TimeMonitor(Callback):
"""
Monitor the time in training.
Monitor the time in train or eval process.
Args:
data_size (int): How many steps are the intervals between print information each time.
@ -71,6 +71,7 @@ class TimeMonitor(Callback):
epoch_seconds = (time.time() - self.epoch_time) * 1000
step_size = self.data_size
cb_params = run_context.original_args()
mode = cb_params.get("mode", "")
if hasattr(cb_params, "batch_num"):
batch_num = cb_params.batch_num
if isinstance(batch_num, int) and batch_num > 0:
@ -78,4 +79,5 @@ class TimeMonitor(Callback):
Validator.check_positive_int(step_size)
step_seconds = epoch_seconds / step_size
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, step_seconds), flush=True)
print("{} epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format
(mode.title(), epoch_seconds, step_seconds), flush=True)

View File

@ -27,7 +27,8 @@ from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, Validator
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor
from .callback import __all__ as internal_cb_names
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
@ -503,7 +504,8 @@ class Model:
return [callbacks]
@_save_final_ckpt
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True):
"""
Training.
@ -541,6 +543,7 @@ class Model:
cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset
cb_params.list_callback = self._transform_callbacks(callbacks)
valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
if context.get_context("mode") == context.PYNATIVE_MODE:
cb_params.list_callback.insert(0, _StepSync())
callbacks = cb_params.list_callback
@ -555,17 +558,21 @@ class Model:
with _CallbackManager(callbacks) as list_callback:
self._check_reuse_dataset(train_dataset)
if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch)
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
elif context.get_context("device_target") == "CPU":
logger.info("The CPU cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch)
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback,
cb_params, sink_size, initial_epoch)
cb_params, sink_size, initial_epoch, valid_infos)
@staticmethod
def _should_eval(epoch, validation_freq):
return epoch % validation_freq == 0 if isinstance(validation_freq, int) else epoch in validation_freq
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
sink_size=-1, initial_epoch=0):
sink_size=-1, initial_epoch=0, valid_infos=None):
"""
Training process. The data would be passed to network through dataset channel.
@ -593,7 +600,7 @@ class Model:
cb_params.dataset_sink_mode = True
run_context = RunContext(cb_params)
list_callback.begin(run_context)
list_callback.on_train_begin(run_context)
# used to stop training for early stop, such as stopAtTIme or stopATStep
dataset_helper = None
if hasattr(train_dataset, '_dataset_helper'):
@ -609,7 +616,7 @@ class Model:
cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
self._current_epoch_num = cb_params.cur_epoch_num
self._current_step_num = 0
list_callback.epoch_begin(run_context)
list_callback.on_train_epoch_begin(run_context)
dataset_helper, train_network = self._exec_preprocess(is_train=True,
dataset=train_dataset,
dataset_sink_mode=True,
@ -632,22 +639,51 @@ class Model:
cb_params.cur_step_num += 1
self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context)
list_callback.on_train_step_begin(run_context)
outputs = train_network(*inputs)
cb_params.net_outputs = outputs
# In disaster recovery scenarios, need not to execute callbacks if this step executes failed.
need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
if need_exec_callback_step_end:
list_callback.step_end(run_context)
list_callback.on_train_step_end(run_context)
if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
os._exit(0)
dataset_helper.continue_send()
valid_dataset, valid_frequency, valid_dataset_sink_mode = valid_infos
if valid_dataset and self._should_eval(cb_params.cur_epoch_num, valid_frequency):
train_cur_step_num = cb_params.cur_step_num
train_batch_num = cb_params.batch_num
train_dataset_sink_mode = cb_params.dataset_sink_mode
train_net_outputs = cb_params.net_outputs
eval_callback = []
for cb in list_callback._callbacks:
if cb.__class__.__name__ in internal_cb_names:
if isinstance(cb, TimeMonitor):
eval_callback.append(cb)
else:
eval_callback.append(cb)
self._eval_in_fit(valid_dataset,
callbacks=eval_callback,
dataset_sink_mode=valid_dataset_sink_mode,
cb_params=cb_params)
cb_params.mode = "train"
cb_params.cur_step_num = train_cur_step_num
cb_params.batch_num = train_batch_num
cb_params.dataset_sink_mode = train_dataset_sink_mode
cb_params.net_outputs = train_net_outputs
# In disaster recovery scenarios, need not to execute callbacks if this epoch executes failed.
need_exec_callback_epoch_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
if need_exec_callback_epoch_end:
list_callback.epoch_end(run_context)
list_callback.on_train_epoch_end(run_context)
if "metrics" in cb_params:
cb_params.pop("metrics")
should_stop = run_context.get_stop_requested()
if should_stop:
@ -663,7 +699,7 @@ class Model:
dataset_helper.stop_send()
dataset_helper.release()
list_callback.end(run_context)
list_callback.on_train_end(run_context)
def _check_enable_recovery(self):
"""
@ -753,7 +789,8 @@ class Model:
_set_recovery_context(need_reset=False)
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0):
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0,
valid_infos=None):
"""
Training process. The data would be passed to network directly.
@ -776,13 +813,13 @@ class Model:
cb_params.cur_step_num = 0
cb_params.dataset_sink_mode = False
run_context = RunContext(cb_params)
list_callback.begin(run_context)
list_callback.on_train_begin(run_context)
for i in range(initial_epoch, epoch):
cb_params.cur_epoch_num = i + 1
self._current_epoch_num = cb_params.cur_epoch_num
self._current_step_num = 0
list_callback.epoch_begin(run_context)
list_callback.on_train_epoch_begin(run_context)
for next_element in dataset_helper:
len_element = len(next_element)
@ -795,7 +832,7 @@ class Model:
self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
cb_params.train_dataset_element = next_element
list_callback.step_begin(run_context)
list_callback.on_train_step_begin(run_context)
outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
@ -803,24 +840,52 @@ class Model:
overflow = np.all(overflow.asnumpy())
self._loss_scale_manager.update_loss_scale(overflow)
list_callback.step_end(run_context)
list_callback.on_train_step_end(run_context)
if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
os._exit(0)
should_stop = run_context.get_stop_requested()
if should_stop:
break
valid_dataset, valid_frequency, valid_dataset_sink_mode = valid_infos
if valid_dataset and self._should_eval(cb_params.cur_epoch_num, valid_frequency):
train_cur_step_num = cb_params.cur_step_num
train_batch_num = cb_params.batch_num
train_dataset_sink_mode = cb_params.dataset_sink_mode
train_net_outputs = cb_params.net_outputs
eval_callback = []
for cb in list_callback._callbacks:
if cb.__class__.__name__ in internal_cb_names:
if isinstance(cb, TimeMonitor):
eval_callback.append(cb)
else:
eval_callback.append(cb)
self._eval_in_fit(valid_dataset,
callbacks=eval_callback,
dataset_sink_mode=valid_dataset_sink_mode,
cb_params=cb_params)
cb_params.mode = "train"
cb_params.cur_step_num = train_cur_step_num
cb_params.batch_num = train_batch_num
cb_params.dataset_sink_mode = train_dataset_sink_mode
cb_params.net_outputs = train_net_outputs
train_dataset.reset()
# if param is cache enable, flush data from cache to host before epoch end
self._flush_from_cache(cb_params)
list_callback.epoch_end(run_context)
list_callback.on_train_epoch_end(run_context)
if "metrics" in cb_params:
cb_params.pop("metrics")
should_stop = run_context.get_stop_requested()
if should_stop:
break
list_callback.end(run_context)
list_callback.on_train_end(run_context)
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
"""
@ -912,6 +977,9 @@ class Model:
_device_number_check(self._parallel_mode, self._device_number)
if callbacks:
self._check_methods_for_custom_callbacks(callbacks, "train")
self._train(epoch,
train_dataset,
callbacks=callbacks,
@ -924,6 +992,140 @@ class Model:
if _is_ps_mode() and _enable_distributed_mindrt():
_reset_op_id_with_offset()
@staticmethod
def _check_methods_for_custom_callbacks(callbacks, current_mode):
"""
Check whether methods of custimized callbacks are valid.
Args:
callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object.
current_mode (str): 'fit', 'train' or 'eval'.
"""
old_version_methods_names = {'begin', 'end', 'epoch_begin', 'epoch_end', 'step_begin', 'step_end'}
if not isinstance(callbacks, list):
callbacks = [callbacks]
for cb in callbacks:
cb_name = cb.__class__.__name__
if cb_name not in internal_cb_names:
cb_methods_names = set(cb.__class__.__dict__.keys())
invalid_methods_names = cb_methods_names & old_version_methods_names
if invalid_methods_names:
if current_mode in ["train", "eval"]:
logger.warning("For %s callback, %s methods may not be supported in later version, "
"Use methods prefixed with 'on_train' or 'on_eval' instead "
"when using customized callbacks." % (cb_name, invalid_methods_names))
else:
raise ValueError("For %s callback, %s methods may not be supported in later version, "
"Use methods prefixed with 'on_train' or 'on_eval' instead when"
"using customized callbacks." % (cb_name, invalid_methods_names))
def fit(self, epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None,
dataset_sink_mode=True, valid_dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
"""
Fit API.
Evaluation process will be performed during training process if `valid_dataset` is provided.
More details please refer to `mindspore.Model.train` and `mindspore.Model.eval`.
Args:
epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch.
If `dataset_sink_mode` is set to True and `sink_size` is greater than 0, each epoch will
train `sink_size` steps instead of total steps of dataset.
If `epoch` used with `initial_epoch`, it is to be understood as "final epoch".
train_dataset (Dataset): A training dataset iterator. If `loss_fn` is defined, the data and label will be
passed to the `network` and the `loss_fn` respectively, so a tuple (data, label)
should be returned from dataset. If there is multiple data or labels, set `loss_fn`
to None and implement calculation of loss in `network`,
then a tuple (data1, data2, data3, ...) with all data returned from dataset
will be passed to the `network`.
valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
will be performed on the end of training process. Default: None.
valid_frequency (int, list): Only relevant if `valid_dataset` is provided. If an integer, specifies
how many training epochs to run before a new validation run is performed,
e.g. `valid_frequency=2` runs validation every 2 epochs.
If a list, specifies the epochs on which to run validation,
e.g. `valid_frequency=[1, 5]` runs validation at the end of the 1st, 5th epochs.
Default: 1
callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
which should be executed while training.
Default: None.
dataset_sink_mode (bool): Determines whether to pass the train data through dataset channel.
Configure pynative mode or CPU, the training process will be performed with
dataset not sink. Default: True.
valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
Default: True.
sink_size (int): Control the amount of data in each sink. `sink_size` is invalid if `dataset_sink_mode`
is False.
If sink_size = -1, sink the complete dataset for each epoch.
If sink_size > 0, sink sink_size data for each epoch.
Default: -1.
initial_epoch (int): Epoch at which to start train, it useful for resuming a previous training run.
Default: 0.
Examples:
>>> from mindspore import Model, nn, FixedLossScaleManager
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> train_dataset = create_custom_dataset()
>>> valid_dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics="accuracy")
>>> model.fit(2, train_dataset, valid_dataset)
"""
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
raise ValueError("when use Model.build to initialize model, the value of parameter `epoch` in Model.build "
"should be equal to value in Model.fit, but got {} and {} separately."
.format(train_dataset._warmup_epoch, epoch))
if dataset_sink_mode and _is_ps_mode() and not _cache_enable():
raise ValueError("Parameter server mode does not support 'data_sink_mode=True'.")
Validator.check_is_int(sink_size)
Validator.check_non_negative_int(initial_epoch)
if initial_epoch >= epoch:
raise ValueError(f"For 'Model.train', the parameter 'epoch' must bigger than parameter 'initial_epoch',"
f" but got the parameter 'epoch' is {epoch}, 'initial_epoch' is {initial_epoch}.")
dataset_size = train_dataset.get_dataset_size()
if dataset_size == 0:
raise ValueError("There is no valid data in dataset, please check dataset file firstly.")
if sink_size == -1:
sink_size = dataset_size
if sink_size < -1 or sink_size == 0:
raise ValueError("For 'Model.fit', The parameter 'sink_size' must be -1 or positive, "
"but got {}.".format(sink_size))
_device_number_check(self._parallel_mode, self._device_number)
if not isinstance(valid_frequency, (int, list)):
raise ValueError(f"For 'Model.fit', the type of 'valid_frequency' must be a list or a integer, but got"
"type {type(validation_freq)}.")
if valid_dataset and not self._metric_fns:
raise ValueError("For 'Model.fit', if valid_dataset is not None, the model argument 'metrics' can not be"
"None or empty, you should set the argument 'metrics' for model.")
if callbacks:
self._check_methods_for_custom_callbacks(callbacks, "fit")
self._train(epoch,
train_dataset,
callbacks=callbacks,
dataset_sink_mode=dataset_sink_mode,
sink_size=sink_size,
initial_epoch=initial_epoch,
valid_dataset=valid_dataset,
valid_frequency=valid_frequency,
valid_dataset_sink_mode=valid_dataset_sink_mode)
def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, jit_config=None):
"""
Build computational graphs and data graphs with the sink mode.
@ -971,6 +1173,40 @@ class Model:
_cell_graph_executor.set_jit_config(jit_config)
self._init(train_dataset, valid_dataset, sink_size, epoch)
def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
"""
Evaluation process in `mindspore.Model.fit`.
Args:
valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
will be performed on the end of training process. Default: None.
callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, which should be
executed while evaluation. Default: None.
valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
Default: True.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset
cb_params.batch_num = valid_dataset.get_dataset_size()
cb_params.mode = "eval"
cb_params.cur_step_num = 0
self._clear_metrics()
if context.get_context("device_target") == "CPU" and dataset_sink_mode:
dataset_sink_mode = False
logger.info("CPU cannot support dataset sink mode currently."
"So the evaluating process will be performed with dataset non-sink mode.")
with _CallbackManager(callbacks) as list_callback:
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
"""
Evaluation. The data would be passed to network through dataset channel.
@ -990,20 +1226,20 @@ class Model:
dataset_sink_mode=True)
cb_params.eval_network = eval_network
cb_params.dataset_sink_mode = True
list_callback.begin(run_context)
list_callback.epoch_begin(run_context)
list_callback.on_eval_begin(run_context)
list_callback.on_eval_epoch_begin(run_context)
for inputs in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
list_callback.on_eval_step_begin(run_context)
outputs = eval_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
list_callback.on_eval_step_end(run_context)
self._update_metrics(outputs)
list_callback.epoch_end(run_context)
list_callback.on_eval_epoch_end(run_context)
metrics = self._get_metrics()
cb_params.metrics = metrics
list_callback.end(run_context)
list_callback.on_eval_end(run_context)
return metrics
@ -1021,25 +1257,25 @@ class Model:
"""
run_context = RunContext(cb_params)
cb_params.dataset_sink_mode = False
list_callback.begin(run_context)
list_callback.on_eval_begin(run_context)
dataset_helper, _ = self._exec_preprocess(is_train=False,
dataset=valid_dataset,
dataset_sink_mode=False)
list_callback.epoch_begin(run_context)
list_callback.on_eval_epoch_begin(run_context)
for next_element in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
list_callback.on_eval_step_begin(run_context)
next_element = _transfer_tensor_to_tuple(next_element)
outputs = self._eval_network(*next_element)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
list_callback.on_eval_step_end(run_context)
self._update_metrics(outputs)
list_callback.epoch_end(run_context)
list_callback.on_eval_epoch_end(run_context)
valid_dataset.reset()
metrics = self._get_metrics()
cb_params.metrics = metrics
list_callback.end(run_context)
list_callback.on_eval_end(run_context)
return metrics
def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
@ -1087,7 +1323,8 @@ class Model:
"you should set the argument 'metrics' for model.")
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
if callbacks:
self._check_methods_for_custom_callbacks(callbacks, "eval")
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset

View File

@ -150,7 +150,7 @@ class ModelThor(Model):
self.switch_branch_one = not self.switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
list_callback.on_train_step_end(run_context)
else:
cb_params.cur_step_num += 1
if self.train_network_init_flag:
@ -163,7 +163,7 @@ class ModelThor(Model):
if self.index_first_order == iter_first_order:
self.index_first_order = 0
self.switch_branch_one = not self.switch_branch_one
list_callback.step_end(run_context)
list_callback.on_train_step_end(run_context)
def _train_ascend_sink_step(self, cb_params, train_dataset, iter_first_order, inputs, list_callback, run_context):
"""train ascend sink step"""
@ -184,10 +184,10 @@ class ModelThor(Model):
self.switch_branch_one = not self.switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
list_callback.on_train_step_end(run_context)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
sink_size=-1, initial_epoch=0):
sink_size=-1, initial_epoch=0, valid_infos=None):
"""
Training process. The data would be passed to network through dataset channel.
@ -204,6 +204,9 @@ class ModelThor(Model):
initial_epoch (int): Epoch at which to start train, it useful for resuming a previous training run.
Default: 0.
"""
valid_dataset, _, _ = valid_infos
if valid_dataset:
raise ValueError("Evaluation in training is currently not supported in the second-order scenario of thor.")
if sink_size == -1:
epoch_num = epoch - initial_epoch
else:
@ -226,28 +229,28 @@ class ModelThor(Model):
cb_params.cur_step_num = 0
run_context = RunContext(cb_params)
list_callback.begin(run_context)
list_callback.on_train_begin(run_context)
for i in range(initial_epoch, epoch):
cb_params.cur_epoch_num = i + 1
list_callback.epoch_begin(run_context)
list_callback.on_train_epoch_begin(run_context)
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
list_callback.on_train_step_begin(run_context)
if context.get_context("device_target") == "GPU":
self._train_gpu_sink_step(cb_params, inputs, list_callback, iter_first_order, run_context)
else:
self._train_ascend_sink_step(cb_params, train_dataset, iter_first_order, inputs, list_callback,
run_context)
list_callback.epoch_end(run_context)
list_callback.on_train_epoch_end(run_context)
should_stop = False or run_context.get_stop_requested()
if should_stop:
break
dataset_helper.stop_send()
list_callback.end(run_context)
list_callback.on_train_end(run_context)
__all__ = ["ModelThor"]

202
tests/st/train/test_fit.py Normal file
View File

@ -0,0 +1,202 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_fit """
import pytest
import numpy as np
from mindspore import Model, nn, Tensor
from mindspore.common.initializer import Normal
from mindspore.train.callback import Callback, TimeMonitor, LossMonitor
from mindspore import dataset as ds
def get_data(num, w=2.0, b=3.0):
for _ in range(num):
x = np.random.uniform(-10.0, 10.0)
noise = np.random.normal(0, 1)
y = x * w + b + noise
yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
input_data = input_data.batch(batch_size, drop_remainder=True)
input_data = input_data.repeat(repeat_size)
return input_data
def define_model():
net = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
net_loss = nn.MSELoss()
net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
return Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={'mse', 'mae'})
class MyCallbackOldMethod(Callback):
""" Raise warning in `mindspore.Model.train` and `mindspore.Model.eval`; raise error in `mindspore.Model.fit`"""
def begin(self, run_context):
print("custom callback: print on begin, just for test.")
def step_end(self, run_context):
cb_params = run_context.original_args()
outputs = cb_params.get("net_outputs")
result = outputs if isinstance(outputs, Tensor) else outputs[0]
print("custom train callback: step end, loss is %s" % (result))
def on_train_epoch_end(self, run_context):
cb_params = run_context.original_args()
print("custom train callback: epoch end, loss is %s" % (cb_params.get("net_outputs")))
class MyCallbackNewMethod(Callback):
""" Custom callback running in `mindspore.Model.train`, `mindspore.Model.eval`, `mindspore.Model.fit`"""
def on_train_epoch_end(self, run_context):
cb_params = run_context.original_args()
print("custom callback: train epoch end, loss is %s" % (cb_params.get("net_outputs")))
def on_eval_epoch_end(self, run_context):
cb_params = run_context.original_args()
print("custom callback: eval epoch end, metric is %s" % (cb_params.get("net_outputs")[0]))
def test_fit_train_dataset_non_sink_mode():
"""
Feature: `mindspore.Model.fit` with train dataset in non-sink mode.
Description: test fit with train dataset in non-sink mode.
Expectation: run in non-sink mode.
"""
model = define_model()
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
callbacks = [LossMonitor()]
model.fit(3, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=False)
def test_fit_train_dataset_sink_mode():
"""
Feature: `mindspore.Model.fit` with train dataset in sink mode.
Description: test fit with train dataset in sink mode.
Expectation: run in sink mode.
"""
model = define_model()
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
callbacks = [LossMonitor()]
model.fit(3, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=True, sink_size=256)
def test_fit_valid_dataset_non_sink_mode():
"""
Feature: `mindspore.Model.fit` with valid dataset in non-sink mode.
Description: test fit with valid dataset in non-sink mode.
Expectation: run in non-sink mode.
"""
model = define_model()
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
callbacks = [LossMonitor()]
model.fit(3, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=False)
def test_fit_valid_dataset_sink_mode():
"""
Feature: `mindspore.Model.fit` with valid dataset in sink mode.
Description: test fit with valid dataset in sink mode.
Expectation: run in sink mode.
"""
model = define_model()
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
callbacks = [LossMonitor()]
model.fit(3, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=True)
def test_fit_without_valid_dataset():
"""
Feature: `mindspore.Model.fit` without `valid_dataset` input .
Description: test fit when `valid_dataset` is None and `valid_dataset_sink_mode` is True or False.
Expectation: network train without eval process, `valid_dataset_sink_mode` does not take effect.
"""
model = define_model()
ds_train = create_dataset(4096, 1024)
callbacks = [LossMonitor()]
model.fit(3, ds_train, None, callbacks=callbacks, valid_dataset_sink_mode=False)
model.fit(3, ds_train, None, callbacks=callbacks)
def test_fit_valid_frequency():
"""
Feature: check `valid_frequency` input in `mindspore.Model.fit`.
Description: when `valid_frequency` is integer, list or other types.
Expectation: raise ValueError when the type of valid_frequency is not int or list.
"""
model = define_model()
callbacks = [LossMonitor()]
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
model.fit(3, ds_train, ds_eval, valid_frequency=1, callbacks=callbacks)
model.fit(5, ds_train, ds_eval, valid_frequency=2, callbacks=callbacks)
model.fit(5, ds_train, ds_eval, valid_frequency=[0, 1, 4], callbacks=callbacks)
with pytest.raises(ValueError):
model.fit(5, ds_train, ds_eval, valid_frequency=(0, 2), callbacks=callbacks)
def test_fit_callbacks():
"""
Feature: check `callbacks` input in `mindspore.Model.fit`.
Description: test internal or custom callbacks in fit.
Expectation: raise ValueError when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
"""
model = define_model()
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
model.fit(3, ds_train, ds_eval, callbacks=None)
model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor()])
model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), LossMonitor()])
model.fit(3, ds_train, ds_eval, callbacks=[MyCallbackNewMethod()])
model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), MyCallbackNewMethod()])
with pytest.raises(ValueError):
model.fit(3, ds_train, ds_eval, callbacks=[MyCallbackOldMethod()])
with pytest.raises(ValueError):
model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), MyCallbackOldMethod()])
with pytest.raises(ValueError):
model.fit(3, ds_train, valid_dataset=None, callbacks=[TimeMonitor(), MyCallbackOldMethod()])
def test_train_eval_callbacks():
"""
Feature: check `callbacks` input in `mindspore.Model.train` or `mindspore.Model.eval`.
Description: test internal or custom callbacks in train or eval.
Expectation: raise warning when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
"""
model = define_model()
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
model.train(3, ds_train, callbacks=None)
model.train(3, ds_train, callbacks=[TimeMonitor()])
model.train(3, ds_train, callbacks=[LossMonitor()])
model.train(3, ds_train, callbacks=[MyCallbackNewMethod()])
model.train(3, ds_train, callbacks=[MyCallbackOldMethod()])
metric_results = model.eval(ds_eval, callbacks=None)
print("{}".format(metric_results))
metric_results = model.eval(ds_eval, callbacks=[TimeMonitor()])
print("{}".format(metric_results))
metric_results = model.eval(ds_eval, callbacks=[MyCallbackNewMethod()])
print("{}".format(metric_results))
metric_results = model.eval(ds_eval, callbacks=[MyCallbackOldMethod()])
print("{}".format(metric_results))

View File

@ -513,13 +513,13 @@ def test_lambda():
run_context = RunContext(cb_params)
lambda_cb = LambdaCallback(
epoch_end=lambda run_context: print("loss result: ", run_context.original_args().net_outputs))
on_train_epoch_end=lambda run_context: print("loss result: ", run_context.original_args().net_outputs))
callbacks = [lambda_cb]
with _CallbackManager(callbacks) as callbacklist:
callbacklist.begin(run_context)
callbacklist.epoch_begin(run_context)
callbacklist.step_begin(run_context)
callbacklist.step_end(run_context)
callbacklist.epoch_end(run_context)
callbacklist.end(run_context)
callbacklist.on_train_begin(run_context)
callbacklist.on_train_epoch_begin(run_context)
callbacklist.on_train_step_begin(run_context)
callbacklist.on_train_step_end(run_context)
callbacklist.on_train_epoch_end(run_context)
callbacklist.on_train_end(run_context)