!19848 [AutoParallel]Fix_auto_parallel_bug_master
Merge pull request !19848 from lichen/fix_auto_parallel_bug_master
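
In short: GetPreNode now skips a blacklist of pass-through ops (Depend, TupleGetItem, SoftmaxCrossEntropyWithLogits) instead of Depend alone; LabelMicroBatch follows one extra user level when the first data user is not a StridedSlice; the pipeline transformer walks nodes in a deterministic order derived from the return node; strategy checkpointing tolerates parameters without a tensor layout; and loss-scale training is extended to pipeline parallelism via a new internal _TrainPipelineWithLossScaleCell, with PipelineCell exported from mindspore.nn.wrap.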
@@ -29,6 +29,18 @@
 namespace mindspore {
 namespace parallel {
+const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {prim::kPrimDepend, prim::kPrimTupleGetItem,
+                                                    prim::kPrimSoftmaxCrossEntropyWithLogits};
+
+static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
+  for (auto &prim : END_NODE_BLACK_LIST) {
+    if (IsPrimitiveCNode(cnode, prim)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
   auto pre_node = cnode->input(1);
   while (true) {
@@ -392,7 +404,7 @@ void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, con
 AnfNodePtr GetPreNode(const AnfNodePtr &node) {
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+  if (IsInEndNodeBlackList(cnode)) {
     return GetPreNode(cnode->input(1));
   }
   return cnode;
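
The hunk above generalizes GetPreNode: it previously skipped only Depend nodes while walking backwards to the last real compute node, and now skips every primitive in END_NODE_BLACK_LIST. A minimal plain-Python sketch of the walk; the Node class and primitive-name strings here are hypothetical stand-ins, not MindSpore API:

```python
# Hypothetical stand-in for a CNode; `data_inputs[0]` plays the role of
# cnode->input(1) (input 0 of a real CNode is the primitive itself).
END_NODE_BLACK_LIST = {"Depend", "TupleGetItem", "SoftmaxCrossEntropyWithLogits"}

class Node:
    def __init__(self, prim, data_inputs=()):
        self.prim = prim
        self.data_inputs = list(data_inputs)

def get_pre_node(node):
    # Keep following the first data input while the node is a blacklisted
    # pass-through op, mirroring the recursion in the C++ change.
    while node.prim in END_NODE_BLACK_LIST and node.data_inputs:
        node = node.data_inputs[0]
    return node

# A MatMul wrapped in TupleGetItem and Depend resolves to the MatMul.
matmul = Node("MatMul")
assert get_pre_node(Node("Depend", [Node("TupleGetItem", [matmul])])) is matmul
```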
@@ -134,6 +134,11 @@ void PipelineTransformer::LabelMicroBatch() {
     for (auto &node_user : node_users) {
       if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
         auto data_users = manager_->node_users()[node_user.first];
+        auto node_first = data_users.front().first;
+        if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
+          data_users.clear();
+          data_users = node_user_map[node_first];
+        }
         auto micro_size = int64_t(data_users.size());
         micro_size_ = micro_size;
         MS_LOG(INFO) << "Micro Size is: " << micro_size;
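
LabelMicroBatch derives the micro-batch count from the number of users of the TupleGetItem on the data node. The added lines tolerate one intermediate node: when the first user is not a StridedSlice, the users of that intermediate node are counted instead. A sketch of the counting logic; the node_user_map shape and prim_of accessor are assumptions for illustration:

```python
def micro_size_of(tuple_get_item, node_user_map, prim_of):
    """node_user_map: node -> list of (user_node, input_index) pairs;
    prim_of: node -> primitive name (both hypothetical accessors)."""
    data_users = node_user_map[tuple_get_item]
    node_first = data_users[0][0]
    if prim_of(node_first) != "StridedSlice":
        # The slices hang one level deeper; count that node's users instead.
        data_users = node_user_map[node_first]
    return len(data_users)
```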
@@ -690,7 +695,10 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
   std::vector<AnfNodePtr> receive_ops;
   std::vector<AnfNodePtr> send_ops;
-  auto all_nodes = graph->nodes();
+  auto ret = graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  std::reverse(all_nodes.begin(), all_nodes.end());
   auto stage_num = g_device_manager->stage_num();
   if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
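
`graph->nodes()` yields the manager's node set with no ordering guarantee; the replacement collects the nodes reachable from the return node and reverses the visit order, so the send/receive handling that follows sees nodes deterministically. A plain-Python analogue, under the assumption (suggested by the std::reverse) that the search visits a node before its inputs:

```python
def all_nodes_from_return(ret):
    # Iterative DFS from the return node (hypothetical `.inputs` accessor).
    visited, order = set(), []
    stack = [ret]
    while stack:
        node = stack.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        order.append(node)                    # node recorded before its inputs
        stack.extend(getattr(node, "inputs", []))
    order.reverse()                           # inputs now precede their users
    return order
```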
@@ -3045,6 +3045,9 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
     }
     std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
     auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
+    if (cloned_param_layout == nullptr) {
+      continue;
+    }
     tensor_info_map[cloned_param_name] = cloned_param_layout;
   }
   if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
@@ -161,6 +161,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
   }
   for (auto &node_tensor_info : tensor_info_map) {
     TensorLayoutPtr tensor_layout = node_tensor_info.second;
+    MS_EXCEPTION_IF_NULL(tensor_layout);
     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
     MS_EXCEPTION_IF_NULL(parallel_layout_item);
     parallel_layout_item->set_param_name(node_tensor_info.first);
@@ -18,7 +18,7 @@ Wrap cells for networks.
 Use the Wrapper to combine the loss or build the training steps.
 """
 from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
-    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
+    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, PipelineCell
 from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
 from .grad_reducer import DistributedGradReducer
 from ..layer.timedistributed import TimeDistributed
@@ -29,6 +29,7 @@ __all__ = [
     "TrainOneStepCell",
     "WithLossCell",
     "WithGradCell",
+    "PipelineCell",
     "WithEvalCell",
     "GetNextSingleOp",
     "TrainOneStepWithLossScaleCell",
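
With PipelineCell re-exported and listed in __all__, a loss network can be wrapped for pipeline parallelism straight from mindspore.nn. A hedged usage sketch: backbone and loss_fn are placeholders, and the micro-batch count of 4 is an arbitrary example value:

```python
import mindspore.nn as nn

# backbone and loss_fn are assumed to be existing Cells.
net_with_loss = nn.WithLossCell(backbone, loss_fn)
# Split each input batch into 4 micro-batches for pipeline execution.
pipeline_net = nn.PipelineCell(net_with_loss, 4)
```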
@@ -15,6 +15,7 @@
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
 from mindspore.context import ParallelMode
+from mindspore.parallel._utils import _get_enable_parallel_optimizer
 from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
@@ -430,3 +431,100 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         if self.loss_scaling_manager is not None:
             return self.loss_scaling_manager(self.scale_sense, overflow)
         return overflow
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+    return new_grad
+
+
+@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+    new_grad = grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+    return new_grad
+
+
+class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
+    """
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_sense (Cell): Cell to do the loss scale.
+    """
+    def __init__(self, network, optimizer, scale_sense):
+        super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
+        self.network = network
+        self.network.add_flags(defer_inline=True)
+        self.weights = optimizer.parameters
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
+        self.optimizer = optimizer
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.grad_reducer = F.identity
+        self.degree = 1
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.hyper_map = C.HyperMap()
+        self.reshape = P.Reshape()
+        self.loss_scaling_manager = None
+        if isinstance(scale_sense, Cell):
+            self.loss_scaling_manager = scale_sense
+            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
+                                         name="scale_sense")
+        elif isinstance(scale_sense, Tensor):
+            if scale_sense.shape == (1,) or scale_sense.shape == ():
+                self.scale_sense = Parameter(scale_sense, name='scale_sense')
+            else:
+                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
+        else:
+            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
+        self.opt_shard = _get_enable_parallel_optimizer()
+
+    def construct(self, *inputs):
+        weights = self.weights
+        loss = self.network(*inputs)
+        scaling_sens = self.scale_sense
+        init = self.alloc_status()
+        status_clear = self.clear_before_grad(init)
+        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        init = F.depend(init, grads)
+        get_status = self.get_status(init)
+        init = F.depend(init, get_status)
+        flag_sum = self.reduce_sum(init, (0,))
+        loss = F.depend(loss, status_clear)
+        if self.opt_shard:
+            grads = self.grad_reducer(grads)
+            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
+        else:
+            accu_grads = self.grad_reducer(self.accu_grads)
+            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
+        cond = self.less_equal(self.base, flag_sum)
+        overflow = cond
+        if self.loss_scaling_manager is not None:
+            overflow = self.loss_scaling_manager(self.scale_sense, cond)
+        if overflow:
+            succ = False
+        else:
+            succ = self.optimizer(grads)
+        ret = (loss, overflow, scaling_sens)
+        return F.depend(ret, succ)
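
The new cell accumulates micro-batch gradients into the accu_grads buffers and unscales them at step boundaries; the F.depend chains in both scale functions force each buffer to be read and unscaled before F.assign zeroes it, an ordering the graph executor would otherwise be free to break. A hedged wiring sketch (backbone, loss_fn, inputs, labels are placeholders; in practice build_train_network selects this internal cell automatically, as the amp.py hunk below shows):

```python
import mindspore.nn as nn

net = nn.PipelineCell(nn.WithLossCell(backbone, loss_fn), 4)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12,
                                            scale_factor=2, scale_window=1000)
train_net = _TrainPipelineWithLossScaleCell(net, opt, scale_sense=update_cell).set_train()
loss, overflow, scaling_sens = train_net(inputs, labels)  # the 3-tuple from construct
```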
@@ -19,6 +19,7 @@ from .._checkparam import Rel
 from ..common import dtype as mstype
 from ..nn import acc
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
+from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
 from ..ops import functional as F
 from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
@@ -184,8 +185,12 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
                 raise ValueError("Only `loss_scale_manager=None` or "
                                  "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
                                  "are supported on device `CPU`. ")
-            network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
-                                                       scale_sense=update_cell).set_train()
+            if _get_pipeline_stages() > 1:
+                network = _TrainPipelineWithLossScaleCell(network, optimizer,
+                                                          scale_sense=update_cell).set_train()
+            else:
+                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
+                                                           scale_sense=update_cell).set_train()
             return network
     if _get_pipeline_stages() > 1:
         network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
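
build_train_network now branches on _get_pipeline_stages(): with more than one stage it returns the pipeline-aware loss-scale cell, otherwise the original TrainOneStepWithLossScaleCell, so existing call sites need no change. A hedged end-to-end sketch; it assumes a configured distributed environment and an existing net_with_loss Cell, and the exact context flags vary by MindSpore version:

```python
from mindspore import context, nn
from mindspore.train import amp
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=2)
net = nn.PipelineCell(net_with_loss, 4)           # net_with_loss assumed defined
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
train_net = amp.build_train_network(net, opt, level="O2",
                                    loss_scale_manager=DynamicLossScaleManager())
```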