|
|
|
@ -50,90 +50,127 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::SetExecOrderByDefault() {
|
|
|
|
|
BfsToUpdateNodeOutput();
|
|
|
|
|
std::stack<AnfNodePtr> seed_nodes;
|
|
|
|
|
UpdateNodeEdgeList(&seed_nodes);
|
|
|
|
|
execution_order_.clear();
|
|
|
|
|
std::queue<AnfNodePtr> allreduce_nodes;
|
|
|
|
|
std::queue<AnfNodePtr> zero_output_nodes;
|
|
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes;
|
|
|
|
|
auto clear_output = [&zero_output_nodes, &allreduce_nodes, &visited_nodes, this](const AnfNodePtr &input) -> void {
|
|
|
|
|
if (node_output_num_[input] == 0 && visited_nodes.find(input) == visited_nodes.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
|
|
MS_LOG(DEBUG) << "Clear output num:" << input->DebugString();
|
|
|
|
|
(void)visited_nodes.insert(input);
|
|
|
|
|
if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == kAllReduceOpName) {
|
|
|
|
|
allreduce_nodes.push(input);
|
|
|
|
|
} else {
|
|
|
|
|
zero_output_nodes.push(input);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
zero_output_nodes.emplace(get_return());
|
|
|
|
|
while (!zero_output_nodes.empty() || !allreduce_nodes.empty()) {
|
|
|
|
|
AnfNodePtr node;
|
|
|
|
|
if (!zero_output_nodes.empty()) {
|
|
|
|
|
node = zero_output_nodes.front();
|
|
|
|
|
zero_output_nodes.pop();
|
|
|
|
|
} else {
|
|
|
|
|
node = allreduce_nodes.front();
|
|
|
|
|
allreduce_nodes.pop();
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
execution_order_.push_back(node->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
auto it = node_input_edges_.find(node);
|
|
|
|
|
if (it == node_input_edges_.end()) {
|
|
|
|
|
std::queue<AnfNodePtr> zero_input_nodes;
|
|
|
|
|
|
|
|
|
|
auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue) {
|
|
|
|
|
auto it = node_output_edges_.find(node);
|
|
|
|
|
if (it == node_output_edges_.end()) {
|
|
|
|
|
// value nodes and parameters have no input, so there is no need to print a log
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
for (const auto &input_edge : it->second) {
|
|
|
|
|
if (node_output_num_.find(input_edge.first) == node_output_num_.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_edge.first);
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find node[" << input_edge.first->DebugString() << "]";
|
|
|
|
|
|
|
|
|
|
// visit all reduce node first, then other nodes
|
|
|
|
|
std::vector<AnfNodePtr> active_nodes;
|
|
|
|
|
for (const auto &output_edge : it->second) {
|
|
|
|
|
auto next_node = output_edge.first;
|
|
|
|
|
if (node_input_num_.find(next_node) == node_input_num_.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(next_node);
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_edge.first);
|
|
|
|
|
MS_LOG(DEBUG) << "Decrease input:" << input_edge.first->DebugString() << ",node:" << node->DebugString()
|
|
|
|
|
<< ",num: " << node_output_num_[input_edge.first] << ",decrease num:" << input_edge.second;
|
|
|
|
|
if (node_output_num_[input_edge.first] < input_edge.second) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input node:" << input_edge.first->DebugString() << ",node_output_num"
|
|
|
|
|
<< node_output_num_[input_edge.first] << "depend edge:" << input_edge.second;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(next_node);
|
|
|
|
|
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
|
|
|
|
|
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
|
|
|
|
|
if (node_input_num_[next_node] < output_edge.second) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num"
|
|
|
|
|
<< node_input_num_[next_node] << ",depend edge:" << output_edge.second;
|
|
|
|
|
}
|
|
|
|
|
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
|
|
|
|
|
// allreduce first
|
|
|
|
|
if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) {
|
|
|
|
|
(void)visited_nodes.insert(next_node);
|
|
|
|
|
if (AnfAlgo::IsAllReduceOp(next_node)) {
|
|
|
|
|
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
|
|
|
|
|
visit_queue->push(next_node);
|
|
|
|
|
} else {
|
|
|
|
|
active_nodes.emplace_back(next_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &node : active_nodes) {
|
|
|
|
|
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
|
|
|
|
|
visit_queue->push(node);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
AnfNodePtr last_allreduce_node = nullptr;
|
|
|
|
|
std::queue<AnfNodePtr> allreduce_descendants;
|
|
|
|
|
while (!seed_nodes.empty() || last_allreduce_node != nullptr) {
|
|
|
|
|
// seed nodes first, then visit last all reduce node descendant
|
|
|
|
|
if (seed_nodes.empty()) {
|
|
|
|
|
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
|
|
|
|
|
last_allreduce_node = nullptr;
|
|
|
|
|
} else {
|
|
|
|
|
zero_input_nodes.push(seed_nodes.top());
|
|
|
|
|
seed_nodes.pop();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// all reduce node descendant first, then common queue
|
|
|
|
|
while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) {
|
|
|
|
|
AnfNodePtr node = nullptr;
|
|
|
|
|
bool is_allreduce_descendant = false;
|
|
|
|
|
if (allreduce_descendants.empty()) {
|
|
|
|
|
node = zero_input_nodes.front();
|
|
|
|
|
zero_input_nodes.pop();
|
|
|
|
|
} else {
|
|
|
|
|
node = allreduce_descendants.front();
|
|
|
|
|
allreduce_descendants.pop();
|
|
|
|
|
is_allreduce_descendant = true;
|
|
|
|
|
}
|
|
|
|
|
// add execute node
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
execution_order_.push_back(node->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
// for all reduce node, visit last all reduce node descendant
|
|
|
|
|
if (AnfAlgo::IsAllReduceOp(node)) {
|
|
|
|
|
if (last_allreduce_node != nullptr) {
|
|
|
|
|
visit_node_descendant(last_allreduce_node, &allreduce_descendants);
|
|
|
|
|
}
|
|
|
|
|
last_allreduce_node = node;
|
|
|
|
|
} else if (is_allreduce_descendant) {
|
|
|
|
|
visit_node_descendant(node, &allreduce_descendants);
|
|
|
|
|
} else {
|
|
|
|
|
visit_node_descendant(node, &zero_input_nodes);
|
|
|
|
|
}
|
|
|
|
|
node_output_num_[input_edge.first] = node_output_num_[input_edge.first] - input_edge.second;
|
|
|
|
|
clear_output(input_edge.first);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CheckLoop();
|
|
|
|
|
std::reverse(execution_order_.begin(), execution_order_.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::CheckLoop() {
|
|
|
|
|
std::map<AnfNodePtr, size_t> none_zero_output;
|
|
|
|
|
if (node_output_edges_.size() != node_output_num_.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node_output_edges_ size :" << node_output_edges_.size()
|
|
|
|
|
<< "not equal to node_output_num_ size:" << node_output_num_.size();
|
|
|
|
|
std::map<AnfNodePtr, size_t> none_zero_nodes;
|
|
|
|
|
if (node_input_edges_.size() != node_input_num_.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size()
|
|
|
|
|
<< "not equal to node_input_num_ size:" << node_input_num_.size();
|
|
|
|
|
}
|
|
|
|
|
for (auto &it : node_output_num_) {
|
|
|
|
|
for (auto &it : node_input_num_) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(it.first);
|
|
|
|
|
string str;
|
|
|
|
|
auto node_output_it = node_output_edges_.find(it.first);
|
|
|
|
|
if (node_output_it == node_output_edges_.end()) {
|
|
|
|
|
auto node_input_it = node_input_edges_.find(it.first);
|
|
|
|
|
if (node_input_it == node_input_edges_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
|
|
|
|
|
}
|
|
|
|
|
for (const auto &output_edge : node_output_edges_[it.first]) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_edge.first);
|
|
|
|
|
str = str.append(output_edge.first->DebugString()).append("|");
|
|
|
|
|
for (const auto &input_edge : node_input_edges_[it.first]) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_edge.first);
|
|
|
|
|
str = str.append(input_edge.first->DebugString()).append("|");
|
|
|
|
|
}
|
|
|
|
|
if (it.second != 0) {
|
|
|
|
|
MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",outputs:" << str << ",output num:" << it.second;
|
|
|
|
|
none_zero_output[it.first] = it.second;
|
|
|
|
|
MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second;
|
|
|
|
|
none_zero_nodes[it.first] = it.second;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// if control depend and loop exit are not considered, an exception will be thrown
|
|
|
|
|
if (!none_zero_output.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_output.size();
|
|
|
|
|
if (!none_zero_nodes.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -346,12 +383,13 @@ void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input,
|
|
|
|
|
} else {
|
|
|
|
|
input_it->second.push_back(input_depend_edge);
|
|
|
|
|
}
|
|
|
|
|
// add the depend sum of node
|
|
|
|
|
auto depend_it = node_output_num_.find(input);
|
|
|
|
|
if (depend_it == node_output_num_.end()) {
|
|
|
|
|
node_output_num_[input] = 0;
|
|
|
|
|
// add node input depend num
|
|
|
|
|
auto depend_it = node_input_num_.find(node);
|
|
|
|
|
if (depend_it == node_input_num_.end()) {
|
|
|
|
|
node_input_num_[node] = depend_edge_num;
|
|
|
|
|
} else {
|
|
|
|
|
depend_it->second += depend_edge_num;
|
|
|
|
|
}
|
|
|
|
|
node_output_num_[input] += depend_edge_num;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
|
|
|
|
@ -429,9 +467,9 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::BfsToUpdateNodeOutput() {
|
|
|
|
|
void KernelGraph::UpdateNodeEdgeList(std::stack<AnfNodePtr> *seed_nodes) {
|
|
|
|
|
node_output_edges_.clear();
|
|
|
|
|
node_output_num_.clear();
|
|
|
|
|
node_input_num_.clear();
|
|
|
|
|
node_input_edges_.clear();
|
|
|
|
|
std::vector<AnfNodePtr> control_depends;
|
|
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes;
|
|
|
|
@ -441,6 +479,11 @@ void KernelGraph::BfsToUpdateNodeOutput() {
|
|
|
|
|
auto node = que.front();
|
|
|
|
|
que.pop();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->isa<Parameter>() || node->isa<ValueNode>()) {
|
|
|
|
|
seed_nodes->push(node);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -454,10 +497,6 @@ void KernelGraph::BfsToUpdateNodeOutput() {
|
|
|
|
|
control_depends.push_back(input);
|
|
|
|
|
depend_edge_num = 0;
|
|
|
|
|
}
|
|
|
|
|
// the 2nd input of depend is not a depend edge
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && input == cnode->input(kDependAttachNodeIndex)) {
|
|
|
|
|
depend_edge_num = 0;
|
|
|
|
|
}
|
|
|
|
|
PushNoVisitedNode(input, &que, &visited_nodes);
|
|
|
|
|
AddDependEdge(node, input, depend_edge_num);
|
|
|
|
|
}
|
|
|
|
|