diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index fe6be575eea..3f1e18183ac 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -636,6 +636,15 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
       // Dealing with the RefKey case
       auto refkeys = cnode_with_refkeys.second;
       auto cnode = cnode_with_refkeys.first;
+
+      auto cnode_ptr = cnode->cast<CNodePtr>();
+      if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
+        continue;
+      }
+      if (!IsAutoParallelCareNode(cnode_ptr)) {
+        continue;
+      }
+
       if (refkeys.size() > 1) {
         MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
       }
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 850e895ad01..a7c3f504404 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1235,10 +1235,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = [1, 2, 3, 4]
-        >>> segment_ids = [0, 0, 1, 2]
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
+        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
         >>> num_segments = 4
-        >>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        [3, 3, 4, 0]
     """
 
     @prim_attr_register
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 91f6d7ec01d..acccfbaba39 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -22,6 +22,8 @@ from functools import reduce
 import numpy as np
 
 from ... import context
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel, check_bool, check_int_positive
 from ...common import dtype as mstype
@@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer):
                                 filter(lambda x: x.requires_grad, net.get_parameters()))
         >>> model = Model(net, loss, opt)
     """
-
+    __mindspore_signature__ = (
+        ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
+    )
     @prim_attr_register
     def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])
 
     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
-        validator.check(f'variable shape {v_shape}', len(v_shape), '', 0, Rel.GT)
-        validator.check(f'accumulation shape {a_shape}', len(a_shape), '', 0, Rel.GT)
-        validator.check(f'learning rate shape {l_shape}', len(l_shape), '', 0, Rel.GE)
-        validator.check(f'gradient shape {g_shape}', len(g_shape), '', 0, Rel.GE)
-        validator.check(f'momentum shape {m_shape}', len(m_shape), '', 0, Rel.GE)
         return v_shape
 
     def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
-        validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
-        validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
-        v_type = validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
+        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
+            validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
+            validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
+            validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
+            validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        return v_type
+        return g_dtype
 
 
 class SmoothL1Loss(PrimitiveWithInfer):
diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py
index e909b44e409..c4c115ef27f 100644
--- a/mindspore/train/amp.py
+++ b/mindspore/train/amp.py
@@ -82,6 +82,29 @@ def _check_kwargs(key_words):
         if loss_scale_manager:
             validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)
 
+
+def _add_loss_network(network, loss_fn, cast_model_type):
+    class WithLossCell(nn.Cell):
+        "Wrap loss for amp. Cast network output back to float32"
+
+        def __init__(self, backbone, loss_fn):
+            super(WithLossCell, self).__init__(auto_prefix=False)
+            self._backbone = backbone
+            self._loss_fn = loss_fn
+
+        def construct(self, data, label):
+            out = self._backbone(data)
+            label = _mp_cast_helper(mstype.float32, label)
+            return self._loss_fn(F.cast(out, mstype.float32), label)
+
+    validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
+    if cast_model_type == mstype.float16:
+        network = WithLossCell(network, loss_fn)
+    else:
+        network = nn.WithLossCell(network, loss_fn)
+    return network
+
+
 def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.
@@ -117,24 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
             _do_keep_batchnorm_fp32(network)
 
     if loss_fn:
-        class WithLossCell(nn.Cell):
-            "Wrap loss for amp. Cast network output back to float32"
-
-            def __init__(self, backbone, loss_fn):
-                super(WithLossCell, self).__init__(auto_prefix=False)
-                self._backbone = backbone
-                self._loss_fn = loss_fn
-
-            def construct(self, data, label):
-                out = self._backbone(data)
-                label = _mp_cast_helper(mstype.float32, label)
-                return self._loss_fn(F.cast(out, mstype.float32), label)
-
-        validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
-        if config.cast_model_type == mstype.float16:
-            network = WithLossCell(network, loss_fn)
-        else:
-            network = nn.WithLossCell(network, loss_fn)
+        network = _add_loss_network(network, loss_fn, config.cast_model_type)
 
     if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
         network = _VirtualDatasetCell(network)
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 833fb07256b..a1acec859c3 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -24,8 +24,7 @@ from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
 from ..nn.metrics import Loss
-from ..nn.wrap import WithLossCell, WithEvalCell, \
-    DataWrapper
+from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
 from ..common import dtype as mstype
@@ -151,7 +150,10 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            self._eval_network = WithEvalCell(self._network, self._loss_fn)
+            if self._optimizer:
+                self._eval_network = self._train_network.network
+            else:
+                self._eval_network = WithEvalCell(self._network, self._loss_fn)
             self._eval_indexes = [0, 1, 2]
 
     def _clear_metrics(self):
diff --git a/tests/train_step_wrap.py b/tests/train_step_wrap.py
index 7289c010044..d48e25b8371 100644
--- a/tests/train_step_wrap.py
+++ b/tests/train_step_wrap.py
@@ -21,47 +21,6 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore import Parameter, ParameterTuple
 
-
-run_opt = C.MultitypeFuncGraph("run_opt")
-
-# pylint: disable=unused-argument
-@run_opt.register("Function", "Int", "Number", "Number",
-                  "Tensor", "Tensor", "Tensor")
-def tensor_run_opt(opt, iterator, learning_rate, momentum,
-                   gradient, variable, moment):
-    success = True
-    new_weight = opt(gradient, moment, variable, learning_rate, momentum)
-    success = F.depend(success, P.Assign()(variable, new_weight))
-    return success
-
-
-class OptimizerByMomentum(nn.Cell):
-    """
-    OptimizerByMomentum definition
-    """
-    # list of tensor
-    def __init__(self, weights):
-        super(OptimizerByMomentum, self).__init__()
-        self.learning_rate = Parameter(0.1, name="learning_rate")
-        self.momentum = Parameter(0.05, name="momentum")
-        self.iter = Parameter(0, name="iter")
-
-        self.weights = weights
-        self.moments = weights.clone(prefix="moments", init='zeros')
-
-        self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
-
-    def construct(self, grads):
-        success = True
-        weights = self.weights
-        moments = self.moments
-        success = self.hyper_map(
-            F.partial(run_opt, self.opt, self.iter,
-                      self.learning_rate, self.momentum), grads, weights, moments)
-        # self.learning_rate = updata_lr(self.learning_rate, self.momentum)
-        return success
-
 class TrainStepWrap(nn.Cell):
     """
     TrainStepWrap definition
@@ -71,7 +30,7 @@ class TrainStepWrap(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.trainable_params())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True)
 
@@ -107,7 +66,7 @@ class TrainStepWrap2(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.get_parameters())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens