forked from mindspore-Ecosystem/mindspore
!528 optimize execute order for memreuse
Merge pull request !528 from kisnwang/optimize-execute-order-for-memreuse
This commit is contained in:
commit
c6d21ccd12
|
@ -251,9 +251,10 @@ void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BestFitMemReuse::FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const {
|
size_t BestFitMemReuse::FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const {
|
||||||
size_t membuf_index = 0;
|
size_t membuf_index = membuf_ptr_list.size();
|
||||||
for (size_t n = 0; n < membuf_ptr_list.size(); ++n) {
|
for (size_t n = 0; n < membuf_ptr_list.size(); ++n) {
|
||||||
auto membuf = membuf_ptr_list[n];
|
auto membuf = membuf_ptr_list[n];
|
||||||
|
MS_EXCEPTION_IF_NULL(membuf);
|
||||||
if (membuf->index_ == fac_idx) {
|
if (membuf->index_ == fac_idx) {
|
||||||
membuf_index = n;
|
membuf_index = n;
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -851,17 +851,12 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
|
||||||
|
|
||||||
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
|
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto kernel_name = AnfAlgo::GetCNodeName(node);
|
if (!node->isa<CNode>()) {
|
||||||
auto kernel_type = AnfAlgo::GetKernelType(node);
|
return false;
|
||||||
if (kernel_name == kAllReduceOpName || kernel_type == HCCL_KERNEL) {
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
return false;
|
auto kernel_name = AnfAlgo::GetCNodeName(node);
|
||||||
}
|
if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
|
||||||
|
kernel_name == kReduceScatterOpName) {
|
||||||
bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
if (node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -176,7 +176,6 @@ class AnfRuntimeAlgorithm {
|
||||||
// get real input index for some tbe ops which input order is different between me and tbe impl
|
// get real input index for some tbe ops which input order is different between me and tbe impl
|
||||||
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
|
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
|
||||||
static bool IsCommunicationOp(const AnfNodePtr &node);
|
static bool IsCommunicationOp(const AnfNodePtr &node);
|
||||||
static bool IsAllReduceOp(const AnfNodePtr &node);
|
|
||||||
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
|
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
|
|
|
@ -49,80 +49,81 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
||||||
return std::vector<AnfNodePtr>();
|
return std::vector<AnfNodePtr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||||
|
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(visit_queue);
|
||||||
|
MS_EXCEPTION_IF_NULL(visited_nodes);
|
||||||
|
auto it = node_output_edges_.find(node);
|
||||||
|
if (it == node_output_edges_.end()) {
|
||||||
|
// value node and parameter has no input,no need to print log
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// visit all reduce node first, then other nodes
|
||||||
|
std::vector<AnfNodePtr> active_nodes;
|
||||||
|
for (const auto &output_edge : it->second) {
|
||||||
|
auto next_node = output_edge.first;
|
||||||
|
if (node_input_num_.find(next_node) == node_input_num_.end()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(next_node);
|
||||||
|
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
|
||||||
|
}
|
||||||
|
MS_EXCEPTION_IF_NULL(next_node);
|
||||||
|
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
|
||||||
|
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
|
||||||
|
if (node_input_num_[next_node] < output_edge.second) {
|
||||||
|
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
|
||||||
|
<< ",depend edge:" << output_edge.second;
|
||||||
|
}
|
||||||
|
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
|
||||||
|
// allreduce first
|
||||||
|
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
|
||||||
|
(void)visited_nodes->insert(next_node);
|
||||||
|
if (AnfAlgo::IsCommunicationOp(next_node)) {
|
||||||
|
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
|
||||||
|
visit_queue->push(next_node);
|
||||||
|
} else {
|
||||||
|
active_nodes.emplace_back(next_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &node : active_nodes) {
|
||||||
|
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
|
||||||
|
visit_queue->push(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void KernelGraph::SetExecOrderByDefault() {
|
void KernelGraph::SetExecOrderByDefault() {
|
||||||
std::stack<AnfNodePtr> seed_nodes;
|
std::queue<AnfNodePtr> seed_nodes;
|
||||||
UpdateNodeEdgeList(&seed_nodes);
|
UpdateNodeEdgeList(&seed_nodes);
|
||||||
execution_order_.clear();
|
execution_order_.clear();
|
||||||
std::unordered_set<AnfNodePtr> visited_nodes;
|
std::unordered_set<AnfNodePtr> visited_nodes;
|
||||||
std::queue<AnfNodePtr> zero_input_nodes;
|
std::queue<AnfNodePtr> zero_input_nodes;
|
||||||
|
AnfNodePtr last_communication_node = nullptr;
|
||||||
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) {
|
std::queue<AnfNodePtr> communication_descendants;
|
||||||
auto it = node_output_edges_.find(node);
|
while (!seed_nodes.empty() || last_communication_node != nullptr) {
|
||||||
if (it == node_output_edges_.end()) {
|
|
||||||
// value node and parameter has no input,no need to print log
|
|
||||||
if (node->isa<CNode>()) {
|
|
||||||
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// visit all reduce node first, then other nodes
|
|
||||||
std::vector<AnfNodePtr> active_nodes;
|
|
||||||
for (const auto &output_edge : it->second) {
|
|
||||||
auto next_node = output_edge.first;
|
|
||||||
if (node_input_num_.find(next_node) == node_input_num_.end()) {
|
|
||||||
MS_EXCEPTION_IF_NULL(next_node);
|
|
||||||
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(next_node);
|
|
||||||
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
|
|
||||||
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
|
|
||||||
if (node_input_num_[next_node] < output_edge.second) {
|
|
||||||
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
|
|
||||||
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second;
|
|
||||||
}
|
|
||||||
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
|
|
||||||
// allreduce first
|
|
||||||
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
|
|
||||||
(void)visited_nodes.insert(next_node);
|
|
||||||
if (AnfAlgo::IsAllReduceOp(next_node)) {
|
|
||||||
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
|
|
||||||
visit_queue->push(next_node);
|
|
||||||
} else {
|
|
||||||
active_nodes.emplace_back(next_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto &node : active_nodes) {
|
|
||||||
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
|
|
||||||
visit_queue->push(node);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
AnfNodePtr last_allreduce_node = nullptr;
|
|
||||||
std::queue<AnfNodePtr> allreduce_descendants;
|
|
||||||
while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
|
|
||||||
// seed nodes first, then visit last all reduce node descendant
|
// seed nodes first, then visit last all reduce node descendant
|
||||||
if (seed_nodes.empty()) {
|
if (seed_nodes.empty()) {
|
||||||
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
|
||||||
last_allreduce_node = nullptr;
|
last_communication_node = nullptr;
|
||||||
} else {
|
} else {
|
||||||
zero_input_nodes.push(seed_nodes.top());
|
zero_input_nodes.push(seed_nodes.front());
|
||||||
seed_nodes.pop();
|
seed_nodes.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
// all reduce node descendant first, then common queue
|
// all reduce node descendant first, then common queue
|
||||||
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) {
|
while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
|
||||||
AnfNodePtr node = nullptr;
|
AnfNodePtr node = nullptr;
|
||||||
bool is_allreduce_descendant = false;
|
bool is_communication_descendant = false;
|
||||||
if (allreduce_descendants.empty()) {
|
if (communication_descendants.empty()) {
|
||||||
node = zero_input_nodes.front();
|
node = zero_input_nodes.front();
|
||||||
zero_input_nodes.pop();
|
zero_input_nodes.pop();
|
||||||
} else {
|
} else {
|
||||||
node = allreduce_descendants.front();
|
node = communication_descendants.front();
|
||||||
allreduce_descendants.pop();
|
communication_descendants.pop();
|
||||||
is_allreduce_descendant = true;
|
is_communication_descendant = true;
|
||||||
}
|
}
|
||||||
// add execute node
|
// add execute node
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -130,19 +131,18 @@ void KernelGraph::SetExecOrderByDefault() {
|
||||||
execution_order_.push_back(node->cast<CNodePtr>());
|
execution_order_.push_back(node->cast<CNodePtr>());
|
||||||
}
|
}
|
||||||
// for all reduce node, visit last all reduce node descendant
|
// for all reduce node, visit last all reduce node descendant
|
||||||
if (AnfAlgo::IsAllReduceOp(node)) {
|
if (AnfAlgo::IsCommunicationOp(node)) {
|
||||||
if (last_allreduce_node != nullptr) {
|
if (last_communication_node != nullptr) {
|
||||||
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
|
||||||
}
|
}
|
||||||
last_allreduce_node = node;
|
last_communication_node = node;
|
||||||
} else if (is_allreduce_descendant) {
|
} else if (is_communication_descendant) {
|
||||||
visit_node_descendant(node, &allreduce_descendants);
|
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
|
||||||
} else {
|
} else {
|
||||||
visit_node_descendant(node, &zero_input_nodes);
|
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CheckLoop();
|
CheckLoop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -467,7 +467,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
|
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
|
||||||
node_output_edges_.clear();
|
node_output_edges_.clear();
|
||||||
node_input_num_.clear();
|
node_input_num_.clear();
|
||||||
node_input_edges_.clear();
|
node_input_edges_.clear();
|
||||||
|
@ -483,7 +483,6 @@ void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
|
||||||
seed_nodes->push(node);
|
seed_nodes->push(node);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <stack>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
|
@ -94,8 +93,10 @@ class KernelGraph : public FuncGraph {
|
||||||
private:
|
private:
|
||||||
// remove value node from graph
|
// remove value node from graph
|
||||||
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
|
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
|
||||||
|
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||||
|
std::unordered_set<AnfNodePtr> *visited_nodes);
|
||||||
// update node edge list
|
// update node edge list
|
||||||
void UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes);
|
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
|
||||||
// add node depend edge by data edge or control depend
|
// add node depend edge by data edge or control depend
|
||||||
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
|
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
|
||||||
// handle control depend
|
// handle control depend
|
||||||
|
|
Loading…
Reference in New Issue