diff --git a/docs/api/api_python/mindspore/mindspore.Callback.rst b/docs/api/api_python/mindspore/mindspore.Callback.rst
index 76f00b0dbd5..269da2b3a03 100644
--- a/docs/api/api_python/mindspore/mindspore.Callback.rst
+++ b/docs/api/api_python/mindspore/mindspore.Callback.rst
@@ -6,8 +6,7 @@ mindspore.Callback
    用于构建Callback函数的基类。Callback函数是一个上下文管理器,在运行模型时被调用。
    可以使用此机制进行一些自定义操作。

-   Callback类的每个方法对应了训练或推理过程的不同阶段,这些方法有相同的入参 `run_context`,用于保存模型
-   训练或推理过程模型的相关信息。定义Callback子类或自定义Callback时,请根据需要重写对应的方法。
+   Callback类的每个方法对应了训练或推理过程的不同阶段,这些方法有相同的入参 `run_context`,用于保存训练或推理过程中模型的相关信息。定义Callback子类或自定义Callback时,请根据需要重写名称前缀为“on_train”或“on_eval”的方法,否则自定义的Callback在 `model.fit` 中使用时会产生错误。

    自定义Callback场景下,在类方法中通过 `RunContext.original_args()` 方法可以获取模型训练或推理过程中已有
    的上下文信息,此信息为一个存储了已有属性的字典型变量;用户也可以在此信息中添加其他的自定义属性;此外,
@@ -16,7 +15,7 @@ mindspore.Callback

    .. py:method:: begin(run_context)

-       在网络执行之前被调用一次。
+       在网络执行之前被调用一次。与 `on_train_begin` 和 `on_eval_begin` 方法具有兼容性。

        **参数:**

@@ -24,7 +23,7 @@ mindspore.Callback

    .. py:method:: end(run_context)

-       网络执行后被调用一次。
+       网络执行后被调用一次。与 `on_train_end` 和 `on_eval_end` 方法具有兼容性。

        **参数:**

@@ -32,7 +31,7 @@ mindspore.Callback

    .. py:method:: epoch_begin(run_context)

-       在每个epoch开始之前被调用。
+       在每个epoch开始之前被调用。与 `on_train_epoch_begin` 和 `on_eval_epoch_begin` 方法具有兼容性。

        **参数:**

@@ -40,7 +39,7 @@ mindspore.Callback

    .. py:method:: epoch_end(run_context)

-       在每个epoch结束后被调用。
+       在每个epoch结束后被调用。与 `on_train_epoch_end` 和 `on_eval_epoch_end` 方法具有兼容性。

        **参数:**

@@ -52,7 +51,7 @@ mindspore.Callback

    .. py:method:: step_begin(run_context)

-       在每个step开始之前被调用。
+       在每个step开始之前被调用。与 `on_train_step_begin` 和 `on_eval_step_begin` 方法具有兼容性。

        **参数:**

        - **run_context** (RunContext) - 包含模型的一些基本信息。

    .. py:method:: step_end(run_context)

-       在每个step完成后被调用。
+       在每个step完成后被调用。与 `on_train_step_end` 和 `on_eval_step_end` 方法具有兼容性。

        **参数:**

@@ -60,4 +59,100 @@ mindspore.Callback

        - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_train_begin(run_context)
+
+       在网络执行训练之前调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_train_end(run_context)
+
+       网络训练执行结束时调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_train_epoch_begin(run_context)
+
+       在训练的每个epoch开始之前被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_train_epoch_end(run_context)
+
+       在训练的每个epoch结束后被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_train_step_begin(run_context)
+
+       在训练的每个step开始之前被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_train_step_end(run_context)
+
+       在训练的每个step完成后被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_eval_begin(run_context)
+
+       在网络执行推理之前调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_eval_end(run_context)
+
+       网络执行推理之后调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_eval_epoch_begin(run_context)
+
+       在推理的epoch开始之前被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_eval_epoch_end(run_context)
+
+       在推理的epoch结束后被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_eval_step_begin(run_context)
+
+       在推理的每个step开始之前被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
+
+   .. py:method:: on_eval_step_end(run_context)
+
+       在推理的每个step完成后被调用。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的一些基本信息。
diff --git a/docs/api/api_python/mindspore/mindspore.History.rst b/docs/api/api_python/mindspore/mindspore.History.rst
index 1b266c31ce5..77da7db28fc 100644
--- a/docs/api/api_python/mindspore/mindspore.History.rst
+++ b/docs/api/api_python/mindspore/mindspore.History.rst
@@ -3,12 +3,12 @@ mindspore.History

 .. py:class:: mindspore.History

-   将网络输出的相关信息记录到 `History` 对象中。
+   将网络输出和评估指标的相关信息记录到 `History` 对象中。

    用户不自定义训练网络或评估网络情况下,记录的内容将为损失值;用户自定义了训练网络/评估网络的情况下,如果定义的网络返回 `Tensor` 或 `numpy.ndarray`,则记录此返回值均值,如果返回 `tuple` 或 `list`,则记录第一个元素。

    .. note::
-       通常使用在 `mindspore.Model.train` 中。
+       通常使用在 `mindspore.Model.train` 和 `mindspore.Model.fit` 中。

    .. py:method:: begin(run_context)

@@ -20,7 +20,7 @@ mindspore.History

    .. py:method:: epoch_end(run_context)

-       epoch结束时记录网络输出的相关信息。
+       epoch结束时记录网络输出和评估指标的相关信息。

        **参数:**

diff --git a/docs/api/api_python/mindspore/mindspore.LambdaCallback.rst b/docs/api/api_python/mindspore/mindspore.LambdaCallback.rst
index 5b689cbaaf6..2e114a29b34 100644
--- a/docs/api/api_python/mindspore/mindspore.LambdaCallback.rst
+++ b/docs/api/api_python/mindspore/mindspore.LambdaCallback.rst
@@ -5,7 +5,7 @@ mindspore.LambdaCallback

    用于自定义简单的callback。

-   使用匿名函数构建callback,定义的匿名函数将在 `mindspore.Model.{train | eval}` 的对应阶段被调用。
+   使用匿名函数构建callback,定义的匿名函数将在 `mindspore.Model.{train | eval | fit}` 的对应阶段被调用。

    请注意,callback的每个阶段都需要一个位置参数:`run_context`。

@@ -14,9 +14,15 @@ mindspore.LambdaCallback

    **参数:**

-   - **epoch_begin** (Function) - 每个epoch开始时被调用。
-   - **epoch_end** (Function) - 每个epoch结束时被调用。
-   - **step_begin** (Function) - 每个step开始时被调用。
-   - **step_end** (Function) - 每个step结束时被调用。
-   - **begin** (Function) - 模型训练、评估开始时被调用。
-   - **end** (Function) - 模型训练、评估结束时被调用。
+   - **on_train_epoch_begin** (Function) - 训练每个epoch开始时被调用。
+   - **on_train_epoch_end** (Function) - 训练每个epoch结束时被调用。
+   - **on_train_step_begin** (Function) - 训练每个step开始时被调用。
+   - **on_train_step_end** (Function) - 训练每个step结束时被调用。
+   - **on_train_begin** (Function) - 模型训练开始时被调用。
+   - **on_train_end** (Function) - 模型训练结束时被调用。
+   - **on_eval_epoch_begin** (Function) - 推理的epoch开始时被调用。
+   - **on_eval_epoch_end** (Function) - 推理的epoch结束时被调用。
+   - **on_eval_step_begin** (Function) - 推理的每个step开始时被调用。
+   - **on_eval_step_end** (Function) - 推理的每个step结束时被调用。
+   - **on_eval_begin** (Function) - 模型推理开始时被调用。
+   - **on_eval_end** (Function) - 模型推理结束时被调用。
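Note: as a usage sketch of the renamed LambdaCallback arguments above — `model`, `train_dataset` and `valid_dataset` are assumed to be built as in the `mindspore.Model.fit` documentation that follows, with `metrics` configured on the model:

    from mindspore.train.callback import LambdaCallback

    lambda_cb = LambdaCallback(
        on_train_epoch_end=lambda run_context: print(
            "train loss:", run_context.original_args().net_outputs),
        on_eval_end=lambda run_context: print(
            "eval metrics:", run_context.original_args().metrics))
    model.fit(2, train_dataset, valid_dataset, callbacks=[lambda_cb])

The `metrics` attribute is attached to the callback parameters only after the eval pass has finished, which is why it is read in `on_eval_end` rather than in an earlier eval hook.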
diff --git a/docs/api/api_python/mindspore/mindspore.LossMonitor.rst b/docs/api/api_python/mindspore/mindspore.LossMonitor.rst
index 3aae81b3faa..7697b1f5c44 100644
--- a/docs/api/api_python/mindspore/mindspore.LossMonitor.rst
+++ b/docs/api/api_python/mindspore/mindspore.LossMonitor.rst
@@ -3,7 +3,7 @@ mindspore.LossMonitor

 .. py:class:: mindspore.LossMonitor(per_print_times=1)

-   监控训练的loss。
+   训练场景下,监控训练的loss;边训练边推理场景下,监控训练的loss和推理的metrics。

    如果loss是NAN或INF,则终止训练。

@@ -25,3 +25,11 @@ mindspore.LossMonitor

    **参数:**

    - **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
+
+   .. py:method:: on_train_epoch_end(run_context)
+
+       LossMonitor用于 `model.fit`,即边训练边推理场景时,打印训练的loss和当前epoch推理的metrics。
+
+       **参数:**
+
+       - **run_context** (RunContext) - 包含模型的相关信息。详情请参考 :class:`mindspore.RunContext`。
diff --git a/docs/api/api_python/mindspore/mindspore.Model.rst b/docs/api/api_python/mindspore/mindspore.Model.rst
index 6bf7d9daf62..0ed99784882 100644
--- a/docs/api/api_python/mindspore/mindspore.Model.rst
+++ b/docs/api/api_python/mindspore/mindspore.Model.rst
@@ -162,6 +162,24 @@

    - **sink_size** (int) – 控制每次数据下沉的数据量。`dataset_sink_mode` 为False时 `sink_size` 无效。如果sink_size=-1,则每一次epoch下沉完整数据集。如果sink_size>0,则每一次epoch下沉数据量为sink_size的数据集。默认值:-1。
    - **initial_epoch** (int) - 从哪个epoch开始训练,一般用于中断恢复训练场景。

+   .. py:method:: fit(epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None, dataset_sink_mode=True, valid_dataset_sink_mode=True, sink_size=-1, initial_epoch=0)
+
+       模型边训练边推理接口。
+
+       如果 `valid_dataset` 不为None,在训练过程中同时执行推理。更多详细信息请参考 `mindspore.Model.train`。
+
+       **参数:**
+
+       - **epoch** (int) – 训练执行轮次。通常每个epoch都会使用全量数据集进行训练。当 `dataset_sink_mode` 设置为True且 `sink_size` 大于零时,则每个epoch训练次数为 `sink_size` 而不是数据集的总步数。如果 `epoch` 与 `initial_epoch` 一起使用,它表示训练的最后一个 `epoch` 是多少。
+       - **train_dataset** (Dataset) – 训练数据集迭代器。如果定义了 `loss_fn` ,则数据和标签会被分别传给 `network` 和 `loss_fn` ,此时数据集需要返回一个元组(data, label)。如果数据集中有多个数据或者标签,可以设置 `loss_fn` 为None,并在 `network` 中实现损失函数计算,此时数据集返回的所有数据组成的元组(data1, data2, data3, ...)会传给 `network` 。
+       - **valid_dataset** (Dataset) – 评估模型的数据集迭代器。默认值:None。
+       - **valid_frequency** (int, list) – 此参数只有在 `valid_dataset` 不为None时生效。如果为int类型,表示执行推理的频率,例如 `valid_frequency=2`,则每2个训练epoch执行一次推理;如果为list类型,指明在哪几个epoch时执行推理,例如 `valid_frequency=[1, 5]`,则在第1个和第5个epoch执行推理。默认值:1。
+       - **callbacks** (Optional[list[Callback], Callback]) – 训练过程中需要执行的回调对象或者回调对象列表。默认值:None。
+       - **dataset_sink_mode** (bool) – 训练数据是否直接下沉至处理器进行处理。使用PYNATIVE_MODE模式或CPU处理器时,模型训练流程将以非下沉模式执行。默认值:True。
+       - **valid_dataset_sink_mode** (bool) - 推理数据是否直接下沉至处理器进行处理。默认值:True。
+       - **sink_size** (int) – 控制每次数据下沉的数据量。`dataset_sink_mode` 为False时 `sink_size` 无效。如果sink_size=-1,则每一次epoch下沉完整数据集。如果sink_size>0,则每一次epoch下沉数据量为sink_size的数据集。默认值:-1。
+       - **initial_epoch** (int) - 从哪个epoch开始训练,一般用于中断恢复训练场景。默认值:0。
+
    .. py:method:: train_network
       :property:
@@ -169,4 +187,4 @@

    **返回:**

-   预测网络实例。
\ No newline at end of file
+   预测网络实例。
diff --git a/docs/api/api_python/mindspore/mindspore.TimeMonitor.rst b/docs/api/api_python/mindspore/mindspore.TimeMonitor.rst
index 91bfa3d32df..c59140febdd 100644
--- a/docs/api/api_python/mindspore/mindspore.TimeMonitor.rst
+++ b/docs/api/api_python/mindspore/mindspore.TimeMonitor.rst
@@ -3,7 +3,7 @@ mindspore.TimeMonitor

 .. py:class:: mindspore.TimeMonitor(data_size=None)

-   监控训练时间。
+   监控训练或推理的时间。

    **参数:**
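Note: to make the `fit` interface documented above concrete, a minimal sketch of a fit call — the network class and the dataset variables are placeholders, not part of this change:

    from mindspore import Model, nn

    net = MyNet()                    # placeholder network
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})

    # Evaluate after the 1st and 5th epochs only; an int means "every N epochs".
    model.fit(5, train_dataset, valid_dataset, valid_frequency=[1, 5])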
diff --git a/mindspore/python/mindspore/train/callback/_callback.py b/mindspore/python/mindspore/train/callback/_callback.py
index 9ff3d65fd48..56e421e0bb7 100644
--- a/mindspore/python/mindspore/train/callback/_callback.py
+++ b/mindspore/python/mindspore/train/callback/_callback.py
@@ -82,7 +82,8 @@ class Callback:
    Each method of Callback class corresponds to a stage in training or eval process, and those methods
    have the same input `run_context`, which hold context information of the model in
    training or eval process. When defining a Callback subclass or creating a custom Callback,
-    override these methods.
+    note that you should override methods with names prefixed with "on_train" or "on_eval",
+    otherwise a ValueError will be raised if the customized callbacks are used in `model.fit`.

    When creating a custom Callback, model context information can be obtained in Callback
    methods by calling `RunContext.original_args()`, which is a dictionary variable
@@ -122,6 +123,7 @@ class Callback:
    def begin(self, run_context):
        """
        Called once before the network executing.
+        A backwards compatibility alias for `on_train_begin` and `on_eval_begin`.

        Args:
            run_context (RunContext): Include some information of the model.
        """

@@ -130,6 +132,7 @@ class Callback:
    def epoch_begin(self, run_context):
        """
        Called before each epoch beginning.
+        A backwards compatibility alias for `on_train_epoch_begin` and `on_eval_epoch_begin`.

        Args:
            run_context (RunContext): Include some information of the model.
        """

@@ -138,6 +141,7 @@ class Callback:
    def epoch_end(self, run_context):
        """
        Called after each epoch finished.
+        A backwards compatibility alias for `on_train_epoch_end` and `on_eval_epoch_end`.

        Args:
            run_context (RunContext): Include some information of the model.
        """

@@ -146,6 +150,7 @@ class Callback:
    def step_begin(self, run_context):
        """
        Called before each step beginning.
+        A backwards compatibility alias for `on_train_step_begin` and `on_eval_step_begin`.

        Args:
            run_context (RunContext): Include some information of the model.
        """

@@ -154,6 +159,7 @@ class Callback:
    def step_end(self, run_context):
        """
        Called after each step finished.
+        A backwards compatibility alias for `on_train_step_end` and `on_eval_step_end`.

        Args:
            run_context (RunContext): Include some information of the model.
        """

@@ -162,11 +168,120 @@ class Callback:
    def end(self, run_context):
        """
        Called once after network training.
+        A backwards compatibility alias for `on_train_end` and `on_eval_end`.

        Args:
            run_context (RunContext): Include some information of the model.
        """

+    def on_train_begin(self, run_context):
+        """
+        Called once before the network training.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.begin(run_context)
+
+    def on_train_epoch_begin(self, run_context):
+        """
+        Called before each training epoch begin.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.epoch_begin(run_context)
+
+    def on_train_epoch_end(self, run_context):
+        """
+        Called after each training epoch end.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.epoch_end(run_context)
+
+    def on_train_step_begin(self, run_context):
+        """
+        Called before each training step begin.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.step_begin(run_context)
+
+    def on_train_step_end(self, run_context):
+        """
+        Called after each training step end.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.step_end(run_context)
+
+    def on_train_end(self, run_context):
+        """
+        Called after training end.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.end(run_context)
+
+    def on_eval_begin(self, run_context):
+        """
+        Called before eval begin.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.begin(run_context)
+
+    def on_eval_epoch_begin(self, run_context):
+        """
+        Called before eval epoch begin.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        self.epoch_begin(run_context)
+
+    def on_eval_epoch_end(self, run_context):
+        """
+        Called after eval epoch end.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+ """ + self.epoch_end(run_context) + + def on_eval_step_begin(self, run_context): + """ + Called before each eval step begin. + + Args: + run_context (RunContext): Include some information of the model. + """ + self.step_begin(run_context) + + def on_eval_step_end(self, run_context): + """ + Called after each eval step end. + + Args: + run_context (RunContext): Include some information of the model. + """ + self.step_end(run_context) + + def on_eval_end(self, run_context): + """ + Called after eval end. + + Args: + run_context (RunContext): Include some information of the model. + """ + self.end(run_context) + class CallbackManager(Callback): """ @@ -208,7 +323,7 @@ class CallbackManager(Callback): return self._stack.__exit__(*err) def begin(self, run_context): - """Called once before network training.""" + """Called once before network train or eval.""" for cb in self._callbacks: cb.begin(run_context) @@ -223,7 +338,7 @@ class CallbackManager(Callback): cb.epoch_end(run_context) def step_begin(self, run_context): - """Called before each epoch begin.""" + """Called before each step begin.""" for cb in self._callbacks: cb.step_begin(run_context) @@ -233,10 +348,70 @@ class CallbackManager(Callback): cb.step_end(run_context) def end(self, run_context): - """Called once after network training.""" + """Called once after network train or eval.""" for cb in self._callbacks: cb.end(run_context) + def on_train_begin(self, run_context): + """Called before network train.""" + for cb in self._callbacks: + cb.on_train_begin(run_context) + + def on_train_epoch_begin(self, run_context): + """Called before each train epoch begin.""" + for cb in self._callbacks: + cb.on_train_epoch_begin(run_context) + + def on_train_epoch_end(self, run_context): + """Called after each train epoch finished.""" + for cb in self._callbacks: + cb.on_train_epoch_end(run_context) + + def on_train_step_begin(self, run_context): + """Called before each train step begin.""" + for cb in self._callbacks: + cb.on_train_step_begin(run_context) + + def on_train_step_end(self, run_context): + """Called after each train step finished.""" + for cb in self._callbacks: + cb.step_end(run_context) + + def on_train_end(self, run_context): + """Called after network train end.""" + for cb in self._callbacks: + cb.on_train_end(run_context) + + def on_eval_begin(self, run_context): + """Called before network eval.""" + for cb in self._callbacks: + cb.on_eval_begin(run_context) + + def on_eval_epoch_begin(self, run_context): + """Called before eval epoch begin.""" + for cb in self._callbacks: + cb.on_eval_epoch_begin(run_context) + + def on_eval_epoch_end(self, run_context): + """Called after eval epoch finished.""" + for cb in self._callbacks: + cb.on_eval_epoch_end(run_context) + + def on_eval_step_begin(self, run_context): + """Called before each eval step begin.""" + for cb in self._callbacks: + cb.on_eval_step_begin(run_context) + + def on_eval_step_end(self, run_context): + """Called after each eval step finished.""" + for cb in self._callbacks: + cb.on_eval_step_end(run_context) + + def on_eval_end(self, run_context): + """Called after network eval end.""" + for cb in self._callbacks: + cb.on_eval_end(run_context) + class InternalCallbackParam(dict): """Internal callback object's parameters.""" diff --git a/mindspore/python/mindspore/train/callback/_history.py b/mindspore/python/mindspore/train/callback/_history.py index bbab78f1b81..aa602487f77 100644 --- a/mindspore/python/mindspore/train/callback/_history.py +++ 
diff --git a/mindspore/python/mindspore/train/callback/_history.py b/mindspore/python/mindspore/train/callback/_history.py
index bbab78f1b81..aa602487f77 100644
--- a/mindspore/python/mindspore/train/callback/_history.py
+++ b/mindspore/python/mindspore/train/callback/_history.py
@@ -21,7 +21,7 @@ from ._callback import Callback

class History(Callback):
    """
-    Records the network outputs information into a `History` object.
+    Records the network outputs and metrics information into a `History` object.

    The network outputs information will be the loss value if not customizing the train network
    or eval network; if the customized network returns a `Tensor` or `numpy.ndarray`, the mean value of network
@@ -29,7 +29,7 @@ class History(Callback):
    outputs will be recorded.

    Note:
-        Normally used in `mindspore.Model.train`.
+        Normally used in `mindspore.Model.train` or `mindspore.Model.fit`.

    Examples:
        >>> import numpy as np
@@ -65,7 +65,7 @@ class History(Callback):

    def epoch_end(self, run_context):
        """
-        Records the first element of network outputs at the end of epoch.
+        Records the first element of network outputs and metrics information at the end of epoch.

        Args:
            run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. For more details,
diff --git a/mindspore/python/mindspore/train/callback/_lambda_callback.py b/mindspore/python/mindspore/train/callback/_lambda_callback.py
index 9aeb7e98dc7..9e0aa6a5832 100644
--- a/mindspore/python/mindspore/train/callback/_lambda_callback.py
+++ b/mindspore/python/mindspore/train/callback/_lambda_callback.py
@@ -22,19 +22,25 @@ class LambdaCallback(Callback):
    Callback for creating simple, custom callbacks.

    This callback is constructed with anonymous functions that will be called
-    at the appropriate time (during `mindspore.Model.{train | eval}`). Note that
+    at the appropriate time (during `mindspore.Model.{train | eval | fit}`). Note that
    each stage of callbacks expects one positional argument: `run_context`.

    Note:
        This is an experimental interface that is subject to change or deletion.

    Args:
-        epoch_begin (Function): called at the beginning of every epoch.
-        epoch_end (Function): called at the end of every epoch.
-        step_begin (Function): called at the beginning of every batch.
-        step_end (Function): called at the end of every batch.
-        begin (Function): called at the beginning of model train/eval.
-        end (Function): called at the end of model train/eval.
+        on_train_epoch_begin (Function): called at each train epoch begin.
+        on_train_epoch_end (Function): called at each train epoch end.
+        on_train_step_begin (Function): called at each train step begin.
+        on_train_step_end (Function): called at each train step end.
+        on_train_begin (Function): called at the beginning of model train.
+        on_train_end (Function): called at the end of model train.
+        on_eval_epoch_begin (Function): called at eval epoch begin.
+        on_eval_epoch_end (Function): called at eval epoch end.
+        on_eval_step_begin (Function): called at each eval step begin.
+        on_eval_step_end (Function): called at each eval step end.
+        on_eval_begin (Function): called at the beginning of model eval.
+        on_eval_end (Function): called at the end of model eval.

    Examples:
        >>> import numpy as np
@@ -46,19 +52,28 @@ class LambdaCallback(Callback):
        >>> net = nn.Dense(10, 5)
        >>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
-        >>> lambda_callback = LambdaCallback(epoch_end=
+        >>> lambda_callback = LambdaCallback(on_train_epoch_end=
        ...                 lambda run_context: print("loss: ", run_context.original_args().net_outputs))
        >>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
        >>> model.train(2, train_dataset, callbacks=[lambda_callback])
        loss: 1.6127687
        loss: 1.6106578
    """
-    def __init__(self, epoch_begin=None, epoch_end=None, step_begin=None,
-                 step_end=None, begin=None, end=None):
+    def __init__(self, on_train_epoch_begin=None, on_train_epoch_end=None, on_train_step_begin=None,
+                 on_train_step_end=None, on_train_begin=None, on_train_end=None,
+                 on_eval_epoch_begin=None, on_eval_epoch_end=None, on_eval_step_begin=None,
+                 on_eval_step_end=None, on_eval_begin=None, on_eval_end=None):
        super(LambdaCallback, self).__init__()
-        self.epoch_begin = epoch_begin if epoch_begin else lambda run_context: None
-        self.epoch_end = epoch_end if epoch_end else lambda run_context: None
-        self.step_begin = step_begin if step_begin else lambda run_context: None
-        self.step_end = step_end if step_end else lambda run_context: None
-        self.begin = begin if begin else lambda run_context: None
-        self.end = end if end else lambda run_context: None
+        self.on_train_epoch_begin = on_train_epoch_begin if on_train_epoch_begin else lambda run_context: None
+        self.on_train_epoch_end = on_train_epoch_end if on_train_epoch_end else lambda run_context: None
+        self.on_train_step_begin = on_train_step_begin if on_train_step_begin else lambda run_context: None
+        self.on_train_step_end = on_train_step_end if on_train_step_end else lambda run_context: None
+        self.on_train_begin = on_train_begin if on_train_begin else lambda run_context: None
+        self.on_train_end = on_train_end if on_train_end else lambda run_context: None
+
+        self.on_eval_epoch_begin = on_eval_epoch_begin if on_eval_epoch_begin else lambda run_context: None
+        self.on_eval_epoch_end = on_eval_epoch_end if on_eval_epoch_end else lambda run_context: None
+        self.on_eval_step_begin = on_eval_step_begin if on_eval_step_begin else lambda run_context: None
+        self.on_eval_step_end = on_eval_step_end if on_eval_step_end else lambda run_context: None
+        self.on_eval_begin = on_eval_begin if on_eval_begin else lambda run_context: None
+        self.on_eval_end = on_eval_end if on_eval_end else lambda run_context: None
diff --git a/mindspore/python/mindspore/train/callback/_loss_monitor.py b/mindspore/python/mindspore/train/callback/_loss_monitor.py
index 4037935762a..c96f698e4c2 100644
--- a/mindspore/python/mindspore/train/callback/_loss_monitor.py
+++ b/mindspore/python/mindspore/train/callback/_loss_monitor.py
@@ -23,7 +23,7 @@ from ._callback import Callback

class LossMonitor(Callback):
    """
-    Monitor the loss in training.
+    Monitor the loss in train or monitor the loss and eval metrics in fit.

    If the loss is NAN or INF, it will terminate training.

@@ -90,3 +90,17 @@ class LossMonitor(Callback):
        if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
            self._last_print_time = cb_params.cur_step_num
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
+
+    def on_train_epoch_end(self, run_context):
+        """
+        When LossMonitor is used in `model.fit`, i.e. the train-while-eval scenario, print the eval
+        metrics at the end of each epoch in which evaluation was performed.
+
+        Args:
+            run_context (RunContext): Include some information of the model. For more details,
+                please refer to :class:`mindspore.RunContext`.
+        """
+        cb_params = run_context.original_args()
+        metrics = cb_params.get("metrics")
+        if metrics:
+            print("Eval result: epoch %d, metrics: %s" % (cb_params.cur_epoch_num, metrics))
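Note: with the hook above, attaching a plain LossMonitor to a fit run is enough to interleave step losses with per-epoch metric summaries; a sketch, assuming `model` and the datasets are already configured with metrics:

    from mindspore.train.callback import LossMonitor

    model.fit(3, train_dataset, valid_dataset, callbacks=[LossMonitor(per_print_times=1)])
    # Per the format strings above, each epoch prints step losses followed by a line like:
    #   Eval result: epoch 1, metrics: {'mse': ..., 'mae': ...}

(The metric values in the comment are illustrative, not reproduced output.)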
+ """ + cb_params = run_context.original_args() + metrics = cb_params.get("metrics") + if metrics: + print("Eval result: epoch %d, metrics: %s" % (cb_params.cur_epoch_num, metrics)) diff --git a/mindspore/python/mindspore/train/callback/_time_monitor.py b/mindspore/python/mindspore/train/callback/_time_monitor.py index 4f2bae36ba0..d568b15b747 100644 --- a/mindspore/python/mindspore/train/callback/_time_monitor.py +++ b/mindspore/python/mindspore/train/callback/_time_monitor.py @@ -22,7 +22,7 @@ from ._callback import Callback class TimeMonitor(Callback): """ - Monitor the time in training. + Monitor the time in train or eval process. Args: data_size (int): How many steps are the intervals between print information each time. @@ -71,6 +71,7 @@ class TimeMonitor(Callback): epoch_seconds = (time.time() - self.epoch_time) * 1000 step_size = self.data_size cb_params = run_context.original_args() + mode = cb_params.get("mode", "") if hasattr(cb_params, "batch_num"): batch_num = cb_params.batch_num if isinstance(batch_num, int) and batch_num > 0: @@ -78,4 +79,5 @@ class TimeMonitor(Callback): Validator.check_positive_int(step_size) step_seconds = epoch_seconds / step_size - print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, step_seconds), flush=True) + print("{} epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format + (mode.title(), epoch_seconds, step_seconds), flush=True) diff --git a/mindspore/python/mindspore/train/model.py b/mindspore/python/mindspore/train/model.py index e5098b7e58d..5fbd824e665 100644 --- a/mindspore/python/mindspore/train/model.py +++ b/mindspore/python/mindspore/train/model.py @@ -27,7 +27,8 @@ from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist from ..common.tensor import Tensor from ..nn.metrics import get_metrics from .._checkparam import check_input_data, check_output_data, Validator -from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback +from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor +from .callback import __all__ as internal_cb_names from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \ @@ -503,7 +504,8 @@ class Model: return [callbacks] @_save_final_ckpt - def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0): + def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0, + valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True): """ Training. 
diff --git a/mindspore/python/mindspore/train/model.py b/mindspore/python/mindspore/train/model.py
index e5098b7e58d..5fbd824e665 100644
--- a/mindspore/python/mindspore/train/model.py
+++ b/mindspore/python/mindspore/train/model.py
@@ -27,7 +27,8 @@ from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, Validator
-from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
+from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor
+from .callback import __all__ as internal_cb_names
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
@@ -503,7 +504,8 @@ class Model:
        return [callbacks]

    @_save_final_ckpt
-    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
+    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
+               valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True):
        """
        Training.
@@ -541,6 +543,7 @@ class Model:
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = self._transform_callbacks(callbacks)
+        valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
        if context.get_context("mode") == context.PYNATIVE_MODE:
            cb_params.list_callback.insert(0, _StepSync())
            callbacks = cb_params.list_callback
@@ -555,17 +558,21 @@ class Model:
        with _CallbackManager(callbacks) as list_callback:
            self._check_reuse_dataset(train_dataset)
            if not dataset_sink_mode:
-                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch)
+                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
            elif context.get_context("device_target") == "CPU":
                logger.info("The CPU cannot support dataset sink mode currently."
                            "So the training process will be performed with dataset not sink.")
-                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch)
+                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch, valid_infos)
            else:
                self._train_dataset_sink_process(epoch, train_dataset, list_callback,
-                                                 cb_params, sink_size, initial_epoch)
+                                                 cb_params, sink_size, initial_epoch, valid_infos)
+
+    @staticmethod
+    def _should_eval(epoch, validation_freq):
+        return epoch % validation_freq == 0 if isinstance(validation_freq, int) else epoch in validation_freq

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
-                                    sink_size=-1, initial_epoch=0):
+                                    sink_size=-1, initial_epoch=0, valid_infos=None):
        """
        Training process. The data would be passed to network through dataset channel.
@@ -593,7 +600,7 @@ class Model:
        cb_params.dataset_sink_mode = True
        run_context = RunContext(cb_params)
-        list_callback.begin(run_context)
+        list_callback.on_train_begin(run_context)
        # used to stop training for early stop, such as stopAtTIme or stopATStep
        dataset_helper = None
        if hasattr(train_dataset, '_dataset_helper'):
@@ -609,7 +616,7 @@ class Model:
            cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
            self._current_epoch_num = cb_params.cur_epoch_num
            self._current_step_num = 0
-            list_callback.epoch_begin(run_context)
+            list_callback.on_train_epoch_begin(run_context)
            dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                  dataset=train_dataset,
                                                                  dataset_sink_mode=True,
@@ -632,22 +639,51 @@ class Model:
                cb_params.cur_step_num += 1
                self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
                cb_params.train_dataset_element = inputs
-                list_callback.step_begin(run_context)
+                list_callback.on_train_step_begin(run_context)
                outputs = train_network(*inputs)
                cb_params.net_outputs = outputs

                # In disaster recovery scenarios, need not to execute callbacks if this step executes failed.
need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset")) if need_exec_callback_step_end: - list_callback.step_end(run_context) + list_callback.on_train_step_end(run_context) if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched(): os._exit(0) dataset_helper.continue_send() + + valid_dataset, valid_frequency, valid_dataset_sink_mode = valid_infos + if valid_dataset and self._should_eval(cb_params.cur_epoch_num, valid_frequency): + + train_cur_step_num = cb_params.cur_step_num + train_batch_num = cb_params.batch_num + train_dataset_sink_mode = cb_params.dataset_sink_mode + train_net_outputs = cb_params.net_outputs + + eval_callback = [] + for cb in list_callback._callbacks: + if cb.__class__.__name__ in internal_cb_names: + if isinstance(cb, TimeMonitor): + eval_callback.append(cb) + else: + eval_callback.append(cb) + + self._eval_in_fit(valid_dataset, + callbacks=eval_callback, + dataset_sink_mode=valid_dataset_sink_mode, + cb_params=cb_params) + cb_params.mode = "train" + cb_params.cur_step_num = train_cur_step_num + cb_params.batch_num = train_batch_num + cb_params.dataset_sink_mode = train_dataset_sink_mode + cb_params.net_outputs = train_net_outputs + # In disaster recovery scenarios, need not to execute callbacks if this epoch executes failed. need_exec_callback_epoch_end = not (self.enable_recovery and _get_recovery_context("need_reset")) if need_exec_callback_epoch_end: - list_callback.epoch_end(run_context) + list_callback.on_train_epoch_end(run_context) + if "metrics" in cb_params: + cb_params.pop("metrics") should_stop = run_context.get_stop_requested() if should_stop: @@ -663,7 +699,7 @@ class Model: dataset_helper.stop_send() dataset_helper.release() - list_callback.end(run_context) + list_callback.on_train_end(run_context) def _check_enable_recovery(self): """ @@ -753,7 +789,8 @@ class Model: _set_recovery_context(need_reset=False) - def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0): + def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0, + valid_infos=None): """ Training process. The data would be passed to network directly. 
@@ -776,13 +813,13 @@ class Model:
        cb_params.cur_step_num = 0
        cb_params.dataset_sink_mode = False
        run_context = RunContext(cb_params)
-        list_callback.begin(run_context)
+        list_callback.on_train_begin(run_context)

        for i in range(initial_epoch, epoch):
            cb_params.cur_epoch_num = i + 1
            self._current_epoch_num = cb_params.cur_epoch_num
            self._current_step_num = 0
-            list_callback.epoch_begin(run_context)
+            list_callback.on_train_epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)
@@ -795,7 +832,7 @@ class Model:
                self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)

                cb_params.train_dataset_element = next_element
-                list_callback.step_begin(run_context)
+                list_callback.on_train_step_begin(run_context)
                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
@@ -803,24 +840,52 @@ class Model:
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

-                list_callback.step_end(run_context)
+                list_callback.on_train_step_end(run_context)
                if (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
                    os._exit(0)
                should_stop = run_context.get_stop_requested()
                if should_stop:
                    break

+            valid_dataset, valid_frequency, valid_dataset_sink_mode = valid_infos
+            if valid_dataset and self._should_eval(cb_params.cur_epoch_num, valid_frequency):
+                train_cur_step_num = cb_params.cur_step_num
+                train_batch_num = cb_params.batch_num
+                train_dataset_sink_mode = cb_params.dataset_sink_mode
+                train_net_outputs = cb_params.net_outputs
+
+                eval_callback = []
+                for cb in list_callback._callbacks:
+                    if cb.__class__.__name__ in internal_cb_names:
+                        if isinstance(cb, TimeMonitor):
+                            eval_callback.append(cb)
+                    else:
+                        eval_callback.append(cb)
+
+                self._eval_in_fit(valid_dataset,
+                                  callbacks=eval_callback,
+                                  dataset_sink_mode=valid_dataset_sink_mode,
+                                  cb_params=cb_params)
+
+                cb_params.mode = "train"
+                cb_params.cur_step_num = train_cur_step_num
+                cb_params.batch_num = train_batch_num
+                cb_params.dataset_sink_mode = train_dataset_sink_mode
+                cb_params.net_outputs = train_net_outputs
+
            train_dataset.reset()

            # if param is cache enable, flush data from cache to host before epoch end
            self._flush_from_cache(cb_params)

-            list_callback.epoch_end(run_context)
+            list_callback.on_train_epoch_end(run_context)
+            if "metrics" in cb_params:
+                cb_params.pop("metrics")
            should_stop = run_context.get_stop_requested()
            if should_stop:
                break

-        list_callback.end(run_context)
+        list_callback.on_train_end(run_context)

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0):
        """
@@ -912,6 +977,9 @@ class Model:

        _device_number_check(self._parallel_mode, self._device_number)

+        if callbacks:
+            self._check_methods_for_custom_callbacks(callbacks, "train")
+
        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
@@ -924,6 +992,140 @@ class Model:
        if _is_ps_mode() and _enable_distributed_mindrt():
            _reset_op_id_with_offset()

+    @staticmethod
+    def _check_methods_for_custom_callbacks(callbacks, current_mode):
+        """
+        Check whether methods of customized callbacks are valid.
+
+        Args:
+            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object.
+            current_mode (str): 'fit', 'train' or 'eval'.
+ """ + old_version_methods_names = {'begin', 'end', 'epoch_begin', 'epoch_end', 'step_begin', 'step_end'} + if not isinstance(callbacks, list): + callbacks = [callbacks] + for cb in callbacks: + cb_name = cb.__class__.__name__ + if cb_name not in internal_cb_names: + cb_methods_names = set(cb.__class__.__dict__.keys()) + invalid_methods_names = cb_methods_names & old_version_methods_names + if invalid_methods_names: + if current_mode in ["train", "eval"]: + logger.warning("For %s callback, %s methods may not be supported in later version, " + "Use methods prefixed with 'on_train' or 'on_eval' instead " + "when using customized callbacks." % (cb_name, invalid_methods_names)) + else: + raise ValueError("For %s callback, %s methods may not be supported in later version, " + "Use methods prefixed with 'on_train' or 'on_eval' instead when" + "using customized callbacks." % (cb_name, invalid_methods_names)) + + def fit(self, epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None, + dataset_sink_mode=True, valid_dataset_sink_mode=True, sink_size=-1, initial_epoch=0): + """ + Fit API. + + Evaluation process will be performed during training process if `valid_dataset` is provided. + + More details please refer to `mindspore.Model.train` and `mindspore.Model.eval`. + + Args: + epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch. + If `dataset_sink_mode` is set to True and `sink_size` is greater than 0, each epoch will + train `sink_size` steps instead of total steps of dataset. + If `epoch` used with `initial_epoch`, it is to be understood as "final epoch". + train_dataset (Dataset): A training dataset iterator. If `loss_fn` is defined, the data and label will be + passed to the `network` and the `loss_fn` respectively, so a tuple (data, label) + should be returned from dataset. If there is multiple data or labels, set `loss_fn` + to None and implement calculation of loss in `network`, + then a tuple (data1, data2, data3, ...) with all data returned from dataset + will be passed to the `network`. + valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process + will be performed on the end of training process. Default: None. + valid_frequency (int, list): Only relevant if `valid_dataset` is provided. If an integer, specifies + how many training epochs to run before a new validation run is performed, + e.g. `valid_frequency=2` runs validation every 2 epochs. + If a list, specifies the epochs on which to run validation, + e.g. `valid_frequency=[1, 5]` runs validation at the end of the 1st, 5th epochs. + Default: 1 + callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, + which should be executed while training. + Default: None. + dataset_sink_mode (bool): Determines whether to pass the train data through dataset channel. + Configure pynative mode or CPU, the training process will be performed with + dataset not sink. Default: True. + valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel. + Default: True. + sink_size (int): Control the amount of data in each sink. `sink_size` is invalid if `dataset_sink_mode` + is False. + If sink_size = -1, sink the complete dataset for each epoch. + If sink_size > 0, sink sink_size data for each epoch. + Default: -1. + initial_epoch (int): Epoch at which to start train, it useful for resuming a previous training run. + Default: 0. 
+
+        Examples:
+            >>> from mindspore import Model, nn, FixedLossScaleManager
+            >>>
+            >>> # For details about how to build the dataset, please refer to the tutorial
+            >>> # document on the official website.
+            >>> train_dataset = create_custom_dataset()
+            >>> valid_dataset = create_custom_dataset()
+            >>> net = Net()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics="accuracy")
+            >>> model.fit(2, train_dataset, valid_dataset)
+        """
+        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
+        valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)
+
+        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
+            raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
+
+        if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
+            raise ValueError("When using Model.build to initialize the model, the value of parameter 'epoch' in "
+                             "Model.build should be equal to the value in Model.fit, but got {} and {} separately."
+                             .format(train_dataset._warmup_epoch, epoch))
+
+        if dataset_sink_mode and _is_ps_mode() and not _cache_enable():
+            raise ValueError("Parameter server mode does not support 'data_sink_mode=True'.")
+
+        Validator.check_is_int(sink_size)
+        Validator.check_non_negative_int(initial_epoch)
+        if initial_epoch >= epoch:
+            raise ValueError(f"For 'Model.fit', the parameter 'epoch' must be bigger than parameter 'initial_epoch',"
+                             f" but got the parameter 'epoch' is {epoch}, 'initial_epoch' is {initial_epoch}.")
+
+        dataset_size = train_dataset.get_dataset_size()
+        if dataset_size == 0:
+            raise ValueError("There is no valid data in dataset, please check dataset file firstly.")
+        if sink_size == -1:
+            sink_size = dataset_size
+        if sink_size < -1 or sink_size == 0:
+            raise ValueError("For 'Model.fit', the parameter 'sink_size' must be -1 or positive, "
+                             "but got {}.".format(sink_size))
+
+        _device_number_check(self._parallel_mode, self._device_number)
+
+        if not isinstance(valid_frequency, (int, list)):
+            raise ValueError(f"For 'Model.fit', the type of 'valid_frequency' must be an integer or a list, "
+                             f"but got type {type(valid_frequency)}.")
+
+        if valid_dataset and not self._metric_fns:
+            raise ValueError("For 'Model.fit', if valid_dataset is not None, the model argument 'metrics' can not "
+                             "be None or empty, you should set the argument 'metrics' for model.")
+
+        if callbacks:
+            self._check_methods_for_custom_callbacks(callbacks, "fit")
+
+        self._train(epoch,
+                    train_dataset,
+                    callbacks=callbacks,
+                    dataset_sink_mode=dataset_sink_mode,
+                    sink_size=sink_size,
+                    initial_epoch=initial_epoch,
+                    valid_dataset=valid_dataset,
+                    valid_frequency=valid_frequency,
+                    valid_dataset_sink_mode=valid_dataset_sink_mode)
+
    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1, jit_config=None):
        """
        Build computational graphs and data graphs with the sink mode.
@@ -971,6 +1173,40 @@ class Model:
        _cell_graph_executor.set_jit_config(jit_config)
        self._init(train_dataset, valid_dataset, sink_size, epoch)

+    def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
+        """
+        Evaluation process in `mindspore.Model.fit`.
+
+        Args:
+            valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation
+                         process will be performed during the training process. Default: None.
+            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, which
+                         should be executed while evaluation. Default: None.
+            dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
+                         Default: True.
+            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
+        """
+        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
+            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
+
+        cb_params.eval_network = self._eval_network
+        cb_params.valid_dataset = valid_dataset
+        cb_params.batch_num = valid_dataset.get_dataset_size()
+        cb_params.mode = "eval"
+        cb_params.cur_step_num = 0
+
+        self._clear_metrics()
+
+        if context.get_context("device_target") == "CPU" and dataset_sink_mode:
+            dataset_sink_mode = False
+            logger.info("CPU cannot support dataset sink mode currently."
+                        "So the evaluating process will be performed with dataset non-sink mode.")
+
+        with _CallbackManager(callbacks) as list_callback:
+            if dataset_sink_mode:
+                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
+            return self._eval_process(valid_dataset, list_callback, cb_params)
+
    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data would be passed to network through dataset channel.
@@ -990,20 +1226,20 @@ class Model:
                                                             dataset_sink_mode=True)
        cb_params.eval_network = eval_network
        cb_params.dataset_sink_mode = True
-        list_callback.begin(run_context)
-        list_callback.epoch_begin(run_context)
+        list_callback.on_eval_begin(run_context)
+        list_callback.on_eval_epoch_begin(run_context)
        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
-            list_callback.step_begin(run_context)
+            list_callback.on_eval_step_begin(run_context)
            outputs = eval_network(*inputs)
            cb_params.net_outputs = outputs
-            list_callback.step_end(run_context)
+            list_callback.on_eval_step_end(run_context)
            self._update_metrics(outputs)

-        list_callback.epoch_end(run_context)
+        list_callback.on_eval_epoch_end(run_context)
        metrics = self._get_metrics()
        cb_params.metrics = metrics
-        list_callback.end(run_context)
+        list_callback.on_eval_end(run_context)

        return metrics

@@ -1021,25 +1257,25 @@ class Model:
        """
        run_context = RunContext(cb_params)
        cb_params.dataset_sink_mode = False
-        list_callback.begin(run_context)
+        list_callback.on_eval_begin(run_context)
        dataset_helper, _ = self._exec_preprocess(is_train=False,
                                                  dataset=valid_dataset,
                                                  dataset_sink_mode=False)
-        list_callback.epoch_begin(run_context)
+        list_callback.on_eval_epoch_begin(run_context)
        for next_element in dataset_helper:
            cb_params.cur_step_num += 1
-            list_callback.step_begin(run_context)
+            list_callback.on_eval_step_begin(run_context)
            next_element = _transfer_tensor_to_tuple(next_element)
            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
-            list_callback.step_end(run_context)
+            list_callback.on_eval_step_end(run_context)
            self._update_metrics(outputs)

-        list_callback.epoch_end(run_context)
+        list_callback.on_eval_epoch_end(run_context)
        valid_dataset.reset()
        metrics = self._get_metrics()
        cb_params.metrics = metrics
-        list_callback.end(run_context)
+        list_callback.on_eval_end(run_context)
        return metrics

    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
@@ -1087,7 +1323,8 @@ class Model:
                             "you should set the argument 'metrics' for model.")
        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
-
+        if callbacks:
+            self._check_methods_for_custom_callbacks(callbacks, "eval")
        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
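Note: the caller-side effect of `_check_methods_for_custom_callbacks`, sketched with a hypothetical legacy callback (`model` and the datasets are assumed set up as elsewhere in this patch):

    from mindspore.train.callback import Callback

    class LegacyCallback(Callback):
        def step_end(self, run_context):     # old-style hook name
            pass

    model.train(1, train_dataset, callbacks=[LegacyCallback()])   # logs a warning
    model.fit(1, train_dataset, valid_dataset,
              callbacks=[LegacyCallback()])                       # raises ValueError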
diff --git a/mindspore/python/mindspore/train/train_thor/model_thor.py b/mindspore/python/mindspore/train/train_thor/model_thor.py
index 6b70c8a00db..3cf2f011be9 100644
--- a/mindspore/python/mindspore/train/train_thor/model_thor.py
+++ b/mindspore/python/mindspore/train/train_thor/model_thor.py
@@ -150,7 +150,7 @@ class ModelThor(Model):
                self.switch_branch_one = not self.switch_branch_one
                outputs = self._train_network(*inputs)
                cb_params.net_outputs = outputs
-                list_callback.step_end(run_context)
+                list_callback.on_train_step_end(run_context)
        else:
            cb_params.cur_step_num += 1
            if self.train_network_init_flag:
@@ -163,7 +163,7 @@ class ModelThor(Model):
            if self.index_first_order == iter_first_order:
                self.index_first_order = 0
                self.switch_branch_one = not self.switch_branch_one
-                list_callback.step_end(run_context)
+                list_callback.on_train_step_end(run_context)

    def _train_ascend_sink_step(self, cb_params, train_dataset, iter_first_order, inputs, list_callback, run_context):
        """train ascend sink step"""
@@ -184,10 +184,10 @@ class ModelThor(Model):
            self.switch_branch_one = not self.switch_branch_one
            outputs = self._train_network(*inputs)
            cb_params.net_outputs = outputs
-            list_callback.step_end(run_context)
+            list_callback.on_train_step_end(run_context)

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
-                                    sink_size=-1, initial_epoch=0):
+                                    sink_size=-1, initial_epoch=0, valid_infos=None):
        """
        Training process. The data would be passed to network through dataset channel.
@@ -204,6 +204,9 @@ class ModelThor(Model):
            initial_epoch (int): Epoch at which to start train, it useful for resuming a previous training run.
                Default: 0.
        """
+        valid_dataset, _, _ = valid_infos
+        if valid_dataset:
+            raise ValueError("Evaluation in training is currently not supported in the second-order "
+                             "scenario of thor.")
        if sink_size == -1:
            epoch_num = epoch - initial_epoch
        else:
@@ -226,28 +229,28 @@ class ModelThor(Model):
        cb_params.cur_step_num = 0

        run_context = RunContext(cb_params)
-        list_callback.begin(run_context)
+        list_callback.on_train_begin(run_context)

        for i in range(initial_epoch, epoch):
            cb_params.cur_epoch_num = i + 1
-            list_callback.epoch_begin(run_context)
+            list_callback.on_train_epoch_begin(run_context)
            # for data sink dataset_helper only iter once, otherwise iter epoch_size times.
for inputs in dataset_helper: if _need_to_full() and context.get_context("device_target") == "GPU": inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) - list_callback.step_begin(run_context) + list_callback.on_train_step_begin(run_context) if context.get_context("device_target") == "GPU": self._train_gpu_sink_step(cb_params, inputs, list_callback, iter_first_order, run_context) else: self._train_ascend_sink_step(cb_params, train_dataset, iter_first_order, inputs, list_callback, run_context) - list_callback.epoch_end(run_context) + list_callback.on_train_epoch_end(run_context) should_stop = False or run_context.get_stop_requested() if should_stop: break dataset_helper.stop_send() - list_callback.end(run_context) + list_callback.on_train_end(run_context) __all__ = ["ModelThor"] diff --git a/tests/st/train/test_fit.py b/tests/st/train/test_fit.py new file mode 100644 index 00000000000..2e4dcf82c2d --- /dev/null +++ b/tests/st/train/test_fit.py @@ -0,0 +1,202 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" test_fit """ + +import pytest +import numpy as np +from mindspore import Model, nn, Tensor +from mindspore.common.initializer import Normal +from mindspore.train.callback import Callback, TimeMonitor, LossMonitor +from mindspore import dataset as ds + + +def get_data(num, w=2.0, b=3.0): + for _ in range(num): + x = np.random.uniform(-10.0, 10.0) + noise = np.random.normal(0, 1) + y = x * w + b + noise + yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32) + + +def create_dataset(num_data, batch_size=16, repeat_size=1): + input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label']) + input_data = input_data.batch(batch_size, drop_remainder=True) + input_data = input_data.repeat(repeat_size) + return input_data + + +def define_model(): + net = nn.Dense(1, 1, Normal(0.02), Normal(0.02)) + net_loss = nn.MSELoss() + net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9) + return Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={'mse', 'mae'}) + + +class MyCallbackOldMethod(Callback): + """ Raise warning in `mindspore.Model.train` and `mindspore.Model.eval`; raise error in `mindspore.Model.fit`""" + def begin(self, run_context): + print("custom callback: print on begin, just for test.") + + def step_end(self, run_context): + cb_params = run_context.original_args() + outputs = cb_params.get("net_outputs") + result = outputs if isinstance(outputs, Tensor) else outputs[0] + print("custom train callback: step end, loss is %s" % (result)) + + def on_train_epoch_end(self, run_context): + cb_params = run_context.original_args() + print("custom train callback: epoch end, loss is %s" % (cb_params.get("net_outputs"))) + + +class MyCallbackNewMethod(Callback): + """ Custom callback running in `mindspore.Model.train`, `mindspore.Model.eval`, `mindspore.Model.fit`""" + def on_train_epoch_end(self, 
run_context): + cb_params = run_context.original_args() + print("custom callback: train epoch end, loss is %s" % (cb_params.get("net_outputs"))) + + def on_eval_epoch_end(self, run_context): + cb_params = run_context.original_args() + print("custom callback: eval epoch end, metric is %s" % (cb_params.get("net_outputs")[0])) + + +def test_fit_train_dataset_non_sink_mode(): + """ + Feature: `mindspore.Model.fit` with train dataset in non-sink mode. + Description: test fit with train dataset in non-sink mode. + Expectation: run in non-sink mode. + """ + model = define_model() + ds_train = create_dataset(4096, 1024) + ds_eval = create_dataset(1024, 512) + callbacks = [LossMonitor()] + model.fit(3, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=False) + + +def test_fit_train_dataset_sink_mode(): + """ + Feature: `mindspore.Model.fit` with train dataset in sink mode. + Description: test fit with train dataset in sink mode. + Expectation: run in sink mode. + """ + model = define_model() + ds_train = create_dataset(4096, 1024) + ds_eval = create_dataset(1024, 512) + callbacks = [LossMonitor()] + model.fit(3, ds_train, ds_eval, callbacks=callbacks, dataset_sink_mode=True, sink_size=256) + + +def test_fit_valid_dataset_non_sink_mode(): + """ + Feature: `mindspore.Model.fit` with valid dataset in non-sink mode. + Description: test fit with valid dataset in non-sink mode. + Expectation: run in non-sink mode. + """ + model = define_model() + ds_train = create_dataset(4096, 1024) + ds_eval = create_dataset(1024, 512) + callbacks = [LossMonitor()] + model.fit(3, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=False) + + +def test_fit_valid_dataset_sink_mode(): + """ + Feature: `mindspore.Model.fit` with valid dataset in sink mode. + Description: test fit with valid dataset in sink mode. + Expectation: run in sink mode. + """ + model = define_model() + ds_train = create_dataset(4096, 1024) + ds_eval = create_dataset(1024, 512) + callbacks = [LossMonitor()] + model.fit(3, ds_train, ds_eval, callbacks=callbacks, valid_dataset_sink_mode=True) + + +def test_fit_without_valid_dataset(): + """ + Feature: `mindspore.Model.fit` without `valid_dataset` input . + Description: test fit when `valid_dataset` is None and `valid_dataset_sink_mode` is True or False. + Expectation: network train without eval process, `valid_dataset_sink_mode` does not take effect. + """ + model = define_model() + ds_train = create_dataset(4096, 1024) + callbacks = [LossMonitor()] + model.fit(3, ds_train, None, callbacks=callbacks, valid_dataset_sink_mode=False) + model.fit(3, ds_train, None, callbacks=callbacks) + + +def test_fit_valid_frequency(): + """ + Feature: check `valid_frequency` input in `mindspore.Model.fit`. + Description: when `valid_frequency` is integer, list or other types. + Expectation: raise ValueError when the type of valid_frequency is not int or list. + """ + model = define_model() + callbacks = [LossMonitor()] + ds_train = create_dataset(4096, 1024) + ds_eval = create_dataset(1024, 512) + model.fit(3, ds_train, ds_eval, valid_frequency=1, callbacks=callbacks) + model.fit(5, ds_train, ds_eval, valid_frequency=2, callbacks=callbacks) + model.fit(5, ds_train, ds_eval, valid_frequency=[0, 1, 4], callbacks=callbacks) + with pytest.raises(ValueError): + model.fit(5, ds_train, ds_eval, valid_frequency=(0, 2), callbacks=callbacks) + + +def test_fit_callbacks(): + """ + Feature: check `callbacks` input in `mindspore.Model.fit`. + Description: test internal or custom callbacks in fit. 
+ Expectation: raise ValueError when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'. + """ + model = define_model() + ds_train = create_dataset(4096, 1024) + ds_eval = create_dataset(1024, 512) + model.fit(3, ds_train, ds_eval, callbacks=None) + model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor()]) + model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), LossMonitor()]) + model.fit(3, ds_train, ds_eval, callbacks=[MyCallbackNewMethod()]) + model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), MyCallbackNewMethod()]) + with pytest.raises(ValueError): + model.fit(3, ds_train, ds_eval, callbacks=[MyCallbackOldMethod()]) + with pytest.raises(ValueError): + model.fit(3, ds_train, ds_eval, callbacks=[TimeMonitor(), MyCallbackOldMethod()]) + with pytest.raises(ValueError): + model.fit(3, ds_train, valid_dataset=None, callbacks=[TimeMonitor(), MyCallbackOldMethod()]) + + +def test_train_eval_callbacks(): + """ + Feature: check `callbacks` input in `mindspore.Model.train` or `mindspore.Model.eval`. + Description: test internal or custom callbacks in train or eval. + Expectation: raise warning when methods of custom callbacks are not prefixed with 'on_train' or 'on_eval'. + """ + model = define_model() + ds_train = create_dataset(4096, 1024) + ds_eval = create_dataset(1024, 512) + + model.train(3, ds_train, callbacks=None) + model.train(3, ds_train, callbacks=[TimeMonitor()]) + model.train(3, ds_train, callbacks=[LossMonitor()]) + model.train(3, ds_train, callbacks=[MyCallbackNewMethod()]) + model.train(3, ds_train, callbacks=[MyCallbackOldMethod()]) + + metric_results = model.eval(ds_eval, callbacks=None) + print("{}".format(metric_results)) + metric_results = model.eval(ds_eval, callbacks=[TimeMonitor()]) + print("{}".format(metric_results)) + metric_results = model.eval(ds_eval, callbacks=[MyCallbackNewMethod()]) + print("{}".format(metric_results)) + metric_results = model.eval(ds_eval, callbacks=[MyCallbackOldMethod()]) + print("{}".format(metric_results)) diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index 9a06cba5d71..f0d299e2df8 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -513,13 +513,13 @@ def test_lambda(): run_context = RunContext(cb_params) lambda_cb = LambdaCallback( - epoch_end=lambda run_context: print("loss result: ", run_context.original_args().net_outputs)) + on_train_epoch_end=lambda run_context: print("loss result: ", run_context.original_args().net_outputs)) callbacks = [lambda_cb] with _CallbackManager(callbacks) as callbacklist: - callbacklist.begin(run_context) - callbacklist.epoch_begin(run_context) - callbacklist.step_begin(run_context) - callbacklist.step_end(run_context) - callbacklist.epoch_end(run_context) - callbacklist.end(run_context) + callbacklist.on_train_begin(run_context) + callbacklist.on_train_epoch_begin(run_context) + callbacklist.on_train_step_begin(run_context) + callbacklist.on_train_step_end(run_context) + callbacklist.on_train_epoch_end(run_context) + callbacklist.on_train_end(run_context)
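Note: taken together, migrating an existing callback is mechanical — rename the legacy hooks to their `on_train_*`/`on_eval_*` counterparts; a before/after sketch with hypothetical class names:

    from mindspore.train.callback import Callback

    class OldStyle(Callback):       # still fine for model.train/model.eval, rejected by model.fit
        def epoch_end(self, run_context):
            print(run_context.original_args().net_outputs)

    class NewStyle(Callback):       # accepted everywhere, including model.fit
        def on_train_epoch_end(self, run_context):
            print(run_context.original_args().net_outputs)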