forked from mindspore-Ecosystem/mindspore
opt shard fit micro batch
This commit is contained in:
parent
70152adcb3
commit
be1f5a43d7
|
@ -106,7 +106,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
prim::kPrimMirrorMiniStep);
|
||||
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
|
||||
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
|
||||
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
|
||||
micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(),
|
||||
"micro_step_allgather_replace", prim::kPrimMicroStepAllGather);
|
||||
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
|
||||
check_bprop_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
|
||||
reset_defer_inline_ =
|
||||
|
|
|
@ -60,6 +60,7 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr mirror_mini_step_elim_;
|
||||
SubstitutionPtr virtual_add_elim_;
|
||||
SubstitutionPtr mini_step_allgather_replace_;
|
||||
SubstitutionPtr micro_step_allgather_replace_;
|
||||
|
||||
// Env Item Eliminate
|
||||
SubstitutionPtr env_get_item_eliminate_;
|
||||
|
|
|
@ -300,6 +300,39 @@ class MiniStepAllGatherPass : public AnfVisitor {
|
|||
void Visit(const AnfNodePtr &) override {}
|
||||
};
|
||||
|
||||
// {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
|
||||
class MicroStepAllGatherPass : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() < 2) {
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto attrs = prim->attrs();
|
||||
std::string group = attrs[parallel::GROUP]->ToString();
|
||||
auto fusion = attrs[parallel::FUSION];
|
||||
parallel::Operator op = parallel::CreateAllGatherOp(group);
|
||||
std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
|
||||
auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
|
||||
prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
attrs = prim->attrs();
|
||||
attrs[parallel::FUSION] = fusion;
|
||||
prim->SetAttrs(attrs);
|
||||
auto func_graph = inputs[1]->func_graph();
|
||||
CNodePtr new_node = func_graph->NewCNode(node_input);
|
||||
return new_node;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &) override {}
|
||||
};
|
||||
|
||||
// Reset defer_inline flag
|
||||
class ResetDeferInline : public AnfVisitor {
|
||||
public:
|
||||
|
|
|
@ -36,7 +36,6 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0
|
|||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1;
|
||||
constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1;
|
||||
|
||||
constexpr char PARAMETER[] = "parameter";
|
||||
const uint64_t MAX_RECURSIVE_CALL_TIMES = 100;
|
||||
class AllreduceFusion {
|
||||
public:
|
||||
|
|
|
@ -87,14 +87,12 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) {
|
|||
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
|
||||
const AnfNodePtr &accu_parameter) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag() ||
|
||||
((IsPrimitiveCNode(node_user.first, prim::kPrimSend) || IsPrimitiveCNode(node_user.first, prim::kPrimDepend)) &&
|
||||
ParallelContext::GetInstance()->enable_parallel_optimizer())) {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
|
||||
return;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAd";
|
||||
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
|
||||
return;
|
||||
}
|
||||
OperatorAttrs attrs;
|
||||
|
@ -154,10 +152,12 @@ void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
|
|||
auto node_users = node_users_map[node];
|
||||
for (auto &temp_user : node_users) {
|
||||
auto temp_node = temp_user.first;
|
||||
// Micro virtual operator might be inserted after cast
|
||||
if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
|
||||
temp_node = node_users_map[temp_node].begin()->first;
|
||||
}
|
||||
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
|
||||
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
|
||||
IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
|
||||
auto node_set = node_users_map[temp_node];
|
||||
for (auto &node_user : node_set) {
|
||||
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
|
||||
|
@ -182,10 +182,12 @@ void AddVirtualAssignAdd(const FuncGraphPtr &root) {
|
|||
auto node_users = node_users_map[parameter];
|
||||
for (auto &temp_user : node_users) {
|
||||
auto temp_node = temp_user.first;
|
||||
// Micro virtual operator might be inserted after cast
|
||||
if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
|
||||
temp_node = node_users_map[temp_node].begin()->first;
|
||||
}
|
||||
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
|
||||
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
|
||||
IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
|
||||
auto node_set = node_users_map[temp_node];
|
||||
for (auto &node_user : node_set) {
|
||||
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
|
||||
|
|
|
@ -428,6 +428,26 @@ Operator CreateMiniStepAllGatherOp(const std::string &group) {
|
|||
return op;
|
||||
}
|
||||
|
||||
// Builds the operator description for _MicroStepAllGather on communication group `group`.
// The attribute list holds the group and the gradients-mean flag read from the parallel context;
// the parameter list is left empty.
Operator CreateMicroStepAllGatherOp(const std::string &group) {
  bool use_mean = ParallelContext::GetInstance()->gradients_mean();

  Attr group_attr = std::make_pair(GROUP, MakeValue(group));
  Attr mean_attr = std::make_pair(MEAN_FLAG, MakeValue(use_mean));

  OperatorAttrs attr_list;
  attr_list.push_back(group_attr);
  attr_list.push_back(mean_attr);

  OperatorParams no_params;
  OperatorArgs args = std::make_pair(attr_list, no_params);
  Operator micro_step_op = std::make_pair(OperatorName(MICRO_STEP_ALL_GATHER), args);

  MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group;
  return micro_step_op;
}
|
||||
|
||||
// use for get tensor slice
|
||||
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
|
||||
Shape tensor_map = tensor_layout.tensor_map().array();
|
||||
|
|
|
@ -299,6 +299,7 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
|
|||
Operator CreateAllGatherOp(const std::string &group);
|
||||
Operator CreateMiniStepAllGatherOp(const std::string &group);
|
||||
// NOTE(review): presumably tags comm_node with a fusion attribute derived from param_node;
// confirm against the definition, which is not visible here.
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
|
||||
Operator CreateMicroStepAllGatherOp(const std::string &group);
|
||||
void AddCommOpMeanFlag(const CNodePtr &comm_node);
|
||||
void AddCommOpParamFlag(const CNodePtr &comm_node);
|
||||
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
|
||||
|
|
|
@ -217,6 +217,7 @@ constexpr char LOCAL_STEP[] = "local_step";
|
|||
constexpr char STRIDED_SLICE[] = "StridedSlice";
|
||||
constexpr char ALL_GATHER[] = "AllGather";
|
||||
constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather";
|
||||
constexpr char MICRO_STEP_ALL_GATHER[] = "_MicroStepAllGather";
|
||||
constexpr char REDUCE_SCATTER[] = "ReduceScatter";
|
||||
constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
|
||||
constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
|
||||
|
@ -383,6 +384,7 @@ constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad";
|
|||
constexpr char ACCU_GRAD[] = "accu_grad";
|
||||
constexpr char PARAMETER_START[] = "parameter_start";
|
||||
constexpr char PARAM_INDEX[] = "param_index";
|
||||
constexpr char PARAMETER[] = "parameter";
|
||||
|
||||
// Parallel don't care
|
||||
constexpr char STRING_EQUAL[] = "string_equal";
|
||||
|
|
|
@ -199,7 +199,8 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
|
|||
if (op_name == MIRROR_MINI_STEP_OPERATOR) {
|
||||
op_name = MIRROR_OPERATOR;
|
||||
arg_forward.first.pop_back();
|
||||
} else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR) {
|
||||
} else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR ||
|
||||
op_name == MICRO_STEP_ALL_GATHER) {
|
||||
MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name;
|
||||
}
|
||||
}
|
||||
|
@ -211,7 +212,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
|
|||
|
||||
std::vector<AnfNodePtr> new_node_input;
|
||||
if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
|
||||
op_name == MIRROR_MICRO_STEP_OPERATOR) {
|
||||
op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) {
|
||||
new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
|
||||
MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
|
||||
} else {
|
||||
|
@ -1117,6 +1118,15 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
|||
return std::make_pair(nullptr, false);
|
||||
}
|
||||
|
||||
// only used for FindCNode
|
||||
CNodePtr SkipTrivialNodesMoveDown(FuncGraphManagerPtr manager, CNodePtr node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) {
|
||||
node = manager->node_users()[node].begin()->first->cast<CNodePtr>();
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(anode);
|
||||
MS_EXCEPTION_IF_NULL(anode->func_graph());
|
||||
|
@ -1130,6 +1140,9 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
|
|||
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
|
||||
continue;
|
||||
}
|
||||
if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
|
||||
use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
|
||||
}
|
||||
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_anf_node);
|
||||
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||
|
@ -1202,7 +1215,7 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no
|
|||
}
|
||||
|
||||
// only used for InsertMirrorOps
|
||||
CNodePtr SkipTrivialNodes(CNodePtr node) {
|
||||
CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
while (!IsSomePrimitive(node, LOAD)) {
|
||||
if (IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) {
|
||||
|
@ -1287,7 +1300,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|||
// assume Load is inserted next to parameter
|
||||
// skip Load moving up and insert mirror next to the parameter
|
||||
if (pre_node->cast<CNodePtr>()) {
|
||||
CNodePtr load_node = SkipTrivialNodes(node->input(index)->cast<CNodePtr>());
|
||||
CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast<CNodePtr>());
|
||||
manager->SetEdge(load_node, 1, next_cnode.second);
|
||||
} else {
|
||||
manager->SetEdge(node, static_cast<int>(index), next_cnode.second);
|
||||
|
@ -1306,7 +1319,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|||
if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param)) {
|
||||
// assume Load is inserted next to parameter
|
||||
// skip Load moving up and insert mirror next to the parameter
|
||||
CNodePtr load_node = SkipTrivialNodes(pre_node->cast<CNodePtr>());
|
||||
CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast<CNodePtr>());
|
||||
InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root);
|
||||
auto comm_op = load_node->input(1)->cast<CNodePtr>();
|
||||
// add fusion flag
|
||||
|
@ -1706,6 +1719,8 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
|
|||
auto param_name = node->cast<ParameterPtr>()->name();
|
||||
if (op_name == MINI_STEP_ALL_GATHER) {
|
||||
op = CreateMiniStepAllGatherOp(group);
|
||||
} else if (op_name == MICRO_STEP_ALL_GATHER) {
|
||||
op = CreateMicroStepAllGatherOp(group);
|
||||
} else {
|
||||
op = CreateAllGatherOp(group);
|
||||
}
|
||||
|
@ -1733,9 +1748,12 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
|
|||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
|
||||
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
|
||||
std::string op_name;
|
||||
if (grad_accumulation_step > 1) {
|
||||
op_name = MINI_STEP_ALL_GATHER;
|
||||
} else if (split_stage_num > 1) {
|
||||
op_name = MICRO_STEP_ALL_GATHER;
|
||||
} else {
|
||||
op_name = ALL_GATHER;
|
||||
}
|
||||
|
@ -1744,7 +1762,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
|
|||
for (auto &param_pair : param_sub_set) {
|
||||
auto cnode = param_pair.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->in_forward_flag()) {
|
||||
if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
|
||||
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
|
||||
|
@ -1759,6 +1777,8 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
|
|||
manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
|
||||
MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and "
|
||||
<< GetPrimName(cnode);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Can not find the shared AllGather with multiple node users.";
|
||||
}
|
||||
} else {
|
||||
// insert allgather operator between shard parameter and cnode
|
||||
|
@ -2852,12 +2872,14 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
||||
MS_EXCEPTION_IF_NULL(distribute_operator);
|
||||
|
||||
// insert forward ops
|
||||
InsertForwardOps(distribute_operator, cnode);
|
||||
|
||||
// insert redistribution ops
|
||||
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
|
||||
// skip Send Receive
|
||||
if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
|
||||
// insert forward ops
|
||||
InsertForwardOps(distribute_operator, cnode);
|
||||
|
||||
// insert redistribution ops
|
||||
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
|
||||
}
|
||||
// insert backward ops
|
||||
if (has_backward) {
|
||||
BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
|
||||
|
@ -2873,7 +2895,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
|
||||
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE) ||
|
||||
IsSomePrimitive(cnode, SEND)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -2922,7 +2945,8 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
|
|||
|
||||
bool IsCohesiveNode(const CNodePtr &cnode) {
|
||||
return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
|
||||
IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather);
|
||||
}
|
||||
|
||||
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
|
||||
|
|
|
@ -309,6 +309,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.virtual_add_elim_,
|
||||
irpass.row_tensor_add_zeros_like_,
|
||||
irpass.mini_step_allgather_replace_,
|
||||
irpass.micro_step_allgather_replace_,
|
||||
},
|
||||
false, true);
|
||||
opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
|
||||
|
|
|
@ -362,6 +362,7 @@ inline const PrimitivePtr kFusedMulAdd = std::make_shared<Primitive>("FusedMulAd
|
|||
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
|
||||
inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
|
||||
inline const PrimitivePtr kPrimMicroStepAllGather = std::make_shared<Primitive>("_MicroStepAllGather");
|
||||
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||
inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
|
||||
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
|
|
|
@ -31,7 +31,8 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
||||
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "UpdateState", "Load"};
|
||||
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
|
||||
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
|
||||
prim::kPrimMicroStepAllGather};
|
||||
static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from types import FunctionType, MethodType
|
||||
|
||||
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
|
||||
_get_parallel_mode)
|
||||
_get_parallel_mode, _get_enable_parallel_optimizer)
|
||||
from mindspore.context import ParallelMode, get_auto_parallel_context
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore import ops, nn
|
||||
|
@ -538,6 +538,7 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell):
|
|||
super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
|
||||
self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
|
||||
self.hyper_map = ops.HyperMap()
|
||||
self.opt_shard = _get_enable_parallel_optimizer()
|
||||
|
||||
def construct(self, *inputs):
|
||||
weights = self.weights
|
||||
|
@ -545,7 +546,10 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell):
|
|||
sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(*inputs, sens)
|
||||
accu_grads = ops.depend(self.accu_grads, grads)
|
||||
succ = self.optimizer(accu_grads)
|
||||
if self.opt_shard:
|
||||
succ = self.optimizer(grads)
|
||||
else:
|
||||
succ = self.optimizer(accu_grads)
|
||||
clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
|
||||
loss = ops.depend(loss, succ, clear)
|
||||
return loss
|
||||
|
|
|
@ -18,13 +18,14 @@ from mindspore import Tensor
|
|||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.communication import get_rank, get_group_size
|
||||
from mindspore.parallel._utils import _get_enable_parallel_optimizer
|
||||
from .. import operations as P
|
||||
from ...common.tensor import RowTensor
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
|
||||
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
|
||||
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
|
||||
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator)
|
||||
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
|
||||
from .grad_base import bprop_getters
|
||||
from ..operations._inner_ops import Send, Receive
|
||||
|
||||
|
@ -102,10 +103,14 @@ def get_bprop_receive(self):
|
|||
depend = P.Depend()
|
||||
cast = P.Cast()
|
||||
out_tensor = Tensor(0.0, mstype.float16)
|
||||
is_opt_shard = _get_enable_parallel_optimizer()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
send_out = receive_grad(dout)
|
||||
dx = depend(cast(out_tensor, F.dtype(x)), send_out)
|
||||
if is_opt_shard:
|
||||
dx = depend(F.zeros_like(x), send_out)
|
||||
else:
|
||||
dx = depend(cast(out_tensor, F.dtype(x)), send_out)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
@ -174,6 +179,7 @@ def get_bprop_mirror_micro_step_operator(self):
|
|||
if "parameter_micro" in self.get_attr_dict():
|
||||
assign.add_prim_attr("parameter_micro", 0)
|
||||
out_tensor = Tensor(1.0, mstype.float16)
|
||||
opt_shard = _get_enable_parallel_optimizer()
|
||||
|
||||
def bprop(x, z, out, dout):
|
||||
real_grad = z
|
||||
|
@ -188,6 +194,8 @@ def get_bprop_mirror_micro_step_operator(self):
|
|||
z = F.depend(z, dout)
|
||||
real_grad = all_reduce(z)
|
||||
assign(z, real_grad)
|
||||
if opt_shard:
|
||||
return (real_grad, cast(out_tensor, dtype(z)))
|
||||
return (cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z)))
|
||||
return bprop
|
||||
|
||||
|
@ -205,30 +213,17 @@ def get_bprop_broad_cast(self):
|
|||
def get_bprop_all_gather(self):
|
||||
"""Generate bprop for AllGather"""
|
||||
fusion = self.get_attr_dict()["fusion"]
|
||||
if fusion == 0:
|
||||
reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group)
|
||||
if self.instance_name:
|
||||
instance_name = "grad_" + self.instance_name
|
||||
reduce_scatter.set_prim_instance_name(instance_name)
|
||||
else:
|
||||
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
|
||||
if self.instance_name:
|
||||
instance_name = "grad_" + self.instance_name
|
||||
all_reduce.set_prim_instance_name(instance_name)
|
||||
rank = get_rank(self.group)
|
||||
dev_num = get_group_size(self.group)
|
||||
split = P.Split(output_num=dev_num)
|
||||
mean_flag = self.get_attr_dict()["mean_flag"]
|
||||
scale = 1/self.rank_size
|
||||
reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
|
||||
if self.instance_name:
|
||||
instance_name = "grad_" + self.instance_name
|
||||
reduce_scatter.set_prim_instance_name(instance_name)
|
||||
mean_flag = self.get_attr_dict()["mean_flag"]
|
||||
scale = 1 / self.rank_size
|
||||
|
||||
def bprop(x, out, dout):
|
||||
if fusion == 0:
|
||||
dx = reduce_scatter(dout)
|
||||
else:
|
||||
grad = all_reduce(dout)
|
||||
dx = split(grad)[rank]
|
||||
if mean_flag:
|
||||
dx = F.tensor_mul(dx, scale)
|
||||
dx = reduce_scatter(dout)
|
||||
if mean_flag:
|
||||
dx = F.tensor_mul(dx, scale)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
@ -267,6 +262,35 @@ def get_bprop_mini_step_all_gather(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_MicroStepAllGather)
def get_bprop_micro_step_all_gather(self):
    """
    Generate bprop for _MicroStepAllGather.

    The returned bprop all-reduces the accumulated gradient `z` over the primitive's group, splits the
    result into group-size pieces and keeps the piece at this rank's index, multiplies it by
    1 / rank_size when `mean_flag` is set, and returns it as the gradient of `x`. The second returned
    value is a float16 scalar 1.0 cast to the dtype of `z`.
    """
    fusion = self.get_attr_dict()["fusion"]
    mean_flag = self.get_attr_dict()["mean_flag"]
    scale = 1 / self.rank_size
    grad_all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
    local_rank = get_rank(self.group)
    split_by_rank = P.Split(output_num=get_group_size(self.group))
    if self.instance_name:
        grad_all_reduce.set_prim_instance_name("grad_" + self.instance_name)
    to_dtype = P.Cast()
    dtype_of = P.DType()
    one_fp16 = Tensor(1.0, mstype.float16)

    def bprop(x, z, out, dout):
        # z is the accumulated gradient (accu_grad); make it depend on dout before reducing.
        accu_grad = F.depend(z, dout)
        grad_slice = split_by_rank(grad_all_reduce(accu_grad))[local_rank]
        if mean_flag:
            grad_slice = F.tensor_mul(grad_slice, scale)
        placeholder = to_dtype(one_fp16, dtype_of(accu_grad))
        return (grad_slice, placeholder)

    return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_HostAllGather)
