!12715 Add TimeMonitor to DeepSpeech and Refresh Support list for ResizeBilinear and ResizeNearestNeighbor
From: @wanyiming Reviewed-by: Signed-off-by:
This commit is contained in:
commit
faaaa79314
|
@ -709,7 +709,7 @@ class ResizeBilinear(Cell):
|
|||
ValueError: If `size` is a list or tuple whose length is not equal to 2.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> tensor = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)
|
||||
|
|
|
@ -3297,7 +3297,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
|
|||
ValueError: If length of `size` is not equal to 2.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)
|
||||
|
|
|
@ -3151,7 +3151,7 @@ class ResizeBilinear(PrimitiveWithInfer):
|
|||
ValueError: If length of shape of `input` is not equal to 4.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
|
||||
|
|
|
@ -21,7 +21,7 @@ import argparse
|
|||
from mindspore import context, Tensor, ParameterTuple
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.nn import TrainOneStepCell
|
||||
|
@ -89,7 +89,7 @@ if __name__ == '__main__':
|
|||
print('Successfully loading the pre-trained model')
|
||||
|
||||
model = Model(train_net)
|
||||
callback_list = [LossMonitor()]
|
||||
callback_list = [TimeMonitor(steps_size), LossMonitor()]
|
||||
|
||||
if args.is_distributed:
|
||||
config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())
|
||||
|
|
Loading…
Reference in New Issue