!42104 move Model and Callbacks from mindspore to mindspore.train

Merge pull request !42104 from 吕昱峰(Nate.River)/code_docs_master
i-robot 2022-09-16 02:03:12 +00:00 committed by Gitee
commit 3016e92c1c
28 changed files with 147 additions and 139 deletions

View File

@@ -63,32 +63,6 @@ mindspore
mindspore.get_algo_parameters
mindspore.reset_algo_parameters
Model
-----
.. mscnautosummary::
:toctree: mindspore
mindspore.Model
Callbacks
---------
.. mscnautosummary::
:toctree: mindspore
mindspore.Callback
mindspore.CheckpointConfig
mindspore.EarlyStopping
mindspore.History
mindspore.LambdaCallback
mindspore.LearningRateScheduler
mindspore.LossMonitor
mindspore.ModelCheckpoint
mindspore.ReduceLROnPlateau
mindspore.RunContext
mindspore.TimeMonitor
Data Processing Tools
---------------------

View File

@@ -0,0 +1,28 @@
mindspore.train
===============
Model
-----
.. mscnautosummary::
:toctree: mindspore
mindspore.train.Model
Callbacks
---------
.. mscnautosummary::
:toctree: mindspore
mindspore.train.Callback
mindspore.train.CheckpointConfig
mindspore.train.EarlyStopping
mindspore.train.History
mindspore.train.LambdaCallback
mindspore.train.LearningRateScheduler
mindspore.train.LossMonitor
mindspore.train.ModelCheckpoint
mindspore.train.ReduceLROnPlateau
mindspore.train.RunContext
mindspore.train.TimeMonitor
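For reference, a minimal sketch of the new import path this commit introduces; the network, loss, and optimizer below are illustrative placeholders, not part of the commit:
>>> import mindspore.nn as nn
>>> from mindspore.train import Model, LossMonitor, TimeMonitor
>>>
>>> # Classes formerly exposed as mindspore.Model, mindspore.LossMonitor,
>>> # etc. are now imported from mindspore.train.
>>> net = nn.Dense(10, 5)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)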

View File

@@ -1,7 +1,7 @@
mindspore.Callback
===================
mindspore.train.Callback
========================
.. py:class:: mindspore.Callback
.. py:class:: mindspore.train.Callback
Base class used to build Callback functions. A Callback function is a context manager that is invoked while running the model.
This mechanism can be used to perform custom operations.

View File

@@ -1,7 +1,7 @@
mindspore.CheckpointConfig
===========================
mindspore.train.CheckpointConfig
================================
.. py:class:: mindspore.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False)
.. py:class:: mindspore.train.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False)
The configuration policy for saving checkpoints.

View File

@@ -1,7 +1,7 @@
mindspore.EarlyStopping
================================
mindspore.train.EarlyStopping
=============================
.. py:class:: mindspore.EarlyStopping(monitor='eval_loss', min_delta=0, patience=0, verbose=False, mode='auto', baseline=None, restore_best_weights=False)
.. py:class:: mindspore.train.EarlyStopping(monitor='eval_loss', min_delta=0, patience=0, verbose=False, mode='auto', baseline=None, restore_best_weights=False)
Stop training when the monitored metric has stopped improving.
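A hedged usage sketch: since `EarlyStopping` watches `eval_loss` by default, it is typically passed to `Model.fit` together with a validation dataset. The `net`, `loss`, `optim`, `train_ds`, and `eval_ds` objects below are assumed placeholders:
>>> from mindspore.train import Model, EarlyStopping
>>>
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'accuracy'})
>>> # Stop once eval_loss has failed to improve for 3 consecutive evaluations.
>>> early_stop = EarlyStopping(monitor='eval_loss', patience=3, verbose=True)
>>> model.fit(10, train_ds, eval_ds, callbacks=[early_stop])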

View File

@@ -1,7 +1,7 @@
mindspore.History
===========================
mindspore.train.History
=======================
.. py:class:: mindspore.History
.. py:class:: mindspore.train.History
Records information about the network outputs and evaluation metrics into a `History` object.

View File

@@ -1,7 +1,7 @@
mindspore.LambdaCallback
===========================
mindspore.train.LambdaCallback
==============================
.. py:class:: mindspore.LambdaCallback(on_train_epoch_begin=None, on_train_epoch_end=None, on_train_step_begin=None, on_train_step_end=None, on_train_begin=None, on_train_end=None, on_eval_epoch_begin=None, on_eval_epoch_end=None, on_eval_step_begin=None, on_eval_step_end=None, on_eval_begin=None, on_eval_end=None)
.. py:class:: mindspore.train.LambdaCallback(on_train_epoch_begin=None, on_train_epoch_end=None, on_train_step_begin=None, on_train_step_end=None, on_train_begin=None, on_train_end=None, on_eval_epoch_begin=None, on_eval_epoch_end=None, on_eval_step_begin=None, on_eval_step_end=None, on_eval_begin=None, on_eval_end=None)
Used to define a simple custom callback.

View File

@@ -1,7 +1,7 @@
mindspore.LearningRateScheduler
================================
mindspore.train.LearningRateScheduler
=====================================
.. py:class:: mindspore.LearningRateScheduler(learning_rate_function)
.. py:class:: mindspore.train.LearningRateScheduler(learning_rate_function)
Used to change the learning rate during training.

View File

@@ -1,7 +1,7 @@
mindspore.LossMonitor
================================
mindspore.train.LossMonitor
===========================
.. py:class:: mindspore.LossMonitor(per_print_times=1)
.. py:class:: mindspore.train.LossMonitor(per_print_times=1)
In a training scenario, monitors the training loss; in a train-while-evaluate scenario, monitors both the training loss and the evaluation metrics.

View File

@@ -1,7 +1,7 @@
mindspore.Model
================
mindspore.train.Model
======================
.. py:class:: mindspore.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs)
.. py:class:: mindspore.train.Model(network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None, amp_level="O0", boost_level="O0", **kwargs)
High-level API for model training or inference. `Model` wraps an instance that can be trained or used for inference, according to the arguments passed in by the user.

View File

@@ -1,7 +1,7 @@
mindspore.ModelCheckpoint
================================
mindspore.train.ModelCheckpoint
===============================
.. py:class:: mindspore.ModelCheckpoint(prefix='CKP', directory=None, config=None)
.. py:class:: mindspore.train.ModelCheckpoint(prefix='CKP', directory=None, config=None)
The callback function for saving checkpoints.

View File

@@ -1,7 +1,7 @@
mindspore.ReduceLROnPlateau
================================
mindspore.train.ReduceLROnPlateau
=================================
.. py:class:: mindspore.ReduceLROnPlateau(monitor='eval_loss', factor=0.1, patience=10, verbose=False, mode='auto', min_delta=1e-4, cooldown=0, min_lr=0)
.. py:class:: mindspore.train.ReduceLROnPlateau(monitor='eval_loss', factor=0.1, patience=10, verbose=False, mode='auto', min_delta=1e-4, cooldown=0, min_lr=0)
Reduce the learning rate when `monitor` has stopped improving.
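A hedged usage sketch with the same placeholder objects as above; like `EarlyStopping`, this callback is typically used with `Model.fit` and a validation dataset so that `eval_loss` is produced:
>>> from mindspore.train import Model, ReduceLROnPlateau
>>>
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'accuracy'})
>>> # Multiply the learning rate by 0.5 after 5 evaluations without improvement.
>>> reduce_lr = ReduceLROnPlateau(monitor='eval_loss', factor=0.5, patience=5)
>>> model.fit(10, train_ds, eval_ds, callbacks=[reduce_lr])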

View File

@@ -1,7 +1,7 @@
mindspore.RunContext
================================
mindspore.train.RunContext
==========================
.. py:class:: mindspore.RunContext(original_args)
.. py:class:: mindspore.train.RunContext(original_args)
Holds and manages information about the model.
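A hedged sketch of how a `RunContext` instance is consumed inside a custom callback; the `StopAtStep` class is illustrative, not part of this commit:
>>> from mindspore.train import Callback
>>>
>>> class StopAtStep(Callback):
...     def step_end(self, run_context):
...         # RunContext.original_args() returns the stored training information.
...         cb_params = run_context.original_args()
...         if cb_params.cur_step_num >= 100:
...             # Ask the training loop to stop early.
...             run_context.request_stop()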

View File

@@ -1,7 +1,7 @@
mindspore.TimeMonitor
================================
mindspore.train.TimeMonitor
===========================
.. py:class:: mindspore.TimeMonitor(data_size=None)
.. py:class:: mindspore.train.TimeMonitor(data_size=None)
Monitors the time spent on training or inference.

View File

@@ -174,36 +174,6 @@ Context
mindspore.get_algo_parameters
mindspore.reset_algo_parameters
Model
-----
.. autosummary::
:toctree: mindspore
:nosignatures:
:template: classtemplate.rst
mindspore.Model
Callback
--------
.. autosummary::
:toctree: mindspore
:nosignatures:
:template: classtemplate.rst
mindspore.Callback
mindspore.CheckpointConfig
mindspore.EarlyStopping
mindspore.History
mindspore.LambdaCallback
mindspore.LearningRateScheduler
mindspore.LossMonitor
mindspore.ModelCheckpoint
mindspore.ReduceLROnPlateau
mindspore.RunContext
mindspore.TimeMonitor
Dataset Helper
---------------

View File

@@ -0,0 +1,32 @@
mindspore.train
===============
Model
-----
.. autosummary::
:toctree: mindspore
:nosignatures:
:template: classtemplate.rst
mindspore.train.Model
Callback
--------
.. autosummary::
:toctree: mindspore
:nosignatures:
:template: classtemplate.rst
mindspore.train.Callback
mindspore.train.CheckpointConfig
mindspore.train.EarlyStopping
mindspore.train.History
mindspore.train.LambdaCallback
mindspore.train.LearningRateScheduler
mindspore.train.LossMonitor
mindspore.train.ModelCheckpoint
mindspore.train.ReduceLROnPlateau
mindspore.train.RunContext
mindspore.train.TimeMonitor

View File

@@ -97,10 +97,10 @@ class Callback:
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore import dataset as ds
>>> class Print_info(ms.Callback):
>>> from mindspore.train import Model, Callback
>>> class Print_info(Callback):
... def step_end(self, run_context):
... cb_params = run_context.original_args()
... print("step_num: ", cb_params.cur_step_num)
@@ -111,7 +111,7 @@ class Callback:
>>> net = nn.Dense(10, 5)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> model.train(1, dataset, callbacks=print_cb)
step_num: 2
"""

View File

@@ -105,9 +105,9 @@ class CheckpointConfig:
ValueError: If input parameter is not the correct type.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.common.initializer import Normal
>>> from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
>>>
>>> class LeNet5(nn.Cell):
... def __init__(self, num_class=10, num_channel=1):
@@ -133,11 +133,11 @@ class CheckpointConfig:
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> data_path = './MNIST_Data'
>>> dataset = create_dataset(data_path)
>>> config = ms.CheckpointConfig(saved_network=net)
>>> ckpoint_cb = ms.ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
>>> config = CheckpointConfig(saved_network=net)
>>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
>>> model.train(10, dataset, callbacks=ckpoint_cb)
"""

View File

@@ -82,8 +82,8 @@ class EarlyStopping(Callback):
ValueError: The monitor value is not a scalar.
Examples:
>>> from mindspore.train.callback import EarlyStopping
>>> from mindspore import Model, nn
>>> from mindspore import nn
>>> from mindspore.train import Model, EarlyStopping
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)

View File

@@ -35,16 +35,16 @@ class History(Callback):
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn
>>> from mindspore.train import Model, History
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> history_cb = ms.History()
>>> model = ms.Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> history_cb = History()
>>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model.train(2, train_dataset, callbacks=[history_cb])
>>> print(history_cb.epoch)
>>> print(history_cb.history)

View File

@@ -45,17 +45,17 @@ class LambdaCallback(Callback):
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn
>>> from mindspore.train import Model, LambdaCallback
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> lambda_callback = ms.LambdaCallback(on_train_epoch_end=
>>> lambda_callback = LambdaCallback(on_train_epoch_end=
... lambda run_context: print("loss: ", run_context.original_args().net_outputs))
>>> model = ms.Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model.train(2, train_dataset, callbacks=[lambda_callback])
loss: 1.6127687
loss: 1.6106578

View File

@@ -181,6 +181,7 @@ class SummaryLandscape:
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore.nn import Loss, Accuracy
>>> from mindspore.train import Model, SummaryCollector, SummaryLandscape
>>>
>>> if __name__ == '__main__':
... # If the device_target is Ascend, set the device_target to "Ascend"
@@ -192,10 +193,10 @@ class SummaryLandscape:
... network = LeNet5(10)
... net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
... model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
... model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
... # Simple usage for collect landscape information:
... interval_1 = [1, 2, 3, 4, 5]
... summary_collector = ms.SummaryCollector(summary_dir='./summary/lenet_interval_1',
... summary_collector = SummaryCollector(summary_dir='./summary/lenet_interval_1',
... collect_specified_data={'collect_landscape':{"landscape_size": 4,
... "unit": "step",
... "create_landscape":{"train":True,
@@ -215,7 +216,7 @@ class SummaryLandscape:
... ds_eval = create_dataset(mnist_dataset_dir, 32)
... return model, network, ds_eval, metrics
...
... summary_landscape = ms.SummaryLandscape('./summary/lenet_interval_1')
... summary_landscape = SummaryLandscape('./summary/lenet_interval_1')
... # parameters of collect_landscape can be modified or unchanged
... summary_landscape.gen_landscapes_with_multi_process(callback_fn,
... collect_landscape={"landscape_size": 4,

View File

@@ -38,13 +38,13 @@ class LossMonitor(Callback):
ValueError: If per_print_times is not an integer or less than zero.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.train import Model, LossMonitor
>>>
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> data_path = './MNIST_Data'
>>> dataset = create_dataset(data_path)
>>> loss_monitor = LossMonitor()

View File

@@ -34,10 +34,8 @@ class LearningRateScheduler(Callback):
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore import LearningRateScheduler
>>> import mindspore.nn as nn
>>> from mindspore.train import Model, LearningRateScheduler
>>> from mindspore import dataset as ds
...
>>> def learning_rate_function(lr, cur_step_num):
@@ -50,7 +48,7 @@ class LearningRateScheduler(Callback):
>>> net = nn.Dense(10, 5)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
...
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> dataset = ds.NumpySlicesDataset(data=data).batch(32)

View File

@@ -81,8 +81,8 @@ class ReduceLROnPlateau(Callback):
ValueError: The learning rate is not a Parameter.
Examples:
>>> from mindspore.train.callback import ReduceLROnPlateau
>>> from mindspore import Model, nn
>>> from mindspore import nn
>>> from mindspore.train import Model, ReduceLROnPlateau
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)

View File

@@ -177,6 +177,7 @@ class SummaryCollector(Callback):
Examples:
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore.train import Model, SummaryCollector
>>> from mindspore.nn import Accuracy
>>>
>>> if __name__ == '__main__':
@@ -189,15 +190,15 @@ class SummaryCollector(Callback):
... network = LeNet5(10)
... net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
... model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
... model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")
...
... # Simple usage:
... summary_collector = ms.SummaryCollector(summary_dir='./summary_dir')
... summary_collector = SummaryCollector(summary_dir='./summary_dir')
... model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=False)
...
... # Do not collect metric and collect the first layer parameter, others are collected by default
... specified={'collect_metric': False, 'histogram_regular': '^conv1.*'}
... summary_collector = ms.SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified)
... summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified)
... model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=False)
"""

View File

@@ -34,13 +34,13 @@ class TimeMonitor(Callback):
ValueError: If data_size is not positive int.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.train import Model, TimeMonitor
>>>
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> data_path = './MNIST_Data'
>>> dataset = create_dataset(data_path)
>>> time_monitor = TimeMonitor()

View File

@@ -162,8 +162,8 @@ class Model:
the Graph mode + Ascend platform, and for better acceleration, refer to the documentation to configure
boost_config_dict.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.train import Model
>>>
>>> class Net(nn.Cell):
... def __init__(self, num_class=10, num_channel=1):
@@ -189,7 +189,7 @@ class Model:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> # For details about how to build the dataset, please refer to the variable `dataset_train` in tutorial
>>> # document on the official website:
>>> # https://www.mindspore.cn/tutorials/zh-CN/master/beginner/quick_start.html
@@ -989,8 +989,8 @@ class Model:
Default: 0.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.train import Model
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
@@ -999,7 +999,7 @@ class Model:
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = ms.FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics=None,
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
... loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
"""
@@ -1126,8 +1126,8 @@ class Model:
Default: 0.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.train import Model
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
@@ -1136,7 +1136,7 @@ class Model:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
>>> model.fit(2, train_dataset, valid_dataset)
"""
@@ -1211,17 +1211,18 @@ class Model:
epoch (int): Control the training epochs. Default: 1.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.train import Model
>>> from mindspore.amp import FixedLossScaleManager
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = ms.FixedLossScaleManager()
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics=None,
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
... loss_scale_manager=loss_scale_manager)
>>> model.build(dataset, epoch=2)
>>> model.train(2, dataset)
@@ -1380,15 +1381,15 @@ class Model:
the model in the test mode.
Examples:
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.train import Model
>>>
>>> # For details about how to build the dataset, please refer to the tutorial
>>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = ms.Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> acc = model.eval(dataset, dataset_sink_mode=False)
"""
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
@@ -1451,11 +1452,12 @@ class Model:
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.train import Model
>>>
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = ms.Model(Net())
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), mindspore.float32)
>>> model = Model(Net())
>>> result = model.predict(input_data)
"""
self._check_network_mode(self._predict_network, False)
@@ -1537,6 +1539,7 @@ class Model:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
>>> from mindspore.train import Model
>>> from mindspore.communication import init
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
@@ -1550,7 +1553,7 @@ class Model:
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = ms.FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim, metrics=None,
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
... loss_scale_manager=loss_scale_manager)
>>> layout_dict = model.infer_train_layout(dataset)
"""
@@ -1595,13 +1598,14 @@ class Model:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.train import Model
>>> from mindspore.communication import init
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> ms.set_auto_parallel_context(full_batch=True, parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
>>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = ms.Model(Net())
>>> model = Model(Net())
>>> predict_map = model.infer_predict_layout(input_data)
"""
if context.get_context("mode") != context.GRAPH_MODE: