opt shard fit micro batch

This commit is contained in:
Ziyan 2021-06-16 16:30:41 +08:00
parent 70152adcb3
commit be1f5a43d7
20 changed files with 266 additions and 51 deletions

View File

@ -106,7 +106,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
prim::kPrimMirrorMiniStep);
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(),
"micro_step_allgather_replace", prim::kPrimMicroStepAllGather);
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =

View File

@ -60,6 +60,7 @@ class OptimizeIRPassLib {
SubstitutionPtr mirror_mini_step_elim_;
SubstitutionPtr virtual_add_elim_;
SubstitutionPtr mini_step_allgather_replace_;
SubstitutionPtr micro_step_allgather_replace_;
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;

View File

@ -300,6 +300,39 @@ class MiniStepAllGatherPass : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
//
// Substitution pass that replaces a _MicroStepAllGather call with a newly created AllGather node built from the
// first data input (X). The `group` attribute of the original primitive selects the communication group of the
// new operator and the `fusion` attribute is copied onto the new primitive.
class MicroStepAllGatherPass : public AnfVisitor {
 public:
  // Returns the replacement node, or nullptr when `node` is not a _MicroStepAllGather call, does not belong to a
  // function graph, or has no data input. Raises (through MS_EXCEPTION_IF_NULL) when a required attribute or
  // intermediate object is missing instead of dereferencing a null pointer.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto &inputs = cnode->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }
    MS_EXCEPTION_IF_NULL(inputs[1]);

    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_EXCEPTION_IF_NULL(prim);
    auto attrs = prim->attrs();
    // operator[] yields a null value for an absent key, so validate before use.
    auto group_value = attrs[parallel::GROUP];
    MS_EXCEPTION_IF_NULL(group_value);
    std::string group = group_value->ToString();
    auto fusion = attrs[parallel::FUSION];
    MS_EXCEPTION_IF_NULL(fusion);

    parallel::Operator op = parallel::CreateAllGatherOp(group);
    std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
    if (node_input.empty()) {
      return nullptr;
    }
    auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    auto new_prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(new_prim);
    // Carry the fusion id of the virtual operator over to the real AllGather primitive.
    auto new_attrs = new_prim->attrs();
    new_attrs[parallel::FUSION] = fusion;
    new_prim->SetAttrs(new_attrs);

    // Prefer the graph that owns X; fall back to the graph of the replaced node (checked non-null above) when X
    // reports no owning graph.
    auto func_graph = inputs[1]->func_graph();
    if (func_graph == nullptr) {
      func_graph = node->func_graph();
    }
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;
  }
  void Visit(const AnfNodePtr &) override {}
};
// Reset defer_inline flag
class ResetDeferInline : public AnfVisitor {
public:

View File

@ -36,7 +36,6 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1;
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1;
constexpr char PARAMETER[] = "parameter";
const uint64_t MAX_RECURSIVE_CALL_TIMES = 100;
class AllreduceFusion {
public:

View File

@ -87,14 +87,12 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) {
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
const AnfNodePtr &accu_parameter) {
auto cnode = node_user.first->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag() ||
((IsPrimitiveCNode(node_user.first, prim::kPrimSend) || IsPrimitiveCNode(node_user.first, prim::kPrimDepend)) &&
ParallelContext::GetInstance()->enable_parallel_optimizer())) {
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
return;
}
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAd";
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
return;
}
OperatorAttrs attrs;
@ -154,10 +152,12 @@ void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
auto node_users = node_users_map[node];
for (auto &temp_user : node_users) {
auto temp_node = temp_user.first;
// Micro virtual operator might be inserted after cast
if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
temp_node = node_users_map[temp_node].begin()->first;
}
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
auto node_set = node_users_map[temp_node];
for (auto &node_user : node_set) {
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
@ -182,10 +182,12 @@ void AddVirtualAssignAdd(const FuncGraphPtr &root) {
auto node_users = node_users_map[parameter];
for (auto &temp_user : node_users) {
auto temp_node = temp_user.first;
// Micro virtual operator might be inserted after cast
if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
temp_node = node_users_map[temp_node].begin()->first;
}
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
auto node_set = node_users_map[temp_node];
for (auto &node_user : node_set) {
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);

View File

@ -428,6 +428,26 @@ Operator CreateMiniStepAllGatherOp(const std::string &group) {
return op;
}
// Builds the Operator description for MICRO_STEP_ALL_GATHER on the given communication group.
// The operator carries two attributes: GROUP (the group name) and MEAN_FLAG (the gradients_mean setting read from
// the ParallelContext). It has no operator parameters.
Operator CreateMicroStepAllGatherOp(const std::string &group) {
  bool use_mean = ParallelContext::GetInstance()->gradients_mean();

  OperatorAttrs attrs;
  ValuePtr group_value = MakeValue(group);
  attrs.push_back(Attr(GROUP, group_value));
  ValuePtr mean_value = MakeValue(use_mean);
  attrs.push_back(Attr(MEAN_FLAG, mean_value));

  OperatorParams params;
  Operator op(OperatorName(MICRO_STEP_ALL_GATHER), OperatorArgs(attrs, params));
  MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group;
  return op;
}
// use for get tensor slice
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
Shape tensor_map = tensor_layout.tensor_map().array();

View File

@ -299,6 +299,7 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
Operator CreateAllGatherOp(const std::string &group);
Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
Operator CreateMicroStepAllGatherOp(const std::string &group);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
void AddCommOpParamFlag(const CNodePtr &comm_node);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);

View File

@ -217,6 +217,7 @@ constexpr char LOCAL_STEP[] = "local_step";
constexpr char STRIDED_SLICE[] = "StridedSlice";
constexpr char ALL_GATHER[] = "AllGather";
constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather";
constexpr char MICRO_STEP_ALL_GATHER[] = "_MicroStepAllGather";
constexpr char REDUCE_SCATTER[] = "ReduceScatter";
constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
@ -383,6 +384,7 @@ constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad";
constexpr char ACCU_GRAD[] = "accu_grad";
constexpr char PARAMETER_START[] = "parameter_start";
constexpr char PARAM_INDEX[] = "param_index";
constexpr char PARAMETER[] = "parameter";
// Parallel don't care
constexpr char STRING_EQUAL[] = "string_equal";

View File

@ -199,7 +199,8 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
if (op_name == MIRROR_MINI_STEP_OPERATOR) {
op_name = MIRROR_OPERATOR;
arg_forward.first.pop_back();
} else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR) {
} else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR ||
op_name == MICRO_STEP_ALL_GATHER) {
MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name;
}
}
@ -211,7 +212,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
std::vector<AnfNodePtr> new_node_input;
if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
op_name == MIRROR_MICRO_STEP_OPERATOR) {
op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) {
new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
} else {
@ -1117,6 +1118,15 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
return std::make_pair(nullptr, false);
}
// only used for FindCNode
// Follows the first user of `node` downward for as long as the current node is in the trivial node list or is a
// LOAD primitive, and returns the first node that is neither.
// The walk stops early, returning the last valid node, when the current node has no users or when its first user
// is not a CNode, so the function never dereferences an end iterator or loops on a null node.
CNodePtr SkipTrivialNodesMoveDown(FuncGraphManagerPtr manager, CNodePtr node) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(node);
  while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) {
    auto users = manager->node_users()[node];
    if (users.begin() == users.end()) {
      // Dead end: nothing consumes this node.
      break;
    }
    auto first_user = users.begin()->first;
    MS_EXCEPTION_IF_NULL(first_user);
    auto next_node = first_user->cast<CNodePtr>();
    if (next_node == nullptr) {
      break;
    }
    node = next_node;
  }
  return node;
}
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(anode);
MS_EXCEPTION_IF_NULL(anode->func_graph());
@ -1130,6 +1140,9 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
continue;
}
if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
}
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
@ -1202,7 +1215,7 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no
}
// only used for InsertMirrorOps
CNodePtr SkipTrivialNodes(CNodePtr node) {
CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) {
MS_EXCEPTION_IF_NULL(node);
while (!IsSomePrimitive(node, LOAD)) {
if (IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) {
@ -1287,7 +1300,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
// assume Load is inserted next to parameter
// skip Load moving up and insert mirror next to the parameter
if (pre_node->cast<CNodePtr>()) {
CNodePtr load_node = SkipTrivialNodes(node->input(index)->cast<CNodePtr>());
CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast<CNodePtr>());
manager->SetEdge(load_node, 1, next_cnode.second);
} else {
manager->SetEdge(node, static_cast<int>(index), next_cnode.second);
@ -1306,7 +1319,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param)) {
// assume Load is inserted next to parameter
// skip Load moving up and insert mirror next to the parameter
CNodePtr load_node = SkipTrivialNodes(pre_node->cast<CNodePtr>());
CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast<CNodePtr>());
InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root);
auto comm_op = load_node->input(1)->cast<CNodePtr>();
// add fusion flag
@ -1706,6 +1719,8 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
auto param_name = node->cast<ParameterPtr>()->name();
if (op_name == MINI_STEP_ALL_GATHER) {
op = CreateMiniStepAllGatherOp(group);
} else if (op_name == MICRO_STEP_ALL_GATHER) {
op = CreateMicroStepAllGatherOp(group);
} else {
op = CreateAllGatherOp(group);
}
@ -1733,9 +1748,12 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(manager);
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
std::string op_name;
if (grad_accumulation_step > 1) {
op_name = MINI_STEP_ALL_GATHER;
} else if (split_stage_num > 1) {
op_name = MICRO_STEP_ALL_GATHER;
} else {
op_name = ALL_GATHER;
}
@ -1744,7 +1762,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
for (auto &param_pair : param_sub_set) {
auto cnode = param_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->in_forward_flag()) {
if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
@ -1759,6 +1777,8 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and "
<< GetPrimName(cnode);
} else {
MS_LOG(ERROR) << "Can not find the shared AllGather with multiple node users.";
}
} else {
// insert allgather operator between shard parameter and cnode
@ -2852,12 +2872,14 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
MS_EXCEPTION_IF_NULL(distribute_operator);
// insert forward ops
InsertForwardOps(distribute_operator, cnode);
// insert redistribution ops
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
// skip Send Receive
if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
// insert forward ops
InsertForwardOps(distribute_operator, cnode);
// insert redistribution ops
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
}
// insert backward ops
if (has_backward) {
BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
@ -2873,7 +2895,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE) ||
IsSomePrimitive(cnode, SEND)) {
continue;
}
@ -2922,7 +2945,8 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
// Reports whether `cnode` is a call of one of the primitives tested below: kPrimCast, kPrimLoad, kPrimAllGather,
// kPrimMiniStepAllGather, or kPrimMicroStepAllGather.
bool IsCohesiveNode(const CNodePtr &cnode) {
return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) ||
IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather);
}
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {

View File

@ -309,6 +309,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.virtual_add_elim_,
irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_,
irpass.micro_step_allgather_replace_,
},
false, true);
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});

View File

@ -362,6 +362,7 @@ inline const PrimitivePtr kFusedMulAdd = std::make_shared<Primitive>("FusedMulAd
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
inline const PrimitivePtr kPrimMicroStepAllGather = std::make_shared<Primitive>("_MicroStepAllGather");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");

View File

@ -31,7 +31,8 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "UpdateState", "Load"};
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
prim::kPrimMicroStepAllGather};
static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};
// clang-format on

View File

@ -16,7 +16,7 @@
from types import FunctionType, MethodType
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
_get_parallel_mode, _get_enable_parallel_optimizer)
from mindspore.context import ParallelMode, get_auto_parallel_context
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
@ -538,6 +538,7 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell):
super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
self.hyper_map = ops.HyperMap()
self.opt_shard = _get_enable_parallel_optimizer()
def construct(self, *inputs):
weights = self.weights
@ -545,7 +546,10 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell):
sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens)
accu_grads = ops.depend(self.accu_grads, grads)
succ = self.optimizer(accu_grads)
if self.opt_shard:
succ = self.optimizer(grads)
else:
succ = self.optimizer(accu_grads)
clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
loss = ops.depend(loss, succ, clear)
return loss

View File

@ -18,13 +18,14 @@ from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator)
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@ -102,10 +103,14 @@ def get_bprop_receive(self):
depend = P.Depend()
cast = P.Cast()
out_tensor = Tensor(0.0, mstype.float16)
is_opt_shard = _get_enable_parallel_optimizer()
def bprop(x, out, dout):
send_out = receive_grad(dout)
dx = depend(cast(out_tensor, F.dtype(x)), send_out)
if is_opt_shard:
dx = depend(F.zeros_like(x), send_out)
else:
dx = depend(cast(out_tensor, F.dtype(x)), send_out)
return (dx,)
return bprop
@ -174,6 +179,7 @@ def get_bprop_mirror_micro_step_operator(self):
if "parameter_micro" in self.get_attr_dict():
assign.add_prim_attr("parameter_micro", 0)
out_tensor = Tensor(1.0, mstype.float16)
opt_shard = _get_enable_parallel_optimizer()
def bprop(x, z, out, dout):
real_grad = z
@ -188,6 +194,8 @@ def get_bprop_mirror_micro_step_operator(self):
z = F.depend(z, dout)
real_grad = all_reduce(z)
assign(z, real_grad)
if opt_shard:
return (real_grad, cast(out_tensor, dtype(z)))
return (cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z)))
return bprop
@ -205,30 +213,17 @@ def get_bprop_broad_cast(self):
def get_bprop_all_gather(self):
"""Generate bprop for AllGather"""
fusion = self.get_attr_dict()["fusion"]
if fusion == 0:
reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group)
if self.instance_name:
instance_name = "grad_" + self.instance_name
reduce_scatter.set_prim_instance_name(instance_name)
else:
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "grad_" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
rank = get_rank(self.group)
dev_num = get_group_size(self.group)
split = P.Split(output_num=dev_num)
mean_flag = self.get_attr_dict()["mean_flag"]
scale = 1/self.rank_size
reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "grad_" + self.instance_name
reduce_scatter.set_prim_instance_name(instance_name)
mean_flag = self.get_attr_dict()["mean_flag"]
scale = 1 / self.rank_size
def bprop(x, out, dout):
if fusion == 0:
dx = reduce_scatter(dout)
else:
grad = all_reduce(dout)
dx = split(grad)[rank]
if mean_flag:
dx = F.tensor_mul(dx, scale)
dx = reduce_scatter(dout)
if mean_flag:
dx = F.tensor_mul(dx, scale)
return (dx,)
return bprop
@ -267,6 +262,35 @@ def get_bprop_mini_step_all_gather(self):
return bprop
@bprop_getters.register(_MicroStepAllGather)
def get_bprop_micro_step_all_gather(self):
    """
    Generate bprop for _MicroStepAllGather.

    The returned bprop all-reduces the accumulated gradient `z` (after making it depend on `dout`), keeps the
    slice that belongs to this rank, optionally scales it by 1 / rank_size when `mean_flag` is set, and returns
    that slice as the gradient of `x` together with a float16 constant cast to the dtype of `z`.
    """
    fusion_id = self.get_attr_dict()["fusion"]
    use_mean = self.get_attr_dict()["mean_flag"]
    inv_rank_size = 1 / self.rank_size

    grad_all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion_id)
    local_rank = get_rank(self.group)
    group_size = get_group_size(self.group)
    split_by_rank = P.Split(output_num=group_size)
    if self.instance_name:
        grad_all_reduce.set_prim_instance_name("grad_" + self.instance_name)

    to_dtype = P.Cast()
    get_dtype = P.DType()
    placeholder = Tensor(1.0, mstype.float16)

    # z: accu_grad
    def bprop(x, z, out, dout):
        accu = F.depend(z, dout)
        reduced = grad_all_reduce(accu)
        grad_slice = split_by_rank(reduced)[local_rank]
        if use_mean:
            grad_slice = F.tensor_mul(grad_slice, inv_rank_size)
        return (grad_slice, to_dtype(placeholder, get_dtype(accu)))

    return bprop
