forked from mindspore-Ecosystem/mindspore
!528 optimize execute order for memreuse
Merge pull request !528 from kisnwang/optimize-execute-order-for-memreuse
This commit is contained in:
commit
c6d21ccd12
|
@ -251,9 +251,10 @@ void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) {
|
|||
}
|
||||
|
||||
size_t BestFitMemReuse::FindIndx(const std::vector<MembufPtr> &membuf_ptr_list, int fac_idx) const {
|
||||
size_t membuf_index = 0;
|
||||
size_t membuf_index = membuf_ptr_list.size();
|
||||
for (size_t n = 0; n < membuf_ptr_list.size(); ++n) {
|
||||
auto membuf = membuf_ptr_list[n];
|
||||
MS_EXCEPTION_IF_NULL(membuf);
|
||||
if (membuf->index_ == fac_idx) {
|
||||
membuf_index = n;
|
||||
break;
|
||||
|
|
|
@ -851,17 +851,12 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i
|
|||
|
||||
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(node);
|
||||
auto kernel_type = AnfAlgo::GetKernelType(node);
|
||||
if (kernel_name == kAllReduceOpName || kernel_type == HCCL_KERNEL) {
|
||||
return true;
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(node);
|
||||
if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
|
||||
kernel_name == kReduceScatterOpName) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -176,7 +176,6 @@ class AnfRuntimeAlgorithm {
|
|||
// get the real input index for some tbe ops whose input order is different between me and tbe impl
|
||||
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
|
||||
static bool IsCommunicationOp(const AnfNodePtr &node);
|
||||
static bool IsAllReduceOp(const AnfNodePtr &node);
|
||||
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
|
||||
};
|
||||
} // namespace session
|
||||
|
|
|
@ -49,80 +49,81 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
|||
return std::vector<AnfNodePtr>();
|
||||
}
|
||||
|
||||
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(visit_queue);
|
||||
MS_EXCEPTION_IF_NULL(visited_nodes);
|
||||
auto it = node_output_edges_.find(node);
|
||||
if (it == node_output_edges_.end()) {
|
||||
// value node and parameter has no input,no need to print log
|
||||
if (node->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// visit all reduce node first, then other nodes
|
||||
std::vector<AnfNodePtr> active_nodes;
|
||||
for (const auto &output_edge : it->second) {
|
||||
auto next_node = output_edge.first;
|
||||
if (node_input_num_.find(next_node) == node_input_num_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
|
||||
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
|
||||
if (node_input_num_[next_node] < output_edge.second) {
|
||||
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
|
||||
<< ",depend edge:" << output_edge.second;
|
||||
}
|
||||
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
|
||||
// allreduce first
|
||||
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
|
||||
(void)visited_nodes->insert(next_node);
|
||||
if (AnfAlgo::IsCommunicationOp(next_node)) {
|
||||
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
|
||||
visit_queue->push(next_node);
|
||||
} else {
|
||||
active_nodes.emplace_back(next_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &node : active_nodes) {
|
||||
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
|
||||
visit_queue->push(node);
|
||||
}
|
||||
}
|
||||
|
||||
void KernelGraph::SetExecOrderByDefault() {
|
||||
std::stack<AnfNodePtr> seed_nodes;
|
||||
std::queue<AnfNodePtr> seed_nodes;
|
||||
UpdateNodeEdgeList(&seed_nodes);
|
||||
execution_order_.clear();
|
||||
std::unordered_set<AnfNodePtr> visited_nodes;
|
||||
std::queue<AnfNodePtr> zero_input_nodes;
|
||||
|
||||
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) {
|
||||
auto it = node_output_edges_.find(node);
|
||||
if (it == node_output_edges_.end()) {
|
||||
// value node and parameter has no input,no need to print log
|
||||
if (node->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// visit all reduce node first, then other nodes
|
||||
std::vector<AnfNodePtr> active_nodes;
|
||||
for (const auto &output_edge : it->second) {
|
||||
auto next_node = output_edge.first;
|
||||
if (node_input_num_.find(next_node) == node_input_num_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
|
||||
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
|
||||
if (node_input_num_[next_node] < output_edge.second) {
|
||||
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
|
||||
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second;
|
||||
}
|
||||
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
|
||||
// allreduce first
|
||||
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
|
||||
(void)visited_nodes.insert(next_node);
|
||||
if (AnfAlgo::IsAllReduceOp(next_node)) {
|
||||
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
|
||||
visit_queue->push(next_node);
|
||||
} else {
|
||||
active_nodes.emplace_back(next_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &node : active_nodes) {
|
||||
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
|
||||
visit_queue->push(node);
|
||||
}
|
||||
};
|
||||
|
||||
AnfNodePtr last_allreduce_node = nullptr;
|
||||
std::queue<AnfNodePtr> allreduce_descendants;
|
||||
while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
|
||||
AnfNodePtr last_communication_node = nullptr;
|
||||
std::queue<AnfNodePtr> communication_descendants;
|
||||
while (!seed_nodes.empty() || last_communication_node != nullptr) {
|
||||
// seed nodes first, then visit last all reduce node descendant
|
||||
if (seed_nodes.empty()) {
|
||||
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
|
||||
last_allreduce_node = nullptr;
|
||||
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
|
||||
last_communication_node = nullptr;
|
||||
} else {
|
||||
zero_input_nodes.push(seed_nodes.top());
|
||||
zero_input_nodes.push(seed_nodes.front());
|
||||
seed_nodes.pop();
|
||||
}
|
||||
|
||||
// all reduce node descendant first, then common queue
|
||||
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) {
|
||||
while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
|
||||
AnfNodePtr node = nullptr;
|
||||
bool is_allreduce_descendant = false;
|
||||
if (allreduce_descendants.empty()) {
|
||||
bool is_communication_descendant = false;
|
||||
if (communication_descendants.empty()) {
|
||||
node = zero_input_nodes.front();
|
||||
zero_input_nodes.pop();
|
||||
} else {
|
||||
node = allreduce_descendants.front();
|
||||
allreduce_descendants.pop();
|
||||
is_allreduce_descendant = true;
|
||||
node = communication_descendants.front();
|
||||
communication_descendants.pop();
|
||||
is_communication_descendant = true;
|
||||
}
|
||||
// add execute node
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -130,19 +131,18 @@ void KernelGraph::SetExecOrderByDefault() {
|
|||
execution_order_.push_back(node->cast<CNodePtr>());
|
||||
}
|
||||
// for all reduce node, visit last all reduce node descendant
|
||||
if (AnfAlgo::IsAllReduceOp(node)) {
|
||||
if (last_allreduce_node != nullptr) {
|
||||
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
|
||||
if (AnfAlgo::IsCommunicationOp(node)) {
|
||||
if (last_communication_node != nullptr) {
|
||||
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
|
||||
}
|
||||
last_allreduce_node = node;
|
||||
} else if (is_allreduce_descendant) {
|
||||
visit_node_descendant(node, &allreduce_descendants);
|
||||
last_communication_node = node;
|
||||
} else if (is_communication_descendant) {
|
||||
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
|
||||
} else {
|
||||
visit_node_descendant(node, &zero_input_nodes);
|
||||
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CheckLoop();
|
||||
}
|
||||
|
||||
|
@ -467,7 +467,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
|
|||
return true;
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
|
||||
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
|
||||
node_output_edges_.clear();
|
||||
node_input_num_.clear();
|
||||
node_input_edges_.clear();
|
||||
|
@ -483,7 +483,6 @@ void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
|
|||
seed_nodes->push(node);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include <utility>
|
||||
#include <string>
|
||||
#include <queue>
|
||||
#include <stack>
|
||||
#include <map>
|
||||
#include <unordered_set>
|
||||
#include "ir/func_graph.h"
|
||||
|
@ -94,8 +93,10 @@ class KernelGraph : public FuncGraph {
|
|||
private:
|
||||
// remove value node from graph
|
||||
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
|
||||
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes);
|
||||
// update node edge list
|
||||
void UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes);
|
||||
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
|
||||
// add node depend edge by data edge or control depend
|
||||
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
|
||||
// handle control depend
|
||||
|
|
Loading…
Reference in New Issue