!16457 [AutoParallel]pipeline_split_adapt_master

Merge pull request !16457 from lichen/pipeline_split_adapt_master
This commit is contained in:
i-robot 2021-06-11 11:37:40 +08:00 committed by Gitee
commit 85d860e6a2
24 changed files with 1505 additions and 417 deletions


@ -63,6 +63,10 @@ bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph,
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(cur_node);
if (IsPrimitiveCNode(cur_node, prim::kPrimReceive)) {
return false;
}
// when input is a parameter or is a value node
if (IsParameterOrValueNode(input)) {
return true;


@ -192,90 +192,6 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
}
}
bool IsBackward(const CNodePtr &cnode) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
return prim->HasAttr(BACKWARD);
}
// compare the value of send/recv sr_tag
bool comp(const CNodePtr &node1, const CNodePtr &node2) {
auto prim1 = GetValueNode<PrimitivePtr>(node1->input(0));
MS_EXCEPTION_IF_NULL(prim1);
auto prim2 = GetValueNode<PrimitivePtr>(node1->input(0));
MS_EXCEPTION_IF_NULL(prim2);
auto sr_tag_value1 = prim1->GetAttr(SR_TAG);
MS_EXCEPTION_IF_NULL(sr_tag_value1);
auto sr_tag_value2 = prim2->GetAttr(SR_TAG);
MS_EXCEPTION_IF_NULL(sr_tag_value2);
auto sr_tag1 = GetValue<int64_t>(sr_tag_value1);
auto sr_tag2 = GetValue<int64_t>(sr_tag_value2);
return sr_tag1 < sr_tag2;
}
// Reorder the execution order of send
void ReorderSend(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
auto last_node = op_v.back();
for (auto &node : op_v) {
if (node == last_node) {
continue;
}
auto iter = std::find(execution_order->begin(), execution_order->end(), node);
(void)execution_order->erase(iter);
}
std::sort(op_v.begin(), op_v.end(), comp);
auto last_node_iter = std::find(execution_order->begin(), execution_order->end(), last_node);
auto node_iter = execution_order->erase(last_node_iter);
// all send will insert the end of the last node
execution_order->insert(node_iter, op_v.begin(), op_v.end());
}
// Reorder the execution order of receive
void ReorderRecv(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
auto begin_node = op_v.front();
for (auto &node : op_v) {
if (node == begin_node) {
continue;
}
auto iter = std::find(execution_order->begin(), execution_order->end(), node);
(void)execution_order->erase(iter);
}
std::sort(op_v.begin(), op_v.end(), comp);
auto begin_node_iter = std::find(execution_order->begin(), execution_order->end(), begin_node);
auto node_iter = execution_order->erase(begin_node_iter);
// all receive will insert before the begin node
execution_order->insert(node_iter, op_v.begin(), op_v.end());
}
void ReorderSendRecv(std::vector<CNodePtr> *execution_order) {
std::vector<CNodePtr> forward_send, forward_recv, backward_send, backward_recv;
for (auto &cnode : *execution_order) {
if (IsPrimitiveCNode(cnode, prim::kPrimSend) && IsBackward(cnode)) {
backward_send.push_back(cnode);
continue;
} else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) {
forward_send.push_back(cnode);
continue;
}
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && IsBackward(cnode)) {
backward_recv.push_back(cnode);
} else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
forward_recv.push_back(cnode);
}
}
if (!forward_send.empty()) {
ReorderSend(execution_order, forward_send);
}
if (!backward_send.empty()) {
ReorderSend(execution_order, backward_send);
}
if (!forward_recv.empty()) {
ReorderRecv(execution_order, forward_recv);
}
if (!backward_recv.empty()) {
ReorderRecv(execution_order, backward_recv);
}
}
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Load kInputCtrlTensors";
@ -511,10 +427,6 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
// adjust kernel
AdjustKernel(root_graph);
// reorder send/recv
auto execution_order = root_graph->execution_order();
ReorderSendRecv(&execution_order);
root_graph->set_execution_order(execution_order);
#if ENABLE_CPU && ENABLE_D
InitPsWorker(root_graph);
#endif


@ -206,8 +206,14 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
virtual_output_eliminate_ =
MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
// Receive
// PipelineSplit
receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
virtual_accu_grad_ =
MakeSubstitution(std::make_shared<VirtualAccuGradEliminater>(), "virtual_accu_grad", prim::kPrimVirtualAccuGrad);
virtual_assign_add_ =
MakeSubstitution(std::make_shared<VirtualAssignAddEliminater>(), "virtual_assign_add", prim::kPrimVirtualAssignAdd);
mirror_micro_step_ =
MakeSubstitution(std::make_shared<MirrorMicroStepEliminater>(), "mirror_micro_step", prim::kPrimMirrorMicroStep);
// Convert
print_tuple_wrapper_ =


@ -120,8 +120,12 @@ class OptimizeIRPassLib {
// virtual output
SubstitutionPtr virtual_output_eliminate_;
// Receive
// PipelineSplit
SubstitutionPtr receive_eliminate_;
SubstitutionPtr virtual_accu_grad_;
SubstitutionPtr virtual_assign_add_;
SubstitutionPtr mirror_micro_step_;
// Convert
SubstitutionPtr print_tuple_wrapper_;


@ -134,6 +134,63 @@ class ReceiveEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
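// {prim::kPrimVirtualAssignAdd, X, Y} -> X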
class VirtualAssignAddEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualAssignAdd) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
private:
AnfNodePtr x_{nullptr};
};
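// {prim::kPrimVirtualAccuGrad, X, Y} -> X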
class VirtualAccuGradEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualAccuGrad) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
private:
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMirrorMicroStep, X, Z} -> X
class MirrorMicroStepEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimSameTypeShape, X, Y} -> X
class SameEliminater : public AnfVisitor {
public:


@ -0,0 +1,634 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iterator>
#include <memory>
#include <list>
#include <algorithm>
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "base/core_ops.h"
#include "ir/value.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
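// Trace back along the first input chain of cnode until a Parameter (the gradient accumulation parameter)
// is reached; returns nullptr if the chain ends without one.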
AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
auto pre_node = cnode->input(1);
while (true) {
if (pre_node->isa<Parameter>()) {
return pre_node;
} else {
if (pre_node->isa<CNode>()) {
auto pre_cnode = pre_node->cast<CNodePtr>();
pre_node = pre_cnode->input(1);
} else {
return nullptr;
}
}
}
return nullptr;
}
bool IsLastStage() {
MS_EXCEPTION_IF_NULL(g_device_manager);
auto stage_num = g_device_manager->stage_num();
auto stage_id = g_device_manager->stage_id();
return ((stage_num - 1) == stage_id);
}
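// Give the micro-batch StridedSlice a non-splitting (all-ones) parallel strategy so it is not re-sharded.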
void SetStridedSliceStrategy(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
return;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
int64_t dev_num = 1;
auto attrs_temp = prim->attrs();
std::vector<Shapes> shape_list = ExtractShape(cnode);
if (shape_list.empty()) {
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " failed to extract shape";
}
std::vector<ValuePtr> elements;
for (size_t i = 0; i < shape_list[0].size(); i++) {
if (shape_list[0][i].empty()) {
MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
}
Dimensions input_strategy = {dev_num};
for (size_t j = 1; j < shape_list[0][i].size(); j++) {
input_strategy.push_back(1);
}
elements.push_back(MakeValue(input_strategy));
}
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
attrs_temp[STRATEGY] = strategy;
(void)prim->SetAttrs(attrs_temp);
}
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
const AnfNodePtr &accu_parameter) {
auto cnode = node_user.first->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag() ||
((IsPrimitiveCNode(node_user.first, prim::kPrimSend) || IsPrimitiveCNode(node_user.first, prim::kPrimDepend)) &&
ParallelContext::GetInstance()->enable_parallel_optimizer())) {
return;
}
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd";
return;
}
OperatorAttrs attrs;
auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
auto value_node = NewValueNode(py_instance);
std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
auto graph = cnode->func_graph();
auto virtual_node = graph->NewCNode(virtual_node_input);
manager->SetEdge(cnode, node_user.second, virtual_node);
}
void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param) {
auto cnode = recv->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
OperatorAttrs attrs;
auto py_instance = CreatOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
auto value_node = NewValueNode(py_instance);
std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto virtual_node = graph->NewCNode(virtual_node_input);
manager->Replace(recv, virtual_node);
}
AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->name() == name) {
continue;
}
auto expect_name = "accu_grads." + name;
if (param_ptr->name() == expect_name) {
return parameter;
}
}
return nullptr;
}
void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
auto parameters = root->parameters();
auto node_users_map = root->manager()->node_users();
for (auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
continue;
}
auto parameter_ptr = cnode->input(1)->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter_ptr);
auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
if (!accu_parameter) {
continue;
}
auto node_users = node_users_map[node];
for (auto &temp_user : node_users) {
auto temp_node = temp_user.first;
if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
temp_node = node_users_map[temp_node].begin()->first;
}
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
auto node_set = node_users_map[temp_node];
for (auto &node_user : node_set) {
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
}
} else {
InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
}
}
InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
}
}
void AddVirtualAssignAdd(const FuncGraphPtr &root) {
auto parameters = root->parameters();
auto node_users_map = root->manager()->node_users();
for (auto &parameter : parameters) {
auto parameter_ptr = parameter->cast<ParameterPtr>();
auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
if (!accu_parameter) {
continue;
}
auto node_users = node_users_map[parameter];
for (auto &temp_user : node_users) {
auto temp_node = temp_user.first;
if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
temp_node = node_users_map[temp_node].begin()->first;
}
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
auto node_set = node_users_map[temp_node];
for (auto &node_user : node_set) {
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
}
} else {
InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
}
}
}
}
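// Order pipeline border nodes by micro-batch id, then by peer rank (src_rank/dest_rank), then by sr_tag.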
bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
auto cnode1 = node1->cast<CNodePtr>();
auto cnode2 = node2->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode1);
MS_EXCEPTION_IF_NULL(cnode2);
auto micro1 = cnode1->GetPrimalAttr(MICRO);
auto micro2 = cnode2->GetPrimalAttr(MICRO);
MS_EXCEPTION_IF_NULL(micro1);
MS_EXCEPTION_IF_NULL(micro2);
auto micro1_value = GetValue<int64_t>(micro1);
auto micro2_value = GetValue<int64_t>(micro2);
if (micro1_value == micro2_value) {
auto prim1 = GetCNodePrimitive(cnode1);
auto prim2 = GetCNodePrimitive(cnode2);
MS_EXCEPTION_IF_NULL(prim1);
MS_EXCEPTION_IF_NULL(prim2);
auto rank_tag1 = prim1->GetAttr(SRC_RANK);
auto rank_tag2 = prim2->GetAttr(SRC_RANK);
if (rank_tag1 == nullptr) {
rank_tag1 = prim1->GetAttr(DEST_RANK);
}
if (rank_tag2 == nullptr) {
rank_tag2 = prim2->GetAttr(DEST_RANK);
}
MS_EXCEPTION_IF_NULL(rank_tag1);
MS_EXCEPTION_IF_NULL(rank_tag2);
auto rank1_value = GetValue<int64_t>(rank_tag1);
auto rank2_value = GetValue<int64_t>(rank_tag2);
if (rank1_value == rank2_value) {
auto sr_tag1 = prim1->GetAttr(SR_TAG);
auto sr_tag2 = prim2->GetAttr(SR_TAG);
MS_EXCEPTION_IF_NULL(sr_tag1);
MS_EXCEPTION_IF_NULL(sr_tag2);
auto sr1_value = GetValue<int64_t>(sr_tag1);
auto sr2_value = GetValue<int64_t>(sr_tag2);
return sr1_value < sr2_value;
}
return rank1_value < rank2_value;
}
return micro1_value < micro2_value;
}
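// Route post_node's first input through a Depend on prior_node so prior_node is executed before post_node.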
void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(post_node);
auto post_cnode = post_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(post_cnode);
std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node};
auto depend_node = root->NewCNode(depend_input);
manager->SetEdge(post_node, 1, depend_node);
}
void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(g_device_manager);
MS_EXCEPTION_IF_NULL(root);
auto manager = root->manager();
MS_EXCEPTION_IF_NULL(manager);
auto stage_num = g_device_manager->stage_num();
auto stage_id = g_device_manager->stage_id();
for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) {
auto prior_node = forward_end[i - 1];
auto post_node = forward_start[i];
InsertDepend(prior_node, post_node, manager, root);
}
}
void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(g_device_manager);
MS_EXCEPTION_IF_NULL(root);
auto manager = root->manager();
MS_EXCEPTION_IF_NULL(manager);
auto stage_num = g_device_manager->stage_num();
auto stage_id = g_device_manager->stage_id();
for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) {
auto prior_node1 = forward_end_before_pair.second[i];
auto post_node1 = backward_start_pair.first[i - stage_num + stage_id + 1];
InsertDepend(prior_node1, post_node1, manager, root);
auto prior_node2 = backward_end_pair.second[i - stage_num + stage_id];
auto post_node2 = forward_start_pair.first[i];
InsertDepend(prior_node2, post_node2, manager, root);
}
for (size_t i = (stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) {
if (!IsLastStage()) {
auto prior_node3 = backward_start_pair.second[i - stage_num + stage_id];
auto post_node3 = forward_end_pair.first[i - 1];
InsertDepend(prior_node3, post_node3, manager, root);
auto prior_node4 = forward_end_pair.second[i - 1];
auto post_node4 = backward_end_pair.first[i - stage_num + stage_id];
InsertDepend(prior_node4, post_node4, manager, root);
}
}
for (size_t j = (backward_start_pair.first.size() - stage_num + stage_id + 1); j < backward_start_pair.first.size();
++j) {
auto prior_node5 = backward_end_pair.second[j - 1];
auto post_node5 = backward_start_pair.first[j];
InsertDepend(prior_node5, post_node5, manager, root);
}
if (!IsLastStage()) {
auto prior_node6 = forward_end_before_pair.second[stage_num - 1 - stage_id];
auto post_node6 = backward_start_pair.first[0];
InsertDepend(prior_node6, post_node6, manager, root);
}
}
void ReorderForParams(const std::vector<AnfNodePtr> &backward_params, const std::vector<AnfNodePtr> &forward_params,
const std::vector<AnfNodePtr> &allreduce_params, const PipelinePair &forward_params_pair,
const PipelinePair &backward_params_pair, const std::vector<AnfNodePtr> &backward_end,
const PipelinePair &forward_start_pair, const FuncGraphPtr &root) {
auto manager = root->manager();
MS_EXCEPTION_IF_NULL(manager);
if (!forward_params.empty()) {
auto prior_node = forward_params_pair.second[0];
auto post_node = forward_start_pair.first[0];
InsertDepend(prior_node, post_node, manager, root);
}
if (!backward_params.empty()) {
if (!allreduce_params.empty()) {
for (auto &node : allreduce_params) {
auto post_node1 = backward_params_pair.first[0];
InsertDepend(node, post_node1, manager, root);
}
}
auto prior_node2 = backward_end.back();
auto post_node2 = backward_params[0];
InsertDepend(prior_node2, post_node2, manager, root);
}
}
int64_t GetMicroBatch(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto micro_value = cnode->GetPrimalAttr(MICRO);
MS_EXCEPTION_IF_NULL(micro_value);
return GetValue<int64_t>(micro_value);
}
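// For each micro batch, chain its duplicated border nodes with Depend edges and return the first and last
// node of every micro batch as a PipelinePair.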
PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max) {
std::vector<AnfNodePtr> temp_vec;
std::vector<AnfNodePtr> out_vec_begin;
std::vector<AnfNodePtr> out_vec_end;
auto manager = root->manager();
for (int64_t i = 0; i <= micro_max; ++i) {
temp_vec.clear();
for (auto &node : node_vector) {
auto node_micro = GetMicroBatch(node);
if (node_micro == i) {
temp_vec.push_back(node);
}
}
if (temp_vec.size() <= 1) {
MS_LOG(INFO) << "No Duplicate MicroBatch.";
continue;
}
std::sort(temp_vec.begin(), temp_vec.end(), CompFunc);
for (size_t j = 0; j < temp_vec.size() - 1; ++j) {
auto prior_node = temp_vec[j];
auto post_node = temp_vec[j + 1];
InsertDepend(prior_node, post_node, manager, root);
}
if (!temp_vec.empty()) {
out_vec_begin.push_back(temp_vec.front());
out_vec_end.push_back(temp_vec.back());
}
}
if (out_vec_begin.empty()) {
return std::make_pair(node_vector, node_vector);
}
return std::make_pair(out_vec_begin, out_vec_end);
}
void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value) {
auto node_users = (*node_users_map)[node];
for (auto &node_pair : node_users) {
auto user_node = node_pair.first->cast<CNodePtr>();
if (user_node->HasPrimalAttr(MICRO)) {
continue;
}
user_node->AddPrimalAttr(MICRO, value);
BroadCastMicroBatch(user_node, node_users_map, value);
}
}
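// Skip over Depend nodes to reach the real producer CNode.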
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
return GetPreNode(cnode->input(1));
}
return cnode;
}
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
if (!IsLastStage()) {
return;
}
auto node_users_map = manager->node_users();
for (auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!cnode->HasPrimalAttr(MICRO)) {
continue;
}
auto prim = GetCNodePrimitive(node);
if (prim && prim->HasAttr(PIPELINE_END)) {
for (auto &temp_node : cnode->inputs()) {
if (!temp_node->isa<CNode>()) {
continue;
}
auto temp_cnode = temp_node->cast<CNodePtr>();
auto temp_prim = GetCNodePrimitive(temp_node);
if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
continue;
}
auto end_node = GetPreNode(temp_node);
auto end_cnode = end_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(end_cnode);
auto end_prim = GetCNodePrimitive(end_node);
OperatorAttrs attrs_;
auto op = CreatOpInstance(attrs_, end_prim->name(), "");
auto value_node = NewValueNode(op);
auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
new_prim->SetAttrs(end_prim->attrs());
manager->SetEdge(end_node, 0, value_node);
end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
}
}
}
}
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
auto node_users_map = manager->node_users();
for (auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!cnode->HasPrimalAttr(MICRO)) {
continue;
}
auto micro = cnode->GetPrimalAttr(MICRO);
auto prim = GetCNodePrimitive(node);
if (prim && prim->HasAttr(PARAMETER_START)) {
OperatorAttrs attrs_;
auto op = CreatOpInstance(attrs_, prim->name(), "");
auto value_node = NewValueNode(op);
auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
new_prim->SetAttrs(prim->attrs());
manager->SetEdge(cnode, 0, value_node);
cnode->AddPrimalAttr(PARAMETER_START, micro);
}
}
}
void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
auto node_users_map = manager->node_users();
for (auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!cnode->HasPrimalAttr(MICRO)) {
continue;
}
auto micro = cnode->GetPrimalAttr(MICRO);
MS_EXCEPTION_IF_NULL(micro);
BroadCastMicroBatch(cnode, &node_users_map, micro);
}
}
void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root) {
std::list<ValuePtr> name_list = {};
auto stage_id = g_device_manager->stage_id();
for (auto &node : root->nodes()) {
if (!node->isa<CNode>()) {
continue;
}
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimZerosLike)) {
continue;
}
auto prim = GetCNodePrimitive(node);
auto cnode = node->cast<CNodePtr>();
if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
auto forward_node_name = cnode->GetPrimalAttr(kPrimalAttrForwardNodeName);
if (std::find(name_list.begin(), name_list.end(), forward_node_name) != name_list.end()) {
continue;
}
name_list.push_back(forward_node_name);
if (cnode->HasPrimalAttr(PIPELINE_END)) {
backward_start->push_back(node);
}
if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
backward_end->push_back(node);
}
if (cnode->HasPrimalAttr(PARAMETER_START)) {
backward_end->push_back(node);
}
if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
backward_params->push_back(node);
}
if (prim->HasAttr(PARAMETER_MICRO)) {
allreduce_params->push_back(node);
}
} else {
if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
if (stage_id != 0 && IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
continue;
}
forward_start->push_back(node);
}
if (cnode->HasPrimalAttr(PIPELINE_END)) {
forward_end->push_back(node);
}
if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
forward_params->push_back(node);
}
}
}
std::sort((*backward_start).begin(), (*backward_start).end(), CompFunc);
std::sort((*backward_end).begin(), (*backward_end).end(), CompFunc);
std::sort((*forward_start).begin(), (*forward_start).end(), CompFunc);
std::sort((*forward_end).begin(), (*forward_end).end(), CompFunc);
std::sort((*backward_params).begin(), (*backward_params).end(), CompFunc);
std::sort((*forward_params).begin(), (*forward_params).end(), CompFunc);
}
void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
size_t micro_size) {
micro_size = micro_size + 1;
if (forward_start_pair.first.size() != micro_size) {
MS_LOG(EXCEPTION) << "forward_node's size:" << forward_start_pair.first.size()
<< "is not equal to micro size:" << micro_size;
}
if (forward_end_pair.first.size() != micro_size) {
MS_LOG(EXCEPTION) << "forward_node's size:" << forward_end_pair.first.size()
<< "is not equal to micro size:" << micro_size;
}
if (backward_start_pair.first.size() != micro_size) {
MS_LOG(EXCEPTION) << "backward_node's size:" << backward_start_pair.first.size()
<< "is not equal to micro size:" << micro_size;
}
if (backward_end_pair.first.size() != micro_size) {
MS_LOG(EXCEPTION) << "backward_node's size:" << backward_end_pair.first.size()
<< "is not equal to micro size:" << micro_size;
}
}
void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
std::vector<AnfNodePtr> forward_start;
std::vector<AnfNodePtr> forward_end;
std::vector<AnfNodePtr> forward_params;
std::vector<AnfNodePtr> backward_start;
std::vector<AnfNodePtr> backward_end;
std::vector<AnfNodePtr> backward_params;
std::vector<AnfNodePtr> allreduce_params;
GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
&allreduce_params, root);
auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
MS_EXCEPTION_IF_NULL(micro_size);
auto micro_max = GetValue<int64_t>(micro_size);
auto backward_start_pair = Deduplicate(backward_start, root, micro_max);
auto backward_end_pair = Deduplicate(backward_end, root, micro_max);
auto forward_start_pair = Deduplicate(forward_start, root, micro_max);
auto forward_end_pair = Deduplicate(forward_end, root, micro_max);
auto forward_params_pair = Deduplicate(forward_params, root, micro_max);
auto backward_params_pair = Deduplicate(backward_params, root, micro_max);
CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, LongToSize(micro_max));
PipelinePair forward_end_before_pair;
if (!IsLastStage()) {
for (auto &node : forward_end_pair.first) {
auto cnode = node->cast<CNodePtr>();
forward_end_before_pair.first.push_back(cnode->input(1));
}
for (auto &node : forward_end_pair.second) {
auto cnode = node->cast<CNodePtr>();
forward_end_before_pair.second.push_back(cnode->input(1));
}
} else {
forward_end_before_pair = forward_end_pair;
}
ReorderForForward(forward_start_pair.first, forward_end_pair.second, root);
ReorderForBackward(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair,
forward_end_before_pair, root);
ReorderForParams(backward_params, forward_params, allreduce_params, forward_params_pair, backward_params_pair,
backward_end, forward_start_pair, root);
}
void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
std::vector<AnfNodePtr> forward_end;
std::vector<AnfNodePtr> forward_start;
std::vector<AnfNodePtr> forward_params;
for (auto &node : root->nodes()) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
forward_start.push_back(node);
}
if (cnode->HasPrimalAttr(PIPELINE_END)) {
forward_end.push_back(node);
}
if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
forward_params.push_back(node);
}
}
std::sort(forward_start.begin(), forward_start.end(), CompFunc);
std::sort(forward_end.begin(), forward_end.end(), CompFunc);
std::sort(forward_params.begin(), forward_params.end(), CompFunc);
auto forward_start_pair = Deduplicate(forward_start, root, 0);
auto forward_end_pair = Deduplicate(forward_end, root, 0);
auto forward_params_pair = Deduplicate(forward_params, root, 0);
if (!forward_end.empty() && !forward_params.empty()) {
InsertDepend(forward_params_pair.second[0], forward_end_pair.first[0], manager, root);
}
if (!forward_start.empty() && !forward_params.empty()) {
InsertDepend(forward_params_pair.second[0], forward_start_pair.first[0], manager, root);
}
}
} // namespace parallel
} // namespace mindspore


@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_
#include <utility>
#include <vector>
#include <string>
#include "ir/anf.h"
#include "ir/manager.h"
namespace mindspore {
namespace parallel {
using PipelinePair = std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>;
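// For pairs produced by Deduplicate: first holds each micro batch's earliest border node, second its latest.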
AnfNodePtr FindAccuGrad(const CNodePtr &cnode);
bool IsLastStage();
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
const AnfNodePtr &accu_parameter);
void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param);
AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name);
void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
void AddVirtualAssignAdd(const FuncGraphPtr &root);
bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2);
void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
const FuncGraphPtr &root);
void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root);
void ReorderForParams(const std::vector<AnfNodePtr> &backward_params, const std::vector<AnfNodePtr> &forward_params,
const std::vector<AnfNodePtr> &allreduce_params, const PipelinePair &forward_params_pair,
const PipelinePair &backward_params_pair, const std::vector<AnfNodePtr> &backward_end,
const PipelinePair &forward_start_pair, const FuncGraphPtr &root);
int64_t GetMicroBatch(const AnfNodePtr &node);
void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
const FuncGraphPtr &root);
PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max);
void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root);
void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value);
AnfNodePtr GetPreNode(const AnfNodePtr &node);
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void SetStridedSliceStrategy(const AnfNodePtr &node);
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, size_t micro_size);
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_


@ -342,11 +342,12 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(comm_node);
MS_EXCEPTION_IF_NULL(param_node);
if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
return;
}
auto param = param_node->cast<ParameterPtr>();
ParameterPtr param;
if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
param = param_node->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
} else {
param = param_node->cast<ParameterPtr>();
}
MS_EXCEPTION_IF_NULL(param);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
MS_EXCEPTION_IF_NULL(prim);
@ -372,6 +373,22 @@ void AddCommOpMeanFlag(const CNodePtr &comm_node) {
prim->SetAttrs(attrs);
}
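// Mark the mirror operator with parameter_micro when the parameter it mirrors also feeds a pipeline Send.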
void AddCommOpParamFlag(const CNodePtr &comm_node) {
MS_EXCEPTION_IF_NULL(comm_node);
auto graph = comm_node->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users()[comm_node->input(1)];
for (auto &node_user : node_users) {
if (IsPrimitiveCNode(node_user.first, prim::kPrimSend)) {
auto prim = GetCNodePrimitive(comm_node);
prim->AddAttr(PARAMETER_MICRO, MakeValue(0));
return;
}
}
}
Operator CreateAllGatherOp(const std::string &group) {
OperatorName operator_name = ALL_GATHER;
ValuePtr attr0_value = MakeValue(group);  // group
@ -438,6 +455,7 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
OperatorVector op_for_weight;
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
ValuePtr attr0_value = MakeValue(group_name);
ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
@ -459,6 +477,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value);
operator_attrs.push_back(attr3);
MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror";
} else if (split_stage_num > 1) {
operator_name = MIRROR_MICRO_STEP_OPERATOR;
} else {
operator_name = MIRROR_OPERATOR;
}


@ -294,6 +294,7 @@ Operator CreateAllGatherOp(const std::string &group);
Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
void AddCommOpParamFlag(const CNodePtr &comm_node);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);


@ -168,7 +168,7 @@ constexpr char CLONED_INDEX[] = "cloned_index";
constexpr char BE_CLONED_INDEX[] = "be_cloned_index";
constexpr char GROUP_RANKS[] = "group_ranks";
constexpr char IS_IN_FORWARD[] = "is_in_forward";
constexpr char DTYPE[] = "DType";
constexpr char DTYPE[] = "dtype";
constexpr char DEV_NUM[] = "dev_num";
constexpr char MEAN_FLAG[] = "mean_flag";
constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step";
@ -358,6 +358,23 @@ constexpr char UNIQUE[] = "Unique";
constexpr char GATHERND[] = "GatherNd";
constexpr char SCATTER_UPDATE[] = "ScatterUpdate";
// pipeline
constexpr char MICRO[] = "micro";
constexpr char DEST_RANK[] = "dest_rank";
constexpr char SRC_RANK[] = "src_rank";
constexpr char PIPELINE_PARAM[] = "pipeline_param";
constexpr char PIPELINE_END[] = "pipeline_end";
constexpr char PIPELINE_BEGIN[] = "pipeline_begin";
constexpr char MAIN_GRAPH[] = "main_graph";
constexpr char SR_TAG[] = "sr_tag";
constexpr char GROUP_BACK[] = "group_back";
constexpr char MIRROR_MICRO_STEP_OPERATOR[] = "_MirrorMicroStepOperator";
constexpr char PARAMETER_MICRO[] = "parameter_micro";
constexpr char VIRTUAL_ASSIGN_ADD[] = "_VirtualAssignAdd";
constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad";
constexpr char ACCU_GRAD[] = "accu_grad";
constexpr char PARAMETER_START[] = "parameter_start";
// Parallel don't care
constexpr char STRING_EQUAL[] = "string_equal";
constexpr char MAKE_TUPLE[] = "MakeTuple";


@ -29,6 +29,7 @@
#include "frontend/parallel/step_parallel.h" #include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/node_check.h" #include "frontend/parallel/node_check.h"
#include "frontend/parallel/graph_util/node_info.h" #include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "utils/comm_manager.h" #include "utils/comm_manager.h"
@ -52,30 +53,74 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
return false;
}
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: "
<< MAX_RECURSIVE_DEPTH;
return;
}
const auto &node_users = manager->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first;
if (!user_node->grad()) {
user_node->set_grad(true);
SetGradTag(user_node, manager, ++curr_depth);
}
}
}
void PipelineTransformer::LabelRequiredGradCNode() {
auto parameters = root_->parameters();
for (auto parameter : parameters) {
if (!ParameterRequireGrad(parameter)) {
continue;
}
SetGradTag(parameter, manager_, 0);
}
}
void PipelineTransformer::MainGraph() {
if (!root_->has_flag(TRAINING)) {
main_graph_ = root_;
return;
}
for (auto &fg : manager_->func_graphs()) {
for (auto &node : fg->nodes()) {
if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
main_graph_ = fg;
main_graph_->set_flag(MAIN_GRAPH, true);
virtual_dataset_ = node;
return;
}
}
}
MS_LOG(EXCEPTION) << "Can't find main graph, possible reason is can't find virtual dataset.";
}
ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size) {
if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
}
auto cnode = node->cast<CNodePtr>();
auto value = GetValueNode(cnode->input(2));
MS_EXCEPTION_IF_NULL(value);
auto tuple = GetValue<std::vector<int64_t>>(value);
auto input_shape = GetNodeShape(cnode->input(1)).at(0);
int64_t micro = tuple.at(0) * micro_size / input_shape.at(0);
cnode->AddPrimalAttr(MICRO, MakeValue(micro));
cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
return MakeValue(micro);
}
void PipelineTransformer::LabelMicroBatch() {
MS_EXCEPTION_IF_NULL(main_graph_);
MS_EXCEPTION_IF_NULL(virtual_dataset_);
auto node_user_map = manager_->node_users();
auto node_users = node_user_map[virtual_dataset_];
for (auto &node_user : node_users) {
if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
auto data_users = manager_->node_users()[node_user.first];
auto micro_size = int64_t(data_users.size());
micro_size_ = micro_size;
MS_LOG(INFO) << "Micro Size is: " << micro_size;
for (auto &data_user : data_users) {
auto micro = SetMicroBatch(data_user.first, micro_size);
SetStridedSliceStrategy(data_user.first);
auto cnode = data_user.first->cast<CNodePtr>();
BroadCastMicroBatch(cnode, &node_user_map, micro);
}
}
}
}
void PipelineTransformer::CreateForwardGroup() {
std::vector<int64_t> rank_list;
auto rank_id = g_device_manager->global_rank();
auto stage_id = g_device_manager->stage_id();
auto stage_num = g_device_manager->stage_num();
for (int64_t i = 0; i < stage_num; ++i) {
rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id));
}
auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list);
auto g = g_device_manager->CreateGroup(rank_list);
auto g_back_name = g.name() + BACKWARD;
auto g_back = g_device_manager->CreateGroup(g_back_name, dev_list);
group_.push_back(g.name());
group_.push_back(g_back.name());
}
void PipelineTransformer::Coloring() {
@ -84,7 +129,7 @@ void PipelineTransformer::Coloring() {
while (need_coloring) {
need_coloring = false;
for (auto &fg : manager_->func_graphs()) {
if (fg == root_) {
if (fg == root_ && root_->has_flag(TRAINING)) {
continue;
}
auto value_nodes = fg->value_nodes();
@ -94,16 +139,15 @@ void PipelineTransformer::Coloring() {
continue;
}
auto graph = GetValueNode<FuncGraphPtr>(node);
auto need_grad = graph->get_return()->grad();
if (graph->stage() == -1) {
continue;
}
stage_set.insert(graph->stage());
auto node_users = manager_->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first->cast<CNodePtr>();
user_node->set_stage(graph->stage());
user_node->set_grad(need_grad);
auto user_node_graph = user_node->func_graph();
if (graph->stage() != -1) {
stage_set.insert(graph->stage());
}
if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
user_node_graph->set_stage(graph->stage());
need_coloring = true;
@ -117,22 +161,37 @@ void PipelineTransformer::Coloring() {
if (SizeToLong(stage_set.size()) != stage_num) {
MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size();
}
return;
}
void PipelineTransformer::BroadCastColoring() {
for (auto &fg : manager_->func_graphs()) {
if (fg == root_ || fg->stage() == -1) {
continue;
}
DoBroadCast(fg);
SetNoStageNode(fg);
}
}
void PipelineTransformer::BroadCastColoring() {
auto need_coloring = true;
while (need_coloring) {
need_coloring = false;
auto all_nodes = main_graph_->nodes();
auto node_users = manager_->node_users();
for (auto &node : all_nodes) {
if (!node->isa<CNode>() || node->stage() == -1) {
continue;
}
auto stage = node->stage();
for (auto &user_pair : node_users[node]) {
auto user_node = user_pair.first->cast<CNodePtr>();
auto user_node_stage = user_node->stage();
if (stage > user_node_stage) {
user_node->set_stage(stage);
need_coloring = true;
}
}
}
}
}
bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (!prim) {
return false;
}
if (IsInWhiteList(cnode)) {
return false;
}
@ -148,6 +207,9 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
if (!IsPipelineCareNode(cnode)) {
MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " is not a Pipeline Care Node.";
}
if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
SetVirtualDatasetStrategy(cnode);
}
auto shape_list = ExtractShape(cnode);
if (shape_list.empty()) {
MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " failed to extract shape.";
@ -155,7 +217,7 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == RESHAPE) {
MS_LOG(EXCEPTION) << "Reshape op can't be a border.";
MS_LOG(EXCEPTION) << "Reshape op can't be a border. node:" << cnode->DebugString();
}
auto attrs = prim->attrs();
auto op_info = OperatorInstance(prim, attrs, shape_list);
@ -190,7 +252,11 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A
MS_EXCEPTION_IF_NULL(cnode);
// Handle Cast and TupleGetitem situation
size_t tensor_info_index = 0;
if (IsPrimitiveCNode(cnode, prim::kPrimCast)) {
OperatorInfoPtr op_info;
if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
op_info = node->user_data<OperatorInfo>();
} else {
if (IsPrimitiveCNode(node, prim::kPrimCast)) {
cnode = cnode->input(1)->cast<CNodePtr>();
} else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
@ -198,85 +264,75 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A
}
// Create OperatorInfo to get slice_shape for send/recv
MS_EXCEPTION_IF_NULL(cnode);
auto op_info = CreateOpInfo(cnode);
op_info = CreateOpInfo(cnode);
}
MS_EXCEPTION_IF_NULL(op_info);
auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index];
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
}
CNodePtr PipelineTransformer::HandleMonadLoad(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto &node_users = manager_->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_node);
if (IsPipelineCareNode(user_node)) {
return user_node;
}
}
return nullptr;
}
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto &node_users = manager_->node_users()[node];
for (auto &user_pair : node_users) {
auto care_node = user_pair.first;
auto care_cnode = care_node->cast<CNodePtr>();
if (IsPrimitiveCNode(care_node, prim::kPrimLoad)) {
care_cnode = HandleMonadLoad(care_node);
if (!care_cnode) {
continue;
}
} else {
if (!IsPipelineCareNode(care_cnode)) {
continue;
}
}
MS_EXCEPTION_IF_NULL(care_cnode);
auto op_info = CreateOpInfo(care_cnode);
MS_EXCEPTION_IF_NULL(op_info);
auto tensor_info = op_info->inputs_tensor_info()[IntToSize(user_pair.second) - 1];
return std::make_pair(nullptr, std::make_shared<TensorInfo>(tensor_info));
}
return std::make_pair(nullptr, nullptr);
}
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto node_users_map = manager_->node_users();
auto node_users = node_users_map[node];
for (auto &node_user : node_users) {
auto load = node_user.first->cast<CNodePtr>();
if (IsPrimitiveCNode(load, prim::kPrimLoad)) {
node_users = node_users_map[load];
break;
}
}
for (auto &user_pair : node_users) {
auto user_node = user_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_node);
auto user_node_graph = user_node->func_graph();
MS_EXCEPTION_IF_NULL(user_node_graph);
if (user_node_graph->stage() == -1) {
continue;
}
auto care_node = user_node;
auto index = user_pair.second;
if (IsValueNode<FuncGraph>(user_node->input(0))) {
auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0));
auto temp_params = graph->parameters();
if (temp_params.size() < IntToSize(user_pair.second)) {
MS_LOG(EXCEPTION) << "parameter:" << node->DebugString() << " out of graph: " << graph->ToString()
<< "'s range.";
}
auto temp_param = temp_params[user_pair.second - 1];
auto temp_users = node_users_map[temp_param];
for (auto &temp_user : temp_users) {
auto load_temp = temp_user.first->cast<CNodePtr>();
if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) {
temp_users = node_users_map[load_temp];
break;
}
}
for (auto &temp_pair : temp_users) {
auto temp_cnode = temp_pair.first->cast<CNodePtr>();
if (!IsPipelineCareNode(temp_cnode)) {
continue;
}
care_node = temp_cnode;
index = temp_pair.second;
break;
}
}
if (!IsPipelineCareNode(care_node)) {
continue;
}
auto op_info = CreateOpInfo(care_node);
MS_EXCEPTION_IF_NULL(op_info);
auto tensor_info = op_info->inputs_tensor_info()[IntToSize(index) - 1];
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
}
return std::make_pair(nullptr, nullptr);
}
void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
auto need_coloring = true;
while (need_coloring) {
need_coloring = false;
auto all_nodes = func->nodes();
auto &node_users = manager_->node_users();
for (auto &node : all_nodes) {
if (node->isa<CNode>() || node->stage() == -1) {
continue;
}
auto stage = node->stage();
for (auto &user_pair : node_users[node]) {
auto user_node = user_pair.first->cast<CNodePtr>();
auto user_node_stage = user_node->stage();
if (IsValueNode<FuncGraph>(user_node->input(0)) && stage > user_node_stage) {
user_node->set_stage(stage);
need_coloring = true;
}
}
}
}
}
void PipelineTransformer::SetNoStageNode(const FuncGraphPtr &func) {
auto all_nodes = func->nodes();
for (auto &node : all_nodes) {
if (!node->isa<CNode>() || node->stage() != -1) {
continue;
}
node->set_stage(0);
}
}
void PipelineTransformer::HandleSharedParameter() {
std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
auto parameters = root_->parameters();
std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
std::vector<AnfNodePtr> recvs = {};
for (auto &parameter : parameters) {
auto parameter_stage = parameter_color_map[parameter];
if (parameter_stage.size() <= 1) {
@ -285,38 +341,42 @@ void PipelineTransformer::HandleSharedParameter() {
auto users = manager_->node_users()[parameter];
for (auto &user : users) {
auto node = user.first;
auto graph = node->func_graph();
if (graph != root_ && graph->stage() == -1) {
MS_LOG(EXCEPTION) << "Don't support this situation.";
}
if (graph == root_ || graph->stage() != stage_) {
continue;
}
if (stage_ == *parameter_stage.begin()) {
std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
for (auto &stage : parameter_stage) {
if (stage == stage_) {
continue;
} else {
auto send_out = InsertSend(graph, parameter, stage, stage_);
make_tuple_input.push_back(send_out.depend);
}
}
auto make_tuple = graph->NewCNode(make_tuple_input);
OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "");
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, make_tuple};
auto depend = graph->NewCNode(depend_input);
depend->set_abstract(parameter->abstract());
manager_->SetEdge(node, user.second, depend);
break;
} else {
(void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
break;
}
}
}
}
auto users = manager_->node_users()[parameter];
for (auto &user : users) {
auto node = user.first;
auto cnode = node->cast<CNodePtr>();
auto graph = node->func_graph();
if (IsValueNode<FuncGraph>(cnode->input(0))) {
graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
}
if (graph == root_ || graph->stage() == -1 || !parameter_stage.count(stage_)) {
continue;
}
auto micro = cnode->GetPrimalAttr(MICRO);
if (!micro) {
MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
micro = MakeValue(int64_t(0));
}
auto user_stage = node->stage();
if (stage_ == *parameter_stage.begin()) {
if (graph->stage() == stage_) {
continue;
}
if (Reuse(parameter, user_stage, make_tuple_input, DEST_RANK)) {
continue;
}
auto send_out = InsertSend(main_graph_, parameter, user_stage, stage_, micro);
make_tuple_input.push_back(send_out.depend);
} else {
auto receive = Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
if (receive) {
manager_->SetEdge(node, user.second, receive);
} else {
auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro);
recvs.push_back(recv);
}
}
}
}
return make_tuple_input;
}
void PipelineTransformer::ParameterColoring() { void PipelineTransformer::ParameterColoring() {
auto parameters = root_->parameters(); auto parameters = root_->parameters();
@ -324,14 +384,24 @@ void PipelineTransformer::ParameterColoring() {
auto users = manager_->node_users()[parameter]; auto users = manager_->node_users()[parameter];
std::set<int64_t> parameter_stage; std::set<int64_t> parameter_stage;
for (auto &user : users) { for (auto &user : users) {
auto node = user.first; auto node = user.first->cast<CNodePtr>();
auto graph = node->func_graph(); auto graph = node->func_graph();
if (IsValueNode<FuncGraph>(node->input(0))) {
graph = GetValueNode<FuncGraphPtr>(node->input(0));
}
if (graph != root_ && graph->stage() != -1) { if (graph != root_ && graph->stage() != -1) {
parameter_stage.insert(graph->stage()); parameter_stage.insert(graph->stage());
parameter->set_stage(graph->stage()); parameter->set_stage(graph->stage());
} }
} }
if (*parameter_stage.begin() == stage_ && !virtual_param_) { auto param_info = parameter->cast<ParameterPtr>()->param_info();
if (!param_info) {
parameter_color_map[parameter] = parameter_stage;
continue;
}
MS_EXCEPTION_IF_NULL(param_info);
auto requires_grad = param_info->requires_grad();
if (*parameter_stage.begin() == stage_ && !virtual_param_ && requires_grad) {
virtual_param_ = parameter; virtual_param_ = parameter;
} }
parameter_color_map[parameter] = parameter_stage; parameter_color_map[parameter] = parameter_stage;
@ -343,8 +413,8 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) { if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
auto graph_return = graph->get_return(); auto graph_output = graph->output();
type = graph_return->Type(); type = graph_output->Type();
} else { } else {
type = node->Type(); type = node->Type();
} }
@ -359,40 +429,38 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con
return std::make_pair(shape_list, dtype); return std::make_pair(shape_list, dtype);
} }
AnfNodePtr PipelineTransformer::HandleMonadDepend(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
auto cnode = node->cast<CNodePtr>();
return HandleMonadDepend(cnode->input(1));
}
return node;
}
AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) { AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (IsValueNode<FuncGraph>(cnode->input(0))) { if (IsValueNode<FuncGraph>(cnode->input(0))) {
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0)); auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
auto output = HandleMonadDepend(graph->output()); auto output = graph->output();
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
if (output->isa<Parameter>()) { if (output->isa<Parameter>()) {
return output; auto parameters = graph->parameters();
auto pos_iter = std::find(parameters.begin(), parameters.end(), output);
auto pos = std::distance(parameters.begin(), pos_iter);
return FindPipelineCareNode(cnode->input(pos + 1));
} }
cnode = output->cast<CNodePtr>(); cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
} }
if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
return FindPipelineCareNode(cnode->input(1));
}
if (IsInWhiteList(cnode)) { if (IsInWhiteList(cnode)) {
return cnode->cast<AnfNodePtr>(); return cnode->cast<AnfNodePtr>();
} }
if (!IsPipelineCareNode(cnode)) { if (!IsPipelineCareNode(cnode)) {
MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."; MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."
<< " border node: " << cnode->DebugString();
} }
return cnode->cast<AnfNodePtr>(); return cnode->cast<AnfNodePtr>();
} }
SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter,
int64_t user_node_stage, int64_t node_stage) { int64_t user_node_stage, int64_t node_stage, const ValuePtr &value) {
auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
int64_t send_tag; int64_t send_tag;
if (send_tag_map.find(dest_rank) != send_tag_map.end()) { if (send_tag_map.find(dest_rank) != send_tag_map.end()) {
@ -402,17 +470,25 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
send_tag = 0; send_tag = 0;
send_tag_map[dest_rank] = 0; send_tag_map[dest_rank] = 0;
} }
Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank)); Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage));
OperatorAttrs attrs = {attr_tag, attr_rank}; Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
auto send_op = CreatOpInstance(attrs, SEND, "send"); Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
auto send_op = CreatOpInstance(attrs, SEND, SEND);
auto send_node = NewValueNode(send_op); auto send_node = NewValueNode(send_op);
auto prim = GetValueNode<PrimitivePtr>(send_node); auto prim = GetValueNode<PrimitivePtr>(send_node);
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair; std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
AnfNodePtr care_node;
if (parameter->isa<Parameter>()) { if (parameter->isa<Parameter>()) {
op_info_pair = GetParameterPair(parameter); op_info_pair = GetParameterPair(parameter);
} else { } else {
auto care_node = FindPipelineCareNode(parameter); if (IsPrimitiveCNode(parameter, prim::kPrimCast)) {
auto parameter_cnode = parameter->cast<CNodePtr>();
care_node = FindPipelineCareNode(parameter_cnode->input(1));
} else {
care_node = FindPipelineCareNode(parameter);
}
if (care_node->isa<Parameter>()) { if (care_node->isa<Parameter>()) {
op_info_pair = GetParameterPair(care_node); op_info_pair = GetParameterPair(care_node);
} else { } else {
@ -423,14 +499,20 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
MS_EXCEPTION_IF_NULL(tensor_info); MS_EXCEPTION_IF_NULL(tensor_info);
auto slice_shape = tensor_info->slice_shape(); auto slice_shape = tensor_info->slice_shape();
auto shape_type_pair = GetShapeType(parameter, slice_shape); auto shape_type_pair = GetShapeType(parameter, slice_shape);
prim->set_attr("shape", shape_type_pair.first); prim->set_attr(SHAPE, shape_type_pair.first);
prim->set_attr("dtype", shape_type_pair.second); prim->set_attr(DTYPE, shape_type_pair.second);
std::vector<AnfNodePtr> send_input = {send_node, parameter}; std::vector<AnfNodePtr> send_input = {send_node, parameter};
auto send = graph->NewCNode(send_input); auto send = main_graph_->NewCNode(send_input);
if (!parameter->isa<Parameter>() && care_node != nullptr && !care_node->isa<Parameter>()) {
send->AddPrimalAttr(PIPELINE_END, value);
} else {
send->AddPrimalAttr(PIPELINE_PARAM, value);
send->AddPrimalAttr(MICRO, value);
}
OperatorAttrs depend_attrs; OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend"); auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send}; std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
auto depend = graph->NewCNode(depend_input); auto depend = main_graph_->NewCNode(depend_input);
auto abstract = parameter->abstract(); auto abstract = parameter->abstract();
depend->set_abstract(abstract); depend->set_abstract(abstract);
SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend}; SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
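For reference, the peer-rank arithmetic used by InsertSend (and mirrored in InsertReceive below) is a fixed offset of per_stage_rank_num_ per pipeline stage. A small illustrative calculation in plain Python, with hypothetical values (8 devices per stage, stages 0 and 1):
# Illustrative sketch of the dest_rank / src_rank arithmetic shown above
# (hypothetical values, not taken from a real run).
per_stage_rank_num = 8
sender_global_rank, node_stage, user_node_stage = 3, 0, 1
dest_rank = sender_global_rank + (user_node_stage - node_stage) * per_stage_rank_num    # 11
receiver_global_rank = 11
src_rank = receiver_global_rank - (user_node_stage - node_stage) * per_stage_rank_num   # 3
print(dest_rank, src_rank)  # 11 3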
@ -439,7 +521,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
const AnfNodePtr &use_node, int index, int64_t user_node_stage, const AnfNodePtr &use_node, int index, int64_t user_node_stage,
int64_t node_stage) { int64_t node_stage, const ValuePtr &value) {
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
int64_t recv_tag; int64_t recv_tag;
if (recv_tag_map.find(src_rank) != recv_tag_map.end()) { if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
@ -449,9 +531,10 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
recv_tag = 0; recv_tag = 0;
recv_tag_map[src_rank] = 0; recv_tag_map[src_rank] = 0;
} }
Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank)); Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair; std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
bool is_param = true;
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
op_info_pair = GetParameterPair(node); op_info_pair = GetParameterPair(node);
} else { } else {
@ -460,28 +543,34 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
op_info_pair = GetParameterPair(care_node); op_info_pair = GetParameterPair(care_node);
} else { } else {
op_info_pair = GetOpInfo(care_node); op_info_pair = GetOpInfo(care_node);
is_param = false;
} }
} }
auto tensor_info = op_info_pair.second; auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info); MS_EXCEPTION_IF_NULL(tensor_info);
auto slice_shape = tensor_info->slice_shape(); auto tensor_layout = tensor_info->tensor_layout();
Shape slice_shape = tensor_info->slice_shape();
auto shape_type_pair = GetShapeType(node, slice_shape); auto shape_type_pair = GetShapeType(node, slice_shape);
Attr attr_shape = std::make_pair("shape", shape_type_pair.first); Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second); Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype}; Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv"); Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
auto recv_op = CreatOpInstance(attrs, RECEIVE, RECEIVE);
std::vector<AnfNodePtr> recv_input; std::vector<AnfNodePtr> recv_input;
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
recv_input = {NewValueNode(recv_op), node}; recv_input = {NewValueNode(recv_op), node};
} else { } else {
if (node->grad()) {
recv_input = {NewValueNode(recv_op), virtual_param_}; recv_input = {NewValueNode(recv_op), virtual_param_};
} else {
auto param = root_->parameters()[0];
recv_input = {NewValueNode(recv_op), param};
}
} }
auto recv = graph->NewCNode(recv_input); auto recv = graph->NewCNode(recv_input);
if (is_param) {
recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
recv->AddPrimalAttr(PIPELINE_PARAM, value);
} else {
recv->AddPrimalAttr(PIPELINE_BEGIN, value);
}
recv->AddPrimalAttr(MICRO, value);
auto node_abstract = node->abstract(); auto node_abstract = node->abstract();
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
@ -494,65 +583,53 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
} }
MS_EXCEPTION_IF_NULL(node_abstract); MS_EXCEPTION_IF_NULL(node_abstract);
recv->set_abstract(node_abstract); recv->set_abstract(node_abstract);
if (op_info_pair.first != nullptr) { if (node->isa<Parameter>()) {
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout())); BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
recv->set_user_data<OperatorInfo>(op_info_pair.first); auto abstract_clone = node->abstract()->Clone();
MS_EXCEPTION_IF_NULL(abstract_clone);
abstract_clone->set_shape(parallel_shape);
node->set_abstract(abstract_clone);
node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
} }
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
recv->set_user_data<OperatorInfo>(op_info_pair.first);
manager_->SetEdge(use_node, index, recv); manager_->SetEdge(use_node, index, recv);
return recv; return recv;
} }
bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage, AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
const std::vector<AnfNodePtr> &out_input) { const std::string &tag) {
auto node_users = manager_->node_users()[node]; for (auto &input : out_input) {
auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_; auto cnode = input->cast<CNodePtr>();
for (auto &depend : out_input) { if (!cnode) {
if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
continue; continue;
} }
auto cnode = depend->cast<CNodePtr>(); if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
cnode = cnode->input(2)->cast<CNodePtr>();
}
if (cnode->input(1) == node) { if (cnode->input(1) == node) {
auto send_cnode = cnode->input(2)->cast<CNodePtr>(); auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
auto prim = GetValueNode<PrimitivePtr>(send_cnode->input(0)); auto dest_rank_send = GetValue<int64_t>(prim->GetAttr(tag));
auto dest_rank_send = GetValue<int64_t>(prim->GetAttr("dest_rank")); if (dest_rank_send == stage) {
if (dest_rank_send == dest_rank) { return input;
return true;
} }
} }
} }
return false; return nullptr;
} }
std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) { std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
std::set<int64_t> tag_set;
auto node_stage = node->stage();
int64_t min_tag = node_stage;
for (auto &user_pair : node_users) {
auto user_node = user_pair.first;
auto user_node_stage = user_node->stage();
tag_set.insert(user_node_stage);
if (user_node_stage == -1) {
continue;
}
min_tag = min_tag > user_node_stage ? user_node_stage : min_tag;
}
bool is_shared = tag_set.size() > 1;
return std::make_pair(is_shared, min_tag);
}
void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
OperatorAttrs depend_attrs; OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, "Depend", ""); auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
std::vector<AnfNodePtr> out_input = {NewValueNode(depend_op)}; std::vector<AnfNodePtr> receive_ops;
std::vector<AnfNodePtr> send_ops;
auto all_nodes = graph->nodes(); auto all_nodes = graph->nodes();
for (auto &node : all_nodes) { for (auto &node : all_nodes) {
if (!node->isa<CNode>() || node->stage() == -1) { if (!node->isa<CNode>() || node->stage() == -1) {
continue; continue;
} }
auto node_users = manager_->node_users()[node]; auto node_users = manager_->node_users()[node];
auto shared_min_tag_pair = IsSharedNode(node, node_users);
auto is_shared = shared_min_tag_pair.first;
auto min_tag = shared_min_tag_pair.second;
AnfNodePtr receive = nullptr; AnfNodePtr receive = nullptr;
for (auto &user_pair : node_users) { for (auto &user_pair : node_users) {
auto user_node = user_pair.first; auto user_node = user_pair.first;
@ -561,21 +638,25 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
if (node_stage != stage_ && user_node_stage != stage_) { if (node_stage != stage_ && user_node_stage != stage_) {
continue; continue;
} }
auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
if (!micro) {
MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
micro = MakeValue(int64_t(0));
}
if (node_stage < user_node_stage) { if (node_stage < user_node_stage) {
if (is_shared && (min_tag != node_stage)) {
continue;
}
if (node_stage == stage_) { if (node_stage == stage_) {
if (Reuse(node, user_node_stage, node_stage, out_input)) { if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
continue; continue;
} }
auto send_out = InsertSend(graph, node, user_node_stage, node_stage); auto send_out = InsertSend(graph, node, user_node_stage, node_stage, micro);
out_input.insert(out_input.begin() + 1, send_out.depend); MS_EXCEPTION_IF_NULL(send_out.depend);
type_ptr_ = send_out.type; send_ops.push_back(send_out.depend);
shape_ = send_out.shape; send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
} else { } else {
if (!receive) { if (!receive) {
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage); receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro);
receive_ops.push_back(receive);
} else { } else {
manager_->SetEdge(user_node, user_pair.second, receive); manager_->SetEdge(user_node, user_pair.second, receive);
} }
@ -583,46 +664,40 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
continue; continue;
} }
if (node_stage > user_node_stage) { if (node_stage > user_node_stage) {
auto cnode = node->cast<CNodePtr>(); MS_LOG(EXCEPTION) << "node_stage: " << node_stage
auto user_cnode = user_node->cast<CNodePtr>(); << " must be smaller than user_node_stage: " << user_node_stage;
if (IsValueNode<FuncGraph>(cnode->input(0)) && IsValueNode<FuncGraph>(user_cnode->input(0))) {
MS_LOG(EXCEPTION) << "Don't support this situation";
}
continue;
} }
} }
} }
if (out_input.size() == 2) { return std::make_pair(send_ops, receive_ops);
manager_->Replace(graph->output(), out_input[1]);
}
if (out_input.size() > 2) {
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
make_tuple_inputs.insert(make_tuple_inputs.begin() + 1, out_input.begin() + 2, out_input.end());
auto make_tuple = graph->NewCNode(make_tuple_inputs);
std::vector<AnfNodePtr> out_depend_inputs = {out_input[0], out_input[1], make_tuple};
auto out_node = graph->NewCNode(out_depend_inputs);
manager_->Replace(graph->output(), out_node);
}
} }
void PipelineTransformer::CutGraph() { void PipelineTransformer::CutGraph() {
for (auto &fg : manager_->func_graphs()) { std::vector<AnfNodePtr> make_tuple_inputs;
CutBorder(fg); CreateForwardGroup();
MS_EXCEPTION_IF_NULL(main_graph_);
if (make_tuple_inputs.empty()) {
make_tuple_inputs = HandleSharedParameter();
} }
auto send_recv_ops = CutBorder(main_graph_);
auto send_ops = send_recv_ops.first;
if (IsLastStage()) {
return;
} }
if (send_ops.empty() && !root_->has_flag(TRAINING)) {
bool PipelineTransformer::IsStageNode(const CNodePtr &node) { return;
for (auto &input : node->inputs()) {
if (input->isa<Parameter>()) {
return (*parameter_color_map[input].begin() == stage_ || input->stage() == -1);
} else if (input->isa<CNode>()) {
auto pre_node = input->cast<CNodePtr>();
return IsStageNode(pre_node);
} else {
continue;
} }
make_tuple_inputs.insert(make_tuple_inputs.end(), send_ops.begin(), send_ops.end());
if (!send_ops.empty()) {
type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
} }
return true; auto make_tuple = main_graph_->NewCNode(make_tuple_inputs);
std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend)};
out.push_back(send_ops.back());
out.push_back(make_tuple);
auto out_node = main_graph_->NewCNode(out);
manager_->Replace(main_graph_->output(), out_node);
} }
void PipelineTransformer::ElimGraphStage() { void PipelineTransformer::ElimGraphStage() {
@ -694,7 +769,21 @@ void PipelineTransformer::ElimParameter() {
std::vector<AnfNodePtr> parameter_list; std::vector<AnfNodePtr> parameter_list;
for (auto &parameter : parameters) { for (auto &parameter : parameters) {
if (!manager_->node_users()[parameter].empty()) { if (!manager_->node_users()[parameter].empty()) {
if (!root_->has_flag(TRAINING)) {
for (auto &node_pair : manager_->node_users()[parameter]) {
auto user_node = node_pair.first;
if (!IsPrimitiveCNode(user_node, prim::kPrimReceive)) {
parameter_list.push_back(parameter); parameter_list.push_back(parameter);
break;
}
// remove the inputs of the Receive node
auto cnode = user_node->cast<CNodePtr>();
std::vector<AnfNodePtr> new_inputs = {cnode->input(0)};
cnode->set_inputs(new_inputs);
}
} else {
parameter_list.push_back(parameter);
}
} }
} }
auto del_num = parameters.size() - parameter_list.size(); auto del_num = parameters.size() - parameter_list.size();


@ -45,13 +45,15 @@ class PipelineTransformer {
: manager_(manager), : manager_(manager),
stage_(stage), stage_(stage),
root_(root), root_(root),
main_graph_(nullptr),
virtual_dataset_(nullptr),
global_rank_(global_rank), global_rank_(global_rank),
per_stage_rank_num_(per_stage_rank_num) {} per_stage_rank_num_(per_stage_rank_num) {}
virtual ~PipelineTransformer() = default; virtual ~PipelineTransformer() = default;
void LabelRequiredGradCNode();
void Coloring(); void Coloring();
void MainGraph();
void LabelMicroBatch();
void BroadCastColoring(); void BroadCastColoring();
void HandleSharedParameter();
void CutGraph(); void CutGraph();
void ParameterColoring(); void ParameterColoring();
void CoverSensShape(); void CoverSensShape();
@ -59,21 +61,18 @@ class PipelineTransformer {
void ElimParameter(); void ElimParameter();
private: private:
std::pair<bool, int64_t> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); void CreateForwardGroup();
void DoBroadCast(const FuncGraphPtr &func); ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size);
std::vector<AnfNodePtr> HandleSharedParameter();
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage, SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
int64_t node_stage); int64_t node_stage, const ValuePtr &value);
AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int64_t user_node_stage, int64_t node_stage); int64_t user_node_stage, int64_t node_stage, const ValuePtr &value);
void SetNoStageNode(const FuncGraphPtr &func); std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
void CutBorder(const FuncGraphPtr &graph); AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
bool IsStageNode(const CNodePtr &node); const std::string &tag);
bool Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
const std::vector<AnfNodePtr> &out_input);
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node); AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node); std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
AnfNodePtr HandleMonadDepend(const AnfNodePtr &node);
CNodePtr HandleMonadLoad(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node); std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode); OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
bool IsPipelineCareNode(const CNodePtr &cnode); bool IsPipelineCareNode(const CNodePtr &cnode);
@ -81,11 +80,15 @@ class PipelineTransformer {
FuncGraphManagerPtr manager_; FuncGraphManagerPtr manager_;
int64_t stage_; int64_t stage_;
FuncGraphPtr root_; FuncGraphPtr root_;
FuncGraphPtr main_graph_;
AnfNodePtr virtual_dataset_;
int64_t global_rank_; int64_t global_rank_;
int64_t per_stage_rank_num_; int64_t per_stage_rank_num_;
TypePtr type_ptr_; TypePtr type_ptr_;
ValueListPtr shape_; ValueListPtr shape_;
AnfNodePtr virtual_param_; AnfNodePtr virtual_param_;
int64_t micro_size_ = 0;
std::vector<std::string> group_ = {};
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore


@ -37,6 +37,7 @@
#include "frontend/parallel/graph_util/generate_graph.h" #include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h" #include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h" #include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h" #include "frontend/parallel/node_check.h"
#include "frontend/parallel/ops_info/matmul_info.h" #include "frontend/parallel/ops_info/matmul_info.h"
#include "ir/param_info.h" #include "ir/param_info.h"
@ -172,8 +173,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
OperatorArgs arg_forward = op.second; OperatorArgs arg_forward = op.second;
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
if (grad_accumulation_step > 1) { if (grad_accumulation_step > 1 || split_stage_num > 1) {
auto parameters = root->parameters(); auto parameters = root->parameters();
bool find_grad_accu_node = false; bool find_grad_accu_node = false;
for (auto &param : parameters) { for (auto &param : parameters) {
@ -196,8 +198,8 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
if (op_name == MIRROR_MINI_STEP_OPERATOR) { if (op_name == MIRROR_MINI_STEP_OPERATOR) {
op_name = MIRROR_OPERATOR; op_name = MIRROR_OPERATOR;
arg_forward.first.pop_back(); arg_forward.first.pop_back();
} else if (op_name == MINI_STEP_ALL_GATHER) { } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR) {
MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation."; MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name;
} }
} }
} }
@ -207,7 +209,8 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
OperatorParams params = arg_forward.second; OperatorParams params = arg_forward.second;
std::vector<AnfNodePtr> new_node_input; std::vector<AnfNodePtr> new_node_input;
if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) { if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
op_name == MIRROR_MICRO_STEP_OPERATOR) {
new_node_input = {NewValueNode(pyop_instance), node, grad_accu}; new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input"; MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
} else { } else {
@ -496,6 +499,9 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[LongToSize(index - 1)]; TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator); TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
if (IsPrimitiveCNode(middle_node, prim::kPrimReceive)) {
tensorlayout_in = *(middle_node->user_data<TensorLayout>());
}
if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) { if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name; MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
@ -866,11 +872,13 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
SetUserAttrs(origin_prim->attrs(), prim); SetUserAttrs(origin_prim->attrs(), prim);
if (index == replace_op.size() - 1) { if (index == replace_op.size() - 1) {
replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>()); replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
replace_node->set_primal_attrs(node->primal_attrs());
} }
replace_node->set_in_forward_flag(true); replace_node->set_in_forward_flag(true);
replace_input[0]->set_scope(scope); replace_input[0]->set_scope(scope);
if (replace_op_info_flag && replace_op_info[index].first) { if (replace_op_info_flag && replace_op_info[index].first) {
auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph); auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph);
new_cnode->set_primal_attrs(node->primal_attrs());
(void)manager->Replace(node, new_cnode); // using Replace function to insert node (void)manager->Replace(node, new_cnode); // using Replace function to insert node
} else { } else {
(void)manager->Replace(node, replace_node); // using Replace function to insert node (void)manager->Replace(node, replace_node); // using Replace function to insert node
@ -920,8 +928,9 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
manager->SetEdge(replace_input.first, appear_count, pre_node); manager->SetEdge(replace_input.first, appear_count, pre_node);
} }
// "(void)manager->Replace(replace_graph->first, pre_node);" can not be called // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
auto replace_output = replace_graph->second; auto replace_output = replace_graph->second->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replace_output); MS_EXCEPTION_IF_NULL(replace_output);
replace_output->set_primal_attrs(node->primal_attrs());
(void)manager->Replace(node, replace_output); (void)manager->Replace(node, replace_output);
} }
@ -1075,7 +1084,7 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
} }
} }
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) { if (IsSomePrimitive(cnode, RECEIVE) && cnode->has_user_data(PIPELINE_PARAM)) {
return std::make_pair(node, false); return std::make_pair(node, false);
} }
// When not fully use opt shard, allgather and mirror would be both inserted. // When not fully use opt shard, allgather and mirror would be both inserted.
@ -1193,6 +1202,20 @@ CNodePtr SkipTrivialNodes(CNodePtr node) {
return node; return node;
} }
std::string MirrorOpName() {
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
std::string mirror_op_name;
if (grad_accumulation_step > 1) {
mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
} else if (split_stage_num > 1) {
mirror_op_name = MIRROR_MICRO_STEP_OPERATOR;
} else {
mirror_op_name = MIRROR_OPERATOR;
}
return mirror_op_name;
}
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) { void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
size_t node_size = node->inputs().size(); size_t node_size = node->inputs().size();
@ -1240,12 +1263,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
} }
} }
// not a RefKey // not a RefKey
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); std::string mirror_op_name = MirrorOpName();
std::string mirror_op_name; if (IsPrimitiveCNode(param_node_pair.first, prim::kPrimReceive)) {
if (grad_accumulation_step > 1) { param_ptr = param_node_pair.first->cast<CNodePtr>()->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
mirror_op_name = MIRROR_MINI_STEP_OPERATOR; param_name = param_ptr->name();
} else {
mirror_op_name = MIRROR_OPERATOR;
} }
AnfNodePtr pre_node = node->input(index); AnfNodePtr pre_node = node->input(index);
if (!param_node_pair.second) { if (!param_node_pair.second) {
@ -1282,6 +1303,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
AddCommOpFusionType(comm_op, param_node_pair.first); AddCommOpFusionType(comm_op, param_node_pair.first);
MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>()) MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
<< " and insert mirror before Load"; << " and insert mirror before Load";
AddCommOpParamFlag(comm_op);
continue; continue;
} }
InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root); InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root);
@ -1291,6 +1313,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
// add fusion flag // add fusion flag
// pipeline mirror would not be set, which should be supported later // pipeline mirror would not be set, which should be supported later
AddCommOpFusionType(comm_op, param_node_pair.first); AddCommOpFusionType(comm_op, param_node_pair.first);
AddCommOpParamFlag(comm_op);
} }
} }
@ -2333,6 +2356,9 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
if (!IsValueNode<Primitive>(cnode->input(0))) { if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr; return nullptr;
} }
if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
return cnode->user_data<TensorLayout>();
}
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() && if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() &&
!IsPrimitiveCNode(node, prim::kPrimReshape)) { !IsPrimitiveCNode(node, prim::kPrimReshape)) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
@ -2764,13 +2790,6 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP
return sens_loss_pairs; return sens_loss_pairs;
} }
bool IsLastStage() {
MS_EXCEPTION_IF_NULL(g_device_manager);
auto stage_num = g_device_manager->stage_num();
auto stage_id = g_device_manager->stage_id();
return ((stage_num - 1) == stage_id);
}
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager) { const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(root);
@ -2793,7 +2812,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
// the make_tuple is parallel care node, but it may have not operator info // the make_tuple is parallel care node, but it may have not operator info
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) { if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || cnode->HasPrimalAttr(PIPELINE_PARAM)) {
continue; continue;
} }
@ -3545,20 +3564,6 @@ static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
return false; return false;
} }
static AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->name() == name) {
continue;
}
if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) {
return parameter;
}
}
return nullptr;
}
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user, static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) { const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
auto cnode = node_user.first->cast<CNodePtr>(); auto cnode = node_user.first->cast<CNodePtr>();
@ -3612,6 +3617,17 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) {
} }
} }
void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages) {
if (!root->has_flag(BACKWARD) && pipeline_stages > 1) {
root->set_flag(BACKWARD, true);
if (root->has_flag(TRAINING)) {
Reorder(root, manager);
} else {
ReorderForPredict(root, manager);
}
}
}
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && !_WIN32) #if (ENABLE_CPU && !_WIN32)
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
@ -3622,6 +3638,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(optimizer);
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
pipeline::ResourceBasePtr res = optimizer->resource();
MS_EXCEPTION_IF_NULL(res);
FuncGraphManagerPtr manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
// assume no change to graph // assume no change to graph
bool changes = false; bool changes = false;
// control whether use model_parallel mode // control whether use model_parallel mode
@ -3634,6 +3655,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
} }
root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
} }
ReorderForPipelineSplit(root, manager, pipeline_stages);
return changes; return changes;
} }
@ -3643,23 +3665,22 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
MS_LOG(INFO) << "Now entering step parallel"; MS_LOG(INFO) << "Now entering step parallel";
DumpGraph(root, std::string(STEP_PARALLEL_BEGIN)); DumpGraph(root, std::string(STEP_PARALLEL_BEGIN));
pipeline::ResourceBasePtr res = optimizer->resource();
MS_EXCEPTION_IF_NULL(res);
FuncGraphManagerPtr manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodePtr ret = root->get_return(); AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret); MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
std::reverse(all_nodes.begin(), all_nodes.end()); std::reverse(all_nodes.begin(), all_nodes.end());
if (parallel_mode != AUTO_PARALLEL) { if (parallel_mode != AUTO_PARALLEL) {
TOTAL_OPS = 0; TOTAL_OPS = 0;
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) { if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
MS_LOG(EXCEPTION) << "Parallel init failed"; MS_LOG(EXCEPTION) << "Parallel init failed";
} }
if (pipeline_stages > 1) {
HandleMicroBatch(all_nodes, manager);
ParameterStartNode(all_nodes, manager);
LastStageEndNode(all_nodes, manager);
}
// mark the forward cnodes, parallel only care these nodes // mark the forward cnodes, parallel only care these nodes
MarkForwardCNode(root); MarkForwardCNode(root);
@ -3705,6 +3726,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// ForwardCommunication BackwardCommunication TensorRedistribution // ForwardCommunication BackwardCommunication TensorRedistribution
ParallelCommunication(root, all_nodes, manager); ParallelCommunication(root, all_nodes, manager);
if (pipeline_stages > 1) {
AddVirtualAssignAdd(root);
HandleReceiveParam(root, all_nodes);
}
auto group_info = g_device_manager->group_info(); auto group_info = g_device_manager->group_info();
if (StrategyCheckpoint::GetInstance().group_info_save_on() && if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) { StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {


@ -134,8 +134,6 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim); StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
bool IsLastStage();
// Add node for whole graph // Add node for whole graph
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager); const FuncGraphManagerPtr &manager);
@ -177,6 +175,10 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *u
std::vector<size_t> *indexes); std::vector<size_t> *indexes);
void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes); void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
std::string MirrorOpName();
void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore


@ -410,7 +410,9 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.env_get_item_depend_swap_, irpass.env_get_item_depend_swap_,
irpass.incorporate_env_getitem_switch_layer_, irpass.incorporate_env_getitem_switch_layer_,
irpass.value_based_eliminate_, irpass.value_based_eliminate_,
irpass.receive_eliminate_}, irpass.virtual_accu_grad_,
irpass.virtual_assign_add_,
irpass.mirror_micro_step_},
false, true); false, true);
opt::OptPassConfig b_2 = opt::OptPassConfig({ opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_, irpass.replace_refkey_by_param_,


@ -90,17 +90,19 @@ bool PipelineSplit(const ResourcePtr &res) {
auto transformer = auto transformer =
std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num); std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
// step1: Do color graph // step1: Do color graph
transformer->LabelRequiredGradCNode();
transformer->Coloring(); transformer->Coloring();
transformer->MainGraph();
transformer->LabelMicroBatch();
// step2: Do color broadcast // step2: Do color broadcast
transformer->BroadCastColoring(); transformer->BroadCastColoring();
// step3: Handle shared parameters // step3: Handle shared parameters
transformer->ParameterColoring(); transformer->ParameterColoring();
transformer->HandleSharedParameter();
// step4: Cut Graph // step4: Cut Graph
transformer->CutGraph(); transformer->CutGraph();
// step5: Handle Sens // step5: Handle Sens
if (root->has_flag(parallel::TRAINING)) {
transformer->CoverSensShape(); transformer->CoverSensShape();
}
// step6: Elim Graph stages and no used parameter // step6: Elim Graph stages and no used parameter
transformer->ElimGraphStage(); transformer->ElimGraphStage();
transformer->ElimParameter(); transformer->ElimParameter();


@ -375,6 +375,9 @@ inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill");
inline const PrimitivePtr kPrimFusedPushWeight = std::make_shared<Primitive>("FusedPushWeight"); inline const PrimitivePtr kPrimFusedPushWeight = std::make_shared<Primitive>("FusedPushWeight");
inline const PrimitivePtr kPrimFusedPullWeight = std::make_shared<Primitive>("FusedPullWeight"); inline const PrimitivePtr kPrimFusedPullWeight = std::make_shared<Primitive>("FusedPullWeight");
inline const PrimitivePtr kPrimInitDataSetQueue = std::make_shared<Primitive>("InitDataSetQueue"); inline const PrimitivePtr kPrimInitDataSetQueue = std::make_shared<Primitive>("InitDataSetQueue");
inline const PrimitivePtr kPrimVirtualAssignAdd = std::make_shared<Primitive>("_VirtualAssignAdd");
inline const PrimitivePtr kPrimVirtualAccuGrad = std::make_shared<Primitive>("_VirtualAccuGrad");
inline const PrimitivePtr kPrimMirrorMicroStep = std::make_shared<Primitive>("_MirrorMicroStepOperator");
// Quant ops // Quant ops
inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold"); inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold");


@ -19,6 +19,7 @@ from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode) _get_parallel_mode)
from mindspore.context import ParallelMode, get_auto_parallel_context from mindspore.context import ParallelMode, get_auto_parallel_context
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple from ...common.parameter import Parameter, ParameterTuple
from ...common.tensor import Tensor from ...common.tensor import Tensor
@ -503,6 +504,95 @@ class _VirtualDatasetCell(Cell):
return self._backbone(*output) return self._backbone(*output)
class _MicroBatch(Cell):
"""
Transform the mini-batch into micro-batches in pipeline parallel.
Args:
micro_size (int): The number of micro-batches.
"""
def __init__(self, micro_size):
super(_MicroBatch, self).__init__()
self.shape = P.Shape()
self.micro_size = micro_size
def construct(self, i, *inputs):
micro_inputs = ()
for each_input in inputs:
input_shape = self.shape(each_input)
micro_batch_begin = i * input_shape[0] // self.micro_size
micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
micro_input = each_input[micro_batch_begin:micro_batch_end, :]
micro_inputs += (micro_input,)
return micro_inputs
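The slice bounds computed in _MicroBatch.construct follow simple even-split arithmetic over the batch dimension. A minimal plain-Python sketch with hypothetical values (batch dimension 32, micro_size 4):
# Sketch of the micro-batch slice bounds used above (hypothetical sizes).
batch_dim, micro_size = 32, 4
bounds = [(i * batch_dim // micro_size, (i + 1) * batch_dim // micro_size)
          for i in range(micro_size)]
print(bounds)  # [(0, 8), (8, 16), (16, 24), (24, 32)]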
class PipelineCell(Cell):
"""
Slice each mini-batch into micro-batches and wrap the network so that one step runs all of them.
Note:
micro_size must be greater than or equal to the number of pipeline stages.
Args:
network (Cell): The target network to wrap.
micro_size (int): The number of micro-batches.
Examples:
>>> net = Net()
>>> net = PipelineCell(net, 4)
"""
def __init__(self, network, micro_size):
super(PipelineCell, self).__init__()
self.network = network
self.micro_inputs = nn.CellList()
self.micro_size = micro_size
self.add_list = []
for i in range(micro_size):
micro_input = _MicroBatch(micro_size)
self.micro_inputs.append(micro_input)
self.add = P.Add().add_prim_attr("pipeline_end", i)
self.add_list.append(self.add)
def construct(self, *inputs):
ret = None
for i in range(self.micro_size):
micro_input = self.micro_inputs[i](i, *inputs)
output = self.network(*micro_input)
if ret is not None:
ret = self.add_list[i](ret, output)
else:
ret = output
return ret
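Conceptually, the construct loop above runs the wrapped network once per micro-batch and accumulates the partial outputs with the Add ops tagged pipeline_end. A rough plain-Python equivalent, with the network and the sliced inputs as hypothetical stand-ins:
# Rough plain-Python equivalent of PipelineCell.construct (illustrative only).
def run_pipeline_cell(network, micro_inputs):
    # network: any callable; micro_inputs: one tuple of inputs per micro-batch
    ret = None
    for micro_input in micro_inputs:
        output = network(*micro_input)
        ret = output if ret is None else ret + output
    return ret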
def _pipeline_clear_grad(accu_grad, grad):
accu_grad = F.depend(accu_grad, grad)
zeros = F.tensor_mul(accu_grad, 0.0)
return F.assign(accu_grad, zeros)
class _TrainPipelineAccuStepCell(TrainOneStepCell):
"""
Wraps the network with an optimizer in pipeline mode.
"""
def __init__(self, network, optimizer, sens=1.0):
super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
self.hyper_map = ops.HyperMap()
def construct(self, *inputs):
weights = self.weights
loss = self.network(*inputs)
sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens)
accu_grads = ops.depend(self.accu_grads, grads)
succ = self.optimizer(accu_grads)
clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
loss = ops.depend(loss, succ, clear)
return loss
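The cell above relies on accu_grads being filled during the backward pass (through the _Virtual* operators introduced in this change), consumed by the optimizer, and then zeroed by _pipeline_clear_grad. A schematic plain-Python sketch of that accumulate, apply, clear cycle, with hypothetical gradient values:
# Schematic sketch of the accumulate -> apply -> clear cycle
# (hypothetical values; the real cell does this with MindSpore ops).
accu_grads = [0.0, 0.0]
for micro_grads in ([0.1, 0.2], [0.3, 0.4]):                  # grads from each micro-batch
    accu_grads = [a + g for a, g in zip(accu_grads, micro_grads)]
# ... the optimizer consumes accu_grads here ...
accu_grads = [0.0 for _ in accu_grads]                        # _pipeline_clear_grad zeroes them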
class VirtualDatasetCellTriple(Cell): class VirtualDatasetCellTriple(Cell):
""" """
Wrap the network with virtual dataset to convert data parallel layout to model parallel layout. Wrap the network with virtual dataset to convert data parallel layout to model parallel layout.


@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Generate bprop for comm ops""" """Generate bprop for comm ops"""
from mindspore import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size from mindspore.communication import get_rank, get_group_size
@ -22,7 +23,8 @@ from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap) ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator)
from .grad_base import bprop_getters from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive from ..operations._inner_ops import Send, Receive
@ -84,11 +86,11 @@ def get_bprop_send(self):
"""Generate bprop for Send.""" """Generate bprop for Send."""
shape = self.get_attr_dict()["shape"] shape = self.get_attr_dict()["shape"]
dtype = self.get_attr_dict()["dtype"] dtype = self.get_attr_dict()["dtype"]
send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group) send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group_back)
send_grad.add_prim_attr("backward", True) virtual_input = Tensor(0.0, dtype)
def bprop(x, out, dout): def bprop(x, out, dout):
dx = send_grad() dx = send_grad(virtual_input)
return (dx,) return (dx,)
return bprop return bprop
@ -96,14 +98,14 @@ def get_bprop_send(self):
@bprop_getters.register(Receive) @bprop_getters.register(Receive)
def get_bprop_receive(self): def get_bprop_receive(self):
"""Generate bprop for Receive.""" """Generate bprop for Receive."""
receive_grad = Send(self.tag, self.rank, self.group) receive_grad = Send(self.tag, self.rank, self.group_back)
receive_grad.add_prim_attr("backward", True)
depend = P.Depend() depend = P.Depend()
cast = P.Cast() cast = P.Cast()
out_tensor = Tensor(0.0, mstype.float16)
def bprop(x, out, dout): def bprop(x, out, dout):
send_out = receive_grad(dout) send_out = receive_grad(dout)
dx = depend(cast(zeros_like(x), F.dtype(x)), send_out) dx = depend(cast(out_tensor, F.dtype(x)), send_out)
return (dx,) return (dx,)
return bprop return bprop
@ -116,6 +118,80 @@ def get_bprop_virtual_add(self):
return bprop return bprop
@bprop_getters.register(_VirtualAssignAdd)
def get_bprop_virtual_assign_add(self):
"""Generate bprop for VirtualAssignAdd."""
assign_add = P.AssignAdd()
cast = P.Cast()
dtype = P.DType()
out_tensor = Tensor(0.0, mstype.float16)
def bprop(x, y, out, dout):
temp = assign_add(y, dout)
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp)
return bprop
@bprop_getters.register(_VirtualAccuGrad)
def get_bprop_virtual_accu_grad(self):
"""Generate bprop for VirtualAccuGrad."""
cast = P.Cast()
dtype = P.DType()
out_tensor = Tensor(0.0, mstype.float16)
def bprop(x, y, out, dout):
return (F.depend(y, dout), cast(out_tensor, dtype(y)))
return bprop
@bprop_getters.register(_MirrorMicroStepOperator)
def get_bprop_mirror_micro_step_operator(self):
"""
Backpropagator for _MirrorMicroStepOperator: do AllReduce across the devices in the group
(AllGather for sparse features).
"""
group = self.group
dev_num = self.dev_num
mean_flag = self.mean_flag
scale = 1 / dev_num
all_reduce = AllReduce(group=group)
fusion = self.get_attr_dict()["fusion"]
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter
all_reduce.add_prim_attr("parameter", parameter)
if self.instance_name:
instance_name = "grad_mirror" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
cast = P.Cast()
dtype = P.DType()
assign = P.Assign()
if "parameter_micro" in self.get_attr_dict():
assign.add_prim_attr("parameter_micro", 0)
out_tensor = Tensor(1.0, mstype.float16)
def bprop(x, z, out, dout):
real_grad = z
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
real_grad = F.tensor_mul(real_grad, scale)
assign(z, real_grad)
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
assign(z, real_grad)
return (cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z)))
return bprop
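In the mean_flag branch above, AllReduce sums the accumulated gradient across the group and the 1 / dev_num scale turns that sum into a mean. The arithmetic, in plain Python with hypothetical per-device values:
# Sketch of the mean_flag path: sum across devices, then scale by 1 / dev_num.
dev_num = 4
per_device_grad = [0.1, 0.2, 0.3, 0.4]
all_reduced = sum(per_device_grad)        # what AllReduce leaves on every device
mean_grad = all_reduced * (1 / dev_num)   # 0.25
print(mean_grad)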
@bprop_getters.register(Broadcast) @bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self): def get_bprop_broad_cast(self):
"""Generate bprop for Broadcast.""" """Generate bprop for Broadcast."""


@ -36,8 +36,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect, SearchSorted) EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect, SearchSorted)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter) _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert) TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import GeSwitch, Merge from .control_ops import GeSwitch, Merge


@@ -417,7 +417,7 @@ class Send(PrimitiveWithInfer):
    """
    @prim_attr_register
-    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
+    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
        self.rank = dest_rank
        self.sr_tag = sr_tag
        self.group = group
@@ -427,7 +427,6 @@ class Send(PrimitiveWithInfer):
        return x_shape
    def infer_dtype(self, x_dtype):
-        self.add_prim_attr("dtype", x_dtype)
        return x_dtype
@@ -474,7 +473,8 @@ class Receive(PrimitiveWithInfer):
    """
    @prim_attr_register
-    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
+    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
+                 group_back=GlobalComm.WORLD_COMM_GROUP):
        self.rank = src_rank
        self.tag = sr_tag
        self.shape = shape
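With the new keyword in place, constructing the point-to-point primitives looks roughly like the hedged sketch below (the import path assumes the comm_ops module patched in this hunk; the sr_tag/rank/shape values and the HCCL world-group name are illustrative, group_back presumably names the group used by the corresponding backward communication, and a distributed environment must already be initialized for the ops to run):

from mindspore import dtype as mstype
from mindspore.ops.operations.comm_ops import Send, Receive

send = Send(sr_tag=0, dest_rank=1, group="hccl_world_group",
            group_back="hccl_world_group")   # new keyword added by this change
recv = Receive(sr_tag=0, src_rank=0, shape=[64, 64], dtype=mstype.float32,
               group="hccl_world_group", group_back="hccl_world_group")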


@@ -690,6 +690,72 @@ class _VirtualDataset(PrimitiveWithInfer):
virtual_dataset = _VirtualDataset()
class _VirtualAssignAdd(PrimitiveWithInfer):
"""
Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
internal use of parallel modules and cannot be called by users.
Args:
micro (int): MicroBatch. Default: 0.
"""
@prim_attr_register
def __init__(self):
"""init"""
def infer_shape(self, x_shape, y_shape):
return x_shape
def infer_dtype(self, x_dtype, y_dtype):
return x_dtype
virtual_assign_add = _VirtualAssignAdd()
class _VirtualAccuGrad(PrimitiveWithInfer):
"""
Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
internal use of parallel modules and cannot be called by users.
"""
@prim_attr_register
def __init__(self):
"""init"""
def infer_shape(self, x_shape, y_shape):
return x_shape
def infer_dtype(self, x_dtype, y_dtype):
return x_dtype
virtual_accu_grad = _VirtualAccuGrad()
class _MirrorMicroStepOperator(PrimitiveWithInfer):
"""
Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
internal use of parallel modules and cannot be called by users.
Args:
group (str): The communication group to work on. Default: None.
dev_num (int): The device number of the group. Default: None.
mean_flag (bool): Whether use mean in backward. Default: None.
"""
@prim_attr_register
def __init__(self, group=None, dev_num=None, mean_flag=None):
self.group = group
self.dev_num = dev_num
self.mean_flag = mean_flag
def infer_shape(self, x_shape, z_shape):
return x_shape
def infer_dtype(self, x_dtype, z_shape):
return x_dtype
class _VirtualOutput(PrimitiveWithInfer):
    """
    Auto parallel virtual out operator.
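A plain-Python sketch of the intent behind the new virtual operators may help (conceptual only; the real forward/backward behaviour is inserted by the auto-parallel passes, and the plain float below stands in for the accumulation Parameter referenced by _VirtualAssignAdd's second input):

class _VirtualAssignAddSketch:
    """Conceptual stand-in: identity in the forward pass, accumulate the gradient in the backward pass."""
    def __init__(self):
        self.accu_grad = 0.0        # stands in for the accumulation Parameter (the y input)

    def forward(self, x, y):
        return x                    # "do nothing in forward"

    def backward(self, dout):
        self.accu_grad += dout      # "do AssignAdd in backward"


op = _VirtualAssignAddSketch()
out = op.forward(x=1.0, y=0.0)
op.backward(0.5)
op.backward(0.5)
print(op.accu_grad)                 # 1.0 accumulated over two micro-batches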


@@ -18,9 +18,9 @@ from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn import acc
-from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
+from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
from ..ops import functional as F
-from ..parallel._utils import _get_parallel_mode
+from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode
from .. import context
@@ -187,5 +187,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
        network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                   scale_sense=update_cell).set_train()
        return network
-    network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
+    if _get_pipeline_stages() > 1:
+        network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
+    else:
+        network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
    return network
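A hedged usage sketch of the new branch: once more than one pipeline stage is configured, build_train_network wraps the network in _TrainPipelineAccuStepCell instead of TrainOneStepCell. The network, optimizer, and context values below are illustrative and assume an initialized Ascend/HCCL environment:

import mindspore.nn as nn
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.train.amp import build_train_network

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                  device_num=32, pipeline_stages=2)
net = nn.Dense(64, 64)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
train_net = build_train_network(net, opt, level="O0")  # takes the new pipeline branch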


@@ -21,6 +21,7 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
+from mindspore.nn.wrap.cell_wrapper import PipelineCell
class DatasetLenet():
@@ -90,6 +91,7 @@ class PipelineSplit(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.cell = Net(strategy1, strategy2)
+        self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)
    def construct(self, x, label):
        x = self.cell(x)
@@ -101,6 +103,7 @@ class PipelineSplit2(nn.Cell):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.cell = Net(strategy1, strategy2, self.param)
+        self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)
    def construct(self, x, label):
        x = self.cell(x)
@@ -114,8 +117,8 @@ def test_pipeline_split_stage0():
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit(strategy1, strategy2)
-    params = net.cell.block[0].trainable_params()
+    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
+    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
@@ -131,8 +134,8 @@ def test_pipeline_split_stage1():
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit(strategy1, strategy2)
-    params = net.cell.block[1].trainable_params()
+    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
+    params = net.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
@@ -149,8 +152,8 @@ def test_pipeline_split_shared_parameter_stage0():
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit2(strategy1, strategy2)
-    params = net.cell.block[0].trainable_params()
+    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
+    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
@@ -164,8 +167,8 @@ def test_pipeline_split_shared_parameter_stage1():
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit2(strategy1, strategy2)
-    params = net.cell.block[1].trainable_params()
+    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
+    params = net.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
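Pulling the scattered test changes together, the updated tests now drive training through roughly the following pattern (PipelineSplit and DatasetLenet are the helpers defined earlier in this test file; the context values and the final model.train call are illustrative, not a verbatim copy of the tests):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.nn.wrap.cell_wrapper import PipelineCell
from mindspore.train.model import Model

context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2,
                                  parallel_mode="semi_auto_parallel")
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
net = PipelineCell(PipelineSplit(((4, 1), (1, 1)), ((2, 1), (1, 1))), 4)  # 4 micro-batches
params = net.network.cell.block[0].trainable_params()   # note the extra .network level added by the wrapper
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, DatasetLenet(data, label, 3), dataset_sink_mode=False)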