@bprop_getters.register(_HostAllGather)
def get_bprop_host_all_gather(self):
"""Generate bprop for _HostAllGather"""

View File

@ -6,4 +6,4 @@
bprop.10:x*
bprop.10:out*
bprop.10:dout2
bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

View File

@ -8,4 +8,4 @@
bprop.2:x*
bprop.2:out*
bprop.2:dout2
bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

View File

@ -37,7 +37,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator)
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import GeSwitch, Merge

View File

@ -191,6 +191,7 @@ class AllGather(PrimitiveWithInfer):
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('mean_flag', False)
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
@ -239,6 +240,36 @@ class _MiniStepAllGather(PrimitiveWithInfer):
return x_dtype
class _MicroStepAllGather(PrimitiveWithInfer):
    """
    Auto parallel virtual operator used for gradient accumulation over micro-steps. It takes two inputs, the
    parameter slice `x` and the accumulated gradient `z`, and its inferred output is `x` with the first dimension
    multiplied by the size of the communication group. The registered bprop all-reduces `z` and keeps the slice
    of the current rank. It is only for internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
        mean_flag (bool): Whether the backward pass scales the gradient by 1 / rank_size. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        """Return the shape of `x` with dim 0 scaled by rank_size; `z_shape` is not used."""
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        # Work on a copy so the caller's shape list is not modified in place.
        out_shape = list(x_shape)
        if out_shape[0] > 0:
            out_shape[0] = out_shape[0] * self.rank_size
        return out_shape

    def infer_dtype(self, x_dtype, z_dtype):
        """Validate the dtype of `x` against target_dtypes and return it unchanged; `z_dtype` is not used."""
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype
class _HostAllGather(PrimitiveWithInfer):
"""
Gathers tensors from the specified communication group on host.

