!528 optimize execute order for memreuse

Merge pull request !528 from kisnwang/optimize-execute-order-for-memreuse
This commit is contained in:
mindspore-ci-bot 2020-04-26 09:38:20 +08:00 committed by Gitee
commit c6d21ccd12
5 changed files with 78 additions and 83 deletions

View File

@ -251,9 +251,10 @@ void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) {
} }
size_t BestFitMemReuse::FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const { size_t BestFitMemReuse::FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const {
size_t membuf_index = 0; size_t membuf_index = membuf_ptr_list.size();
for (size_t n = 0; n < membuf_ptr_list.size(); ++n) { for (size_t n = 0; n < membuf_ptr_list.size(); ++n) {
auto membuf = membuf_ptr_list[n]; auto membuf = membuf_ptr_list[n];
MS_EXCEPTION_IF_NULL(membuf);
if (membuf->index_ == fac_idx) { if (membuf->index_ == fac_idx) {
membuf_index = n; membuf_index = n;
break; break;

View File

@ -851,17 +851,12 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto kernel_name = AnfAlgo::GetCNodeName(node); if (!node->isa<CNode>()) {
auto kernel_type = AnfAlgo::GetKernelType(node); return false;
if (kernel_name == kAllReduceOpName || kernel_type == HCCL_KERNEL) {
return true;
} }
return false; auto kernel_name = AnfAlgo::GetCNodeName(node);
} if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
kernel_name == kReduceScatterOpName) {
bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
return true; return true;
} }
return false; return false;

View File

@ -176,7 +176,6 @@ class AnfRuntimeAlgorithm {
// get real input index for some tbe ops which input order is different between me and tbe impl // get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsAllReduceOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node); static bool IsGetNext(const NotNull<AnfNodePtr> &node);
}; };
} // namespace session } // namespace session

View File

@ -49,80 +49,81 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
return std::vector<AnfNodePtr>(); return std::vector<AnfNodePtr>();
} }
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(visit_queue);
MS_EXCEPTION_IF_NULL(visited_nodes);
auto it = node_output_edges_.find(node);
if (it == node_output_edges_.end()) {
// value node and parameter has no input,no need to print log
if (node->isa<CNode>()) {
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
}
return;
}
// visit all reduce node first, then other nodes
std::vector<AnfNodePtr> active_nodes;
for (const auto &output_edge : it->second) {
auto next_node = output_edge.first;
if (node_input_num_.find(next_node) == node_input_num_.end()) {
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
}
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
if (node_input_num_[next_node] < output_edge.second) {
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
<< ",depend edge:" << output_edge.second;
}
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
// allreduce first
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
(void)visited_nodes->insert(next_node);
if (AnfAlgo::IsCommunicationOp(next_node)) {
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
active_nodes.emplace_back(next_node);
}
}
}
for (auto &node : active_nodes) {
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
visit_queue->push(node);
}
}
void KernelGraph::SetExecOrderByDefault() { void KernelGraph::SetExecOrderByDefault() {
std::stack<AnfNodePtr> seed_nodes; std::queue<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes); UpdateNodeEdgeList(&seed_nodes);
execution_order_.clear(); execution_order_.clear();
std::unordered_set<AnfNodePtr> visited_nodes; std::unordered_set<AnfNodePtr> visited_nodes;
std::queue<AnfNodePtr> zero_input_nodes; std::queue<AnfNodePtr> zero_input_nodes;
AnfNodePtr last_communication_node = nullptr;
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) { std::queue<AnfNodePtr> communication_descendants;
auto it = node_output_edges_.find(node); while (!seed_nodes.empty() || last_communication_node != nullptr) {
if (it == node_output_edges_.end()) {
// value node and parameter has no input,no need to print log
if (node->isa<CNode>()) {
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
}
return;
}
// visit all reduce node first, then other nodes
std::vector<AnfNodePtr> active_nodes;
for (const auto &output_edge : it->second) {
auto next_node = output_edge.first;
if (node_input_num_.find(next_node) == node_input_num_.end()) {
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
}
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
if (node_input_num_[next_node] < output_edge.second) {
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second;
}
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
// allreduce first
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
(void)visited_nodes.insert(next_node);
if (AnfAlgo::IsAllReduceOp(next_node)) {
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
active_nodes.emplace_back(next_node);
}
}
}
for (auto &node : active_nodes) {
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
visit_queue->push(node);
}
};
AnfNodePtr last_allreduce_node = nullptr;
std::queue<AnfNodePtr> allreduce_descendants;
while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
// seed nodes first, then visit last all reduce node descendant // seed nodes first, then visit last all reduce node descendant
if (seed_nodes.empty()) { if (seed_nodes.empty()) {
visit_node_descendant(last_allreduce_node, &allreduce_descendants); VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
last_allreduce_node = nullptr; last_communication_node = nullptr;
} else { } else {
zero_input_nodes.push(seed_nodes.top()); zero_input_nodes.push(seed_nodes.front());
seed_nodes.pop(); seed_nodes.pop();
} }
// all reduce node descendant first, then common queue // all reduce node descendant first, then common queue
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) { while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
AnfNodePtr node = nullptr; AnfNodePtr node = nullptr;
bool is_allreduce_descendant = false; bool is_communication_descendant = false;
if (allreduce_descendants.empty()) { if (communication_descendants.empty()) {
node = zero_input_nodes.front(); node = zero_input_nodes.front();
zero_input_nodes.pop(); zero_input_nodes.pop();
} else { } else {
node = allreduce_descendants.front(); node = communication_descendants.front();
allreduce_descendants.pop(); communication_descendants.pop();
is_allreduce_descendant = true; is_communication_descendant = true;
} }
// add execute node // add execute node
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -130,19 +131,18 @@ void KernelGraph::SetExecOrderByDefault() {
execution_order_.push_back(node->cast<CNodePtr>()); execution_order_.push_back(node->cast<CNodePtr>());
} }
// for all reduce node, visit last all reduce node descendant // for all reduce node, visit last all reduce node descendant
if (AnfAlgo::IsAllReduceOp(node)) { if (AnfAlgo::IsCommunicationOp(node)) {
if (last_allreduce_node != nullptr) { if (last_communication_node != nullptr) {
visit_node_descendant(last_allreduce_node, &allreduce_descendants); VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
} }
last_allreduce_node = node; last_communication_node = node;
} else if (is_allreduce_descendant) { } else if (is_communication_descendant) {
visit_node_descendant(node, &allreduce_descendants); VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
} else { } else {
visit_node_descendant(node, &zero_input_nodes); VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
} }
} }
} }
CheckLoop(); CheckLoop();
} }
@ -467,7 +467,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
return true; return true;
} }
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) { void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
node_output_edges_.clear(); node_output_edges_.clear();
node_input_num_.clear(); node_input_num_.clear();
node_input_edges_.clear(); node_input_edges_.clear();
@ -483,7 +483,6 @@ void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
seed_nodes->push(node); seed_nodes->push(node);
continue; continue;
} }
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
} }

View File

@ -22,7 +22,6 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <queue> #include <queue>
#include <stack>
#include <map> #include <map>
#include <unordered_set> #include <unordered_set>
#include "ir/func_graph.h" #include "ir/func_graph.h"
@ -94,8 +93,10 @@ class KernelGraph : public FuncGraph {
private: private:
// remove value node form graph // remove value node form graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes);
// update node edge list // update node edge list
void UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes); void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
// add node depend edge by data edge or control depend // add node depend edge by data edge or control depend
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
// handle control depend // handle control depend