diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
index 0c8d691e776..5f93d5ee8c9 100755
--- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
@@ -29,6 +29,18 @@
 
 namespace mindspore {
 namespace parallel {
+const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {prim::kPrimDepend, prim::kPrimTupleGetItem,
+                                                    prim::kPrimSoftmaxCrossEntropyWithLogits};
+
+static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
+  for (auto &prim : END_NODE_BLACK_LIST) {
+    if (IsPrimitiveCNode(cnode, prim)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
   auto pre_node = cnode->input(1);
   while (true) {
@@ -392,7 +404,7 @@ void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, con
 AnfNodePtr GetPreNode(const AnfNodePtr &node) {
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+  if (IsInEndNodeBlackList(cnode)) {
     return GetPreNode(cnode->input(1));
   }
   return cnode;
diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
index f91e3cf55ac..d4c6f63347a 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
@@ -134,6 +134,11 @@ void PipelineTransformer::LabelMicroBatch() {
     for (auto &node_user : node_users) {
       if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
         auto data_users = manager_->node_users()[node_user.first];
+        auto node_first = data_users.front().first;
+        if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
+          data_users.clear();
+          data_users = node_user_map[node_first];
+        }
         auto micro_size = int64_t(data_users.size());
         micro_size_ = micro_size;
         MS_LOG(INFO) << "Micro Size is: " << micro_size;
@@ -690,7 +695,10 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
   std::vector<AnfNodePtr> receive_ops;
   std::vector<AnfNodePtr> send_ops;
-  auto all_nodes = graph->nodes();
+  auto ret = graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  std::reverse(all_nodes.begin(), all_nodes.end());
   auto stage_num = g_device_manager->stage_num();
   if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index a207c7de716..ccb16a5ebc5 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -3045,6 +3045,9 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
     }
     std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
     auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
+    if (cloned_param_layout == nullptr) {
+      continue;
+    }
     tensor_info_map[cloned_param_name] = cloned_param_layout;
   }
   if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
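The three C++ hunks above cooperate: GetPreNode now skips a small blacklist of wrapper nodes (Depend, TupleGetItem, SoftmaxCrossEntropyWithLogits) while walking back to the real end node, send/receive insertion walks the nodes reachable from the return node in a fixed order instead of the unordered node set, and micro-batch labeling follows a TupleGetItem user through to its StridedSlice consumers before counting. What a user sees of all this is the rule that the micro batch count must be at least the pipeline stage count. A minimal sketch of that constraint from the Python side (the tiny network and the sizes are illustrative only, not from this patch):

    import mindspore.nn as nn
    from mindspore import context

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=4)
    net_with_loss = nn.WithLossCell(nn.Dense(16, 16), nn.L1Loss())
    # The second argument is the micro batch count; it must be >= pipeline_stages,
    # otherwise the training-path check above raises the "MicroBatch size" exception.
    pipeline_net = nn.PipelineCell(net_with_loss, 8)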
diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
index 2e32258712a..71f6eb767c2 100644
--- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
+++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
@@ -161,6 +161,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
   }
   for (auto &node_tensor_info : tensor_info_map) {
     TensorLayoutPtr tensor_layout = node_tensor_info.second;
+    MS_EXCEPTION_IF_NULL(tensor_layout);
     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
     MS_EXCEPTION_IF_NULL(parallel_layout_item);
     parallel_layout_item->set_param_name(node_tensor_info.first);
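This MS_EXCEPTION_IF_NULL is the second half of the step_parallel.cc guard: a cloned parameter (for example an accumulation copy) may carry no tensor layout, and previously a null TensorLayoutPtr could reach the protobuf writer. Strategy saving is driven by the usual context knob; a minimal sketch (the file path is illustrative):

    from mindspore import context

    # With the two guards above, cloned parameters that have no layout yet are
    # skipped instead of being dereferenced while this file is written.
    context.set_auto_parallel_context(strategy_ckpt_save_file="./train_strategy.ckpt")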
diff --git a/mindspore/nn/wrap/__init__.py b/mindspore/nn/wrap/__init__.py
index 63db4437de8..a663fbbf4fc 100644
--- a/mindspore/nn/wrap/__init__.py
+++ b/mindspore/nn/wrap/__init__.py
@@ -18,7 +18,7 @@ Wrap cells for networks.
 Use the Wrapper to combine the loss or build the training steps.
 """
 from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
-    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
+    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, PipelineCell
 from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
 from .grad_reducer import DistributedGradReducer
 from ..layer.timedistributed import TimeDistributed
@@ -29,6 +29,7 @@ __all__ = [
     "TrainOneStepCell",
     "WithLossCell",
     "WithGradCell",
+    "PipelineCell",
     "WithEvalCell",
     "GetNextSingleOp",
     "TrainOneStepWithLossScaleCell",
diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index fed7d040e01..735ef2edcec 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -15,6 +15,7 @@
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
 from mindspore.context import ParallelMode
+from mindspore.parallel._utils import _get_enable_parallel_optimizer
 from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
@@ -430,3 +431,100 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         if self.loss_scaling_manager is not None:
             return self.loss_scaling_manager(self.scale_sense, overflow)
         return overflow
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+    return new_grad
+
+
+@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+    new_grad = grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+    return new_grad
+
+
+class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
+    """
+    Append an optimizer to the training network, after which the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that the loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_sense (Cell): Cell to do the loss scale.
+    """
+    def __init__(self, network, optimizer, scale_sense):
+        super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
+        self.network = network
+        self.network.add_flags(defer_inline=True)
+        self.weights = optimizer.parameters
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
+        self.optimizer = optimizer
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.grad_reducer = F.identity
+        self.degree = 1
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.hyper_map = C.HyperMap()
+        self.reshape = P.Reshape()
+        self.loss_scaling_manager = None
+        if isinstance(scale_sense, Cell):
+            self.loss_scaling_manager = scale_sense
+            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
+                                         name="scale_sense")
+        elif isinstance(scale_sense, Tensor):
+            if scale_sense.shape == (1,) or scale_sense.shape == ():
+                self.scale_sense = Parameter(scale_sense, name='scale_sense')
+            else:
+                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
+        else:
+            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
+        self.opt_shard = _get_enable_parallel_optimizer()
+
+    def construct(self, *inputs):
+        weights = self.weights
+        loss = self.network(*inputs)
+        scaling_sens = self.scale_sense
+        init = self.alloc_status()
+        status_clear = self.clear_before_grad(init)
+        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        init = F.depend(init, grads)
+        get_status = self.get_status(init)
+        init = F.depend(init, get_status)
+        flag_sum = self.reduce_sum(init, (0,))
+        loss = F.depend(loss, status_clear)
+        if self.opt_shard:
+            grads = self.grad_reducer(grads)
+            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
+        else:
+            accu_grads = self.grad_reducer(self.accu_grads)
+            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
+        cond = self.less_equal(self.base, flag_sum)
+        overflow = cond
+        if self.loss_scaling_manager is not None:
+            overflow = self.loss_scaling_manager(self.scale_sense, cond)
+        if overflow:
+            succ = False
+        else:
+            succ = self.optimizer(grads)
+        ret = (loss, overflow, scaling_sens)
+        return F.depend(ret, succ)
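The F.depend calls in the two scale functions only pin execution order in graph mode. Read eagerly, tensor_grad_scale_pipeline unscales the accumulated gradient and then clears the accumulator; a NumPy paraphrase of that data flow (illustrative only, not part of the patch):

    import numpy as np

    def grad_scale_eager(scale, grad, accu_grad):
        # grad has already been summed into accu_grad by the accumulation pass,
        # so the unscaled result is read out of the accumulator first.
        new_grad = accu_grad / scale
        accu_grad[:] = 0.0  # mirrors F.assign(accu_grad, zeros)
        return new_grad

    acc = np.array([4.0, 8.0])
    print(grad_scale_eager(2.0, np.array([1.0, 2.0]), acc))  # [2. 4.]; acc is now zeroed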
") - network = nn.TrainOneStepWithLossScaleCell(network, optimizer, - scale_sense=update_cell).set_train() + if _get_pipeline_stages() > 1: + network = _TrainPipelineWithLossScaleCell(network, optimizer, + scale_sense=update_cell).set_train() + else: + network = nn.TrainOneStepWithLossScaleCell(network, optimizer, + scale_sense=update_cell).set_train() return network if _get_pipeline_stages() > 1: network = _TrainPipelineAccuStepCell(network, optimizer).set_train()