|
||||
def get_bprop_host_all_gather(self):
|
||||
"""Generate bprop for _HostAllGather"""
|
||||
|
|
|
@ -6,4 +6,4 @@
|
|||
bprop.10:x*
|
||||
bprop.10:out*
|
||||
bprop.10:dout2
|
||||
bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
|
||||
bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
|
|
@ -8,4 +8,4 @@
|
|||
bprop.2:x*
|
||||
bprop.2:out*
|
||||
bprop.2:dout2
|
||||
bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
|
||||
bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
|
|
@ -37,7 +37,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
|
||||
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
|
||||
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator)
|
||||
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
|
||||
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
||||
TensorSummary, HistogramSummary, Print, Assert)
|
||||
from .control_ops import GeSwitch, Merge
|
||||
|
|
|
@ -191,6 +191,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
self.add_prim_attr('rank_size', self.rank_size)
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('mean_flag', False)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||
|
@ -239,6 +240,36 @@ class _MiniStepAllGather(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class _MicroStepAllGather(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in micro-step. It is only
    for internal use of parallel modules and cannot be called by users.

    Shape inference reports the shape of the first input with its first dimension multiplied by the
    rank size of the group; the dtype of the first input is passed through.

    Args:
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
        mean_flag (bool): Stored on the primitive as `mean_flag`. Default: None.
            NOTE(review): presumably selects gradient averaging in backward - confirm against the bprop.
    """
    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        """Return x_shape with its first dimension scaled by rank_size; z_shape is not used."""
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        # Work on a copy so the caller's shape list is not modified in place.
        out_shape = list(x_shape)
        if out_shape[0] > 0:
            out_shape[0] = out_shape[0] * self.rank_size
        return out_shape

    def infer_dtype(self, x_dtype, z_dtype):
        """Validate the dtype of x and return it unchanged; z_dtype is not used."""
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype
|
||||
|
||||
|
||||
class _HostAllGather(PrimitiveWithInfer):
|
||||
"""
|
||||
Gathers tensors from the specified communication group on host.
|
||||
|
|
|
@ -160,6 +160,11 @@ def _get_parameter_broadcast():
|
|||
return parameter_broadcast
|
||||
|
||||
|
||||
def _get_enable_parallel_optimizer():
    """Return the enable_parallel_optimizer setting held by the auto parallel context."""
    parallel_ctx = auto_parallel_context()
    return parallel_ctx.get_enable_parallel_optimizer()
|
||||
|
||||
|
||||
def _device_number_check(parallel_mode, device_number):
|
||||
"""
|
||||
Check device num.
|
||||
|
|
|
@ -173,3 +173,67 @@ def test_pipeline_split_shared_parameter_stage1():
|
|||
optimizer = nn.Lamb(params, learning_rate=0.01)
|
||||
model = Model(net, optimizer=optimizer)
|
||||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
|
||||
def test_pipeline_split_stage0_opt_shard():
    """
    Train the two-stage PipelineSplit network as rank 0 of 8 devices with the parallel optimizer
    enabled, then check that the train network holds none of block 1's parameters.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    targets = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[0].trainable_params()
    train_data = DatasetLenet(inputs, targets, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    trainer = Model(pipeline_net, optimizer=lamb)
    trainer.train(2, train_data, dataset_sink_mode=False)
    other_stage_names = ("cell.block.1.param", "cell.block.1.param1")
    for _, param in trainer._train_network.parameters_and_names():
        assert param.name not in other_stage_names
|
||||
|
||||
def test_pipeline_split_stage1_opt_shard():
    """
    Train the two-stage PipelineSplit network as rank 4 of 8 devices with the parallel optimizer
    enabled, then check that the train network holds none of block 0's parameters.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    targets = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[1].trainable_params()
    train_data = DatasetLenet(inputs, targets, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    trainer = Model(pipeline_net, optimizer=lamb)
    trainer.train(2, train_data, dataset_sink_mode=False)
    other_stage_names = ("cell.block.0.param", "cell.block.0.param1")
    for _, param in trainer._train_network.parameters_and_names():
        assert param.name not in other_stage_names
|
||||
|
||||
|
||||
def test_pipeline_split_shared_parameter_stage0_opt_shard():
    """
    Train the two-stage PipelineSplit2 network as rank 0 of 8 devices with the parallel optimizer
    enabled, optimizing block 0's trainable parameters; the test passes if training completes.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    targets = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit2(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[0].trainable_params()
    train_data = DatasetLenet(inputs, targets, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    trainer = Model(pipeline_net, optimizer=lamb)
    trainer.train(2, train_data, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def test_pipeline_split_shared_parameter_stage1_opt_shard():
    """
    Train the two-stage PipelineSplit2 network as rank 4 of 8 devices with the parallel optimizer
    enabled, optimizing block 1's trainable parameters; the test passes if training completes.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    targets = Tensor(np.ones([64, 64]), dtype=ms.float32)
    first_strategy = ((4, 1), (1, 1))
    second_strategy = ((2, 1), (1, 1))
    pipeline_net = PipelineCell(PipelineSplit2(first_strategy, second_strategy), 4)
    stage_params = pipeline_net.network.cell.block[1].trainable_params()
    train_data = DatasetLenet(inputs, targets, 3)
    lamb = nn.Lamb(stage_params, learning_rate=0.01)
    trainer = Model(pipeline_net, optimizer=lamb)
    trainer.train(2, train_data, dataset_sink_mode=False)
|
||||
|
|
Loading…
Reference in New Issue