support amp when running model eval, fix the example of UnsortedSegmentSum

Wei Luning 2020-04-09 23:37:29 +08:00
parent c478be0ff0
commit 2fecdede6b
6 changed files with 59 additions and 78 deletions

View File

@@ -636,6 +636,15 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
// Dealing with the RefKey case
auto refkeys = cnode_with_refkeys.second;
auto cnode = cnode_with_refkeys.first;
auto cnode_ptr = cnode->cast<CNodePtr>();
if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
continue;
}
if (!IsAutoParallelCareNode(cnode_ptr)) {
continue;
}
if (refkeys.size() > 1) {
MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
}
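The new guard makes the RefKey pass defensive: CNodes whose first input is not a Primitive, and nodes that auto-parallel does not care about, are skipped before any RefKey bookkeeping runs. A minimal Python paraphrase of that control flow, with every name invented for illustration (the real logic is the C++ above):

def handle_ref_key_nodes(cnodes_with_refkeys):
    for cnode, refkeys in cnodes_with_refkeys:
        # Skip anything that is not a Primitive call.
        if cnode is None or not getattr(cnode, "is_primitive_call", False):
            continue
        # Skip nodes auto-parallel does not partition.
        if not getattr(cnode, "auto_parallel_care", False):
            continue
        # A CNode may carry at most one RefKey input.
        if len(refkeys) > 1:
            raise ValueError("CNode inputs have more than 1 RefKey")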

View File

@@ -1235,10 +1235,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
Examples:
>>> input_x = [1, 2, 3, 4]
>>> segment_ids = [0, 0, 1, 2]
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
>>> num_segments = 4
>>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
[3, 3, 4, 0]
"""
@prim_attr_register
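The corrected example can be sanity-checked with a short NumPy equivalent (NumPy is used here only for verification; the operator itself works on Tensors):

import numpy as np

def unsorted_segment_sum(x, segment_ids, num_segments):
    # out[s] accumulates every x[i] whose segment_ids[i] == s;
    # segments that receive no element stay 0.
    out = np.zeros(num_segments, dtype=x.dtype)
    np.add.at(out, segment_ids, x)
    return out

print(unsorted_segment_sum(np.array([1., 2., 3., 4.]), np.array([0, 0, 1, 2]), 4))
# [3. 3. 4. 0.], matching the docstring output [3, 3, 4, 0]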

View File

@@ -22,6 +22,8 @@ from functools import reduce
import numpy as np
from ... import context
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel, check_bool, check_int_positive
from ...common import dtype as mstype
@@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer):
filter(lambda x: x.requires_grad, net.get_parameters()))
>>> model = Model(net, loss, opt)
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
)
@prim_attr_register
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
outputs=['output'])
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
validator.check(f'variable shape {v_shape}', len(v_shape), '', 0, Rel.GT)
validator.check(f'accumulation shape {a_shape}', len(a_shape), '', 0, Rel.GT)
validator.check(f'learning rate shape {l_shape}', len(l_shape), '', 0, Rel.GE)
validator.check(f'gradient shape {g_shape}', len(g_shape), '', 0, Rel.GE)
validator.check(f'momentum shape {m_shape}', len(m_shape), '', 0, Rel.GE)
return v_shape
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
v_type = validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64])
validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64])
validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64])
return v_type
return g_dtype
class SmoothL1Loss(PrimitiveWithInfer):
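The relaxed check is what lets ApplyMomentum run under amp: `variable` and `accumulation` arrive as RefKeys when parameters are updated in place, so their dtypes cannot be inspected here, and returning the gradient dtype rather than the variable dtype preserves float16 gradients. A self-contained paraphrase of the new rule, with dtypes as plain strings purely for illustration:

FLOATS = ("float16", "float32", "float64")

def infer_dtype(v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
    # RefKey inputs bypass the tensor dtype checks.
    if v_dtype != "refkey" and a_dtype != "refkey":
        assert v_dtype in FLOATS and a_dtype in FLOATS
    for dtype in (l_dtype, g_dtype, m_dtype):
        assert dtype in FLOATS
    return g_dtype  # previously returned the variable dtype

assert infer_dtype("refkey", "refkey", "float32", "float16", "float32") == "float16"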

View File

@@ -82,6 +82,29 @@ def _check_kwargs(key_words):
if loss_scale_manager:
validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)
def _add_loss_network(network, loss_fn, cast_model_type):
class WithLossCell(nn.Cell):
"Wrap loss for amp. Cast network output back to float32"
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, data, label):
out = self._backbone(data)
label = _mp_cast_helper(mstype.float32, label)
return self._loss_fn(F.cast(out, mstype.float32), label)
validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
if cast_model_type == mstype.float16:
network = WithLossCell(network, loss_fn)
else:
network = nn.WithLossCell(network, loss_fn)
return network
def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
"""
Build the mixed precision training cell automatically.
@@ -117,24 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
_do_keep_batchnorm_fp32(network)
if loss_fn:
class WithLossCell(nn.Cell):
"Wrap loss for amp. Cast network output back to float32"
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, data, label):
out = self._backbone(data)
label = _mp_cast_helper(mstype.float32, label)
return self._loss_fn(F.cast(out, mstype.float32), label)
validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
if config.cast_model_type == mstype.float16:
network = WithLossCell(network, loss_fn)
else:
network = nn.WithLossCell(network, loss_fn)
network = _add_loss_network(network, loss_fn, config.cast_model_type)
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network = _VirtualDatasetCell(network)
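With the wrapper factored out into `_add_loss_network`, both this function and Model can share it. A hedged usage sketch, assuming the 2020-era signatures shown in this diff (amp level names and APIs may differ in later releases):

import mindspore.nn as nn
from mindspore.train.amp import build_train_network

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Dense(16, 10)

    def construct(self, x):
        return self.fc(x)

net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# Under a float16 cast level, _add_loss_network wraps the backbone so
# the loss is still computed in float32.
train_net = build_train_network(net, opt, loss, level='O2')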

View File

@@ -24,8 +24,7 @@ from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
from ..nn.metrics import Loss
from ..nn.wrap import WithLossCell, WithEvalCell, \
DataWrapper
from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ..common import dtype as mstype
@@ -151,7 +150,10 @@ class Model:
else:
if self._loss_fn is None:
raise ValueError("loss_fn can not be None.")
self._eval_network = WithEvalCell(self._network, self._loss_fn)
if self._optimizer:
self._eval_network = self._train_network.network
else:
self._eval_network = WithEvalCell(self._network, self._loss_fn)
self._eval_indexes = [0, 1, 2]
def _clear_metrics(self):
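This is the eval half of the commit: when an optimizer exists, evaluation now reuses the backbone inside the amp-processed train network rather than wrapping the raw float32 network a second time, so eval runs through the same cast model as training. A hedged workflow sketch, reusing `net`, `loss`, and `opt` from the previous sketch and assuming `train_ds` and `eval_ds` datasets are available:

from mindspore.train.model import Model

model = Model(net, loss, opt, metrics={'loss'}, amp_level='O2')
model.train(1, train_ds)
result = model.eval(eval_ds)  # goes through self._train_network.network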

View File

@@ -21,47 +21,6 @@ from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import Parameter, ParameterTuple
run_opt = C.MultitypeFuncGraph("run_opt")
# pylint: disable=unused-argument
@run_opt.register("Function", "Int", "Number", "Number",
"Tensor", "Tensor", "Tensor")
def tensor_run_opt(opt, iterator, learning_rate, momentum,
gradient, variable, moment):
success = True
new_weight = opt(gradient, moment, variable, learning_rate, momentum)
success = F.depend(success, P.Assign()(variable, new_weight))
return success
class OptimizerByMomentum(nn.Cell):
"""
OptimizerByMomentum definition
"""
# list of tensor
def __init__(self, weights):
super(OptimizerByMomentum, self).__init__()
self.learning_rate = Parameter(0.1, name="learning_rate")
self.momentum = Parameter(0.05, name="momentum")
self.iter = Parameter(0, name="iter")
self.weights = weights
self.moments = weights.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
def construct(self, grads):
success = True
weights = self.weights
moments = self.moments
success = self.hyper_map(
F.partial(run_opt, self.opt, self.iter,
self.learning_rate, self.momentum), grads, weights, moments)
# self.learning_rate = updata_lr(self.learning_rate, self.momentum)
return success
class TrainStepWrap(nn.Cell):
"""
TrainStepWrap definition
@@ -71,7 +30,7 @@ class TrainStepWrap(nn.Cell):
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = OptimizerByMomentum(self.weights)
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation('grad', get_by_list=True)
@@ -107,7 +66,7 @@ class TrainStepWrap2(nn.Cell):
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.get_parameters())
self.optimizer = OptimizerByMomentum(self.weights)
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
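The hand-rolled OptimizerByMomentum is dropped in favour of the built-in optimizer, whose positional arguments match the call sites above. A minimal construction sketch:

import mindspore.nn as nn

net = nn.Dense(16, 10)  # stands in for any Cell with trainable parameters
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
# Inside a training cell, opt(grads) applies the in-place momentum update.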