forked from mindspore-Ecosystem/mindspore
!16457 [AutoParallel]pipeline_split_adapt_master
Merge pull request !16457 from lichen/pipeline_split_adapt_master
commit 85d860e6a2
@ -63,6 +63,10 @@ bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph,
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(input);
   MS_EXCEPTION_IF_NULL(cur_node);
+  if (IsPrimitiveCNode(cur_node, prim::kPrimReceive)) {
+    return false;
+  }
+
   // when input is a parameter or is a value node
   if (IsParameterOrValueNode(input)) {
     return true;
@ -192,90 +192,6 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
   }
 }
 
-bool IsBackward(const CNodePtr &cnode) {
-  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  return prim->HasAttr(BACKWARD);
-}
-
-// compare the value of send/recv sr_tag
-bool comp(const CNodePtr &node1, const CNodePtr &node2) {
-  auto prim1 = GetValueNode<PrimitivePtr>(node1->input(0));
-  MS_EXCEPTION_IF_NULL(prim1);
-  auto prim2 = GetValueNode<PrimitivePtr>(node1->input(0));
-  MS_EXCEPTION_IF_NULL(prim2);
-  auto sr_tag_value1 = prim1->GetAttr(SR_TAG);
-  MS_EXCEPTION_IF_NULL(sr_tag_value1);
-  auto sr_tag_value2 = prim2->GetAttr(SR_TAG);
-  MS_EXCEPTION_IF_NULL(sr_tag_value2);
-  auto sr_tag1 = GetValue<int64_t>(sr_tag_value1);
-  auto sr_tag2 = GetValue<int64_t>(sr_tag_value2);
-  return sr_tag1 < sr_tag2;
-}
-
-// Reorder the execution order of send
-void ReorderSend(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
-  auto last_node = op_v.back();
-  for (auto &node : op_v) {
-    if (node == last_node) {
-      continue;
-    }
-    auto iter = std::find(execution_order->begin(), execution_order->end(), node);
-    (void)execution_order->erase(iter);
-  }
-  std::sort(op_v.begin(), op_v.end(), comp);
-  auto last_node_iter = std::find(execution_order->begin(), execution_order->end(), last_node);
-  auto node_iter = execution_order->erase(last_node_iter);
-  // all send will insert the end of the last node
-  execution_order->insert(node_iter, op_v.begin(), op_v.end());
-}
-
-// Reorder the execution order of receive
-void ReorderRecv(std::vector<CNodePtr> *execution_order, std::vector<CNodePtr> op_v) {
-  auto begin_node = op_v.front();
-  for (auto &node : op_v) {
-    if (node == begin_node) {
-      continue;
-    }
-    auto iter = std::find(execution_order->begin(), execution_order->end(), node);
-    (void)execution_order->erase(iter);
-  }
-  std::sort(op_v.begin(), op_v.end(), comp);
-  auto begin_node_iter = std::find(execution_order->begin(), execution_order->end(), begin_node);
-  auto node_iter = execution_order->erase(begin_node_iter);
-  // all receive will insert before the begin node
-  execution_order->insert(node_iter, op_v.begin(), op_v.end());
-}
-
-void ReorderSendRecv(std::vector<CNodePtr> *execution_order) {
-  std::vector<CNodePtr> forward_send, forward_recv, backward_send, backward_recv;
-  for (auto &cnode : *execution_order) {
-    if (IsPrimitiveCNode(cnode, prim::kPrimSend) && IsBackward(cnode)) {
-      backward_send.push_back(cnode);
-      continue;
-    } else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) {
-      forward_send.push_back(cnode);
-      continue;
-    }
-    if (IsPrimitiveCNode(cnode, prim::kPrimReceive) && IsBackward(cnode)) {
-      backward_recv.push_back(cnode);
-    } else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
-      forward_recv.push_back(cnode);
-    }
-  }
-  if (!forward_send.empty()) {
-    ReorderSend(execution_order, forward_send);
-  }
-  if (!backward_send.empty()) {
-    ReorderSend(execution_order, backward_send);
-  }
-  if (!forward_recv.empty()) {
-    ReorderRecv(execution_order, forward_recv);
-  }
-  if (!backward_recv.empty()) {
-    ReorderRecv(execution_order, backward_recv);
-  }
-}
-
 size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Load kInputCtrlTensors";
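The ReorderSend/ReorderRecv helpers deleted above all follow one splice pattern: pull a group of communication nodes out of the execution order, sort the group by its sr_tag, and re-insert the whole group at the position of a pivot node. A minimal stand-alone sketch of that pattern on plain integers (illustrative only, not the MindSpore API):

#include <algorithm>
#include <vector>

// Erase every member of the group except the pivot (the group's last element),
// sort the group, then re-insert the whole group at the pivot's old position.
void SpliceGroupAtPivot(std::vector<int> *order, std::vector<int> group) {
  int pivot = group.back();
  for (int node : group) {
    if (node == pivot) continue;
    order->erase(std::find(order->begin(), order->end(), node));
  }
  std::sort(group.begin(), group.end());
  auto pos = order->erase(std::find(order->begin(), order->end(), pivot));
  order->insert(pos, group.begin(), group.end());
}

int main() {
  std::vector<int> order = {1, 9, 2, 7, 3, 8};
  SpliceGroupAtPivot(&order, {9, 7, 8});  // order becomes {1, 2, 3, 7, 8, 9}
  return 0;
}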
@ -511,10 +427,6 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
 
   // adjust kernel
   AdjustKernel(root_graph);
-  // reorder send/recv
-  auto execution_order = root_graph->execution_order();
-  ReorderSendRecv(&execution_order);
-  root_graph->set_execution_order(execution_order);
 #if ENABLE_CPU && ENABLE_D
   InitPsWorker(root_graph);
 #endif
@ -206,8 +206,14 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   virtual_output_eliminate_ =
     MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
 
-  // Receive
+  // PipelineSplit
   receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
+  virtual_accu_grad_ =
+    MakeSubstitution(std::make_shared<VirtualAccuGradEliminater>(), "virtual_accu_grad", prim::kPrimVirtualAccuGrad);
+  virtual_assign_add_ =
+    MakeSubstitution(std::make_shared<VirtualAssignAddEliminater>(), "virtual_assign_add", prim::kPrimVirtualAssignAdd);
+  mirror_micro_step_ =
+    MakeSubstitution(std::make_shared<MirrorMicroStepEliminater>(), "mirror_micro_step", prim::kPrimMirrorMicroStep);
 
   // Convert
   print_tuple_wrapper_ =
@ -120,8 +120,12 @@ class OptimizeIRPassLib {
 
   // virtual output
   SubstitutionPtr virtual_output_eliminate_;
-  // Receive
+  // PipelineSplit
   SubstitutionPtr receive_eliminate_;
+  SubstitutionPtr virtual_accu_grad_;
+  SubstitutionPtr virtual_assign_add_;
+  SubstitutionPtr mirror_micro_step_;
 
   // Convert
   SubstitutionPtr print_tuple_wrapper_;
@ -134,6 +134,63 @@ class ReceiveEliminater : public AnfVisitor {
   void Visit(const AnfNodePtr &) override {}
 };
 
+class VirtualAssignAddEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAssignAdd) || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    if (inputs.size() < 2) {
+      return nullptr;
+    }
+
+    return inputs[1];
+  }
+
+ private:
+  AnfNodePtr x_{nullptr};
+};
+
+class VirtualAccuGradEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAccuGrad) || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    if (inputs.size() < 2) {
+      return nullptr;
+    }
+
+    return inputs[1];
+  }
+
+ private:
+  AnfNodePtr x_{nullptr};
+};
+
+// {prim::kPrimMirrorMicroStep, X, Z} -> X
+class MirrorMicroStepEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    if (inputs.size() < 2) {
+      return nullptr;
+    }
+
+    return inputs[1];
+  }
+
+  void Visit(const AnfNodePtr &) override {}
+};
+
 // {prim::kPrimSameTypeShape, X, Y} -> X
 class SameEliminater : public AnfVisitor {
  public:
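Each eliminater added above implements the same rewrite {prim, X, ...} -> X: when the optimizer visits one of these virtual bookkeeping nodes, the node is dropped and its first data input is returned as the replacement. A generic sketch of that substitution shape (not the MindSpore AnfVisitor API; Node here is a hypothetical stand-in for a CNode):

#include <memory>
#include <string>
#include <vector>

// Toy stand-in for a CNode: inputs[0] is conceptually the primitive, inputs[1] is X.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};

// Return X as the substitution result, or nullptr when the pattern does not match.
std::shared_ptr<Node> EliminateWrapper(const std::shared_ptr<Node> &node, const std::string &wrapper_op) {
  if (!node || node->op != wrapper_op || node->inputs.size() < 2) {
    return nullptr;
  }
  return node->inputs[1];
}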
@ -0,0 +1,634 @@ (new file: frontend/parallel/graph_util/pipeline_split_utils.cc)
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <iterator>
#include <memory>
#include <list>
#include <algorithm>
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "base/core_ops.h"
#include "ir/value.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {
AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
  auto pre_node = cnode->input(1);
  while (true) {
    if (pre_node->isa<Parameter>()) {
      return pre_node;
    } else {
      if (pre_node->isa<CNode>()) {
        auto pre_cnode = pre_node->cast<CNodePtr>();
        pre_node = pre_cnode->input(1);
      } else {
        return nullptr;
      }
    }
  }
  return nullptr;
}

bool IsLastStage() {
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto stage_num = g_device_manager->stage_num();
  auto stage_id = g_device_manager->stage_id();
  return ((stage_num - 1) == stage_id);
}

void SetStridedSliceStrategy(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
    return;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  int64_t dev_num = 1;
  auto attrs_temp = prim->attrs();
  std::vector<Shapes> shape_list = ExtractShape(cnode);
  if (shape_list.empty()) {
    MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " failed to extract shape";
  }
  std::vector<ValuePtr> elements;
  for (size_t i = 0; i < shape_list[0].size(); i++) {
    if (shape_list[0][i].empty()) {
      MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
    }
    Dimensions input_strategy = {dev_num};
    for (size_t j = 1; j < shape_list[0][i].size(); j++) {
      input_strategy.push_back(1);
    }
    elements.push_back(MakeValue(input_strategy));
  }
  ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
  attrs_temp[STRATEGY] = strategy;
  (void)prim->SetAttrs(attrs_temp);
}

void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
                            const AnfNodePtr &accu_parameter) {
  auto cnode = node_user.first->cast<CNodePtr>();
  if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag() ||
      ((IsPrimitiveCNode(node_user.first, prim::kPrimSend) || IsPrimitiveCNode(node_user.first, prim::kPrimDepend)) &&
       ParallelContext::GetInstance()->enable_parallel_optimizer())) {
    return;
  }
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAd";
    return;
  }
  OperatorAttrs attrs;
  auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
  auto value_node = NewValueNode(py_instance);
  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
  auto graph = cnode->func_graph();
  auto virtual_node = graph->NewCNode(virtual_node_input);
  manager->SetEdge(cnode, node_user.second, virtual_node);
}

void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param) {
  auto cnode = recv->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorAttrs attrs;
  auto py_instance = CreatOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
  auto value_node = NewValueNode(py_instance);
  std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  auto virtual_node = graph->NewCNode(virtual_node_input);
  manager->Replace(recv, virtual_node);
}

AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (param_ptr->name() == name) {
      continue;
    }
    auto expect_name = "accu_grads." + name;
    if (param_ptr->name() == expect_name) {
      return parameter;
    }
  }
  return nullptr;
}

void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
      continue;
    }
    auto parameter_ptr = cnode->input(1)->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(parameter_ptr);
    auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
    if (!accu_parameter) {
      continue;
    }
    auto node_users = node_users_map[node];
    for (auto &temp_user : node_users) {
      auto temp_node = temp_user.first;
      if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
        temp_node = node_users_map[temp_node].begin()->first;
      }
      if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
        auto node_set = node_users_map[temp_node];
        for (auto &node_user : node_set) {
          InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
        }
      } else {
        InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
      }
    }
    InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
  }
}

void AddVirtualAssignAdd(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &parameter : parameters) {
    auto parameter_ptr = parameter->cast<ParameterPtr>();
    auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
    if (!accu_parameter) {
      continue;
    }
    auto node_users = node_users_map[parameter];
    for (auto &temp_user : node_users) {
      auto temp_node = temp_user.first;
      if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
        temp_node = node_users_map[temp_node].begin()->first;
      }
      if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) {
        auto node_set = node_users_map[temp_node];
        for (auto &node_user : node_set) {
          InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
        }
      } else {
        InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
      }
    }
  }
}

bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  auto cnode1 = node1->cast<CNodePtr>();
  auto cnode2 = node2->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode1);
  MS_EXCEPTION_IF_NULL(cnode2);
  auto micro1 = cnode1->GetPrimalAttr(MICRO);
  auto micro2 = cnode2->GetPrimalAttr(MICRO);
  MS_EXCEPTION_IF_NULL(micro1);
  MS_EXCEPTION_IF_NULL(micro2);
  auto micro1_value = GetValue<int64_t>(micro1);
  auto micro2_value = GetValue<int64_t>(micro2);
  if (micro1_value == micro2_value) {
    auto prim1 = GetCNodePrimitive(cnode1);
    auto prim2 = GetCNodePrimitive(cnode2);
    MS_EXCEPTION_IF_NULL(prim1);
    MS_EXCEPTION_IF_NULL(prim2);
    auto rank_tag1 = prim1->GetAttr(SRC_RANK);
    auto rank_tag2 = prim2->GetAttr(SRC_RANK);
    if (rank_tag1 == nullptr) {
      rank_tag1 = prim1->GetAttr(DEST_RANK);
    }
    if (rank_tag2 == nullptr) {
      rank_tag2 = prim2->GetAttr(DEST_RANK);
    }
    MS_EXCEPTION_IF_NULL(rank_tag1);
    MS_EXCEPTION_IF_NULL(rank_tag2);
    auto rank1_value = GetValue<int64_t>(rank_tag1);
    auto rank2_value = GetValue<int64_t>(rank_tag2);
    if (rank1_value == rank2_value) {
      auto sr_tag1 = prim1->GetAttr(SR_TAG);
      auto sr_tag2 = prim2->GetAttr(SR_TAG);
      MS_EXCEPTION_IF_NULL(sr_tag1);
      MS_EXCEPTION_IF_NULL(sr_tag2);
      auto sr1_value = GetValue<int64_t>(sr_tag1);
      auto sr2_value = GetValue<int64_t>(sr_tag2);
      return sr1_value < sr2_value;
    }
    return rank1_value < rank2_value;
  }
  return micro1_value < micro2_value;
}

void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
                  const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(prior_node);
  MS_EXCEPTION_IF_NULL(post_node);
  auto post_cnode = post_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(post_cnode);
  std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node};
  auto depend_node = root->NewCNode(depend_input);
  manager->SetEdge(post_node, 1, depend_node);
}

void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
                       const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(g_device_manager);
  MS_EXCEPTION_IF_NULL(root);
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto stage_num = g_device_manager->stage_num();
  auto stage_id = g_device_manager->stage_id();
  for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) {
    auto prior_node = forward_end[i - 1];
    auto post_node = forward_start[i];
    InsertDepend(prior_node, post_node, manager, root);
  }
}

void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
                        const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
                        const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(g_device_manager);
  MS_EXCEPTION_IF_NULL(root);
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto stage_num = g_device_manager->stage_num();
  auto stage_id = g_device_manager->stage_id();
  for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) {
    auto prior_node1 = forward_end_before_pair.second[i];
    auto post_node1 = backward_start_pair.first[i - stage_num + stage_id + 1];
    InsertDepend(prior_node1, post_node1, manager, root);
    auto prior_node2 = backward_end_pair.second[i - stage_num + stage_id];
    auto post_node2 = forward_start_pair.first[i];
    InsertDepend(prior_node2, post_node2, manager, root);
  }
  for (size_t i = (stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) {
    if (!IsLastStage()) {
      auto prior_node3 = backward_start_pair.second[i - stage_num + stage_id];
      auto post_node3 = forward_end_pair.first[i - 1];
      InsertDepend(prior_node3, post_node3, manager, root);
      auto prior_node4 = forward_end_pair.second[i - 1];
      auto post_node4 = backward_end_pair.first[i - stage_num + stage_id];
      InsertDepend(prior_node4, post_node4, manager, root);
    }
  }
  for (size_t j = (backward_start_pair.first.size() - stage_num + stage_id + 1); j < backward_start_pair.first.size();
       ++j) {
    auto prior_node5 = backward_end_pair.second[j - 1];
    auto post_node5 = backward_start_pair.first[j];
    InsertDepend(prior_node5, post_node5, manager, root);
  }
  if (!IsLastStage()) {
    auto prior_node6 = forward_end_before_pair.second[stage_num - 1 - stage_id];
    auto post_node6 = backward_start_pair.first[0];
    InsertDepend(prior_node6, post_node6, manager, root);
  }
}

void ReorderForParams(const std::vector<AnfNodePtr> &backward_params, const std::vector<AnfNodePtr> &forward_params,
                      const std::vector<AnfNodePtr> &allreduce_params, const PipelinePair &forward_params_pair,
                      const PipelinePair &backward_params_pair, const std::vector<AnfNodePtr> &backward_end,
                      const PipelinePair &forward_start_pair, const FuncGraphPtr &root) {
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (!forward_params.empty()) {
    auto prior_node = forward_params_pair.second[0];
    auto post_node = forward_start_pair.first[0];
    InsertDepend(prior_node, post_node, manager, root);
  }
  if (!backward_params.empty()) {
    if (!allreduce_params.empty()) {
      for (auto &node : allreduce_params) {
        auto post_node1 = backward_params_pair.first[0];
        InsertDepend(node, post_node1, manager, root);
      }
    }
    auto prior_node2 = backward_end.back();
    auto post_node2 = backward_params[0];
    InsertDepend(prior_node2, post_node2, manager, root);
  }
}

int64_t GetMicroBatch(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto micro_value = cnode->GetPrimalAttr(MICRO);
  MS_EXCEPTION_IF_NULL(micro_value);
  return GetValue<int64_t>(micro_value);
}

PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max) {
  std::vector<AnfNodePtr> temp_vec;
  std::vector<AnfNodePtr> out_vec_begin;
  std::vector<AnfNodePtr> out_vec_end;
  auto manager = root->manager();
  for (int64_t i = 0; i <= micro_max; ++i) {
    temp_vec.clear();
    for (auto &node : node_vector) {
      auto node_micro = GetMicroBatch(node);
      if (node_micro == i) {
        temp_vec.push_back(node);
      }
    }
    if (temp_vec.size() <= 1) {
      MS_LOG(INFO) << "No Duplicate MicroBatch.";
      continue;
    }
    std::sort(temp_vec.begin(), temp_vec.end(), CompFunc);
    for (size_t j = 0; j < temp_vec.size() - 1; ++j) {
      auto prior_node = temp_vec[j];
      auto post_node = temp_vec[j + 1];
      InsertDepend(prior_node, post_node, manager, root);
    }
    if (!temp_vec.empty()) {
      out_vec_begin.push_back(temp_vec.front());
      out_vec_end.push_back(temp_vec.back());
    }
  }
  if (out_vec_begin.empty()) {
    return std::make_pair(node_vector, node_vector);
  }
  return std::make_pair(out_vec_begin, out_vec_end);
}

void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value) {
  auto node_users = (*node_users_map)[node];
  for (auto &node_pair : node_users) {
    auto user_node = node_pair.first->cast<CNodePtr>();
    if (user_node->HasPrimalAttr(MICRO)) {
      continue;
    }
    user_node->AddPrimalAttr(MICRO, value);
    BroadCastMicroBatch(user_node, node_users_map, value);
  }
}

AnfNodePtr GetPreNode(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
    return GetPreNode(cnode->input(1));
  }
  return cnode;
}

void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
  if (!IsLastStage()) {
    return;
  }
  auto node_users_map = manager->node_users();
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!cnode->HasPrimalAttr(MICRO)) {
      continue;
    }
    auto prim = GetCNodePrimitive(node);
    if (prim && prim->HasAttr(PIPELINE_END)) {
      for (auto &temp_node : cnode->inputs()) {
        if (!temp_node->isa<CNode>()) {
          continue;
        }
        auto temp_cnode = temp_node->cast<CNodePtr>();
        auto temp_prim = GetCNodePrimitive(temp_node);
        if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
          continue;
        }
        auto end_node = GetPreNode(temp_node);
        auto end_cnode = end_node->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(end_cnode);
        auto end_prim = GetCNodePrimitive(end_node);
        OperatorAttrs attrs_;
        auto op = CreatOpInstance(attrs_, end_prim->name(), "");
        auto value_node = NewValueNode(op);
        auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
        new_prim->SetAttrs(end_prim->attrs());
        manager->SetEdge(end_node, 0, value_node);
        end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
      }
    }
  }
}

void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
  auto node_users_map = manager->node_users();
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!cnode->HasPrimalAttr(MICRO)) {
      continue;
    }
    auto micro = cnode->GetPrimalAttr(MICRO);
    auto prim = GetCNodePrimitive(node);
    if (prim && prim->HasAttr(PARAMETER_START)) {
      OperatorAttrs attrs_;
      auto op = CreatOpInstance(attrs_, prim->name(), "");
      auto value_node = NewValueNode(op);
      auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
      new_prim->SetAttrs(prim->attrs());
      manager->SetEdge(cnode, 0, value_node);
      cnode->AddPrimalAttr(PARAMETER_START, micro);
    }
  }
}

void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
  auto node_users_map = manager->node_users();
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!cnode->HasPrimalAttr(MICRO)) {
      continue;
    }
    auto micro = cnode->GetPrimalAttr(MICRO);
    MS_EXCEPTION_IF_NULL(micro);
    BroadCastMicroBatch(cnode, &node_users_map, micro);
  }
}

void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
                   std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
                   std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
                   std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root) {
  std::list<ValuePtr> name_list = {};
  auto stage_id = g_device_manager->stage_id();
  for (auto &node : root->nodes()) {
    if (!node->isa<CNode>()) {
      continue;
    }
    if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimZerosLike)) {
      continue;
    }
    auto prim = GetCNodePrimitive(node);
    auto cnode = node->cast<CNodePtr>();
    if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
      auto forward_node_name = cnode->GetPrimalAttr(kPrimalAttrForwardNodeName);
      if (std::find(name_list.begin(), name_list.end(), forward_node_name) != name_list.end()) {
        continue;
      }
      name_list.push_back(forward_node_name);
      if (cnode->HasPrimalAttr(PIPELINE_END)) {
        backward_start->push_back(node);
      }
      if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
        backward_end->push_back(node);
      }
      if (cnode->HasPrimalAttr(PARAMETER_START)) {
        backward_end->push_back(node);
      }
      if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
        backward_params->push_back(node);
      }
      if (prim->HasAttr(PARAMETER_MICRO)) {
        allreduce_params->push_back(node);
      }
    } else {
      if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
        if (stage_id != 0 && IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
          continue;
        }
        forward_start->push_back(node);
      }
      if (cnode->HasPrimalAttr(PIPELINE_END)) {
        forward_end->push_back(node);
      }
      if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
        forward_params->push_back(node);
      }
    }
  }
  std::sort((*backward_start).begin(), (*backward_start).end(), CompFunc);
  std::sort((*backward_end).begin(), (*backward_end).end(), CompFunc);
  std::sort((*forward_start).begin(), (*forward_start).end(), CompFunc);
  std::sort((*forward_end).begin(), (*forward_end).end(), CompFunc);
  std::sort((*backward_params).begin(), (*backward_params).end(), CompFunc);
  std::sort((*forward_params).begin(), (*forward_params).end(), CompFunc);
}

void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
                     const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
                     size_t micro_size) {
  micro_size = micro_size + 1;
  if (forward_start_pair.first.size() != micro_size) {
    MS_LOG(EXCEPTION) << "forward_node's size:" << forward_start_pair.first.size()
                      << "is not equal to micro size:" << micro_size;
  }
  if (forward_end_pair.first.size() != micro_size) {
    MS_LOG(EXCEPTION) << "forward_node's size:" << forward_end_pair.first.size()
                      << "is not equal to micro size:" << micro_size;
  }
  if (backward_start_pair.first.size() != micro_size) {
    MS_LOG(EXCEPTION) << "backward_node's size:" << backward_start_pair.first.size()
                      << "is not equal to micro size:" << micro_size;
  }
  if (backward_end_pair.first.size() != micro_size) {
    MS_LOG(EXCEPTION) << "backward_node's size:" << backward_end_pair.first.size()
                      << "is not equal to micro size:" << micro_size;
  }
}

void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  std::vector<AnfNodePtr> forward_start;
  std::vector<AnfNodePtr> forward_end;
  std::vector<AnfNodePtr> forward_params;
  std::vector<AnfNodePtr> backward_start;
  std::vector<AnfNodePtr> backward_end;
  std::vector<AnfNodePtr> backward_params;
  std::vector<AnfNodePtr> allreduce_params;
  GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
                &allreduce_params, root);
  auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
  auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
  MS_EXCEPTION_IF_NULL(micro_size);
  auto micro_max = GetValue<int64_t>(micro_size);
  auto backward_start_pair = Deduplicate(backward_start, root, micro_max);
  auto backward_end_pair = Deduplicate(backward_end, root, micro_max);
  auto forward_start_pair = Deduplicate(forward_start, root, micro_max);
  auto forward_end_pair = Deduplicate(forward_end, root, micro_max);
  auto forward_params_pair = Deduplicate(forward_params, root, micro_max);
  auto backward_params_pair = Deduplicate(backward_params, root, micro_max);
  CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, LongToSize(micro_max));
  PipelinePair forward_end_before_pair;
  if (!IsLastStage()) {
    for (auto &node : forward_end_pair.first) {
      auto cnode = node->cast<CNodePtr>();
      forward_end_before_pair.first.push_back(cnode->input(1));
    }
    for (auto &node : forward_end_pair.second) {
      auto cnode = node->cast<CNodePtr>();
      forward_end_before_pair.second.push_back(cnode->input(1));
    }
  } else {
    forward_end_before_pair = forward_end_pair;
  }
  ReorderForForward(forward_start_pair.first, forward_end_pair.second, root);
  ReorderForBackward(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair,
                     forward_end_before_pair, root);
  ReorderForParams(backward_params, forward_params, allreduce_params, forward_params_pair, backward_params_pair,
                   backward_end, forward_start_pair, root);
}

void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  std::vector<AnfNodePtr> forward_end;
  std::vector<AnfNodePtr> forward_start;
  std::vector<AnfNodePtr> forward_params;
  for (auto &node : root->nodes()) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
      forward_start.push_back(node);
    }
    if (cnode->HasPrimalAttr(PIPELINE_END)) {
      forward_end.push_back(node);
    }
    if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
      forward_params.push_back(node);
    }
  }
  std::sort(forward_start.begin(), forward_start.end(), CompFunc);
  std::sort(forward_end.begin(), forward_end.end(), CompFunc);
  std::sort(forward_params.begin(), forward_params.end(), CompFunc);
  auto forward_start_pair = Deduplicate(forward_start, root, 0);
  auto forward_end_pair = Deduplicate(forward_end, root, 0);
  auto forward_params_pair = Deduplicate(forward_params, root, 0);
  if (!forward_end.empty() && !forward_params.empty()) {
    InsertDepend(forward_params_pair.second[0], forward_end_pair.first[0], manager, root);
  }
  if (!forward_start.empty() && !forward_params.empty()) {
    InsertDepend(forward_params_pair.second[0], forward_start_pair.first[0], manager, root);
  }
}

}  // namespace parallel
}  // namespace mindspore
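CompFunc above orders pipeline border nodes by micro batch first, then by peer rank (src_rank for Receive, dest_rank for Send), then by sr_tag. The same lexicographic comparison can be written compactly with std::tie; a generic sketch (not the MindSpore node types):

#include <tuple>

struct BorderKey {
  long long micro;   // micro batch index
  long long rank;    // src_rank for Receive, dest_rank for Send
  long long sr_tag;  // send/receive pair tag
};

// Lexicographic order: micro batch, then peer rank, then sr_tag.
bool CompareBorder(const BorderKey &a, const BorderKey &b) {
  return std::tie(a.micro, a.rank, a.sr_tag) < std::tie(b.micro, b.rank, b.sr_tag);
}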
@ -0,0 +1,68 @@ (new file: frontend/parallel/graph_util/pipeline_split_utils.h)
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_

#include <utility>
#include <vector>
#include <string>
#include "ir/anf.h"
#include "ir/manager.h"

namespace mindspore {
namespace parallel {
using PipelinePair = std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>;
AnfNodePtr FindAccuGrad(const CNodePtr &cnode);
bool IsLastStage();
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
                            const AnfNodePtr &accu_parameter);
void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param);
AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name);
void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
void AddVirtualAssignAdd(const FuncGraphPtr &root);
bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2);
void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
                       const FuncGraphPtr &root);
void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
                        const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
                        const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root);
void ReorderForParams(const std::vector<AnfNodePtr> &backward_params, const std::vector<AnfNodePtr> &forward_params,
                      const std::vector<AnfNodePtr> &allreduce_params, const PipelinePair &forward_params_pair,
                      const PipelinePair &backward_params_pair, const std::vector<AnfNodePtr> &backward_end,
                      const PipelinePair &forward_start_pair, const FuncGraphPtr &root);
int64_t GetMicroBatch(const AnfNodePtr &node);
void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
                  const FuncGraphPtr &root);
PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max);
void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
                   std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
                   std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
                   std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root);
void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value);
AnfNodePtr GetPreNode(const AnfNodePtr &node);
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void SetStridedSliceStrategy(const AnfNodePtr &node);
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
                     const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, size_t micro_size);
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_
@ -342,11 +342,12 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
 void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
   MS_EXCEPTION_IF_NULL(comm_node);
   MS_EXCEPTION_IF_NULL(param_node);
+  ParameterPtr param;
   if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
-    MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
-    return;
+    param = param_node->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
+  } else {
+    param = param_node->cast<ParameterPtr>();
   }
-  auto param = param_node->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(param);
   auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
   MS_EXCEPTION_IF_NULL(prim);
@ -372,6 +373,22 @@ void AddCommOpMeanFlag(const CNodePtr &comm_node) {
   prim->SetAttrs(attrs);
 }
 
+void AddCommOpParamFlag(const CNodePtr &comm_node) {
+  MS_EXCEPTION_IF_NULL(comm_node);
+  auto graph = comm_node->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto node_users = manager->node_users()[comm_node->input(1)];
+  for (auto &node_user : node_users) {
+    if (IsPrimitiveCNode(node_user.first, prim::kPrimSend)) {
+      auto prim = GetCNodePrimitive(comm_node);
+      prim->AddAttr(PARAMETER_MICRO, MakeValue(0));
+      return;
+    }
+  }
+}
+
 Operator CreateAllGatherOp(const std::string &group) {
   OperatorName operator_name = ALL_GATHER;
   ValuePtr attr0_value = MakeValue(group);  // group
@ -438,6 +455,7 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
   OperatorVector op_for_weight;
   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
 
   ValuePtr attr0_value = MakeValue(group_name);
   ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
@ -459,6 +477,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
     Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value);
     operator_attrs.push_back(attr3);
     MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror";
+  } else if (split_stage_num > 1) {
+    operator_name = MIRROR_MICRO_STEP_OPERATOR;
   } else {
     operator_name = MIRROR_OPERATOR;
   }
@ -294,6 +294,7 @@ Operator CreateAllGatherOp(const std::string &group);
 Operator CreateMiniStepAllGatherOp(const std::string &group);
 void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
 void AddCommOpMeanFlag(const CNodePtr &comm_node);
+void AddCommOpParamFlag(const CNodePtr &comm_node);
 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
@ -168,7 +168,7 @@ constexpr char CLONED_INDEX[] = "cloned_index";
 constexpr char BE_CLONED_INDEX[] = "be_cloned_index";
 constexpr char GROUP_RANKS[] = "group_ranks";
 constexpr char IS_IN_FORWARD[] = "is_in_forward";
-constexpr char DTYPE[] = "DType";
+constexpr char DTYPE[] = "dtype";
 constexpr char DEV_NUM[] = "dev_num";
 constexpr char MEAN_FLAG[] = "mean_flag";
 constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step";
@ -358,6 +358,23 @@ constexpr char UNIQUE[] = "Unique";
 constexpr char GATHERND[] = "GatherNd";
 constexpr char SCATTER_UPDATE[] = "ScatterUpdate";
+
+// pipeline
+constexpr char MICRO[] = "micro";
+constexpr char DEST_RANK[] = "dest_rank";
+constexpr char SRC_RANK[] = "src_rank";
+constexpr char PIPELINE_PARAM[] = "pipeline_param";
+constexpr char PIPELINE_END[] = "pipeline_end";
+constexpr char PIPELINE_BEGIN[] = "pipeline_begin";
+constexpr char MAIN_GRAPH[] = "main_graph";
+constexpr char SR_TAG[] = "sr_tag";
+constexpr char GROUP_BACK[] = "group_back";
+constexpr char MIRROR_MICRO_STEP_OPERATOR[] = "_MirrorMicroStepOperator";
+constexpr char PARAMETER_MICRO[] = "parameter_micro";
+constexpr char VIRTUAL_ASSIGN_ADD[] = "_VirtualAssignAdd";
+constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad";
+constexpr char ACCU_GRAD[] = "accu_grad";
+constexpr char PARAMETER_START[] = "parameter_start";
 
 // Parallel don't care
 constexpr char STRING_EQUAL[] = "string_equal";
 constexpr char MAKE_TUPLE[] = "MakeTuple";
@ -29,6 +29,7 @@
 #include "frontend/parallel/step_parallel.h"
 #include "frontend/parallel/node_check.h"
 #include "frontend/parallel/graph_util/node_info.h"
+#include "frontend/parallel/graph_util/pipeline_split_utils.h"
 #include "ir/anf.h"
 #include "base/core_ops.h"
 #include "utils/comm_manager.h"
@ -52,30 +53,74 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
   return false;
 }
 
-static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) {
-  if (curr_depth > MAX_RECURSIVE_DEPTH) {
-    MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: "
-                    << MAX_RECURSIVE_DEPTH;
+void PipelineTransformer::MainGraph() {
+  if (!root_->has_flag(TRAINING)) {
+    main_graph_ = root_;
     return;
   }
-  const auto &node_users = manager->node_users()[node];
-  for (auto &user_pair : node_users) {
-    auto user_node = user_pair.first;
-    if (!user_node->grad()) {
-      user_node->set_grad(true);
-      SetGradTag(user_node, manager, ++curr_depth);
+  for (auto &fg : manager_->func_graphs()) {
+    for (auto &node : fg->nodes()) {
+      if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
+        main_graph_ = fg;
+        main_graph_->set_flag(MAIN_GRAPH, true);
+        virtual_dataset_ = node;
+        return;
+      }
+    }
+  }
+  MS_LOG(EXCEPTION) << "Can't find main graph, possible reason is can't find virtual dataset.";
+}
+
+ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micro_size) {
+  if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
+    MS_LOG(EXCEPTION) << "Can't find MicroBatch information.";
+  }
+  auto cnode = node->cast<CNodePtr>();
+  auto value = GetValueNode(cnode->input(2));
+  MS_EXCEPTION_IF_NULL(value);
+  auto tuple = GetValue<std::vector<int64_t>>(value);
+  auto input_shape = GetNodeShape(cnode->input(1)).at(0);
+  int64_t micro = tuple.at(0) * micro_size / input_shape.at(0);
+  cnode->AddPrimalAttr(MICRO, MakeValue(micro));
+  cnode->AddPrimalAttr(PIPELINE_BEGIN, MakeValue(micro));
+  return MakeValue(micro);
+}
+
+void PipelineTransformer::LabelMicroBatch() {
+  MS_EXCEPTION_IF_NULL(main_graph_);
+  MS_EXCEPTION_IF_NULL(virtual_dataset_);
+  auto node_user_map = manager_->node_users();
+  auto node_users = node_user_map[virtual_dataset_];
+  for (auto &node_user : node_users) {
+    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
+      auto data_users = manager_->node_users()[node_user.first];
+      auto micro_size = int64_t(data_users.size());
+      micro_size_ = micro_size;
+      MS_LOG(INFO) << "Micro Size is: " << micro_size;
+      for (auto &data_user : data_users) {
+        auto micro = SetMicroBatch(data_user.first, micro_size);
+        SetStridedSliceStrategy(data_user.first);
+        auto cnode = data_user.first->cast<CNodePtr>();
+        BroadCastMicroBatch(cnode, &node_user_map, micro);
+      }
     }
   }
 }
 
-void PipelineTransformer::LabelRequiredGradCNode() {
-  auto parameters = root_->parameters();
-  for (auto parameter : parameters) {
-    if (!ParameterRequireGrad(parameter)) {
-      continue;
-    }
-    SetGradTag(parameter, manager_, 0);
+void PipelineTransformer::CreateForwardGroup() {
+  std::vector<int64_t> rank_list;
+  auto rank_id = g_device_manager->global_rank();
+  auto stage_id = g_device_manager->stage_id();
+  auto stage_num = g_device_manager->stage_num();
+  for (int64_t i = 0; i < stage_num; ++i) {
+    rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id));
   }
+  auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list);
+  auto g = g_device_manager->CreateGroup(rank_list);
+  auto g_back_name = g.name() + BACKWARD;
+  auto g_back = g_device_manager->CreateGroup(g_back_name, dev_list);
+  group_.push_back(g.name());
+  group_.push_back(g_back.name());
 }
 
 void PipelineTransformer::Coloring() {
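SetMicroBatch above derives the micro batch index of a StridedSlice from its begin offset: micro = begin[0] * micro_size / input_shape[0], i.e. which of the micro_size equal slices along dimension 0 the slice starts in. A small worked example (the numbers are illustrative, not taken from the commit):

#include <cassert>
#include <cstdint>

int main() {
  int64_t input_rows = 64;  // full batch size along dimension 0
  int64_t micro_size = 4;   // number of micro batches
  int64_t begin = 32;       // StridedSlice begin offset of the third slice
  int64_t micro = begin * micro_size / input_rows;
  assert(micro == 2);  // slices starting at 0, 16, 32, 48 map to micro 0, 1, 2, 3
  return 0;
}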
@ -84,7 +129,7 @@ void PipelineTransformer::Coloring() {
   while (need_coloring) {
     need_coloring = false;
     for (auto &fg : manager_->func_graphs()) {
-      if (fg == root_) {
+      if (fg == root_ && root_->has_flag(TRAINING)) {
         continue;
       }
       auto value_nodes = fg->value_nodes();
@ -94,16 +139,15 @@ void PipelineTransformer::Coloring() {
           continue;
         }
         auto graph = GetValueNode<FuncGraphPtr>(node);
-        auto need_grad = graph->get_return()->grad();
+        if (graph->stage() == -1) {
+          continue;
+        }
+        stage_set.insert(graph->stage());
         auto node_users = manager_->node_users()[node];
         for (auto &user_pair : node_users) {
           auto user_node = user_pair.first->cast<CNodePtr>();
           user_node->set_stage(graph->stage());
-          user_node->set_grad(need_grad);
           auto user_node_graph = user_node->func_graph();
-          if (graph->stage() != -1) {
-            stage_set.insert(graph->stage());
-          }
           if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
             user_node_graph->set_stage(graph->stage());
             need_coloring = true;
@ -117,22 +161,37 @@ void PipelineTransformer::Coloring() {
|
||||||
if (SizeToLong(stage_set.size()) != stage_num) {
|
if (SizeToLong(stage_set.size()) != stage_num) {
|
||||||
MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size();
|
MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size();
|
||||||
}
|
}
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
 void PipelineTransformer::BroadCastColoring() {
-  for (auto &fg : manager_->func_graphs()) {
-    if (fg == root_ || fg->stage() == -1) {
-      continue;
+  auto need_coloring = true;
+  while (need_coloring) {
+    need_coloring = false;
+    auto all_nodes = main_graph_->nodes();
+    auto node_users = manager_->node_users();
+    for (auto &node : all_nodes) {
+      if (!node->isa<CNode>() || node->stage() == -1) {
+        continue;
+      }
+      auto stage = node->stage();
+      for (auto &user_pair : node_users[node]) {
+        auto user_node = user_pair.first->cast<CNodePtr>();
+        auto user_node_stage = user_node->stage();
+        if (stage > user_node_stage) {
+          user_node->set_stage(stage);
+          need_coloring = true;
+        }
+      }
     }
-    DoBroadCast(fg);
-    SetNoStageNode(fg);
   }
 }
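For intuition, a self-contained sketch (my own simplification over an invented Node struct, not the patch) of the fixed-point propagation BroadCastColoring performs: a node's stage label is pushed to every consumer whose label is smaller, and the sweep repeats until nothing changes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Each node keeps a stage label (-1 = unassigned) and the indices of its consumers.
struct Node {
  int64_t stage;
  std::vector<size_t> users;
};

// Illustrative fixed-point loop: raise every consumer's stage to at least the
// producer's stage, repeating until no label changes.
void BroadcastStages(std::vector<Node> *nodes) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (auto &node : *nodes) {
      if (node.stage == -1) {
        continue;
      }
      for (auto user : node.users) {
        if (node.stage > (*nodes)[user].stage) {
          (*nodes)[user].stage = node.stage;
          changed = true;
        }
      }
    }
  }
}

int main() {
  // Chain 0 -> 1 -> 2, where only node 0 starts with a stage label.
  std::vector<Node> nodes = {{1, {1}}, {-1, {2}}, {-1, {}}};
  BroadcastStages(&nodes);
  assert(nodes[1].stage == 1 && nodes[2].stage == 1);
  return 0;
}
```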
 bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (!prim) {
+    return false;
+  }
   if (IsInWhiteList(cnode)) {
     return false;
   }
@@ -148,6 +207,9 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
   if (!IsPipelineCareNode(cnode)) {
     MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " is not a Pipeline Care Node.";
   }
+  if (IsPrimitiveCNode(cnode, prim::kPrimVirtualDataset)) {
+    SetVirtualDatasetStrategy(cnode);
+  }
   auto shape_list = ExtractShape(cnode);
   if (shape_list.empty()) {
     MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " failed to extract shape.";
@@ -155,7 +217,7 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   MS_EXCEPTION_IF_NULL(prim);
   if (prim->name() == RESHAPE) {
-    MS_LOG(EXCEPTION) << "Reshape op can't be a border.";
+    MS_LOG(EXCEPTION) << "Reshape op can't be a border. node:" << cnode->DebugString();
   }
   auto attrs = prim->attrs();
   auto op_info = OperatorInstance(prim, attrs, shape_list);
@@ -190,93 +252,87 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A
   MS_EXCEPTION_IF_NULL(cnode);
   // Handle Cast and TupleGetitem situation
   size_t tensor_info_index = 0;
-  if (IsPrimitiveCNode(cnode, prim::kPrimCast)) {
-    cnode = cnode->input(1)->cast<CNodePtr>();
-  } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
-    tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
-    cnode = cnode->input(1)->cast<CNodePtr>();
+  OperatorInfoPtr op_info;
+  if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
+    op_info = node->user_data<OperatorInfo>();
+  } else {
+    if (IsPrimitiveCNode(node, prim::kPrimCast)) {
+      cnode = cnode->input(1)->cast<CNodePtr>();
+    } else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
+      tensor_info_index = LongToSize(GetTupleGetItemIndex(cnode));
+      cnode = cnode->input(1)->cast<CNodePtr>();
+    }
+    // Create OperatorInfo to get slice_shape for send/recv
+    MS_EXCEPTION_IF_NULL(cnode);
+    op_info = CreateOpInfo(cnode);
   }
-  // Create OperatorInfo to get slice_shape for send/recv
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto op_info = CreateOpInfo(cnode);
   MS_EXCEPTION_IF_NULL(op_info);
   auto tensor_info = op_info->outputs_tensor_info()[tensor_info_index];
   return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
 }
-CNodePtr PipelineTransformer::HandleMonadLoad(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto &node_users = manager_->node_users()[node];
-  for (auto &user_pair : node_users) {
-    auto user_node = user_pair.first->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(user_node);
-    if (IsPipelineCareNode(user_node)) {
-      return user_node;
-    }
-  }
-  return nullptr;
-}
-
 std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
-  auto &node_users = manager_->node_users()[node];
+  auto node_users_map = manager_->node_users();
+  auto node_users = node_users_map[node];
+  for (auto &node_user : node_users) {
+    auto load = node_user.first->cast<CNodePtr>();
+    if (IsPrimitiveCNode(load, prim::kPrimLoad)) {
+      node_users = node_users_map[load];
+      break;
+    }
+  }
   for (auto &user_pair : node_users) {
-    auto care_node = user_pair.first;
-    auto care_cnode = care_node->cast<CNodePtr>();
-    if (IsPrimitiveCNode(care_node, prim::kPrimLoad)) {
-      care_cnode = HandleMonadLoad(care_node);
-      if (!care_cnode) {
-        continue;
+    auto user_node = user_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(user_node);
+    auto user_node_graph = user_node->func_graph();
+    MS_EXCEPTION_IF_NULL(user_node_graph);
+    if (user_node_graph->stage() == -1) {
+      continue;
+    }
+    auto care_node = user_node;
+    auto index = user_pair.second;
+    if (IsValueNode<FuncGraph>(user_node->input(0))) {
+      auto graph = GetValueNode<FuncGraphPtr>(user_node->input(0));
+      auto temp_params = graph->parameters();
+      if (temp_params.size() < IntToSize(user_pair.second)) {
+        MS_LOG(EXCEPTION) << "parameter:" << node->DebugString() << " out of graph: " << graph->ToString()
+                          << "'s range.";
       }
-    } else {
-      if (!IsPipelineCareNode(care_cnode)) {
-        continue;
+      auto temp_param = temp_params[user_pair.second - 1];
+      auto temp_users = node_users_map[temp_param];
+      for (auto &temp_user : temp_users) {
+        auto load_temp = temp_user.first->cast<CNodePtr>();
+        if (IsPrimitiveCNode(load_temp, prim::kPrimLoad)) {
+          temp_users = node_users_map[load_temp];
+          break;
+        }
+      }
+      for (auto &temp_pair : temp_users) {
+        auto temp_cnode = temp_pair.first->cast<CNodePtr>();
+        if (!IsPipelineCareNode(temp_cnode)) {
+          continue;
+        }
+        care_node = temp_cnode;
+        index = temp_pair.second;
+        break;
       }
     }
-    MS_EXCEPTION_IF_NULL(care_cnode);
-    auto op_info = CreateOpInfo(care_cnode);
+    if (!IsPipelineCareNode(care_node)) {
+      continue;
+    }
+    auto op_info = CreateOpInfo(care_node);
     MS_EXCEPTION_IF_NULL(op_info);
-    auto tensor_info = op_info->inputs_tensor_info()[IntToSize(user_pair.second) - 1];
-    return std::make_pair(nullptr, std::make_shared<TensorInfo>(tensor_info));
+    auto tensor_info = op_info->inputs_tensor_info()[IntToSize(index) - 1];
+    return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
   }
   return std::make_pair(nullptr, nullptr);
 }
-void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
-  auto need_coloring = true;
-  while (need_coloring) {
-    need_coloring = false;
-    auto all_nodes = func->nodes();
-    auto &node_users = manager_->node_users();
-    for (auto &node : all_nodes) {
-      if (node->isa<CNode>() || node->stage() == -1) {
-        continue;
-      }
-      auto stage = node->stage();
-      for (auto &user_pair : node_users[node]) {
-        auto user_node = user_pair.first->cast<CNodePtr>();
-        auto user_node_stage = user_node->stage();
-        if (IsValueNode<FuncGraph>(user_node->input(0)) && stage > user_node_stage) {
-          user_node->set_stage(stage);
-          need_coloring = true;
-        }
-      }
-    }
-  }
-}
-
-void PipelineTransformer::SetNoStageNode(const FuncGraphPtr &func) {
-  auto all_nodes = func->nodes();
-  for (auto &node : all_nodes) {
-    if (!node->isa<CNode>() || node->stage() != -1) {
-      continue;
-    }
-    node->set_stage(0);
-  }
-}
-
-void PipelineTransformer::HandleSharedParameter() {
+std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
   auto parameters = root_->parameters();
+  std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
+  std::vector<AnfNodePtr> recvs = {};
   for (auto &parameter : parameters) {
     auto parameter_stage = parameter_color_map[parameter];
     if (parameter_stage.size() <= 1) {
@@ -285,37 +341,41 @@ void PipelineTransformer::HandleSharedParameter() {
     auto users = manager_->node_users()[parameter];
     for (auto &user : users) {
       auto node = user.first;
+      auto cnode = node->cast<CNodePtr>();
       auto graph = node->func_graph();
-      if (graph != root_ && graph->stage() == -1) {
-        MS_LOG(EXCEPTION) << "Don't support this situation.";
+      if (IsValueNode<FuncGraph>(cnode->input(0))) {
+        graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
       }
-      if (graph == root_ || graph->stage() != stage_) {
+      if (graph == root_ || graph->stage() == -1 || !parameter_stage.count(stage_)) {
         continue;
       }
+      auto micro = cnode->GetPrimalAttr(MICRO);
+      if (!micro) {
+        MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
+        micro = MakeValue(int64_t(0));
+      }
+      auto user_stage = node->stage();
       if (stage_ == *parameter_stage.begin()) {
-        std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
-        for (auto &stage : parameter_stage) {
-          if (stage == stage_) {
-            continue;
-          } else {
-            auto send_out = InsertSend(graph, parameter, stage, stage_);
-            make_tuple_input.push_back(send_out.depend);
-          }
+        if (graph->stage() == stage_) {
+          continue;
         }
-        auto make_tuple = graph->NewCNode(make_tuple_input);
-        OperatorAttrs depend_attrs;
-        auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "");
-        std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, make_tuple};
-        auto depend = graph->NewCNode(depend_input);
-        depend->set_abstract(parameter->abstract());
-        manager_->SetEdge(node, user.second, depend);
-        break;
+        if (Reuse(parameter, user_stage, make_tuple_input, DEST_RANK)) {
+          continue;
+        }
+        auto send_out = InsertSend(main_graph_, parameter, user_stage, stage_, micro);
+        make_tuple_input.push_back(send_out.depend);
       } else {
-        (void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
-        break;
+        auto receive = Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
+        if (receive) {
+          manager_->SetEdge(node, user.second, receive);
+        } else {
+          auto recv = InsertReceive(main_graph_, parameter, node, user.second, stage_, *parameter_stage.begin(), micro);
+          recvs.push_back(recv);
+        }
       }
     }
   }
+  return make_tuple_input;
 }
 void PipelineTransformer::ParameterColoring() {
@@ -324,14 +384,24 @@ void PipelineTransformer::ParameterColoring() {
     auto users = manager_->node_users()[parameter];
     std::set<int64_t> parameter_stage;
     for (auto &user : users) {
-      auto node = user.first;
+      auto node = user.first->cast<CNodePtr>();
       auto graph = node->func_graph();
+      if (IsValueNode<FuncGraph>(node->input(0))) {
+        graph = GetValueNode<FuncGraphPtr>(node->input(0));
+      }
       if (graph != root_ && graph->stage() != -1) {
        parameter_stage.insert(graph->stage());
        parameter->set_stage(graph->stage());
      }
    }
-    if (*parameter_stage.begin() == stage_ && !virtual_param_) {
+    auto param_info = parameter->cast<ParameterPtr>()->param_info();
+    if (!param_info) {
+      parameter_color_map[parameter] = parameter_stage;
+      continue;
+    }
+    MS_EXCEPTION_IF_NULL(param_info);
+    auto requires_grad = param_info->requires_grad();
+    if (*parameter_stage.begin() == stage_ && !virtual_param_ && requires_grad) {
       virtual_param_ = parameter;
     }
     parameter_color_map[parameter] = parameter_stage;
@@ -343,8 +413,8 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con
   auto cnode = node->cast<CNodePtr>();
   if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
-    auto graph_return = graph->get_return();
-    type = graph_return->Type();
+    auto graph_output = graph->output();
+    type = graph_output->Type();
   } else {
     type = node->Type();
   }
@@ -359,40 +429,38 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con
   return std::make_pair(shape_list, dtype);
 }
-AnfNodePtr PipelineTransformer::HandleMonadDepend(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
-    auto cnode = node->cast<CNodePtr>();
-    return HandleMonadDepend(cnode->input(1));
-  }
-  return node;
-}
-
 AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   if (IsValueNode<FuncGraph>(cnode->input(0))) {
     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
-    auto output = HandleMonadDepend(graph->output());
+    auto output = graph->output();
     MS_EXCEPTION_IF_NULL(output);
     if (output->isa<Parameter>()) {
-      return output;
+      auto parameters = graph->parameters();
+      auto pos_iter = std::find(parameters.begin(), parameters.end(), output);
+      auto pos = std::distance(parameters.begin(), pos_iter);
+      return FindPipelineCareNode(cnode->input(pos + 1));
     }
     cnode = output->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
   }
+  if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
+    return FindPipelineCareNode(cnode->input(1));
+  }
   if (IsInWhiteList(cnode)) {
     return cnode->cast<AnfNodePtr>();
   }
   if (!IsPipelineCareNode(cnode)) {
-    MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border.";
+    MS_LOG(EXCEPTION) << "Only PipelineSplit cared node can be a border."
+                      << " border node: " << cnode->DebugString();
   }
   return cnode->cast<AnfNodePtr>();
 }
 SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter,
-                                         int64_t user_node_stage, int64_t node_stage) {
+                                         int64_t user_node_stage, int64_t node_stage, const ValuePtr &value) {
   auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
   int64_t send_tag;
   if (send_tag_map.find(dest_rank) != send_tag_map.end()) {
@@ -402,17 +470,25 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
     send_tag = 0;
     send_tag_map[dest_rank] = 0;
   }
-  Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag));
-  Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank));
-  OperatorAttrs attrs = {attr_tag, attr_rank};
-  auto send_op = CreatOpInstance(attrs, SEND, "send");
+  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
+  Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage));
+  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
+  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
+  OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
+  auto send_op = CreatOpInstance(attrs, SEND, SEND);
   auto send_node = NewValueNode(send_op);
   auto prim = GetValueNode<PrimitivePtr>(send_node);
   std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
+  AnfNodePtr care_node;
   if (parameter->isa<Parameter>()) {
     op_info_pair = GetParameterPair(parameter);
   } else {
-    auto care_node = FindPipelineCareNode(parameter);
+    if (IsPrimitiveCNode(parameter, prim::kPrimCast)) {
+      auto parameter_cnode = parameter->cast<CNodePtr>();
+      care_node = FindPipelineCareNode(parameter_cnode->input(1));
+    } else {
+      care_node = FindPipelineCareNode(parameter);
+    }
     if (care_node->isa<Parameter>()) {
       op_info_pair = GetParameterPair(care_node);
     } else {
@@ -423,14 +499,20 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   MS_EXCEPTION_IF_NULL(tensor_info);
   auto slice_shape = tensor_info->slice_shape();
   auto shape_type_pair = GetShapeType(parameter, slice_shape);
-  prim->set_attr("shape", shape_type_pair.first);
-  prim->set_attr("dtype", shape_type_pair.second);
+  prim->set_attr(SHAPE, shape_type_pair.first);
+  prim->set_attr(DTYPE, shape_type_pair.second);
   std::vector<AnfNodePtr> send_input = {send_node, parameter};
-  auto send = graph->NewCNode(send_input);
+  auto send = main_graph_->NewCNode(send_input);
+  if (!parameter->isa<Parameter>() && care_node != nullptr && !care_node->isa<Parameter>()) {
+    send->AddPrimalAttr(PIPELINE_END, value);
+  } else {
+    send->AddPrimalAttr(PIPELINE_PARAM, value);
+    send->AddPrimalAttr(MICRO, value);
+  }
   OperatorAttrs depend_attrs;
-  auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend");
+  auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
   std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
-  auto depend = graph->NewCNode(depend_input);
+  auto depend = main_graph_->NewCNode(depend_input);
   auto abstract = parameter->abstract();
   depend->set_abstract(abstract);
   SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
@@ -439,7 +521,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
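A minimal sketch of the per-destination tag bookkeeping that send_tag_map implements; the increment branch is elided by the hunk above, so its exact form here is an assumption, but the idea is that each destination rank gets its own monotonically increasing sr_tag so repeated Send/Receive pairs between the same peers stay matched.

```cpp
#include <cassert>
#include <cstdint>
#include <map>

// Hypothetical standalone version of the tag lookup: existing destinations get
// the next tag, first-time destinations start at 0.
int64_t NextSendTag(std::map<int64_t, int64_t> *send_tag_map, int64_t dest_rank) {
  auto it = send_tag_map->find(dest_rank);
  if (it != send_tag_map->end()) {
    it->second += 1;  // assumed increment semantics
    return it->second;
  }
  (*send_tag_map)[dest_rank] = 0;
  return 0;
}

int main() {
  std::map<int64_t, int64_t> tags;
  assert(NextSendTag(&tags, 8) == 0);
  assert(NextSendTag(&tags, 8) == 1);
  assert(NextSendTag(&tags, 16) == 0);
  return 0;
}
```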
 AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const AnfNodePtr &use_node, int index, int64_t user_node_stage,
-                                              int64_t node_stage) {
+                                              int64_t node_stage, const ValuePtr &value) {
   auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
   int64_t recv_tag;
   if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
@@ -449,9 +531,10 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
     recv_tag = 0;
     recv_tag_map[src_rank] = 0;
   }
-  Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag));
-  Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank));
+  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
+  Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
   std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
+  bool is_param = true;
   if (node->isa<Parameter>()) {
     op_info_pair = GetParameterPair(node);
   } else {
@@ -460,28 +543,34 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
       op_info_pair = GetParameterPair(care_node);
     } else {
       op_info_pair = GetOpInfo(care_node);
+      is_param = false;
     }
   }
   auto tensor_info = op_info_pair.second;
   MS_EXCEPTION_IF_NULL(tensor_info);
-  auto slice_shape = tensor_info->slice_shape();
+  auto tensor_layout = tensor_info->tensor_layout();
+  Shape slice_shape = tensor_info->slice_shape();
   auto shape_type_pair = GetShapeType(node, slice_shape);
-  Attr attr_shape = std::make_pair("shape", shape_type_pair.first);
-  Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
-  OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
-  auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv");
+  Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
+  Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
+  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
+  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
+  OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
+  auto recv_op = CreatOpInstance(attrs, RECEIVE, RECEIVE);
   std::vector<AnfNodePtr> recv_input;
   if (node->isa<Parameter>()) {
     recv_input = {NewValueNode(recv_op), node};
   } else {
-    if (node->grad()) {
-      recv_input = {NewValueNode(recv_op), virtual_param_};
-    } else {
-      auto param = root_->parameters()[0];
-      recv_input = {NewValueNode(recv_op), param};
-    }
+    recv_input = {NewValueNode(recv_op), virtual_param_};
   }
   auto recv = graph->NewCNode(recv_input);
+  if (is_param) {
+    recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
+    recv->AddPrimalAttr(PIPELINE_PARAM, value);
+  } else {
+    recv->AddPrimalAttr(PIPELINE_BEGIN, value);
+  }
+  recv->AddPrimalAttr(MICRO, value);
   auto node_abstract = node->abstract();
   if (node->isa<CNode>()) {
     auto cnode = node->cast<CNodePtr>();
@@ -494,65 +583,53 @@ AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const A
   }
   MS_EXCEPTION_IF_NULL(node_abstract);
   recv->set_abstract(node_abstract);
-  if (op_info_pair.first != nullptr) {
-    recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout()));
-    recv->set_user_data<OperatorInfo>(op_info_pair.first);
+  if (node->isa<Parameter>()) {
+    BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
+    auto abstract_clone = node->abstract()->Clone();
+    MS_EXCEPTION_IF_NULL(abstract_clone);
+    abstract_clone->set_shape(parallel_shape);
+    node->set_abstract(abstract_clone);
+    node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
   }
+  recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+  recv->set_user_data<OperatorInfo>(op_info_pair.first);
   manager_->SetEdge(use_node, index, recv);
   return recv;
 }
-bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
-                                const std::vector<AnfNodePtr> &out_input) {
-  auto node_users = manager_->node_users()[node];
-  auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_;
-  for (auto &depend : out_input) {
-    if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
+AnfNodePtr PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
+                                      const std::string &tag) {
+  for (auto &input : out_input) {
+    auto cnode = input->cast<CNodePtr>();
+    if (!cnode) {
       continue;
     }
-    auto cnode = depend->cast<CNodePtr>();
+    if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
+      cnode = cnode->input(2)->cast<CNodePtr>();
+    }
     if (cnode->input(1) == node) {
-      auto send_cnode = cnode->input(2)->cast<CNodePtr>();
-      auto prim = GetValueNode<PrimitivePtr>(send_cnode->input(0));
-      auto dest_rank_send = GetValue<int64_t>(prim->GetAttr("dest_rank"));
-      if (dest_rank_send == dest_rank) {
-        return true;
+      auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+      auto dest_rank_send = GetValue<int64_t>(prim->GetAttr(tag));
+      if (dest_rank_send == stage) {
+        return input;
       }
     }
   }
-  return false;
+  return nullptr;
 }
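As an aside, the reuse check amounts to scanning previously created border ops for one that already targets the same peer stage; the sketch below uses an invented BorderOp record purely for illustration and is not part of the patch.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Invented stand-in for a cached Send/Receive border op.
struct BorderOp {
  std::string source_name;  // the node this op forwards
  int64_t peer_stage;       // value of the DEST_RANK / SRC_RANK attribute
};

// Return the existing op for (source, peer_stage), or nullptr so the caller
// knows it has to insert a fresh Send/Receive.
const BorderOp *FindReusableOp(const std::vector<BorderOp> &cache, const std::string &source,
                               int64_t peer_stage) {
  for (const auto &op : cache) {
    if (op.source_name == source && op.peer_stage == peer_stage) {
      return &op;
    }
  }
  return nullptr;
}

int main() {
  std::vector<BorderOp> cache = {{"layer1.weight", 1}};
  return FindReusableOp(cache, "layer1.weight", 1) != nullptr ? 0 : 1;
}
```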
-std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
-  std::set<int64_t> tag_set;
-  auto node_stage = node->stage();
-  int64_t min_tag = node_stage;
-  for (auto &user_pair : node_users) {
-    auto user_node = user_pair.first;
-    auto user_node_stage = user_node->stage();
-    tag_set.insert(user_node_stage);
-    if (user_node_stage == -1) {
-      continue;
-    }
-    min_tag = min_tag > user_node_stage ? user_node_stage : min_tag;
-  }
-  bool is_shared = tag_set.size() > 1;
-  return std::make_pair(is_shared, min_tag);
-}
-
-void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
+std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
   OperatorAttrs depend_attrs;
-  auto depend_op = CreatOpInstance(depend_attrs, "Depend", "");
-  std::vector<AnfNodePtr> out_input = {NewValueNode(depend_op)};
+  auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
+  std::vector<AnfNodePtr> receive_ops;
+  std::vector<AnfNodePtr> send_ops;
   auto all_nodes = graph->nodes();
   for (auto &node : all_nodes) {
     if (!node->isa<CNode>() || node->stage() == -1) {
       continue;
     }
     auto node_users = manager_->node_users()[node];
-    auto shared_min_tag_pair = IsSharedNode(node, node_users);
-    auto is_shared = shared_min_tag_pair.first;
-    auto min_tag = shared_min_tag_pair.second;
     AnfNodePtr receive = nullptr;
     for (auto &user_pair : node_users) {
       auto user_node = user_pair.first;
@@ -561,21 +638,25 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
       if (node_stage != stage_ && user_node_stage != stage_) {
         continue;
       }
+      auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
+      if (!micro) {
+        MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
+        micro = MakeValue(int64_t(0));
+      }
       if (node_stage < user_node_stage) {
-        if (is_shared && (min_tag != node_stage)) {
-          continue;
-        }
         if (node_stage == stage_) {
-          if (Reuse(node, user_node_stage, node_stage, out_input)) {
+          if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
             continue;
           }
-          auto send_out = InsertSend(graph, node, user_node_stage, node_stage);
-          out_input.insert(out_input.begin() + 1, send_out.depend);
-          type_ptr_ = send_out.type;
-          shape_ = send_out.shape;
+          auto send_out = InsertSend(graph, node, user_node_stage, node_stage, micro);
+          MS_EXCEPTION_IF_NULL(send_out.depend);
+          send_ops.push_back(send_out.depend);
+          send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
+          send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
         } else {
           if (!receive) {
-            receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
+            receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro);
+            receive_ops.push_back(receive);
           } else {
             manager_->SetEdge(user_node, user_pair.second, receive);
           }
@@ -583,46 +664,40 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
         continue;
       }
       if (node_stage > user_node_stage) {
-        auto cnode = node->cast<CNodePtr>();
-        auto user_cnode = user_node->cast<CNodePtr>();
-        if (IsValueNode<FuncGraph>(cnode->input(0)) && IsValueNode<FuncGraph>(user_cnode->input(0))) {
-          MS_LOG(EXCEPTION) << "Don't support this situation";
-        }
-        continue;
+        MS_LOG(EXCEPTION) << "node_stage: " << node_stage
+                          << " must be smaller than user_node_stage: " << user_node_stage;
       }
     }
   }
-  if (out_input.size() == 2) {
-    manager_->Replace(graph->output(), out_input[1]);
-  }
-  if (out_input.size() > 2) {
-    std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
-    make_tuple_inputs.insert(make_tuple_inputs.begin() + 1, out_input.begin() + 2, out_input.end());
-    auto make_tuple = graph->NewCNode(make_tuple_inputs);
-    std::vector<AnfNodePtr> out_depend_inputs = {out_input[0], out_input[1], make_tuple};
-    auto out_node = graph->NewCNode(out_depend_inputs);
-    manager_->Replace(graph->output(), out_node);
-  }
+  return std::make_pair(send_ops, receive_ops);
 }
 void PipelineTransformer::CutGraph() {
-  for (auto &fg : manager_->func_graphs()) {
-    CutBorder(fg);
-  }
-}
-
-bool PipelineTransformer::IsStageNode(const CNodePtr &node) {
-  for (auto &input : node->inputs()) {
-    if (input->isa<Parameter>()) {
-      return (*parameter_color_map[input].begin() == stage_ || input->stage() == -1);
-    } else if (input->isa<CNode>()) {
-      auto pre_node = input->cast<CNodePtr>();
-      return IsStageNode(pre_node);
-    } else {
-      continue;
-    }
-  }
-  return true;
+  std::vector<AnfNodePtr> make_tuple_inputs;
+  CreateForwardGroup();
+  MS_EXCEPTION_IF_NULL(main_graph_);
+  if (make_tuple_inputs.empty()) {
+    make_tuple_inputs = HandleSharedParameter();
+  }
+  auto send_recv_ops = CutBorder(main_graph_);
+  auto send_ops = send_recv_ops.first;
+  if (IsLastStage()) {
+    return;
+  }
+  if (send_ops.empty() && !root_->has_flag(TRAINING)) {
+    return;
+  }
+  make_tuple_inputs.insert(make_tuple_inputs.end(), send_ops.begin(), send_ops.end());
+  if (!send_ops.empty()) {
+    type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
+    shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
+  }
+  auto make_tuple = main_graph_->NewCNode(make_tuple_inputs);
+  std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend)};
+  out.push_back(send_ops.back());
+  out.push_back(make_tuple);
+  auto out_node = main_graph_->NewCNode(out);
+  manager_->Replace(main_graph_->output(), out_node);
 }
 void PipelineTransformer::ElimGraphStage() {
@@ -694,7 +769,21 @@ void PipelineTransformer::ElimParameter() {
   std::vector<AnfNodePtr> parameter_list;
   for (auto &parameter : parameters) {
     if (!manager_->node_users()[parameter].empty()) {
-      parameter_list.push_back(parameter);
+      if (!root_->has_flag(TRAINING)) {
+        for (auto &node_pair : manager_->node_users()[parameter]) {
+          auto user_node = node_pair.first;
+          if (!IsPrimitiveCNode(user_node, prim::kPrimReceive)) {
+            parameter_list.push_back(parameter);
+            break;
+          }
+          // remove_receive_inputs
+          auto cnode = user_node->cast<CNodePtr>();
+          std::vector<AnfNodePtr> new_inputs = {cnode->input(0)};
+          cnode->set_inputs(new_inputs);
+        }
+      } else {
+        parameter_list.push_back(parameter);
+      }
     }
   }
   auto del_num = parameters.size() - parameter_list.size();
@@ -45,13 +45,15 @@ class PipelineTransformer {
       : manager_(manager),
         stage_(stage),
         root_(root),
+        main_graph_(nullptr),
+        virtual_dataset_(nullptr),
         global_rank_(global_rank),
         per_stage_rank_num_(per_stage_rank_num) {}
   virtual ~PipelineTransformer() = default;
-  void LabelRequiredGradCNode();
   void Coloring();
+  void MainGraph();
+  void LabelMicroBatch();
   void BroadCastColoring();
-  void HandleSharedParameter();
   void CutGraph();
   void ParameterColoring();
   void CoverSensShape();
@@ -59,21 +61,18 @@ class PipelineTransformer {
   void ElimParameter();

  private:
-  std::pair<bool, int64_t> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
-  void DoBroadCast(const FuncGraphPtr &func);
+  void CreateForwardGroup();
+  ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size);
+  std::vector<AnfNodePtr> HandleSharedParameter();
   SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
-                      int64_t node_stage);
+                      int64_t node_stage, const ValuePtr &value);
   AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
-                           int64_t user_node_stage, int64_t node_stage);
-  void SetNoStageNode(const FuncGraphPtr &func);
-  void CutBorder(const FuncGraphPtr &graph);
-  bool IsStageNode(const CNodePtr &node);
-  bool Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
-             const std::vector<AnfNodePtr> &out_input);
+                           int64_t user_node_stage, int64_t node_stage, const ValuePtr &value);
+  std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
+  AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
+                   const std::string &tag);
   AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
   std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
-  AnfNodePtr HandleMonadDepend(const AnfNodePtr &node);
-  CNodePtr HandleMonadLoad(const AnfNodePtr &node);
   std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);
   OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
   bool IsPipelineCareNode(const CNodePtr &cnode);
@@ -81,11 +80,15 @@ class PipelineTransformer {
   FuncGraphManagerPtr manager_;
   int64_t stage_;
   FuncGraphPtr root_;
+  FuncGraphPtr main_graph_;
+  AnfNodePtr virtual_dataset_;
   int64_t global_rank_;
   int64_t per_stage_rank_num_;
   TypePtr type_ptr_;
   ValueListPtr shape_;
   AnfNodePtr virtual_param_;
+  int64_t micro_size_ = 0;
+  std::vector<std::string> group_ = {};
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -37,6 +37,7 @@
 #include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/graph_util/graph_info.h"
 #include "frontend/parallel/graph_util/node_info.h"
+#include "frontend/parallel/graph_util/pipeline_split_utils.h"
 #include "frontend/parallel/node_check.h"
 #include "frontend/parallel/ops_info/matmul_info.h"
 #include "ir/param_info.h"
@@ -172,8 +173,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   OperatorArgs arg_forward = op.second;

   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();

-  if (grad_accumulation_step > 1) {
+  if (grad_accumulation_step > 1 || split_stage_num > 1) {
     auto parameters = root->parameters();
     bool find_grad_accu_node = false;
     for (auto &param : parameters) {
@@ -196,8 +198,8 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
       if (op_name == MIRROR_MINI_STEP_OPERATOR) {
         op_name = MIRROR_OPERATOR;
         arg_forward.first.pop_back();
-      } else if (op_name == MINI_STEP_ALL_GATHER) {
-        MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation.";
+      } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR) {
+        MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name;
       }
     }
   }
@@ -207,7 +209,8 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   OperatorParams params = arg_forward.second;

   std::vector<AnfNodePtr> new_node_input;
-  if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) {
+  if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
+      op_name == MIRROR_MICRO_STEP_OPERATOR) {
     new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
     MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
   } else {
@@ -496,6 +499,9 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
   TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
   TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
+  if (IsPrimitiveCNode(middle_node, prim::kPrimReceive)) {
+    tensorlayout_in = *(middle_node->user_data<TensorLayout>());
+  }
   if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
     MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
     MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
@@ -866,11 +872,13 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
   SetUserAttrs(origin_prim->attrs(), prim);
   if (index == replace_op.size() - 1) {
     replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
+    replace_node->set_primal_attrs(node->primal_attrs());
   }
   replace_node->set_in_forward_flag(true);
   replace_input[0]->set_scope(scope);
   if (replace_op_info_flag && replace_op_info[index].first) {
     auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph);
+    new_cnode->set_primal_attrs(node->primal_attrs());
     (void)manager->Replace(node, new_cnode);  // using Replace function to insert node
   } else {
     (void)manager->Replace(node, replace_node);  // using Replace function to insert node
@@ -920,8 +928,9 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
     manager->SetEdge(replace_input.first, appear_count, pre_node);
   }
   // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
-  auto replace_output = replace_graph->second;
+  auto replace_output = replace_graph->second->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(replace_output);
+  replace_output->set_primal_attrs(node->primal_attrs());
   (void)manager->Replace(node, replace_output);
 }
|
@ -1075,7 +1084,7 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
|
if (IsSomePrimitive(cnode, RECEIVE) && cnode->has_user_data(PIPELINE_PARAM)) {
|
||||||
return std::make_pair(node, false);
|
return std::make_pair(node, false);
|
||||||
}
|
}
|
||||||
// When not fully use opt shard, allgather and mirror would be both inserted.
|
// When not fully use opt shard, allgather and mirror would be both inserted.
|
||||||
|
@@ -1193,6 +1202,20 @@ CNodePtr SkipTrivialNodes(CNodePtr node) {
   return node;
 }

+std::string MirrorOpName() {
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
+  std::string mirror_op_name;
+  if (grad_accumulation_step > 1) {
+    mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
+  } else if (split_stage_num > 1) {
+    mirror_op_name = MIRROR_MICRO_STEP_OPERATOR;
+  } else {
+    mirror_op_name = MIRROR_OPERATOR;
+  }
+  return mirror_op_name;
+}
+
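A standalone sketch (hypothetical names, same selection rule) of how MirrorOpName chooses the operator: gradient accumulation takes precedence over pipeline split, which takes precedence over the plain mirror.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-alone version of the selection rule above.
std::string SelectMirrorOp(int64_t grad_accumulation_step, int64_t split_stage_num) {
  if (grad_accumulation_step > 1) {
    return "MirrorMiniStepOperator";
  }
  if (split_stage_num > 1) {
    return "MirrorMicroStepOperator";
  }
  return "MirrorOperator";
}

int main() {
  assert(SelectMirrorOp(1, 4) == "MirrorMicroStepOperator");
  assert(SelectMirrorOp(8, 4) == "MirrorMiniStepOperator");
  assert(SelectMirrorOp(1, 1) == "MirrorOperator");
  return 0;
}
```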
 void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
@@ -1240,12 +1263,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
       }
     }
     // not a RefKey
-    int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
-    std::string mirror_op_name;
-    if (grad_accumulation_step > 1) {
-      mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
-    } else {
-      mirror_op_name = MIRROR_OPERATOR;
+    std::string mirror_op_name = MirrorOpName();
+    if (IsPrimitiveCNode(param_node_pair.first, prim::kPrimReceive)) {
+      param_ptr = param_node_pair.first->cast<CNodePtr>()->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
+      param_name = param_ptr->name();
     }
     AnfNodePtr pre_node = node->input(index);
     if (!param_node_pair.second) {
@@ -1282,6 +1303,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
         AddCommOpFusionType(comm_op, param_node_pair.first);
         MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
                      << " and insert mirror before Load";
+        AddCommOpParamFlag(comm_op);
         continue;
       }
       InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root);
@@ -1291,6 +1313,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
       // add fusion flag
       // pipeline mirror would not be set, which should be supported later
       AddCommOpFusionType(comm_op, param_node_pair.first);
+      AddCommOpParamFlag(comm_op);
     }
   }
|
@ -2333,6 +2356,9 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
|
||||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
|
||||||
|
return cnode->user_data<TensorLayout>();
|
||||||
|
}
|
||||||
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() &&
|
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() &&
|
||||||
!IsPrimitiveCNode(node, prim::kPrimReshape)) {
|
!IsPrimitiveCNode(node, prim::kPrimReshape)) {
|
||||||
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
|
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
|
||||||
|
@@ -2764,13 +2790,6 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP
   return sens_loss_pairs;
 }

-bool IsLastStage() {
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  auto stage_num = g_device_manager->stage_num();
-  auto stage_id = g_device_manager->stage_id();
-  return ((stage_num - 1) == stage_id);
-}
-
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
|
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
|
||||||
const FuncGraphManagerPtr &manager) {
|
const FuncGraphManagerPtr &manager) {
|
||||||
MS_EXCEPTION_IF_NULL(root);
|
MS_EXCEPTION_IF_NULL(root);
|
||||||
|
@@ -2793,7 +2812,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
 if (node->isa<CNode>()) {
 auto cnode = node->cast<CNodePtr>();
 // the make_tuple is parallel care node, but it may have not operator info
-if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
+if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || cnode->HasPrimalAttr(PIPELINE_PARAM)) {
 continue;
 }
@@ -3545,20 +3564,6 @@ static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
 return false;
 }

-static AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
-for (auto &parameter : parameters) {
-auto param_ptr = parameter->cast<ParameterPtr>();
-MS_EXCEPTION_IF_NULL(param_ptr);
-if (param_ptr->name() == name) {
-continue;
-}
-if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) {
-return parameter;
-}
-}
-return nullptr;
-}
-
 static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
 const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
 auto cnode = node_user.first->cast<CNodePtr>();
@@ -3612,6 +3617,17 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) {
 }
 }

+void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages) {
+if (!root->has_flag(BACKWARD) && pipeline_stages > 1) {
+root->set_flag(BACKWARD, true);
+if (root->has_flag(TRAINING)) {
+Reorder(root, manager);
+} else {
+ReorderForPredict(root, manager);
+}
+}
+}
+
 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
 #if (ENABLE_CPU && !_WIN32)
 if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
@@ -3622,6 +3638,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 MS_EXCEPTION_IF_NULL(optimizer);
 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+pipeline::ResourceBasePtr res = optimizer->resource();
+MS_EXCEPTION_IF_NULL(res);
+FuncGraphManagerPtr manager = res->manager();
+MS_EXCEPTION_IF_NULL(manager);
+auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
 // assume no change to graph
 bool changes = false;
 // control whether use model_parallel mode
@@ -3634,6 +3655,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 }
 root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
 }
+ReorderForPipelineSplit(root, manager, pipeline_stages);

 return changes;
 }
@@ -3643,23 +3665,22 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 MS_LOG(INFO) << "Now entering step parallel";
 DumpGraph(root, std::string(STEP_PARALLEL_BEGIN));

-pipeline::ResourceBasePtr res = optimizer->resource();
-MS_EXCEPTION_IF_NULL(res);
-
-FuncGraphManagerPtr manager = res->manager();
-MS_EXCEPTION_IF_NULL(manager);
 AnfNodePtr ret = root->get_return();
 MS_EXCEPTION_IF_NULL(ret);
 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
 std::reverse(all_nodes.begin(), all_nodes.end());
 if (parallel_mode != AUTO_PARALLEL) {
 TOTAL_OPS = 0;
-auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
 if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
 MS_LOG(EXCEPTION) << "Parallel init failed";
 }

+if (pipeline_stages > 1) {
+HandleMicroBatch(all_nodes, manager);
+ParameterStartNode(all_nodes, manager);
+LastStageEndNode(all_nodes, manager);
+}

 // mark the forward cnodes, parallel only care these nodes
 MarkForwardCNode(root);
@@ -3705,6 +3726,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 // ForwardCommunication BackwardCommunication TensorRedistribution
 ParallelCommunication(root, all_nodes, manager);

+if (pipeline_stages > 1) {
+AddVirtualAssignAdd(root);
+HandleReceiveParam(root, all_nodes);
+}
+
 auto group_info = g_device_manager->group_info();
 if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
 StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
@@ -134,8 +134,6 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);

 StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);

-bool IsLastStage();
-
 // Add node for whole graph
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
 const FuncGraphManagerPtr &manager);
@@ -177,6 +175,10 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *u
 std::vector<size_t> *indexes);

 void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);

+std::string MirrorOpName();
+
+void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages);
+
 } // namespace parallel
 } // namespace mindspore
@@ -410,7 +410,9 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
 irpass.env_get_item_depend_swap_,
 irpass.incorporate_env_getitem_switch_layer_,
 irpass.value_based_eliminate_,
-irpass.receive_eliminate_},
+irpass.virtual_accu_grad_,
+irpass.virtual_assign_add_,
+irpass.mirror_micro_step_},
 false, true);
 opt::OptPassConfig b_2 = opt::OptPassConfig({
 irpass.replace_refkey_by_param_,
@@ -90,17 +90,19 @@ bool PipelineSplit(const ResourcePtr &res) {
 auto transformer =
 std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
 // step1: Do color graph
-transformer->LabelRequiredGradCNode();
 transformer->Coloring();
+transformer->MainGraph();
+transformer->LabelMicroBatch();
 // step2: Do color broadcast
 transformer->BroadCastColoring();
 // step3: Handle shared parameters
 transformer->ParameterColoring();
-transformer->HandleSharedParameter();
 // step4: Cut Graph
 transformer->CutGraph();
 // step5: Handle Sens
-transformer->CoverSensShape();
+if (root->has_flag(parallel::TRAINING)) {
+transformer->CoverSensShape();
+}
 // step6: Elim Graph stages and no used parameter
 transformer->ElimGraphStage();
 transformer->ElimParameter();
@@ -375,6 +375,9 @@ inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill");
 inline const PrimitivePtr kPrimFusedPushWeight = std::make_shared<Primitive>("FusedPushWeight");
 inline const PrimitivePtr kPrimFusedPullWeight = std::make_shared<Primitive>("FusedPullWeight");
 inline const PrimitivePtr kPrimInitDataSetQueue = std::make_shared<Primitive>("InitDataSetQueue");
+inline const PrimitivePtr kPrimVirtualAssignAdd = std::make_shared<Primitive>("_VirtualAssignAdd");
+inline const PrimitivePtr kPrimVirtualAccuGrad = std::make_shared<Primitive>("_VirtualAccuGrad");
+inline const PrimitivePtr kPrimMirrorMicroStep = std::make_shared<Primitive>("_MirrorMicroStepOperator");

 // Quant ops
 inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold");
@@ -19,6 +19,7 @@ from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
 from mindspore.context import ParallelMode, get_auto_parallel_context
 from mindspore._checkparam import Validator as validator
+from mindspore import ops, nn
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
 from ...common.tensor import Tensor
@@ -503,6 +504,95 @@ class _VirtualDatasetCell(Cell):
         return self._backbone(*output)


+class _MicroBatch(Cell):
+    """
+    transform mini-batch to micro-batch in pipeline parallel.
+
+    Args:
+       params (micro_size): The number of micro-batch.
+    """
+    def __init__(self, micro_size):
+        super(_MicroBatch, self).__init__()
+        self.shape = P.Shape()
+        self.micro_size = micro_size
+
+    def construct(self, i, *inputs):
+        micro_inputs = ()
+        for each_input in inputs:
+            input_shape = self.shape(each_input)
+            micro_batch_begin = i * input_shape[0] // self.micro_size
+            micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
+            micro_input = each_input[micro_batch_begin:micro_batch_end, :]
+            micro_inputs += (micro_input,)
+        return micro_inputs
+
+
+class PipelineCell(Cell):
+    """
+    Wrap the network with Micro Batch.
+
+    Note:
+        micro_size must be greater or equal to pipeline stages.
+
+    Args:
+        network (Cell): The target network to wrap.
+        micro_size (Int): MicroBatch size.
+
+    Examples:
+        >>> net = Net()
+        >>> net = PipelineCell(net, 4)
+    """
+    def __init__(self, network, micro_size):
+        super(PipelineCell, self).__init__()
+        self.network = network
+        self.micro_inputs = nn.CellList()
+        self.micro_size = micro_size
+        self.add_list = []
+        for i in range(micro_size):
+            micro_input = _MicroBatch(micro_size)
+            self.micro_inputs.append(micro_input)
+            self.add = P.Add().add_prim_attr("pipeline_end", i)
+            self.add_list.append(self.add)
+
+    def construct(self, *inputs):
+        ret = None
+        for i in range(self.micro_size):
+            micro_input = self.micro_inputs[i](i, *inputs)
+            output = self.network(*micro_input)
+            if ret is not None:
+                ret = self.add_list[i](ret, output)
+            else:
+                ret = output
+        return ret
+
+
+def _pipeline_clear_grad(accu_grad, grad):
+    accu_grad = F.depend(accu_grad, grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    return F.assign(accu_grad, zeros)
+
+
+class _TrainPipelineAccuStepCell(TrainOneStepCell):
+    """
+    Wraps the network with an optimizer in pipeline mode.
+    """
+    def __init__(self, network, optimizer, sens=1.0):
+        super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
+        self.hyper_map = ops.HyperMap()
+
+    def construct(self, *inputs):
+        weights = self.weights
+        loss = self.network(*inputs)
+        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
+        accu_grads = ops.depend(self.accu_grads, grads)
+        succ = self.optimizer(accu_grads)
+        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
+        loss = ops.depend(loss, succ, clear)
+        return loss
+
+
 class VirtualDatasetCellTriple(Cell):
     """
     Wrap the network with virtual dataset to convert data parallel layout to model parallel layout.
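A note on the slicing arithmetic in _MicroBatch.construct above: each micro-batch is a contiguous slice of the mini-batch along the first axis, computed with integer division so the slices cover the batch exactly. A minimal plain-Python sketch of the same arithmetic (the helper name and data are illustrative only, not part of this change):

# Plain-Python sketch of the _MicroBatch slicing arithmetic; split_micro_batches is a
# hypothetical helper used only for illustration.
def split_micro_batches(batch, micro_size):
    n = len(batch)
    slices = []
    for i in range(micro_size):
        begin = i * n // micro_size          # same formula as micro_batch_begin above
        end = (i + 1) * n // micro_size      # same formula as micro_batch_end above
        slices.append(batch[begin:end])
    return slices

print(split_micro_batches(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]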
@@ -14,6 +14,7 @@
 # ============================================================================

 """Generate bprop for comm ops"""
+from mindspore import Tensor
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from mindspore.communication import get_rank, get_group_size
@@ -22,7 +23,8 @@ from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
-                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap)
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
+                                   _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator)
 from .grad_base import bprop_getters
 from ..operations._inner_ops import Send, Receive
@@ -84,11 +86,11 @@ def get_bprop_send(self):
     """Generate bprop for Send."""
     shape = self.get_attr_dict()["shape"]
     dtype = self.get_attr_dict()["dtype"]
-    send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
-    send_grad.add_prim_attr("backward", True)
+    send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group_back)
+    virtual_input = Tensor(0.0, dtype)

     def bprop(x, out, dout):
-        dx = send_grad()
+        dx = send_grad(virtual_input)
         return (dx,)
     return bprop
@@ -96,14 +98,14 @@ def get_bprop_send(self):
 @bprop_getters.register(Receive)
 def get_bprop_receive(self):
     """Generate bprop for Receive."""
-    receive_grad = Send(self.tag, self.rank, self.group)
-    receive_grad.add_prim_attr("backward", True)
+    receive_grad = Send(self.tag, self.rank, self.group_back)
     depend = P.Depend()
     cast = P.Cast()
+    out_tensor = Tensor(0.0, mstype.float16)

     def bprop(x, out, dout):
         send_out = receive_grad(dout)
-        dx = depend(cast(zeros_like(x), F.dtype(x)), send_out)
+        dx = depend(cast(out_tensor, F.dtype(x)), send_out)
         return (dx,)
     return bprop
@@ -116,6 +118,80 @@ def get_bprop_virtual_add(self):
     return bprop


+@bprop_getters.register(_VirtualAssignAdd)
+def get_bprop_virtual_assign_add(self):
+    """Generate bprop for VirtualAssignAdd."""
+    assign_add = P.AssignAdd()
+    cast = P.Cast()
+    dtype = P.DType()
+    out_tensor = Tensor(0.0, mstype.float16)
+
+    def bprop(x, y, out, dout):
+        temp = assign_add(y, dout)
+        return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp)
+
+    return bprop
+
+
+@bprop_getters.register(_VirtualAccuGrad)
+def get_bprop_virtual_accu_grad(self):
+    """Generate bprop for VirtualAccuGrad."""
+    cast = P.Cast()
+    dtype = P.DType()
+    out_tensor = Tensor(0.0, mstype.float16)
+
+    def bprop(x, y, out, dout):
+        return (F.depend(y, dout), cast(out_tensor, dtype(y)))
+
+    return bprop
+
+
+@bprop_getters.register(_MirrorMicroStepOperator)
+def get_bprop_mirror_micro_step_operator(self):
+    """
+    Backpropagator for _MirrorMicroStepOperator, do allreduce or allgather for the devices in the group,
+    allgather for sparse feature.
+    """
+    group = self.group
+    dev_num = self.dev_num
+    mean_flag = self.mean_flag
+    scale = 1 / dev_num
+
+    all_reduce = AllReduce(group=group)
+
+    fusion = self.get_attr_dict()["fusion"]
+    all_reduce.add_prim_attr("fusion", fusion)
+    if hasattr(self, 'parameter'):
+        parameter = self.parameter
+        all_reduce.add_prim_attr("parameter", parameter)
+
+    if self.instance_name:
+        instance_name = "grad_mirror" + self.instance_name
+        all_reduce.set_prim_instance_name(instance_name)
+    cast = P.Cast()
+    dtype = P.DType()
+    assign = P.Assign()
+    if "parameter_micro" in self.get_attr_dict():
+        assign.add_prim_attr("parameter_micro", 0)
+    out_tensor = Tensor(1.0, mstype.float16)
+
+    def bprop(x, z, out, dout):
+        real_grad = z
+        if mean_flag:
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                z = F.depend(z, dout)
+                real_grad = all_reduce(z)
+                real_grad = F.tensor_mul(real_grad, scale)
+                assign(z, real_grad)
+        else:
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                z = F.depend(z, dout)
+                real_grad = all_reduce(z)
+                assign(z, real_grad)
+        return (cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z)))
+    return bprop
+
+
 @bprop_getters.register(Broadcast)
 def get_bprop_broad_cast(self):
     """Generate bprop for Broadcast."""
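Taken together, these bprops give the pipeline its gradient-accumulation behaviour: _VirtualAssignAdd adds each micro-batch gradient into an accumulation buffer, _VirtualAccuGrad routes the accumulated value back, and _pipeline_clear_grad (in the cell_wrapper change above) zeroes the buffer once the optimizer has consumed it. A minimal plain-Python sketch of that cycle, with hypothetical scalar values standing in for tensors:

# Illustrative accumulation cycle for one optimizer step over several micro-batches.
accu_grad = 0.0                            # accumulation buffer (accu_grads above)
weight, lr = 1.0, 0.1
micro_batch_grads = [0.2, 0.4, 0.1, 0.3]   # hypothetical per-micro-batch gradients

for g in micro_batch_grads:
    accu_grad += g                         # what _VirtualAssignAdd's bprop does via AssignAdd

weight -= lr * accu_grad                   # optimizer applies the accumulated gradient once
accu_grad *= 0.0                           # what _pipeline_clear_grad does for the next step
print(weight)                              # 0.9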
@@ -36,8 +36,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                         EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect, SearchSorted)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
-                       _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd,
-                       _HostAllGather, _HostReduceScatter)
+                       _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
+                       _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
 from .control_ops import GeSwitch, Merge
@@ -417,7 +417,7 @@ class Send(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
+    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
         self.rank = dest_rank
         self.sr_tag = sr_tag
         self.group = group
@@ -427,7 +427,6 @@ class Send(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        self.add_prim_attr("dtype", x_dtype)
         return x_dtype
@@ -474,7 +473,8 @@ class Receive(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
+    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
+                 group_back=GlobalComm.WORLD_COMM_GROUP):
         self.rank = src_rank
         self.tag = sr_tag
         self.shape = shape
@@ -690,6 +690,72 @@ class _VirtualDataset(PrimitiveWithInfer):

 virtual_dataset = _VirtualDataset()


+class _VirtualAssignAdd(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
+    internal use of parallel modules and cannot be called by users.
+
+    Args:
+        micro (int): MicroBatch. Default: 0.
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init"""
+
+    def infer_shape(self, x_shape, y_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_dtype):
+        return x_dtype
+
+
+virtual_assign_add = _VirtualAssignAdd()
+
+
+class _VirtualAccuGrad(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
+    internal use of parallel modules and cannot be called by users.
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init"""
+
+    def infer_shape(self, x_shape, y_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_dtype):
+        return x_dtype
+
+
+virtual_accu_grad = _VirtualAccuGrad()
+
+
+class _MirrorMicroStepOperator(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
+    internal use of parallel modules and cannot be called by users.
+
+    Args:
+        group (str): The communication group to work on. Default: None.
+        dev_num (int): The device number of the group. Default: None.
+        mean_flag (bool): Whether use mean in backward. Default: None.
+    """
+
+    @prim_attr_register
+    def __init__(self, group=None, dev_num=None, mean_flag=None):
+        self.group = group
+        self.dev_num = dev_num
+        self.mean_flag = mean_flag
+
+    def infer_shape(self, x_shape, z_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, z_shape):
+        return x_dtype
+
+
 class _VirtualOutput(PrimitiveWithInfer):
     """
     Auto parallel virtual out operator.
@@ -18,9 +18,9 @@ from .._checkparam import Validator as validator
 from .._checkparam import Rel
 from ..common import dtype as mstype
 from ..nn import acc
-from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
+from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
 from ..ops import functional as F
-from ..parallel._utils import _get_parallel_mode
+from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
 from ..context import ParallelMode
 from .. import context
@@ -187,5 +187,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
             network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                        scale_sense=update_cell).set_train()
             return network
-    network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
+    if _get_pipeline_stages() > 1:
+        network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
+    else:
+        network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
     return network
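From the user's side, the only visible change is that a pipeline network is wrapped in PipelineCell before being handed to Model or build_train_network; once pipeline_stages is greater than 1, the branch above selects _TrainPipelineAccuStepCell automatically. A usage sketch in the shape of the tests below, assuming a configured pipeline-parallel context and a user-defined Net (a placeholder here):

# Hypothetical wiring; Net() stands for the user's own Cell and the parallel context
# (device_num, pipeline_stages > 1) is assumed to be configured already.
import mindspore.nn as nn
from mindspore.nn.wrap.cell_wrapper import PipelineCell
from mindspore.train.model import Model

net = PipelineCell(Net(), micro_size=4)
optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
model = Model(net, optimizer=optimizer)   # training steps now run through _TrainPipelineAccuStepCell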
@@ -21,6 +21,7 @@ from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.train.model import Model
+from mindspore.nn.wrap.cell_wrapper import PipelineCell


 class DatasetLenet():
@@ -90,6 +91,7 @@ class PipelineSplit(nn.Cell):
     def __init__(self, strategy1, strategy2):
         super().__init__()
         self.cell = Net(strategy1, strategy2)
+        self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)

     def construct(self, x, label):
         x = self.cell(x)
@@ -101,6 +103,7 @@ class PipelineSplit2(nn.Cell):
         super().__init__()
         self.param = Parameter(initializer("zeros", [64, 64]), name="param")
         self.cell = Net(strategy1, strategy2, self.param)
+        self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)

     def construct(self, x, label):
         x = self.cell(x)
@@ -114,8 +117,8 @@ def test_pipeline_split_stage0():
     label = Tensor(np.ones([64, 64]), dtype=ms.float32)
     strategy1 = ((4, 1), (1, 1))
     strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit(strategy1, strategy2)
-    params = net.cell.block[0].trainable_params()
+    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
+    params = net.network.cell.block[0].trainable_params()
     dataset = DatasetLenet(data, label, 3)
     optimizer = nn.Lamb(params, learning_rate=0.01)
     model = Model(net, optimizer=optimizer)
@@ -131,8 +134,8 @@ def test_pipeline_split_stage1():
     label = Tensor(np.ones([64, 64]), dtype=ms.float32)
     strategy1 = ((4, 1), (1, 1))
     strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit(strategy1, strategy2)
-    params = net.cell.block[1].trainable_params()
+    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
+    params = net.network.cell.block[1].trainable_params()
     dataset = DatasetLenet(data, label, 3)
     optimizer = nn.Lamb(params, learning_rate=0.01)
     model = Model(net, optimizer=optimizer)
@@ -149,8 +152,8 @@ def test_pipeline_split_shared_parameter_stage0():
     label = Tensor(np.ones([64, 64]), dtype=ms.float32)
     strategy1 = ((4, 1), (1, 1))
     strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit2(strategy1, strategy2)
-    params = net.cell.block[0].trainable_params()
+    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
+    params = net.network.cell.block[0].trainable_params()
     dataset = DatasetLenet(data, label, 3)
     optimizer = nn.Lamb(params, learning_rate=0.01)
     model = Model(net, optimizer=optimizer)
@@ -164,8 +167,8 @@ def test_pipeline_split_shared_parameter_stage1():
     label = Tensor(np.ones([64, 64]), dtype=ms.float32)
     strategy1 = ((4, 1), (1, 1))
     strategy2 = ((2, 1), (1, 1))
-    net = PipelineSplit2(strategy1, strategy2)
-    params = net.cell.block[1].trainable_params()
+    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
+    params = net.network.cell.block[1].trainable_params()
     dataset = DatasetLenet(data, label, 3)
     optimizer = nn.Lamb(params, learning_rate=0.01)
     model = Model(net, optimizer=optimizer)