forked from mindspore-Ecosystem/mindspore
support amp for model eval, fix example of UnsortedSegmentSum
parent c478be0ff0
commit 2fecdede6b
@@ -636,6 +636,15 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
     // Dealing with the RefKey case
     auto refkeys = cnode_with_refkeys.second;
     auto cnode = cnode_with_refkeys.first;
+
+    auto cnode_ptr = cnode->cast<CNodePtr>();
+    if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
+      continue;
+    }
+    if (!IsAutoParallelCareNode(cnode_ptr)) {
+      continue;
+    }
+
     if (refkeys.size() > 1) {
       MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
     }
@@ -1235,10 +1235,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = [1, 2, 3, 4]
-        >>> segment_ids = [0, 0, 1, 2]
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
+        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
         >>> num_segments = 4
-        >>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        [3, 3, 4, 0]
     """
 
     @prim_attr_register
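For reference, the corrected example follows the operator's semantics: each input_x[i] is added into output[segment_ids[i]], and segments with no members stay zero, which is how [3, 3, 4, 0] arises. A minimal standalone sketch, assuming a configured MindSpore context:

import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

input_x = Tensor([1, 2, 3, 4], mindspore.float32)
segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
num_segments = 4
# segment 0: 1 + 2 = 3, segment 1: 3, segment 2: 4, segment 3: empty -> 0
output = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
print(output)  # expected values: [3, 3, 4, 0]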
@@ -22,6 +22,8 @@ from functools import reduce
 import numpy as np
 
 from ... import context
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel, check_bool, check_int_positive
 from ...common import dtype as mstype
@@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer):
                                  filter(lambda x: x.requires_grad, net.get_parameters()))
         >>> model = Model(net, loss, opt)
     """
 
+    __mindspore_signature__ = (
+        ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
+    )
     @prim_attr_register
     def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])
 
     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
         validator.check(f'variable shape {v_shape}', len(v_shape), '', 0, Rel.GT)
         validator.check(f'accumulation shape {a_shape}', len(a_shape), '', 0, Rel.GT)
         validator.check(f'learning rate shape {l_shape}', len(l_shape), '', 0, Rel.GE)
         validator.check(f'gradient shape {g_shape}', len(g_shape), '', 0, Rel.GE)
         validator.check(f'momentum shape {m_shape}', len(m_shape), '', 0, Rel.GE)
         return v_shape
 
     def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
-        validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
-        validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
-        v_type = validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
+        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
+            validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
+            validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
+            validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
+            validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        return v_type
+        return g_dtype
 
 
 class SmoothL1Loss(PrimitiveWithInfer):
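The new __mindspore_signature__ marks variable and accumulation as writable, by-reference inputs, and the relaxed infer_dtype skips the tensor-type checks when those inputs arrive as RefKeys (i.e. as Parameters inside a compiled graph). A hedged sketch of a direct call, assuming a backend that provides the ApplyMomentum kernel; the parameter names below are illustrative:

import numpy as np
import mindspore
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P

apply_momentum = P.ApplyMomentum()
variable = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="variable")
accumulation = Parameter(Tensor(np.zeros([2, 2]).astype(np.float32)), name="accumulation")
learning_rate = Tensor(0.1, mindspore.float32)
gradient = Tensor(np.ones([2, 2]).astype(np.float32))
momentum = Tensor(0.9, mindspore.float32)
# Input order follows init_prim_io_names: variable, accumulation, learning_rate, gradient, momentum.
# When variable/accumulation are Parameters in a compiled graph their dtype is a RefKey,
# which is the case the new branch in infer_dtype handles.
output = apply_momentum(variable, accumulation, learning_rate, gradient, momentum)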
@@ -82,6 +82,29 @@ def _check_kwargs(key_words):
     if loss_scale_manager:
         validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)
 
+
+def _add_loss_network(network, loss_fn, cast_model_type):
+    class WithLossCell(nn.Cell):
+        "Wrap loss for amp. Cast network output back to float32"
+
+        def __init__(self, backbone, loss_fn):
+            super(WithLossCell, self).__init__(auto_prefix=False)
+            self._backbone = backbone
+            self._loss_fn = loss_fn
+
+        def construct(self, data, label):
+            out = self._backbone(data)
+            label = _mp_cast_helper(mstype.float32, label)
+            return self._loss_fn(F.cast(out, mstype.float32), label)
+
+    validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
+    if cast_model_type == mstype.float16:
+        network = WithLossCell(network, loss_fn)
+    else:
+        network = nn.WithLossCell(network, loss_fn)
+    return network
+
+
 def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.
@@ -117,24 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         _do_keep_batchnorm_fp32(network)
 
     if loss_fn:
-        class WithLossCell(nn.Cell):
-            "Wrap loss for amp. Cast network output back to float32"
-
-            def __init__(self, backbone, loss_fn):
-                super(WithLossCell, self).__init__(auto_prefix=False)
-                self._backbone = backbone
-                self._loss_fn = loss_fn
-
-            def construct(self, data, label):
-                out = self._backbone(data)
-                label = _mp_cast_helper(mstype.float32, label)
-                return self._loss_fn(F.cast(out, mstype.float32), label)
-
-        validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
-        if config.cast_model_type == mstype.float16:
-            network = WithLossCell(network, loss_fn)
-        else:
-            network = nn.WithLossCell(network, loss_fn)
+        network = _add_loss_network(network, loss_fn, config.cast_model_type)
 
     if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
         network = _VirtualDatasetCell(network)
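With the loss wrapper factored out into _add_loss_network, build_train_network keeps the same behaviour for callers. A hedged usage sketch; the toy Dense backbone, loss, and optimizer are assumptions, and level='O2' presumes a float16-capable backend:

import mindspore.nn as nn
from mindspore.train import amp

net = nn.Dense(16, 10)                      # placeholder backbone; any nn.Cell works here
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# With level='O2' the backbone is cast to float16 and the loss is attached via
# _add_loss_network, so the loss itself still computes in float32.
train_network = amp.build_train_network(net, opt, loss, level='O2')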
@@ -24,8 +24,7 @@ from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
 from ..nn.metrics import Loss
-from ..nn.wrap import WithLossCell, WithEvalCell, \
-    DataWrapper
+from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
 from ..common import dtype as mstype
@@ -151,7 +150,10 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            self._eval_network = WithEvalCell(self._network, self._loss_fn)
+            if self._optimizer:
+                self._eval_network = self._train_network.network
+            else:
+                self._eval_network = WithEvalCell(self._network, self._loss_fn)
             self._eval_indexes = [0, 1, 2]
 
     def _clear_metrics(self):
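A hedged sketch of what this enables: when the Model owns an optimizer and was built with a float16 amp level, evaluation now reuses the already-cast train network instead of wrapping the raw float32 backbone again. The dataset, metric, and amp_level choices below are assumptions, and a GPU/Ascend backend with float16 kernels is assumed:

import numpy as np
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore.nn.metrics import Loss
from mindspore.train.model import Model

def generate(num=8):
    for _ in range(num):
        yield (np.random.randn(16).astype(np.float32),
               np.array(np.random.randint(10), dtype=np.int32))

eval_ds = ds.GeneratorDataset(generate, column_names=["data", "label"]).batch(4)

net = nn.Dense(16, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# With an optimizer present, _eval_network is now taken from the amp-wrapped
# train network, so evaluation sees the same float16 cast as training.
model = Model(net, loss_fn=loss_fn, optimizer=opt, metrics={'loss': Loss()}, amp_level='O2')
result = model.eval(eval_ds)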
@@ -21,47 +21,6 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore import Parameter, ParameterTuple
 
-
-run_opt = C.MultitypeFuncGraph("run_opt")
-
-# pylint: disable=unused-argument
-@run_opt.register("Function", "Int", "Number", "Number",
-                  "Tensor", "Tensor", "Tensor")
-def tensor_run_opt(opt, iterator, learning_rate, momentum,
-                   gradient, variable, moment):
-    success = True
-    new_weight = opt(gradient, moment, variable, learning_rate, momentum)
-    success = F.depend(success, P.Assign()(variable, new_weight))
-    return success
-
-
-class OptimizerByMomentum(nn.Cell):
-    """
-    OptimizerByMomentum definition
-    """
-    # list of tensor
-    def __init__(self, weights):
-        super(OptimizerByMomentum, self).__init__()
-        self.learning_rate = Parameter(0.1, name="learning_rate")
-        self.momentum = Parameter(0.05, name="momentum")
-        self.iter = Parameter(0, name="iter")
-
-        self.weights = weights
-        self.moments = weights.clone(prefix="moments", init='zeros')
-
-        self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
-
-    def construct(self, grads):
-        success = True
-        weights = self.weights
-        moments = self.moments
-        success = self.hyper_map(
-            F.partial(run_opt, self.opt, self.iter,
-                      self.learning_rate, self.momentum), grads, weights, moments)
-        # self.learning_rate = updata_lr(self.learning_rate, self.momentum)
-        return success
-
 class TrainStepWrap(nn.Cell):
     """
     TrainStepWrap definition
@@ -71,7 +30,7 @@ class TrainStepWrap(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.trainable_params())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True)
 
@@ -107,7 +66,7 @@ class TrainStepWrap2(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.get_parameters())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens
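Both test wrappers now rely on the library optimizer rather than the hand-rolled OptimizerByMomentum removed above. A minimal hedged sketch of the same pattern; the toy Dense backbone, loss, and input shapes are assumptions:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F

class TrainStep(nn.Cell):
    """One optimizer step driven by nn.Momentum, mirroring TrainStepWrap."""
    def __init__(self, network):
        super(TrainStep, self).__init__()
        self.network = network
        self.network.set_train()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)  # learning_rate=0.1, momentum=0.9
        self.grad = C.GradOperation('grad', get_by_list=True)

    def construct(self, data, label):
        loss = self.network(data, label)
        grads = self.grad(self.network, self.weights)(data, label)
        # Keep the data dependency so the optimizer update is not pruned from the graph.
        return F.depend(loss, self.optimizer(grads))

net_with_loss = nn.WithLossCell(nn.Dense(16, 10),
                                nn.SoftmaxCrossEntropyWithLogits(sparse=True))
train_step = TrainStep(net_with_loss)
data = Tensor(np.random.randn(4, 16).astype(np.float32))
label = Tensor(np.random.randint(10, size=(4,)).astype(np.int32))
loss = train_step(data, label)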