!18666 [AutoParallel]change pipeline split shared parameter

Merge pull request !18666 from lichen/change_pipeline_shared_param

Commit ead6c37b3c
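The core of this change is visible in the hunks that follow: PipelineTransformer::GetOpInfo and GetParameterPair now return std::pair<OperatorInfoPtr, int> (the index of the relevant tensor info) instead of std::pair<OperatorInfoPtr, TensorInfoPtr>, and callers such as InsertSend/InsertReceive look the TensorInfo up from the returned OperatorInfo themselves. The standalone sketch below only models that index-based lookup pattern with simplified stand-in types; it is not MindSpore's actual OperatorInfo/TensorInfo API.

#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Simplified stand-ins for illustration only (not the real MindSpore classes).
struct TensorInfo {
  std::vector<int64_t> slice_shape;
};
struct OperatorInfo {
  std::vector<TensorInfo> inputs_tensor_info;
  std::vector<TensorInfo> outputs_tensor_info;
};
using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;

// New style in this PR: return the owning op info plus an index; the caller
// resolves the TensorInfo itself, so Send/Receive insertion and later passes
// can agree on one index instead of passing around a copied TensorInfo.
std::pair<OperatorInfoPtr, int> GetOpInfoSketch(const OperatorInfoPtr &op_info, int tensor_info_index) {
  return std::make_pair(op_info, tensor_info_index);
}

int main() {
  auto op_info = std::make_shared<OperatorInfo>();
  op_info->outputs_tensor_info = {TensorInfo{{8, 16}}, TensorInfo{{4, 16}}};

  auto op_info_pair = GetOpInfoSketch(op_info, 1);
  // Caller-side lookup, mirroring what InsertSend/InsertReceive do in this diff.
  const TensorInfo &tensor_info =
      op_info_pair.first->outputs_tensor_info.at(static_cast<size_t>(op_info_pair.second));
  const std::vector<int64_t> expected{4, 16};
  assert(tensor_info.slice_shape == expected);
  return 0;
}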
@@ -380,6 +380,7 @@ constexpr char VIRTUAL_ASSIGN_ADD[] = "_VirtualAssignAdd";
 constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad";
 constexpr char ACCU_GRAD[] = "accu_grad";
 constexpr char PARAMETER_START[] = "parameter_start";
+constexpr char PARAM_INDEX[] = "param_index";
 
 // Parallel don't care
 constexpr char STRING_EQUAL[] = "string_equal";
@@ -246,12 +246,12 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
   return op_info;
 }
 
-std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
+std::pair<OperatorInfoPtr, int> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   // Handle Cast and TupleGetitem situation
-  size_t tensor_info_index = 0;
+  int tensor_info_index = 0;
   OperatorInfoPtr op_info;
   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
     op_info = node->user_data<OperatorInfo>();
@@ -259,19 +259,17 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A
     if (IsPrimitiveCNode(node, prim::kPrimCast)) {
       cnode = cnode->input(1)->cast<CNodePtr>();
     } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
-      tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
+      tensor_info_index = LongToInt(GetTupleGetItemIndex(cnode));
       cnode = cnode->input(1)->cast<CNodePtr>();
     }
     // Create OperatorInfo to get slice_shape for send/recv
     MS_EXCEPTION_IF_NULL(cnode);
     op_info = CreateOpInfo(cnode);
   }
   MS_EXCEPTION_IF_NULL(op_info);
-  auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index];
-  return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
+  return std::make_pair(op_info, tensor_info_index);
 }
 
-std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
+std::pair<OperatorInfoPtr, int> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto node_users_map = manager_->node_users();
   auto node_users = node_users_map[node];
@@ -322,11 +320,9 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(
       continue;
    }
    auto op_info = CreateOpInfo(care_node);
    MS_EXCEPTION_IF_NULL(op_info);
-    auto tensor_info = op_info->inputs_tensor_info()[IntToSize(index) - 1];
-    return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
+    return std::make_pair(op_info, index - 1);
   }
-  return std::make_pair(nullptr, nullptr);
+  return std::make_pair(nullptr, 0);
 }
 
 std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
@@ -478,10 +474,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   auto send_op = CreatOpInstance(attrs, SEND, SEND);
   auto send_node = NewValueNode(send_op);
   auto prim = GetValueNode<PrimitivePtr>(send_node);
-  std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
+  std::pair<OperatorInfoPtr, int> op_info_pair;
   AnfNodePtr care_node;
+  TensorInfo tensor_info;
   if (parameter->isa<Parameter>()) {
     op_info_pair = GetParameterPair(parameter);
+    tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
   } else {
     if (IsPrimitiveCNode(parameter, prim::kPrimCast)) {
       auto parameter_cnode = parameter->cast<CNodePtr>();
@@ -491,13 +489,15 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
     }
     if (care_node->isa<Parameter>()) {
       op_info_pair = GetParameterPair(care_node);
+      tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
     } else {
       op_info_pair = GetOpInfo(care_node);
+      tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
     }
   }
-  auto tensor_info = op_info_pair.second;
-  MS_EXCEPTION_IF_NULL(tensor_info);
-  auto slice_shape = tensor_info->slice_shape();
+  auto index = op_info_pair.second;
+  auto op_info = op_info_pair.first;
+  auto slice_shape = tensor_info.slice_shape();
   auto shape_type_pair = GetShapeType(parameter, slice_shape);
   prim->set_attr(SHAPE, shape_type_pair.first);
   prim->set_attr(DTYPE, shape_type_pair.second);
@@ -508,6 +508,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   } else {
     send->AddPrimalAttr(PIPELINE_PARAM, value);
     send->AddPrimalAttr(MICRO, value);
+    send->set_user_data<OperatorInfo>(op_info);
+    send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
   }
   OperatorAttrs depend_attrs;
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
@@ -533,23 +535,25 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
   }
   Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
   Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
-  std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
+  std::pair<OperatorInfoPtr, int> op_info_pair;
   bool is_param = true;
+  TensorInfo tensor_info;
   if (node->isa<Parameter>()) {
     op_info_pair = GetParameterPair(node);
+    tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
   } else {
     auto care_node = FindPipelineCareNode(node);
     if (care_node->isa<Parameter>()) {
       op_info_pair = GetParameterPair(care_node);
+      tensor_info = op_info_pair.first->inputs_tensor_info().at(IntToSize(op_info_pair.second));
     } else {
       op_info_pair = GetOpInfo(care_node);
+      tensor_info = op_info_pair.first->outputs_tensor_info().at(IntToSize(op_info_pair.second));
       is_param = false;
     }
   }
-  auto tensor_info = op_info_pair.second;
-  MS_EXCEPTION_IF_NULL(tensor_info);
-  auto tensor_layout = tensor_info->tensor_layout();
-  Shape slice_shape = tensor_info->slice_shape();
+  auto tensor_layout = tensor_info.tensor_layout();
+  Shape slice_shape = tensor_info.slice_shape();
   auto shape_type_pair = GetShapeType(node, slice_shape);
   Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
   Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
@@ -72,8 +72,8 @@ class PipelineTransformer {
   AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
                    const std::string &tag);
   AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
-  std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
-  std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);
+  std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node);
+  std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node);
   OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
   bool IsPipelineCareNode(const CNodePtr &cnode);
   std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
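In the hunks that follow, the Send node created for a shared parameter carries that index as the param_index (PARAM_INDEX) primal attribute, and the consumers (mirror-op insertion, parameter slice-info lookup) read it back instead of relying on the node's own input position. A minimal sketch of that producer/consumer handshake, again with hypothetical stand-in types rather than the real CNode/primal-attribute API:

#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a CNode with primal attributes (illustration only).
struct FakeSendNode {
  std::map<std::string, int> primal_attrs;
};

constexpr char PARAM_INDEX[] = "param_index";

// Producer side (cf. InsertSend): record which tensor-info slot the Send refers to.
void AttachParamIndex(FakeSendNode *send, int index) { send->primal_attrs[PARAM_INDEX] = index; }

// Consumer side (cf. InsertMirrorOps / GetParameterSliceInfo): for a Send node,
// pick the mirror op / tensor info by the recorded index, not by the Send
// node's own input position.
int SelectIndex(const FakeSendNode &node, int fallback_index) {
  auto it = node.primal_attrs.find(PARAM_INDEX);
  return it != node.primal_attrs.end() ? it->second : fallback_index;
}

int main() {
  FakeSendNode send;
  AttachParamIndex(&send, 2);
  assert(SelectIndex(send, 0) == 2);
  return 0;
}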
@@ -611,6 +611,9 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
   AnfNodeIndexSet node_set = manager->node_users()[node];
   CNodePtr insert_node_new;
 
+  if (IsPrimitiveCNode(node, prim::kPrimSend)) {
+    return;
+  }
   if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
     MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
     return;
@@ -1091,9 +1094,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
     }
   }
 
-  if (IsSomePrimitive(cnode, RECEIVE) && cnode->has_user_data(PIPELINE_PARAM)) {
-    return std::make_pair(node, false);
-  }
   // When not fully use opt shard, allgather and mirror would be both inserted.
   // Skip allgather here and find parameter recursively.
   if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
@@ -1180,6 +1180,9 @@ bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
 }
 
 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
+  if (IsPrimitiveCNode(node, prim::kPrimSend)) {
+    return true;
+  }
   if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
     MS_LOG(INFO) << "Input is ValueList, skip it.";
     return false;
@@ -1242,6 +1245,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
 
   for (size_t index = 1; index < node_size; ++index) {
     OperatorVector backward_op = mirror_ops[index - 1];
+    if (IsPrimitiveCNode(node, prim::kPrimSend)) {
+      auto param_index = GetValue<int>(node->GetPrimalAttr(PARAM_INDEX));
+      backward_op = mirror_ops[IntToSize(param_index)];
+    }
     if (backward_op.empty()) {
       continue;
     }
@@ -1271,10 +1278,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
     }
     // not a RefKey
     std::string mirror_op_name = MirrorOpName();
-    if (IsPrimitiveCNode(param_node_pair.first, prim::kPrimReceive)) {
-      param_ptr = param_node_pair.first->cast<CNodePtr>()->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
-      param_name = param_ptr->name();
-    }
     AnfNodePtr pre_node = node->input(index);
     if (!param_node_pair.second) {
       auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
@@ -1329,6 +1332,9 @@ void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &dist
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(node);
 
+  if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
+    return;
+  }
   bool is_loss_cnode =
     std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
                 [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
@@ -2109,7 +2115,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
 
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if (!CheckExtractInfomation(cnode)) {
+    if (!CheckExtractInfomation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend)) {
       continue;
     }
 
@@ -2631,6 +2637,9 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
 void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(cnode);
+  if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
+    return;
+  }
   OperatorVector forward_op = distribute_operator->forward_op();
   if (!forward_op.empty()) {
     MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name();
@@ -2820,7 +2829,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
       // the make_tuple is parallel care node, but it may have not operator info
-      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || cnode->HasPrimalAttr(PIPELINE_PARAM)) {
+      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
        continue;
       }
 
@@ -2828,15 +2837,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
       MS_EXCEPTION_IF_NULL(distribute_operator);
 
       // insert forward ops
-      if (!IsSomePrimitive(cnode, RECEIVE)) {
-        InsertForwardOps(distribute_operator, cnode);
-      }
+      InsertForwardOps(distribute_operator, cnode);
 
       // insert redistribution ops
       StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
 
       // insert backward ops
-      if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
+      if (has_backward) {
        BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
       }
 
@@ -3414,12 +3421,18 @@ ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &p
   OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
   MS_EXCEPTION_IF_NULL(op_info);
 
-  size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
-  if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
-    MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
-                      << ", but the index is " << user_input_index - 1;
+  TensorInfo tensor_info;
+  if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
+    auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
+    tensor_info = op_info->inputs_tensor_info()[param_index];
+  } else {
+    size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
+    if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
+      MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
+                        << ", but the index is " << user_input_index - 1;
+    }
+    tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
   }
-  TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1];
 
   ParameterSliceInfo parameter_slice_info;
   parameter_slice_info.slice_shape = tensor_info.slice_shape();
@@ -30,7 +30,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
   "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
   "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
   "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
-  "stop_gradient", "Send", "UpdateState", "Load"};
+  "stop_gradient", "UpdateState", "Load"};
 static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
 static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};
 // clang-format on
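The last hunk removes "Send" from PARALLEL_BLACK_LIST_, so Send nodes now flow through the parallel passes; the guards added earlier (early returns for Send in StepRedistribution and for Receive in InsertForwardOps/BackwardCommunication) keep the passes that must not rewrite those nodes from touching them. A rough sketch of that guard pattern, with a plain string standing in for the primitive check rather than MindSpore's IsPrimitiveCNode:

#include <iostream>
#include <string>

// Toy guard pattern: a pass that must not handle a given primitive checks the
// node's primitive name and returns early (stand-in for IsPrimitiveCNode).
void StepRedistributionSketch(const std::string &prim_name) {
  if (prim_name == "Send") {
    return;  // Send itself needs no redistribution inserted after it.
  }
  std::cout << "insert redistribution after " << prim_name << "\n";
}

int main() {
  StepRedistributionSketch("Send");    // skipped by the guard
  StepRedistributionSketch("MatMul");  // handled as before
  return 0;
}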