View File

@ -160,6 +160,11 @@ def _get_parameter_broadcast():
return parameter_broadcast
def _get_enable_parallel_optimizer():
    """Return the enable_parallel_optimizer setting held by the auto parallel context."""
    parallel_ctx = auto_parallel_context()
    return parallel_ctx.get_enable_parallel_optimizer()
def _device_number_check(parallel_mode, device_number):
"""
Check device num.

View File

@ -173,3 +173,67 @@ def test_pipeline_split_shared_parameter_stage1():
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_stage0_opt_shard():
    """
    Train PipelineSplit on rank 0 (two pipeline stages, parallel optimizer enabled) and check that the
    block 1 parameters do not appear in the train network.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    labels = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[0].trainable_params()
    train_dataset = DatasetLenet(inputs, labels, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    model = Model(pipeline_net, optimizer=lamb)
    model.train(2, train_dataset, dataset_sink_mode=False)
    other_stage_names = ("cell.block.1.param", "cell.block.1.param1")
    for _, param in model._train_network.parameters_and_names():
        assert param.name not in other_stage_names
def test_pipeline_split_stage1_opt_shard():
    """
    Train PipelineSplit on rank 4 (two pipeline stages, parallel optimizer enabled) and check that the
    block 0 parameters do not appear in the train network.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    labels = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[1].trainable_params()
    train_dataset = DatasetLenet(inputs, labels, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    model = Model(pipeline_net, optimizer=lamb)
    model.train(2, train_dataset, dataset_sink_mode=False)
    other_stage_names = ("cell.block.0.param", "cell.block.0.param1")
    for _, param in model._train_network.parameters_and_names():
        assert param.name not in other_stage_names
def test_pipeline_split_shared_parameter_stage0_opt_shard():
    """
    Train PipelineSplit2 on rank 0 with two pipeline stages and the parallel optimizer enabled,
    optimizing the trainable parameters of block 0.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    labels = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit2(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[0].trainable_params()
    train_dataset = DatasetLenet(inputs, labels, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    model = Model(pipeline_net, optimizer=lamb)
    model.train(2, train_dataset, dataset_sink_mode=False)
def test_pipeline_split_shared_parameter_stage1_opt_shard():
    """
    Train PipelineSplit2 on rank 4 with two pipeline stages and the parallel optimizer enabled,
    optimizing the trainable parameters of block 1.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    labels = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit2(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[1].trainable_params()
    train_dataset = DatasetLenet(inputs, labels, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    model = Model(pipeline_net, optimizer=lamb)
    model.train(2, train_dataset, dataset_sink_mode=False)