forked from mindspore-Ecosystem/mindspore
add model.fit and refactor callbacks
This commit is contained in: parent f35e313d33, commit f6c8053c2b
@@ -6,8 +6,7 @@ mindspore.Callback

    Base class used to build Callback functions. A Callback function is a context manager that is invoked while running the model.

    This mechanism can be used to perform custom operations.

-    Each method of the Callback class corresponds to a stage of the training or inference process, and these methods share the same input parameter `run_context`, which holds the related
-    information of the model during training or inference. When defining a Callback subclass or creating a custom Callback, override the corresponding methods as needed.
+    Each method of the Callback class corresponds to a stage of the training or inference process, and these methods share the same input parameter `run_context`, which holds the related information of the model during training or inference. When defining a Callback subclass or creating a custom Callback, override the methods whose names are prefixed with "on_train" or "on_eval" as needed; otherwise a custom Callback will raise an error when used in `model.fit`.

    In a custom Callback scenario, the existing context information of the training or inference process can be obtained inside the class methods via `RunContext.original_args()`;
    this information is a dictionary-like variable that stores the existing attributes. Users can also add other custom attributes to it. In addition,

@@ -16,7 +15,7 @@ mindspore.Callback

    .. py:method:: begin(run_context)

-        Called once before the network executes.
+        Called once before the network executes. Compatible with the `on_train_begin` and `on_eval_begin` methods.

        **Parameters:**

@@ -24,7 +23,7 @@ mindspore.Callback

    .. py:method:: end(run_context)

-        Called once after the network executes.
+        Called once after the network executes. Compatible with the `on_train_end` and `on_eval_end` methods.

        **Parameters:**

@@ -32,7 +31,7 @@ mindspore.Callback

    .. py:method:: epoch_begin(run_context)

-        Called before each epoch begins.
+        Called before each epoch begins. Compatible with the `on_train_epoch_begin` and `on_eval_epoch_begin` methods.

        **Parameters:**

@@ -40,7 +39,7 @@ mindspore.Callback

    .. py:method:: epoch_end(run_context)

-        Called after each epoch ends.
+        Called after each epoch ends. Compatible with the `on_train_epoch_end` and `on_eval_epoch_end` methods.

        **Parameters:**

@@ -52,7 +51,7 @@ mindspore.Callback

        **Parameters:**

-        - **run_context** (RunContext) - Contains some basic information of the model.
+        - **run_context** (RunContext) - Contains some basic information of the model. Compatible with the `on_train_step_begin` and `on_eval_step_begin` methods.

    .. py:method:: step_end(run_context)

@@ -60,4 +59,100 @@ mindspore.Callback

        **Parameters:**

-        - **run_context** (RunContext) - Contains some basic information of the model.
+        - **run_context** (RunContext) - Contains some basic information of the model. Compatible with the `on_train_step_end` and `on_eval_step_end` methods.
+
+    .. py:method:: on_train_begin(run_context)
+
+        Called before the network performs training.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_train_end(run_context)
+
+        Called when network training ends.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_train_epoch_begin(run_context)
+
+        Called before each training epoch begins.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_train_epoch_end(run_context)
+
+        Called after each training epoch ends.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_train_step_begin(run_context)
+
+        Called before each training step begins.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_train_step_end(run_context)
+
+        Called after each training step finishes.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_eval_begin(run_context)
+
+        Called before the network performs evaluation.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_eval_end(run_context)
+
+        Called after the network performs evaluation.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_eval_epoch_begin(run_context)
+
+        Called before the eval epoch begins.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_eval_epoch_end(run_context)
+
+        Called after the eval epoch ends.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_eval_step_begin(run_context)
+
+        Called before each eval step begins.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
+
+    .. py:method:: on_eval_step_end(run_context)
+
+        Called after each eval step finishes.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains some basic information of the model.
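To make the new contract concrete, here is a minimal sketch of a custom callback written against the renamed hooks. The import path and the surrounding training setup are assumptions for illustration, not part of this commit:

    from mindspore.train.callback import Callback

    class EvalAwareLogger(Callback):
        """Overrides only the 'on_train_*' hooks, so it stays valid under model.fit."""

        def on_train_epoch_end(self, run_context):
            cb_params = run_context.original_args()
            # net_outputs holds the loss when no custom train network is defined.
            print("train epoch %d, loss: %s" % (cb_params.cur_epoch_num, cb_params.net_outputs))
            # Under model.fit, eval metrics are attached to cb_params before this
            # hook fires and popped right after it (see the _train_process diff below).
            metrics = cb_params.get("metrics")
            if metrics:
                print("eval metrics:", metrics)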
@@ -3,12 +3,12 @@ mindspore.History

    .. py:class:: mindspore.History

-    Records the related information of the network outputs into a `History` object.
+    Records the related information of the network outputs and evaluation metrics into a `History` object.

    If the user does not customize the train network or eval network, the recorded content is the loss value; if the user customizes the train/eval network, the mean value is recorded when the defined network returns a `Tensor` or `numpy.ndarray`, and the first element is recorded when it returns a `tuple` or `list`.

    .. note::
-        Normally used in `mindspore.Model.train`.
+        Normally used in `mindspore.Model.train` and `mindspore.Model.fit`.

    .. py:method:: begin(run_context)

@@ -20,7 +20,7 @@ mindspore.History

    .. py:method:: epoch_end(run_context)

-        Records the related information of the network outputs at the end of epoch.
+        Records the related information of the network outputs and evaluation metrics at the end of epoch.

        **Parameters:**
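A hedged usage sketch: the `epoch` and `history` attributes read below are the ones a `History` object conventionally exposes, and the model/dataset setup is assumed rather than taken from this commit:

    from mindspore.train.callback import History

    history_cb = History()
    model.fit(2, train_dataset, valid_dataset, callbacks=[history_cb])
    print(history_cb.epoch)    # e.g. [1, 2]
    print(history_cb.history)  # per-epoch network outputs plus, under fit, the eval metrics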
@@ -5,7 +5,7 @@ mindspore.LambdaCallback

    Used to customize simple callbacks.

-    The callback is built with anonymous functions, which will be called at the corresponding stage of `mindspore.Model.{train | eval}`.
+    The callback is built with anonymous functions, which will be called at the corresponding stage of `mindspore.Model.{train | eval | fit}`.

    Note that each stage of the callback requires one positional parameter: `run_context`.

@@ -14,9 +14,15 @@ mindspore.LambdaCallback

    **Parameters:**

-    - **epoch_begin** (Function) - Called at the beginning of each epoch.
-    - **epoch_end** (Function) - Called at the end of each epoch.
-    - **step_begin** (Function) - Called at the beginning of each step.
-    - **step_end** (Function) - Called at the end of each step.
-    - **begin** (Function) - Called at the beginning of model train/eval.
-    - **end** (Function) - Called at the end of model train/eval.
+    - **on_train_epoch_begin** (Function) - Called at the beginning of each training epoch.
+    - **on_train_epoch_end** (Function) - Called at the end of each training epoch.
+    - **on_train_step_begin** (Function) - Called at the beginning of each training step.
+    - **on_train_step_end** (Function) - Called at the end of each training step.
+    - **on_train_begin** (Function) - Called at the beginning of model train.
+    - **on_train_end** (Function) - Called at the end of model train.
+    - **on_eval_epoch_begin** (Function) - Called at the beginning of the eval epoch.
+    - **on_eval_epoch_end** (Function) - Called at the end of the eval epoch.
+    - **on_eval_step_begin** (Function) - Called at the beginning of each eval step.
+    - **on_eval_step_end** (Function) - Called at the end of each eval step.
+    - **on_eval_begin** (Function) - Called at the beginning of model eval.
+    - **on_eval_end** (Function) - Called at the end of model eval.
@@ -3,7 +3,7 @@ mindspore.LossMonitor

    .. py:class:: mindspore.LossMonitor(per_print_times=1)

-    Monitors the training loss.
+    In the training scenario, monitors the training loss; in the train-while-eval scenario, monitors the training loss and the eval metrics.

    If the loss is NAN or INF, training is terminated.

@@ -25,3 +25,11 @@ mindspore.LossMonitor

        **Parameters:**

        - **run_context** (RunContext) - Contains the related information of the model. For details, refer to :class:`mindspore.RunContext`.
+
+    .. py:method:: on_train_epoch_end(run_context)
+
+        When LossMonitor is used with `model.fit`, i.e. in the train-while-eval scenario, prints the training loss and the eval metrics of the current epoch.
+
+        **Parameters:**
+
+        - **run_context** (RunContext) - Contains the related information of the model. For details, refer to :class:`mindspore.RunContext`.
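A short sketch of LossMonitor in the train-while-eval scenario (the model and datasets are assumed):

    from mindspore.train.callback import LossMonitor

    loss_cb = LossMonitor(per_print_times=100)   # print the loss every 100 steps
    model.fit(5, train_dataset, valid_dataset, callbacks=[loss_cb])
    # For each epoch that runs evaluation, additionally prints, e.g.:
    #   Eval result: epoch 1, metrics: {'accuracy': 0.91}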
@@ -162,6 +162,24 @@

    - **sink_size** (int) – Controls the amount of data sunk each time. `sink_size` is invalid when `dataset_sink_mode` is False. If sink_size=-1, the complete dataset is sunk in each epoch. If sink_size>0, a dataset of size sink_size is sunk in each epoch. Default: -1.
    - **initial_epoch** (int) - The epoch from which training starts, generally used for resuming training after an interruption.

+    .. py:method:: fit(epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None, dataset_sink_mode=True, valid_dataset_sink_mode=True, sink_size=-1, initial_epoch=0)
+
+        Train-while-eval API of the model.
+
+        If `valid_dataset` is not None, evaluation is also performed during training. For more details, refer to `mindspore.Model.train`.
+
+        **Parameters:**
+
+        - **epoch** (int) – Number of training epochs. Generally, each epoch trains on the complete dataset. When `dataset_sink_mode` is set to True and `sink_size` is greater than zero, each epoch trains `sink_size` steps instead of the total number of steps of the dataset. If `epoch` is used together with `initial_epoch`, it denotes the last training epoch.
+        - **train_dataset** (Dataset) – A training dataset iterator. If `loss_fn` is defined, the data and label are passed to `network` and `loss_fn` respectively, and the dataset needs to return a tuple (data, label). If the dataset contains multiple pieces of data or multiple labels, set `loss_fn` to None and implement the loss computation inside `network`; in that case the tuple (data1, data2, data3, ...) composed of all data returned by the dataset is passed to `network`.
+        - **valid_dataset** (Dataset) – A dataset iterator for evaluating the model. Default: None.
+        - **valid_frequency** (int, list) – Only effective when valid_dataset is not None. If of type int, it denotes the evaluation frequency; for example, `valid_frequency=2` runs evaluation every 2 training epochs. If of type list, it specifies the epochs at which evaluation runs; for example, `valid_frequency=[1, 5]` runs evaluation at the 1st and 5th epochs. Default: 1.
+        - **callbacks** (Optional[list[Callback], Callback]) – Callback object, or list of callback objects, to be executed during training. Default: None.
+        - **dataset_sink_mode** (bool) – Whether the training data is sunk directly to the processor for processing. When PYNATIVE_MODE or a CPU processor is used, training runs in non-sink mode. Default: True.
+        - **valid_dataset_sink_mode** (bool) - Whether the evaluation data is sunk directly to the processor for processing. Default: True.
+        - **sink_size** (int) – Controls the amount of data sunk each time. `sink_size` is invalid when `dataset_sink_mode` is False. If sink_size=-1, the complete dataset is sunk in each epoch. If sink_size>0, a dataset of size sink_size is sunk in each epoch. Default: -1.
+        - **initial_epoch** (int) - The epoch from which training starts, generally used for resuming training after an interruption.

    .. py:method:: train_network
        :property:

@@ -169,4 +187,4 @@

        **Returns:**

        Instance of the prediction network.
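A sketch of the two accepted `valid_frequency` forms (the model and datasets are assumed):

    # Evaluate every second training epoch:
    model.fit(10, train_dataset, valid_dataset, valid_frequency=2)
    # Evaluate only after the 1st and 5th epochs:
    model.fit(10, train_dataset, valid_dataset, valid_frequency=[1, 5])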
@@ -3,7 +3,7 @@ mindspore.TimeMonitor

    .. py:class:: mindspore.TimeMonitor(data_size=None)

-    Monitors the training time.
+    Monitors the time of training or evaluation.

    **Parameters:**
@@ -82,7 +82,8 @@ class Callback:
    Each method of Callback class corresponds to a stage in training or eval process, and those methods
    have the same input `run_context`, which holds context information of the model in
    training or eval process. When defining a Callback subclass or creating a custom Callback,
-    override these methods.
+    note that you should override methods with names prefixed with "on_train" or "on_eval",
+    otherwise a ValueError will be raised if the customized Callback is used in `model.fit`.

    When creating a custom Callback, model context information can be obtained in Callback
    methods by calling `RunContext.original_args()`, which is a dictionary variable

@@ -122,6 +123,7 @@ class Callback:
    def begin(self, run_context):
        """
        Called once before the network executes.
+        A backwards-compatibility alias for `on_train_begin` and `on_eval_begin`.

        Args:
            run_context (RunContext): Include some information of the model.

@@ -130,6 +132,7 @@ class Callback:
    def epoch_begin(self, run_context):
        """
        Called before each epoch begins.
+        A backwards-compatibility alias for `on_train_epoch_begin` and `on_eval_epoch_begin`.

        Args:
            run_context (RunContext): Include some information of the model.

@@ -138,6 +141,7 @@ class Callback:
    def epoch_end(self, run_context):
        """
        Called after each epoch finishes.
+        A backwards-compatibility alias for `on_train_epoch_end` and `on_eval_epoch_end`.

        Args:
            run_context (RunContext): Include some information of the model.

@@ -146,6 +150,7 @@ class Callback:
    def step_begin(self, run_context):
        """
        Called before each step begins.
+        A backwards-compatibility alias for `on_train_step_begin` and `on_eval_step_begin`.

        Args:
            run_context (RunContext): Include some information of the model.

@@ -154,6 +159,7 @@ class Callback:
    def step_end(self, run_context):
        """
        Called after each step finishes.
+        A backwards-compatibility alias for `on_train_step_end` and `on_eval_step_end`.

        Args:
            run_context (RunContext): Include some information of the model.

@@ -162,11 +168,120 @@ class Callback:
    def end(self, run_context):
        """
        Called once after network training.
+        A backwards-compatibility alias for `on_train_end` and `on_eval_end`.

        Args:
            run_context (RunContext): Include some information of the model.
        """
+
+    def on_train_begin(self, run_context):
+        """
+        Called once before the network training.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.begin(run_context)
+
+    def on_train_epoch_begin(self, run_context):
+        """
+        Called before each training epoch begins.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.epoch_begin(run_context)
+
+    def on_train_epoch_end(self, run_context):
+        """
+        Called after each training epoch ends.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.epoch_end(run_context)
+
+    def on_train_step_begin(self, run_context):
+        """
+        Called before each training step begins.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.step_begin(run_context)
+
+    def on_train_step_end(self, run_context):
+        """
+        Called after each training step ends.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.step_end(run_context)
+
+    def on_train_end(self, run_context):
+        """
+        Called after training ends.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.end(run_context)
+
+    def on_eval_begin(self, run_context):
+        """
+        Called before eval begins.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.begin(run_context)
+
+    def on_eval_epoch_begin(self, run_context):
+        """
+        Called before the eval epoch begins.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.epoch_begin(run_context)
+
+    def on_eval_epoch_end(self, run_context):
+        """
+        Called after the eval epoch ends.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.epoch_end(run_context)
+
+    def on_eval_step_begin(self, run_context):
+        """
+        Called before each eval step begins.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.step_begin(run_context)
+
+    def on_eval_step_end(self, run_context):
+        """
+        Called after each eval step ends.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.step_end(run_context)
+
+    def on_eval_end(self, run_context):
+        """
+        Called after eval ends.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.end(run_context)


class CallbackManager(Callback):
    """

@@ -208,7 +323,7 @@ class CallbackManager(Callback):
        return self._stack.__exit__(*err)

    def begin(self, run_context):
-        """Called once before network training."""
+        """Called once before network train or eval."""
        for cb in self._callbacks:
            cb.begin(run_context)

@@ -223,7 +338,7 @@ class CallbackManager(Callback):
            cb.epoch_end(run_context)

    def step_begin(self, run_context):
-        """Called before each epoch begin."""
+        """Called before each step begin."""
        for cb in self._callbacks:
            cb.step_begin(run_context)

@@ -233,10 +348,70 @@ class CallbackManager(Callback):
            cb.step_end(run_context)

    def end(self, run_context):
-        """Called once after network training."""
+        """Called once after network train or eval."""
        for cb in self._callbacks:
            cb.end(run_context)
+
+    def on_train_begin(self, run_context):
+        """Called before network train."""
+        for cb in self._callbacks:
+            cb.on_train_begin(run_context)
+
+    def on_train_epoch_begin(self, run_context):
+        """Called before each train epoch begins."""
+        for cb in self._callbacks:
+            cb.on_train_epoch_begin(run_context)
+
+    def on_train_epoch_end(self, run_context):
+        """Called after each train epoch finishes."""
+        for cb in self._callbacks:
+            cb.on_train_epoch_end(run_context)
+
+    def on_train_step_begin(self, run_context):
+        """Called before each train step begins."""
+        for cb in self._callbacks:
+            cb.on_train_step_begin(run_context)
+
+    def on_train_step_end(self, run_context):
+        """Called after each train step finishes."""
+        for cb in self._callbacks:
+            cb.on_train_step_end(run_context)
+
+    def on_train_end(self, run_context):
+        """Called after network train ends."""
+        for cb in self._callbacks:
+            cb.on_train_end(run_context)
+
+    def on_eval_begin(self, run_context):
+        """Called before network eval."""
+        for cb in self._callbacks:
+            cb.on_eval_begin(run_context)
+
+    def on_eval_epoch_begin(self, run_context):
+        """Called before the eval epoch begins."""
+        for cb in self._callbacks:
+            cb.on_eval_epoch_begin(run_context)
+
+    def on_eval_epoch_end(self, run_context):
+        """Called after the eval epoch finishes."""
+        for cb in self._callbacks:
+            cb.on_eval_epoch_end(run_context)
+
+    def on_eval_step_begin(self, run_context):
+        """Called before each eval step begins."""
+        for cb in self._callbacks:
+            cb.on_eval_step_begin(run_context)
+
+    def on_eval_step_end(self, run_context):
+        """Called after each eval step finishes."""
+        for cb in self._callbacks:
+            cb.on_eval_step_end(run_context)
+
+    def on_eval_end(self, run_context):
+        """Called after network eval ends."""
+        for cb in self._callbacks:
+            cb.on_eval_end(run_context)


class InternalCallbackParam(dict):
    """Internal callback object's parameters."""
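Because each new `on_train_*`/`on_eval_*` default forwards to the matching legacy hook, an old-style callback keeps firing under `model.train`; a minimal sketch of that forwarding:

    from mindspore.train.callback import Callback

    class OldStyle(Callback):
        def epoch_end(self, run_context):
            print("epoch finished")

    cb = OldStyle()
    # on_train_epoch_end is inherited; its default body calls self.epoch_end.
    cb.on_train_epoch_end(run_context=None)   # prints "epoch finished"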
@@ -21,7 +21,7 @@ from ._callback import Callback

class History(Callback):
    """
-    Records the network outputs information into a `History` object.
+    Records the network outputs and metrics information into a `History` object.

    The network outputs information will be the loss value if not customizing the train network or eval network;
    if the customized network returns a `Tensor` or `numpy.ndarray`, the mean value of network output

@@ -29,7 +29,7 @@ class History(Callback):
    outputs will be recorded.

    Note:
-        Normally used in `mindspore.Model.train`.
+        Normally used in `mindspore.Model.train` or `mindspore.Model.fit`.

    Examples:
        >>> import numpy as np

@@ -65,7 +65,7 @@ class History(Callback):

    def epoch_end(self, run_context):
        """
-        Records the first element of network outputs at the end of epoch.
+        Records the first element of network outputs and metrics information at the end of epoch.

        Args:
            run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. For more details,
@@ -22,19 +22,25 @@ class LambdaCallback(Callback):
    Callback for creating simple, custom callbacks.

    This callback is constructed with anonymous functions that will be called
-    at the appropriate time (during `mindspore.Model.{train | eval}`). Note that
+    at the appropriate time (during `mindspore.Model.{train | eval | fit}`). Note that
    each stage of callbacks expects one positional argument: `run_context`.

    Note:
        This is an experimental interface that is subject to change or deletion.

    Args:
-        epoch_begin (Function): called at the beginning of every epoch.
-        epoch_end (Function): called at the end of every epoch.
-        step_begin (Function): called at the beginning of every batch.
-        step_end (Function): called at the end of every batch.
-        begin (Function): called at the beginning of model train/eval.
-        end (Function): called at the end of model train/eval.
+        on_train_epoch_begin (Function): called at each train epoch begin.
+        on_train_epoch_end (Function): called at each train epoch end.
+        on_train_step_begin (Function): called at each train step begin.
+        on_train_step_end (Function): called at each train step end.
+        on_train_begin (Function): called at the beginning of model train.
+        on_train_end (Function): called at the end of model train.
+        on_eval_epoch_begin (Function): called at eval epoch begin.
+        on_eval_epoch_end (Function): called at eval epoch end.
+        on_eval_step_begin (Function): called at each eval step begin.
+        on_eval_step_end (Function): called at each eval step end.
+        on_eval_begin (Function): called at the beginning of model eval.
+        on_eval_end (Function): called at the end of model eval.

    Examples:
        >>> import numpy as np

@@ -46,19 +52,28 @@ class LambdaCallback(Callback):
        >>> net = nn.Dense(10, 5)
        >>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
-        >>> lambda_callback = LambdaCallback(epoch_end=
+        >>> lambda_callback = LambdaCallback(on_train_epoch_end=
        ... lambda run_context: print("loss: ", run_context.original_args().net_outputs))
        >>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
        >>> model.train(2, train_dataset, callbacks=[lambda_callback])
        loss: 1.6127687
        loss: 1.6106578
    """
-    def __init__(self, epoch_begin=None, epoch_end=None, step_begin=None,
-                 step_end=None, begin=None, end=None):
+    def __init__(self, on_train_epoch_begin=None, on_train_epoch_end=None, on_train_step_begin=None,
+                 on_train_step_end=None, on_train_begin=None, on_train_end=None,
+                 on_eval_epoch_begin=None, on_eval_epoch_end=None, on_eval_step_begin=None,
+                 on_eval_step_end=None, on_eval_begin=None, on_eval_end=None):
        super(LambdaCallback, self).__init__()
-        self.epoch_begin = epoch_begin if epoch_begin else lambda run_context: None
-        self.epoch_end = epoch_end if epoch_end else lambda run_context: None
-        self.step_begin = step_begin if step_begin else lambda run_context: None
-        self.step_end = step_end if step_end else lambda run_context: None
-        self.begin = begin if begin else lambda run_context: None
-        self.end = end if end else lambda run_context: None
+        self.on_train_epoch_begin = on_train_epoch_begin if on_train_epoch_begin else lambda run_context: None
+        self.on_train_epoch_end = on_train_epoch_end if on_train_epoch_end else lambda run_context: None
+        self.on_train_step_begin = on_train_step_begin if on_train_step_begin else lambda run_context: None
+        self.on_train_step_end = on_train_step_end if on_train_step_end else lambda run_context: None
+        self.on_train_begin = on_train_begin if on_train_begin else lambda run_context: None
+        self.on_train_end = on_train_end if on_train_end else lambda run_context: None
+
+        self.on_eval_epoch_begin = on_eval_epoch_begin if on_eval_epoch_begin else lambda run_context: None
+        self.on_eval_epoch_end = on_eval_epoch_end if on_eval_epoch_end else lambda run_context: None
+        self.on_eval_step_begin = on_eval_step_begin if on_eval_step_begin else lambda run_context: None
+        self.on_eval_step_end = on_eval_step_end if on_eval_step_end else lambda run_context: None
+        self.on_eval_begin = on_eval_begin if on_eval_begin else lambda run_context: None
+        self.on_eval_end = on_eval_end if on_eval_end else lambda run_context: None
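Under `model.fit`, the eval-side lambdas fire during the in-fit evaluation pass; a small sketch (the model and datasets are assumed):

    fit_logger = LambdaCallback(
        on_train_epoch_end=lambda rc: print("loss:", rc.original_args().net_outputs),
        on_eval_step_end=lambda rc: print("eval step", rc.original_args().cur_step_num))
    model.fit(2, train_dataset, valid_dataset, callbacks=[fit_logger])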
@@ -23,7 +23,7 @@ from ._callback import Callback

class LossMonitor(Callback):
    """
-    Monitor the loss in training.
+    Monitor the loss in train or monitor the loss and eval metrics in fit.

    If the loss is NAN or INF, it will terminate training.

@@ -90,3 +90,17 @@ class LossMonitor(Callback):
        if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
            self._last_print_time = cb_params.cur_step_num
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
+
+    def on_train_epoch_end(self, run_context):
+        """
+        When LossMonitor is used in `model.fit`, print eval metrics at the end of epoch if current epoch
+        should do evaluation.
+
+        Args:
+            run_context (RunContext): Include some information of the model. For more details,
+                please refer to :class:`mindspore.RunContext`.
+        """
+        cb_params = run_context.original_args()
+        metrics = cb_params.get("metrics")
+        if metrics:
+            print("Eval result: epoch %d, metrics: %s" % (cb_params.cur_epoch_num, metrics))
@@ -22,7 +22,7 @@ from ._callback import Callback

class TimeMonitor(Callback):
    """
-    Monitor the time in training.
+    Monitor the time in train or eval process.

    Args:
        data_size (int): How many steps are the intervals between print information each time.

@@ -71,6 +71,7 @@ class TimeMonitor(Callback):
        epoch_seconds = (time.time() - self.epoch_time) * 1000
        step_size = self.data_size
        cb_params = run_context.original_args()
+        mode = cb_params.get("mode", "")
        if hasattr(cb_params, "batch_num"):
            batch_num = cb_params.batch_num
            if isinstance(batch_num, int) and batch_num > 0:

@@ -78,4 +79,5 @@ class TimeMonitor(Callback):
        Validator.check_positive_int(step_size)

        step_seconds = epoch_seconds / step_size
-        print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, step_seconds), flush=True)
+        print("{} epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(
+            mode.title(), epoch_seconds, step_seconds), flush=True)
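The effect of threading `mode` into the message, shown standalone with made-up timings:

    # cb_params.mode is "train" during training and "eval" during the in-fit
    # evaluation pass, so the same template renders as two distinct lines.
    for mode, epoch_ms, steps in (("train", 1234.567, 100), ("eval", 456.789, 100)):
        print("{} epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(
            mode.title(), epoch_ms, epoch_ms / steps))
    # Train epoch time: 1234.567 ms, per step time: 12.346 ms
    # Eval epoch time: 456.789 ms, per step time: 4.568 ms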
@@ -27,7 +27,8 @@ from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, Validator
-from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor
+from .callback import __all__ as internal_cb_names
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
@@ -503,7 +504,8 @@ class Model:
        return [callbacks]

    @_save_final_ckpt
-    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
+    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
+               valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True):
        """
        Training.

@@ -541,6 +543,7 @@ class Model:
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = self._transform_callbacks(callbacks)
+        valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
        if context.get_context("mode") == context.PYNATIVE_MODE:
            cb_params.list_callback.insert(0, _StepSync())
            callbacks = cb_params.list_callback

@@ -555,17 +558,21 @@ class Model:
        with _CallbackManager(callbacks) as list_callback:
            self._check_reuse_dataset(train_dataset)
            if not dataset_sink_mode:
-                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch)
+                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
            elif context.get_context("device_target") == "CPU":
                logger.info("The CPU cannot support dataset sink mode currently. "
                            "So the training process will be performed with dataset not sink.")
-                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch)
+                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
            else:
                self._train_dataset_sink_process(epoch, train_dataset, list_callback,
-                                                 cb_params, sink_size, initial_epoch)
+                                                 cb_params, sink_size, initial_epoch, valid_infos)
+
+    @staticmethod
+    def _should_eval(epoch, validation_freq):
+        return epoch % validation_freq == 0 if isinstance(validation_freq, int) else epoch in validation_freq

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
-                                    sink_size=-1, initial_epoch=0):
+                                    sink_size=-1, initial_epoch=0, valid_infos=None):
        """
        Training process. The data would be passed to network through dataset channel.
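The scheduling rule in `_should_eval` is small enough to check by hand:

    def _should_eval(epoch, validation_freq):
        # int: evaluate every validation_freq epochs; list: evaluate on the listed epochs.
        return epoch % validation_freq == 0 if isinstance(validation_freq, int) else epoch in validation_freq

    print([e for e in range(1, 7) if _should_eval(e, 2)])       # [2, 4, 6]
    print([e for e in range(1, 7) if _should_eval(e, [1, 5])])  # [1, 5]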
@@ -593,7 +600,7 @@ class Model:
        cb_params.dataset_sink_mode = True

        run_context = RunContext(cb_params)
-        list_callback.begin(run_context)
+        list_callback.on_train_begin(run_context)
        # used to stop training for early stop, such as stopAtTime or stopAtStep
        dataset_helper = None
        if hasattr(train_dataset, '_dataset_helper'):

@@ -609,7 +616,7 @@ class Model:
            cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
            self._current_epoch_num = cb_params.cur_epoch_num
            self._current_step_num = 0
-            list_callback.epoch_begin(run_context)
+            list_callback.on_train_epoch_begin(run_context)
            dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                  dataset=train_dataset,
                                                                  dataset_sink_mode=True,

@@ -632,22 +639,51 @@ class Model:
                cb_params.cur_step_num += 1
                self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
                cb_params.train_dataset_element = inputs
-                list_callback.step_begin(run_context)
+                list_callback.on_train_step_begin(run_context)
                outputs = train_network(*inputs)
                cb_params.net_outputs = outputs
                # In disaster recovery scenarios, need not to execute callbacks if this step executes failed.
                need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
                if need_exec_callback_step_end:
-                    list_callback.step_end(run_context)
+                    list_callback.on_train_step_end(run_context)

                if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
                    os._exit(0)

            dataset_helper.continue_send()
+
+            valid_dataset, valid_frequency, valid_dataset_sink_mode = valid_infos
+            if valid_dataset and self._should_eval(cb_params.cur_epoch_num, valid_frequency):
+                train_cur_step_num = cb_params.cur_step_num
+                train_batch_num = cb_params.batch_num
+                train_dataset_sink_mode = cb_params.dataset_sink_mode
+                train_net_outputs = cb_params.net_outputs
+
+                eval_callback = []
+                for cb in list_callback._callbacks:
+                    if cb.__class__.__name__ in internal_cb_names:
+                        if isinstance(cb, TimeMonitor):
+                            eval_callback.append(cb)
+                    else:
+                        eval_callback.append(cb)
+
+                self._eval_in_fit(valid_dataset,
+                                  callbacks=eval_callback,
+                                  dataset_sink_mode=valid_dataset_sink_mode,
+                                  cb_params=cb_params)
+                cb_params.mode = "train"
+                cb_params.cur_step_num = train_cur_step_num
+                cb_params.batch_num = train_batch_num
+                cb_params.dataset_sink_mode = train_dataset_sink_mode
+                cb_params.net_outputs = train_net_outputs
+
            # In disaster recovery scenarios, need not to execute callbacks if this epoch executes failed.
            need_exec_callback_epoch_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
            if need_exec_callback_epoch_end:
-                list_callback.epoch_end(run_context)
+                list_callback.on_train_epoch_end(run_context)
+            if "metrics" in cb_params:
+                cb_params.pop("metrics")

            should_stop = run_context.get_stop_requested()
            if should_stop:

@@ -663,7 +699,7 @@ class Model:
            dataset_helper.stop_send()
            dataset_helper.release()

-        list_callback.end(run_context)
+        list_callback.on_train_end(run_context)

    def _check_enable_recovery(self):
        """
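The callback partition in the block above is easier to read as a comprehension; this condensed equivalent uses stand-in names for illustration:

    internal = {"LossMonitor", "TimeMonitor", "History"}   # stand-in for internal_cb_names
    def keep_for_eval(cb):
        # Only TimeMonitor among the built-in callbacks also runs during the
        # in-fit evaluation pass; user-defined callbacks always do.
        return cb.__class__.__name__ not in internal or isinstance(cb, TimeMonitor)

    eval_callback = [cb for cb in list_callback._callbacks if keep_for_eval(cb)]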
@@ -753,7 +789,8 @@ class Model:

        _set_recovery_context(need_reset=False)

-    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0):
+    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0,
+                       valid_infos=None):
        """
        Training process. The data would be passed to network directly.

@@ -776,13 +813,13 @@ class Model:
        cb_params.cur_step_num = 0
        cb_params.dataset_sink_mode = False
        run_context = RunContext(cb_params)
-        list_callback.begin(run_context)
+        list_callback.on_train_begin(run_context)
        for i in range(initial_epoch, epoch):
            cb_params.cur_epoch_num = i + 1
            self._current_epoch_num = cb_params.cur_epoch_num
            self._current_step_num = 0

-            list_callback.epoch_begin(run_context)
+            list_callback.on_train_epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)

@@ -795,7 +832,7 @@ class Model:
                self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)

                cb_params.train_dataset_element = next_element
-                list_callback.step_begin(run_context)
+                list_callback.on_train_step_begin(run_context)
                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():

@@ -803,24 +840,52 @@ class Model:
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

-                list_callback.step_end(run_context)
+                list_callback.on_train_step_end(run_context)
                if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
                    os._exit(0)
                should_stop = run_context.get_stop_requested()
                if should_stop:
                    break

+            valid_dataset, valid_frequency, valid_dataset_sink_mode = valid_infos
+            if valid_dataset and self._should_eval(cb_params.cur_epoch_num, valid_frequency):
+                train_cur_step_num = cb_params.cur_step_num
+                train_batch_num = cb_params.batch_num
+                train_dataset_sink_mode = cb_params.dataset_sink_mode
+                train_net_outputs = cb_params.net_outputs
+
+                eval_callback = []
+                for cb in list_callback._callbacks:
+                    if cb.__class__.__name__ in internal_cb_names:
+                        if isinstance(cb, TimeMonitor):
+                            eval_callback.append(cb)
+                    else:
+                        eval_callback.append(cb)
+
+                self._eval_in_fit(valid_dataset,
+                                  callbacks=eval_callback,
+                                  dataset_sink_mode=valid_dataset_sink_mode,
+                                  cb_params=cb_params)
+
+                cb_params.mode = "train"
+                cb_params.cur_step_num = train_cur_step_num
+                cb_params.batch_num = train_batch_num
+                cb_params.dataset_sink_mode = train_dataset_sink_mode
+                cb_params.net_outputs = train_net_outputs
+
            train_dataset.reset()

            # if param is cache enable, flush data from cache to host before epoch end
            self._flush_from_cache(cb_params)

-            list_callback.epoch_end(run_context)
+            list_callback.on_train_epoch_end(run_context)
+            if "metrics" in cb_params:
+                cb_params.pop("metrics")
            should_stop = run_context.get_stop_requested()
            if should_stop:
                break

-        list_callback.end(run_context)
+        list_callback.on_train_end(run_context)

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
        """
@@ -912,6 +977,9 @@ class Model:

        _device_number_check(self._parallel_mode, self._device_number)

+        if callbacks:
+            self._check_methods_for_custom_callbacks(callbacks, "train")
+
        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
@@ -924,6 +992,140 @@ class Model:
        if _is_ps_mode() and _enable_distributed_mindrt():
            _reset_op_id_with_offset()

+    @staticmethod
+    def _check_methods_for_custom_callbacks(callbacks, current_mode):
+        """
+        Check whether methods of customized callbacks are valid.
+
+        Args:
+            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object.
+            current_mode (str): 'fit', 'train' or 'eval'.
+        """
+        old_version_methods_names = {'begin', 'end', 'epoch_begin', 'epoch_end', 'step_begin', 'step_end'}
+        if not isinstance(callbacks, list):
+            callbacks = [callbacks]
+        for cb in callbacks:
+            cb_name = cb.__class__.__name__
+            if cb_name not in internal_cb_names:
+                cb_methods_names = set(cb.__class__.__dict__.keys())
+                invalid_methods_names = cb_methods_names & old_version_methods_names
+                if invalid_methods_names:
+                    if current_mode in ["train", "eval"]:
+                        logger.warning("For %s callback, %s methods may not be supported in a later version, "
+                                       "use methods prefixed with 'on_train' or 'on_eval' instead "
+                                       "when using customized callbacks." % (cb_name, invalid_methods_names))
+                    else:
+                        raise ValueError("For %s callback, %s methods may not be supported in a later version, "
+                                         "use methods prefixed with 'on_train' or 'on_eval' instead "
+                                         "when using customized callbacks." % (cb_name, invalid_methods_names))
+
+    def fit(self, epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None,
+            dataset_sink_mode=True, valid_dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
+        """
+        Fit API.
+
+        Evaluation process will be performed during training process if `valid_dataset` is provided.
+
+        For more details, please refer to `mindspore.Model.train` and `mindspore.Model.eval`.
+
+        Args:
+            epoch (int): Total training epochs. Generally, train network will be trained on complete dataset
+                per epoch. If `dataset_sink_mode` is set to True and `sink_size` is greater than 0, each epoch
+                will train `sink_size` steps instead of total steps of dataset.
+                If `epoch` is used with `initial_epoch`, it is to be understood as the "final epoch".
+            train_dataset (Dataset): A training dataset iterator. If `loss_fn` is defined, the data and label
+                will be passed to the `network` and the `loss_fn` respectively, so a tuple (data, label)
+                should be returned from dataset. If there are multiple data or labels, set `loss_fn`
+                to None and implement calculation of loss in `network`,
+                then a tuple (data1, data2, data3, ...) with all data returned from dataset
+                will be passed to the `network`.
+            valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation
+                will be performed during the training process, at the epochs selected by `valid_frequency`.
+                Default: None.
+            valid_frequency (int, list): Only relevant if `valid_dataset` is provided. If an integer, specifies
+                how many training epochs to run before a new validation run is performed,
+                e.g. `valid_frequency=2` runs validation every 2 epochs.
+                If a list, specifies the epochs on which to run validation,
+                e.g. `valid_frequency=[1, 5]` runs validation at the end of the 1st and 5th epochs.
+                Default: 1.
+            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
+                which should be executed while training.
+                Default: None.
+            dataset_sink_mode (bool): Determines whether to pass the train data through dataset channel.
+                Configure pynative mode or CPU, the training process will be performed with
+                dataset not sink. Default: True.
+            valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset
+                channel. Default: True.
+            sink_size (int): Control the amount of data in each sink. `sink_size` is invalid if
+                `dataset_sink_mode` is False.
+                If sink_size = -1, sink the complete dataset for each epoch.
+                If sink_size > 0, sink sink_size data for each epoch.
+                Default: -1.
+            initial_epoch (int): Epoch at which to start train, it is useful for resuming a previous
+                training run. Default: 0.
+
+        Examples:
+            >>> from mindspore import Model, nn, FixedLossScaleManager
+            >>>
+            >>> # For details about how to build the dataset, please refer to the tutorial
+            >>> # document on the official website.
+            >>> train_dataset = create_custom_dataset()
+            >>> valid_dataset = create_custom_dataset()
+            >>> net = Net()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics="accuracy")
+            >>> model.fit(2, train_dataset, valid_dataset)
+        """
+        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
+        valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)
+
+        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
+            raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
+
+        if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
+            raise ValueError("When using Model.build to initialize the model, the value of parameter `epoch` in "
+                             "Model.build should be equal to the value in Model.fit, but got {} and {} separately."
+                             .format(train_dataset._warmup_epoch, epoch))
+
+        if dataset_sink_mode and _is_ps_mode() and not _cache_enable():
+            raise ValueError("Parameter server mode does not support 'dataset_sink_mode=True'.")
+
+        Validator.check_is_int(sink_size)
+        Validator.check_non_negative_int(initial_epoch)
+        if initial_epoch >= epoch:
+            raise ValueError(f"For 'Model.fit', the parameter 'epoch' must be bigger than the parameter "
+                             f"'initial_epoch', but got the parameter 'epoch' is {epoch}, "
+                             f"'initial_epoch' is {initial_epoch}.")
+        dataset_size = train_dataset.get_dataset_size()
+        if dataset_size == 0:
+            raise ValueError("There is no valid data in dataset, please check the dataset file first.")
+        if sink_size == -1:
+            sink_size = dataset_size
+        if sink_size < -1 or sink_size == 0:
+            raise ValueError("For 'Model.fit', the parameter 'sink_size' must be -1 or positive, "
+                             "but got {}.".format(sink_size))
+
+        _device_number_check(self._parallel_mode, self._device_number)
+
+        if not isinstance(valid_frequency, (int, list)):
+            raise ValueError(f"For 'Model.fit', the type of 'valid_frequency' must be a list or an integer, "
+                             f"but got type {type(valid_frequency)}.")
+
+        if valid_dataset and not self._metric_fns:
+            raise ValueError("For 'Model.fit', if valid_dataset is not None, the model argument 'metrics' can not "
+                             "be None or empty, you should set the argument 'metrics' for model.")
+        if callbacks:
+            self._check_methods_for_custom_callbacks(callbacks, "fit")
+        self._train(epoch,
+                    train_dataset,
+                    callbacks=callbacks,
+                    dataset_sink_mode=dataset_sink_mode,
+                    sink_size=sink_size,
+                    initial_epoch=initial_epoch,
+                    valid_dataset=valid_dataset,
+                    valid_frequency=valid_frequency,
+                    valid_dataset_sink_mode=valid_dataset_sink_mode)
+
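The detection in `_check_methods_for_custom_callbacks` relies on `cb.__class__.__dict__`, which lists only the methods a subclass defines itself, never the inherited defaults; a quick check of that behaviour:

    from mindspore.train.callback import Callback

    class Legacy(Callback):
        def step_end(self, run_context):   # old-style hook
            pass

    old_names = {'begin', 'end', 'epoch_begin', 'epoch_end', 'step_begin', 'step_end'}
    print(set(Legacy.__dict__) & old_names)     # {'step_end'} -> warning in train/eval, ValueError in fit
    print(set(Callback.__dict__) >= old_names)  # True, yet base-class defaults are never flagged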
    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, jit_config=None):
        """
        Build computational graphs and data graphs with the sink mode.

@@ -971,6 +1173,40 @@ class Model:
        _cell_graph_executor.set_jit_config(jit_config)
        self._init(train_dataset, valid_dataset, sink_size, epoch)
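The `_warmup_epoch` check in `fit` above couples these two entry points: if the graphs were pre-compiled with `Model.build`, the `epoch` passed there must match the one later passed to `fit`. A minimal usage sketch, reusing the helper functions from the new test file further down:

model = define_model()
ds_train = create_dataset(4096, 1024)
ds_eval = create_dataset(1024, 512)
model.build(train_dataset=ds_train, valid_dataset=ds_eval, epoch=5)
model.fit(5, ds_train, ds_eval)  # a mismatched epoch value here would raise ValueError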
    def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
        """
        Evaluation process in `mindspore.Model.fit`.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation
                is performed at the end of the training process. Default: None.
            callbacks (Optional[list[Callback], Callback]): List of callback objects or a single callback object,
                to be executed during evaluation. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the validation data through the dataset channel.
                Default: True.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")

        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0

        self._clear_metrics()

        if context.get_context("device_target") == "CPU" and dataset_sink_mode:
            dataset_sink_mode = False
            logger.info("CPU cannot support dataset sink mode currently. "
                        "So the evaluating process will be performed with dataset non-sink mode.")

        with _CallbackManager(callbacks) as list_callback:
            if dataset_sink_mode:
                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
            return self._eval_process(valid_dataset, list_callback, cb_params)
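Since `_eval_in_fit` reuses the training run's `cb_params` and flips `cb_params.mode` to "eval", a user callback can observe fit-time evaluation through the `on_eval_*` hooks alone. A minimal sketch; the class name `EvalLogger` is illustrative:

from mindspore.train.callback import Callback

class EvalLogger(Callback):
    def on_eval_end(self, run_context):
        # `metrics` is attached via `cb_params.metrics = metrics` in the
        # eval helpers below, before `on_eval_end` fires.
        cb_params = run_context.original_args()
        print("eval metrics:", cb_params.metrics)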
    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to the network through the dataset channel.

@@ -990,20 +1226,20 @@ class Model:
                                                  dataset_sink_mode=True)
        cb_params.eval_network = eval_network
        cb_params.dataset_sink_mode = True
-       list_callback.begin(run_context)
+       list_callback.on_eval_begin(run_context)
-       list_callback.epoch_begin(run_context)
+       list_callback.on_eval_epoch_begin(run_context)
        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
-           list_callback.step_begin(run_context)
+           list_callback.on_eval_step_begin(run_context)
            outputs = eval_network(*inputs)
            cb_params.net_outputs = outputs
-           list_callback.step_end(run_context)
+           list_callback.on_eval_step_end(run_context)
            self._update_metrics(outputs)

-       list_callback.epoch_end(run_context)
+       list_callback.on_eval_epoch_end(run_context)
        metrics = self._get_metrics()
        cb_params.metrics = metrics
-       list_callback.end(run_context)
+       list_callback.on_eval_end(run_context)

        return metrics
@@ -1021,25 +1257,25 @@ class Model:
        """
        run_context = RunContext(cb_params)
        cb_params.dataset_sink_mode = False
-       list_callback.begin(run_context)
+       list_callback.on_eval_begin(run_context)
        dataset_helper, _ = self._exec_preprocess(is_train=False,
                                                  dataset=valid_dataset,
                                                  dataset_sink_mode=False)
-       list_callback.epoch_begin(run_context)
+       list_callback.on_eval_epoch_begin(run_context)
        for next_element in dataset_helper:
            cb_params.cur_step_num += 1
-           list_callback.step_begin(run_context)
+           list_callback.on_eval_step_begin(run_context)
            next_element = _transfer_tensor_to_tuple(next_element)
            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
-           list_callback.step_end(run_context)
+           list_callback.on_eval_step_end(run_context)
            self._update_metrics(outputs)

-       list_callback.epoch_end(run_context)
+       list_callback.on_eval_epoch_end(run_context)
        valid_dataset.reset()
        metrics = self._get_metrics()
        cb_params.metrics = metrics
-       list_callback.end(run_context)
+       list_callback.on_eval_end(run_context)
        return metrics
    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):

@@ -1087,7 +1323,8 @@ class Model:
                             "you should set the argument 'metrics' for model.")
        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
+       if callbacks:
+           self._check_methods_for_custom_callbacks(callbacks, "eval")
        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
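`_check_methods_for_custom_callbacks` itself lies outside this hunk. Judging from the tests below, old-style hook names raise a ValueError under `fit` but only warn under `train`/`eval`, while built-in callbacks (already migrated to the `on_train_*`/`on_eval_*` names) pass untouched. A rough sketch of that contract, with every name and message text an assumption:

import logging

logger = logging.getLogger(__name__)
_OLD_HOOKS = ("begin", "epoch_begin", "epoch_end", "step_begin", "step_end", "end")

def _check_methods_for_custom_callbacks(callbacks, current_mode):
    cbs = callbacks if isinstance(callbacks, (list, tuple)) else [callbacks]
    for cb in cbs:
        # Only hooks defined on the subclass itself count as overrides;
        # the no-op defaults on the Callback base class do not trigger this.
        overridden = [m for m in _OLD_HOOKS if m in type(cb).__dict__]
        if not overridden:
            continue
        if current_mode == "fit":
            raise ValueError("Custom callback methods %s must be prefixed with "
                             "'on_train' or 'on_eval' when used in Model.fit." % overridden)
        logger.warning("Callback methods %s are deprecated; prefix them with "
                       "'on_train' or 'on_eval'.", overridden)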
@@ -150,7 +150,7 @@ class ModelThor(Model):
            self.switch_branch_one = not self.switch_branch_one
            outputs = self._train_network(*inputs)
            cb_params.net_outputs = outputs
-           list_callback.step_end(run_context)
+           list_callback.on_train_step_end(run_context)
        else:
            cb_params.cur_step_num += 1
            if self.train_network_init_flag:

@@ -163,7 +163,7 @@ class ModelThor(Model):
            if self.index_first_order == iter_first_order:
                self.index_first_order = 0
                self.switch_branch_one = not self.switch_branch_one
-               list_callback.step_end(run_context)
+               list_callback.on_train_step_end(run_context)

    def _train_ascend_sink_step(self, cb_params, train_dataset, iter_first_order, inputs, list_callback, run_context):
        """train ascend sink step"""

@@ -184,10 +184,10 @@ class ModelThor(Model):
            self.switch_branch_one = not self.switch_branch_one
            outputs = self._train_network(*inputs)
            cb_params.net_outputs = outputs
-           list_callback.step_end(run_context)
+           list_callback.on_train_step_end(run_context)

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
-                                   sink_size=-1, initial_epoch=0):
+                                   sink_size=-1, initial_epoch=0, valid_infos=None):
        """
        Training process. The data would be passed to the network through the dataset channel.

@@ -204,6 +204,9 @@ class ModelThor(Model):
            initial_epoch (int): Epoch at which to start training; useful for resuming a previous training run.
                Default: 0.
        """
+       valid_dataset, _, _ = valid_infos
+       if valid_dataset:
+           raise ValueError("Evaluation in training is currently not supported in the second-order "
+                            "scenario of THOR.")
        if sink_size == -1:
            epoch_num = epoch - initial_epoch
        else:

@@ -226,28 +229,28 @@ class ModelThor(Model):
        cb_params.cur_step_num = 0

        run_context = RunContext(cb_params)
-       list_callback.begin(run_context)
+       list_callback.on_train_begin(run_context)

        for i in range(initial_epoch, epoch):
            cb_params.cur_epoch_num = i + 1
-           list_callback.epoch_begin(run_context)
+           list_callback.on_train_epoch_begin(run_context)
            # With data sinking, dataset_helper iterates only once; otherwise it iterates epoch_size times.
            for inputs in dataset_helper:
                if _need_to_full() and context.get_context("device_target") == "GPU":
                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
-               list_callback.step_begin(run_context)
+               list_callback.on_train_step_begin(run_context)
                if context.get_context("device_target") == "GPU":
                    self._train_gpu_sink_step(cb_params, inputs, list_callback, iter_first_order, run_context)
                else:
                    self._train_ascend_sink_step(cb_params, train_dataset, iter_first_order, inputs, list_callback,
                                                 run_context)
-           list_callback.epoch_end(run_context)
+           list_callback.on_train_epoch_end(run_context)
            should_stop = run_context.get_stop_requested()
            if should_stop:
                break
        dataset_helper.stop_send()

-       list_callback.end(run_context)
+       list_callback.on_train_end(run_context)


__all__ = ["ModelThor"]
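One behavioral consequence of the `valid_infos` guard above, phrased as a hedged test-style snippet; constructing a THOR model (second-order optimizer, dataset sinking enabled) is elided, so `thor_model`, `ds_train` and `ds_eval` are placeholders:

import pytest

with pytest.raises(ValueError):
    # Fit-time evaluation is rejected in the THOR second-order scenario.
    thor_model.fit(2, ds_train, ds_eval)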
@@ -0,0 +1,202 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

""" test_fit """

import pytest
import numpy as np
from mindspore import Model, nn, Tensor
from mindspore.common.initializer import Normal
from mindspore.train.callback import Callback, TimeMonitor, LossMonitor
from mindspore import dataset as ds


def get_data(num, w=2.0, b=3.0):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)


def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size, drop_remainder=True)
    input_data = input_data.repeat(repeat_size)
    return input_data


def define_model():
    net = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
    net_loss = nn.MSELoss()
    net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    return Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={'mse', 'mae'})


class MyCallbackOldMethod(Callback):
    """Raises a warning in `mindspore.Model.train` and `mindspore.Model.eval`; raises an error in `mindspore.Model.fit`."""
    def begin(self, run_context):
        print("custom callback: print on begin, just for test.")

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        outputs = cb_params.get("net_outputs")
        result = outputs if isinstance(outputs, Tensor) else outputs[0]
        print("custom train callback: step end, loss is %s" % result)

    def on_train_epoch_end(self, run_context):
        cb_params = run_context.original_args()
        print("custom train callback: epoch end, loss is %s" % cb_params.get("net_outputs"))


class MyCallbackNewMethod(Callback):
    """Custom callback that runs in `mindspore.Model.train`, `mindspore.Model.eval` and `mindspore.Model.fit`."""
    def on_train_epoch_end(self, run_context):
        cb_params = run_context.original_args()
        print("custom callback: train epoch end, loss is %s" % cb_params.get("net_outputs"))

    def on_eval_epoch_end(self, run_context):
        cb_params = run_context.original_args()
        print("custom callback: eval epoch end, metric is %s" % cb_params.get("net_outputs")[0])


def test_fit_train_dataset_non_sink_mode():
    """
    Feature: `mindspore.Model.fit` with train dataset in non-sink mode.
    Description: test fit with train dataset in non-sink mode.
    Expectation: run in non-sink mode.
    """
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    model.fit(3, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=False)


def test_fit_train_dataset_sink_mode():
    """
    Feature: `mindspore.Model.fit` with train dataset in sink mode.
    Description: test fit with train dataset in sink mode.
    Expectation: run in sink mode.
    """
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    model.fit(3, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=True, sink_size=256)


def test_fit_valid_dataset_non_sink_mode():
    """
    Feature: `mindspore.Model.fit` with valid dataset in non-sink mode.
    Description: test fit with valid dataset in non-sink mode.
    Expectation: run in non-sink mode.
    """
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    model.fit(3, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=False)


def test_fit_valid_dataset_sink_mode():
    """
    Feature: `mindspore.Model.fit` with valid dataset in sink mode.
    Description: test fit with valid dataset in sink mode.
    Expectation: run in sink mode.
    """
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    callbacks = [LossMonitor()]
    model.fit(3, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=True)


def test_fit_without_valid_dataset():
    """
    Feature: `mindspore.Model.fit` without `valid_dataset` input.
    Description: test fit when `valid_dataset` is None and `valid_dataset_sink_mode` is True or False.
    Expectation: network trains without an eval process; `valid_dataset_sink_mode` does not take effect.
    """
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    callbacks = [LossMonitor()]
    model.fit(3, ds_train, None, callbacks=callbacks, valid_dataset_sink_mode=False)
    model.fit(3, ds_train, None, callbacks=callbacks)


def test_fit_valid_frequency():
    """
    Feature: check `valid_frequency` input in `mindspore.Model.fit`.
    Description: when `valid_frequency` is an integer, a list, or another type.
    Expectation: raise ValueError when the type of valid_frequency is not int or list.
    """
    model = define_model()
    callbacks = [LossMonitor()]
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    model.fit(3, ds_train, ds_eval, valid_frequency=1, callbacks=callbacks)
    model.fit(5, ds_train, ds_eval, valid_frequency=2, callbacks=callbacks)
    model.fit(5, ds_train, ds_eval, valid_frequency=[0, 1, 4], callbacks=callbacks)
    with pytest.raises(ValueError):
        model.fit(5, ds_train, ds_eval, valid_frequency=(0, 2), callbacks=callbacks)


def test_fit_callbacks():
    """
    Feature: check `callbacks` input in `mindspore.Model.fit`.
    Description: test internal or custom callbacks in fit.
    Expectation: raise ValueError when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
    """
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)
    model.fit(3, ds_train, ds_eval, callbacks=None)
    model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor()])
    model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), LossMonitor()])
    model.fit(3, ds_train, ds_eval, callbacks=[MyCallbackNewMethod()])
    model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), MyCallbackNewMethod()])
    with pytest.raises(ValueError):
        model.fit(3, ds_train, ds_eval, callbacks=[MyCallbackOldMethod()])
    with pytest.raises(ValueError):
        model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), MyCallbackOldMethod()])
    with pytest.raises(ValueError):
        model.fit(3, ds_train, valid_dataset=None, callbacks=[TimeMonitor(), MyCallbackOldMethod()])


def test_train_eval_callbacks():
    """
    Feature: check `callbacks` input in `mindspore.Model.train` or `mindspore.Model.eval`.
    Description: test internal or custom callbacks in train or eval.
    Expectation: raise a warning when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'.
    """
    model = define_model()
    ds_train = create_dataset(4096, 1024)
    ds_eval = create_dataset(1024, 512)

    model.train(3, ds_train, callbacks=None)
    model.train(3, ds_train, callbacks=[TimeMonitor()])
    model.train(3, ds_train, callbacks=[LossMonitor()])
    model.train(3, ds_train, callbacks=[MyCallbackNewMethod()])
    model.train(3, ds_train, callbacks=[MyCallbackOldMethod()])

    metric_results = model.eval(ds_eval, callbacks=None)
    print("{}".format(metric_results))
    metric_results = model.eval(ds_eval, callbacks=[TimeMonitor()])
    print("{}".format(metric_results))
    metric_results = model.eval(ds_eval, callbacks=[MyCallbackNewMethod()])
    print("{}".format(metric_results))
    metric_results = model.eval(ds_eval, callbacks=[MyCallbackOldMethod()])
    print("{}".format(metric_results))
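To exercise the new suite locally, a pytest invocation along these lines should work; the repository path to the file is an assumption:

pytest -s tests/ut/python/train/test_fit.py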
@@ -513,13 +513,13 @@ def test_lambda():

    run_context = RunContext(cb_params)
    lambda_cb = LambdaCallback(
-       epoch_end=lambda run_context: print("loss result: ", run_context.original_args().net_outputs))
+       on_train_epoch_end=lambda run_context: print("loss result: ", run_context.original_args().net_outputs))

    callbacks = [lambda_cb]
    with _CallbackManager(callbacks) as callbacklist:
-       callbacklist.begin(run_context)
+       callbacklist.on_train_begin(run_context)
-       callbacklist.epoch_begin(run_context)
+       callbacklist.on_train_epoch_begin(run_context)
-       callbacklist.step_begin(run_context)
+       callbacklist.on_train_step_begin(run_context)
-       callbacklist.step_end(run_context)
+       callbacklist.on_train_step_end(run_context)
-       callbacklist.epoch_end(run_context)
+       callbacklist.on_train_epoch_end(run_context)
-       callbacklist.end(run_context)
+       callbacklist.on_train_end(run_context)
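With these renames, `LambdaCallback` rides along in `model.fit` as well. A sketch reusing the helpers from the new test file above; the diff only confirms the `on_train_epoch_end` keyword, so no other hook keywords are assumed here:

from mindspore.train.callback import LambdaCallback

lambda_cb = LambdaCallback(
    on_train_epoch_end=lambda run_context: print(
        "loss result:", run_context.original_args().net_outputs))
model = define_model()
model.fit(2, create_dataset(1024, 256), create_dataset(256, 128), callbacks=[lambda_cb])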