!19848 [AutoParallel] Fix_auto_parallel_bug_master

Merge pull request !19848 from lichen/fix_auto_parallel_bug_master
Committed by: i-robot, 2021-07-10 03:34:57 +00:00 (via Gitee)
Commit: 1260700865
7 changed files with 133 additions and 5 deletions


@@ -29,6 +29,18 @@
 namespace mindspore {
 namespace parallel {
+const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {prim::kPrimDepend, prim::kPrimTupleGetItem,
+                                                    prim::kPrimSoftmaxCrossEntropyWithLogits};
+
+static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
+  for (auto &prim : END_NODE_BLACK_LIST) {
+    if (IsPrimitiveCNode(cnode, prim)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
   auto pre_node = cnode->input(1);
   while (true) {

@@ -392,7 +404,7 @@ void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, con
 AnfNodePtr GetPreNode(const AnfNodePtr &node) {
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+  if (IsInEndNodeBlackList(cnode)) {
     return GetPreNode(cnode->input(1));
   }
   return cnode;
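For intuition: GetPreNode previously skipped only Depend nodes while walking back to a stage's last compute node; it now skips every primitive in END_NODE_BLACK_LIST. A minimal Python sketch of the same traversal over a toy node class (all names below are illustrative, not MindSpore API):

# Toy model of GetPreNode: follow the first data input while the current
# node's op is blacklisted, then return the first "real" node.
# (cnode->input(1) in the C++ code is the first data input; input(0)
# holds the primitive itself, hence inputs[0] here.)
END_NODE_BLACK_LIST = {"Depend", "TupleGetItem", "SoftmaxCrossEntropyWithLogits"}

class TinyNode:
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)

def get_pre_node(node):
    if node.op in END_NODE_BLACK_LIST:
        return get_pre_node(node.inputs[0])
    return node

matmul = TinyNode("MatMul")
wrapped = TinyNode("TupleGetItem", [TinyNode("Depend", [matmul])])
assert get_pre_node(wrapped) is matmul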


@@ -134,6 +134,11 @@ void PipelineTransformer::LabelMicroBatch() {
     for (auto &node_user : node_users) {
       if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
         auto data_users = manager_->node_users()[node_user.first];
+        auto node_first = data_users.front().first;
+        if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
+          data_users.clear();
+          data_users = node_user_map[node_first];
+        }
         auto micro_size = int64_t(data_users.size());
         micro_size_ = micro_size;
         MS_LOG(INFO) << "Micro Size is: " << micro_size;
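The added lookup makes micro-batch counting robust: micro_size_ is meant to count the StridedSlice users that carve the input into micro-batches, so when the first user of the TupleGetItem is some other node (an intermediate inserted between the input and its slices), the count is taken from that node's users instead of from the TupleGetItem's own user list.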
@@ -690,7 +695,10 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
   std::vector<AnfNodePtr> receive_ops;
   std::vector<AnfNodePtr> send_ops;
-  auto all_nodes = graph->nodes();
+  auto ret = graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  std::reverse(all_nodes.begin(), all_nodes.end());
   auto stage_num = g_device_manager->stage_num();
   if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
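Collecting nodes via DeepScopedGraphSearch from the return node, rather than graph->nodes(), appears intended to cover every node reachable from the graph output, including those inside nested sub-graphs; the std::reverse then restores a front-to-back visiting order for the Send/Receive insertion that follows.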


@@ -3045,6 +3045,9 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
     }
     std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
     auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
+    if (cloned_param_layout == nullptr) {
+      continue;
+    }
     tensor_info_map[cloned_param_name] = cloned_param_layout;
   }
   if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {


@@ -161,6 +161,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
   }
   for (auto &node_tensor_info : tensor_info_map) {
     TensorLayoutPtr tensor_layout = node_tensor_info.second;
+    MS_EXCEPTION_IF_NULL(tensor_layout);
     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
     MS_EXCEPTION_IF_NULL(parallel_layout_item);
     parallel_layout_item->set_param_name(node_tensor_info.first);
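These two checks work as a pair: CheckpointStrategy now skips cloned parameters that carry no TensorLayout instead of putting a null pointer into tensor_info_map, and Save fails loudly (MS_EXCEPTION_IF_NULL) if a null layout reaches serialization anyway, turning a latent crash in the protobuf writer into an explicit error.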


@@ -18,7 +18,7 @@ Wrap cells for networks.
 Use the Wrapper to combine the loss or build the training steps.
 """
 from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
-    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
+    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, PipelineCell
 from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
 from .grad_reducer import DistributedGradReducer
 from ..layer.timedistributed import TimeDistributed

@@ -29,6 +29,7 @@ __all__ = [
     "TrainOneStepCell",
     "WithLossCell",
     "WithGradCell",
+    "PipelineCell",
     "WithEvalCell",
     "GetNextSingleOp",
     "TrainOneStepWithLossScaleCell",


@@ -15,6 +15,7 @@
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
 from mindspore.context import ParallelMode
+from mindspore.parallel._utils import _get_enable_parallel_optimizer
 from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor

@@ -430,3 +431,100 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         if self.loss_scaling_manager is not None:
             return self.loss_scaling_manager(self.scale_sense, overflow)
         return overflow
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+    return new_grad
+
+
+@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+    new_grad = grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+    return new_grad
+
+
+class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
+    """
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_sense (Cell): Cell to do the loss scale.
+    """
+    def __init__(self, network, optimizer, scale_sense):
+        super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
+        self.network = network
+        self.network.add_flags(defer_inline=True)
+        self.weights = optimizer.parameters
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
+        self.optimizer = optimizer
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.grad_reducer = F.identity
+        self.degree = 1
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.hyper_map = C.HyperMap()
+        self.reshape = P.Reshape()
+        self.loss_scaling_manager = None
+        if isinstance(scale_sense, Cell):
+            self.loss_scaling_manager = scale_sense
+            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
+                                         name="scale_sense")
+        elif isinstance(scale_sense, Tensor):
+            if scale_sense.shape == (1,) or scale_sense.shape == ():
+                self.scale_sense = Parameter(scale_sense, name='scale_sense')
+            else:
+                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
+        else:
+            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
+        self.opt_shard = _get_enable_parallel_optimizer()
+
+    def construct(self, *inputs):
+        weights = self.weights
+        loss = self.network(*inputs)
+        scaling_sens = self.scale_sense
+        init = self.alloc_status()
+        status_clear = self.clear_before_grad(init)
+        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        init = F.depend(init, grads)
+        get_status = self.get_status(init)
+        init = F.depend(init, get_status)
+        flag_sum = self.reduce_sum(init, (0,))
+        loss = F.depend(loss, status_clear)
+        if self.opt_shard:
+            grads = self.grad_reducer(grads)
+            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
+        else:
+            accu_grads = self.grad_reducer(self.accu_grads)
+            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
+        cond = self.less_equal(self.base, flag_sum)
+        overflow = cond
+        if self.loss_scaling_manager is not None:
+            overflow = self.loss_scaling_manager(self.scale_sense, cond)
+        if overflow:
+            succ = False
+        else:
+            succ = self.optimizer(grads)
+        ret = (loss, overflow, scaling_sens)
+        return F.depend(ret, succ)
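A note on the new helpers and cell: both multitype functions divide by the loss scale and then clear the accumulation buffer. In the unsharded path, tensor_grad_scale_pipeline scales the accumulated gradient (accu_grad) rather than the per-step gradient and then zeroes the accumulator through F.assign, with the F.depend calls pinning the read-scale-clear order in graph mode; with scale = 128 and an accumulated gradient of 12.8, for example, the optimizer receives 0.1 and the buffer is reset to 0. Like TrainOneStepWithLossScaleCell, the cell returns (loss, overflow, scaling_sens) and skips the optimizer step on overflow, so existing training loops and loss-scale callbacks keep working when pipeline parallelism is active.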


@@ -19,6 +19,7 @@ from .._checkparam import Rel
 from ..common import dtype as mstype
 from ..nn import acc
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
+from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
 from ..ops import functional as F
 from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
@@ -184,8 +185,12 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
             raise ValueError("Only `loss_scale_manager=None` or "
                              "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
                              "are supported on device `CPU`. ")
-        network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
-                                                   scale_sense=update_cell).set_train()
+        if _get_pipeline_stages() > 1:
+            network = _TrainPipelineWithLossScaleCell(network, optimizer,
+                                                      scale_sense=update_cell).set_train()
+        else:
+            network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
+                                                       scale_sense=update_cell).set_train()
         return network
     if _get_pipeline_stages() > 1:
         network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
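A hedged end-to-end sketch of how the new branch is reached; the network, optimizer, and parallel settings below are illustrative only and assume a device that supports pipeline parallelism:

# Sketch: with pipeline_stages > 1 and a loss-scale update cell,
# build_train_network now returns a _TrainPipelineWithLossScaleCell
# instead of the plain TrainOneStepWithLossScaleCell.
import mindspore.nn as nn
from mindspore import context
from mindspore.train.amp import build_train_network
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=2)

net = nn.Dense(16, 10)  # stand-in for a partitioned pipeline network
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
loss = nn.SoftmaxCrossEntropyWithLogits()

train_net = build_train_network(net, opt, loss, level="O2",
                                loss_scale_manager=DynamicLossScaleManager())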