!12715 Add TimeMonitor to DeepSpeech and Refresh Support list for ResizeBilinear and ResizeNearestNeighbor

From: @wanyiming
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-08 17:29:49 +08:00 committed by Gitee
commit faaaa79314
4 changed files with 5 additions and 5 deletions

View File

@ -709,7 +709,7 @@ class ResizeBilinear(Cell):
ValueError: If `size` is a list or tuple whose length is not equal to 2.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> tensor = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)

View File

@ -3297,7 +3297,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
ValueError: If length of `size` is not equal to 2.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)

View File

@ -3151,7 +3151,7 @@ class ResizeBilinear(PrimitiveWithInfer):
ValueError: If length of shape of `input` is not equal to 4.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)

View File

@ -21,7 +21,7 @@ import argparse
from mindspore import context, Tensor, ParameterTuple
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore.nn import TrainOneStepCell
@ -89,7 +89,7 @@ if __name__ == '__main__':
print('Successfully loading the pre-trained model')
model = Model(train_net)
callback_list = [LossMonitor()]
callback_list = [TimeMonitor(steps_size), LossMonitor()]
if args.is_distributed:
config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())