Get output nodes by control node.

This commit is contained in:
gaoyong10 2021-10-26 20:14:03 +08:00
parent 1e119e7756
commit 83ee99b07d
8 changed files with 213 additions and 54 deletions

View File

@ -49,6 +49,11 @@ namespace {
constexpr size_t kNopNodeInputSize = 2; constexpr size_t kNopNodeInputSize = 2;
constexpr size_t kNopNodeRealInputIndex = 1; constexpr size_t kNopNodeRealInputIndex = 1;
constexpr size_t kReturnDataIndex = 1; constexpr size_t kReturnDataIndex = 1;
constexpr size_t kSwitchTrueBranchIndex = 2;
constexpr size_t kPartialFuncGraphPos = 1;
constexpr size_t kSwitchLayerBranchPos = 2;
constexpr size_t kSwitchTrueBranchPos = 2;
constexpr size_t kMakeTupleInputStartPos = 1;
const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad}; const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
@ -142,6 +147,54 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
return inputs->push_back(std::make_pair(node, output_index)); return inputs->push_back(std::make_pair(node, output_index));
} }
// Fetch all outputs of control nodes, visited nodes indicates the call node that has been processed. In control flow,
// there are recursive calls between funcgraphs, so the processed call nodes are recorded to prevent infinite loops.
std::vector<KernelWithIndex> GetAllOutputByControlFlowNode(const KernelWithIndex &output_with_index,
std::set<AnfNodePtr> *visited_call_nodes) {
std::vector<KernelWithIndex> ret;
const auto &node = output_with_index.first;
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
const auto &switch_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
const auto &switch_inputs = switch_cnode->inputs();
auto output_vector = AnfAlgo::GetAllOutputWithIndex(switch_inputs[kSwitchTrueBranchIndex], visited_call_nodes);
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
} else if (AnfAlgo::IsCallNode(node)) {
if (visited_call_nodes != nullptr) {
if (visited_call_nodes->find(node) != visited_call_nodes->end()) {
return ret;
} else {
visited_call_nodes->emplace(node);
}
}
// The output of the call node is the output of the funcgraph actually called.
const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(node);
for (const auto &func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);
// The call in the graph kernel does not need to be parsed, and the node is directly output.
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
ret.emplace_back(output_with_index);
break;
}
MS_EXCEPTION_IF_NULL(func_graph->output());
const auto &func_graph_output =
AnfAlgo::VisitKernelWithReturnType(func_graph->output(), output_with_index.second);
std::set<AnfNodePtr> tmp_visited_nodes = {node};
auto output_vector = AnfAlgo::GetAllOutputWithIndex(
func_graph_output.first, (visited_call_nodes == nullptr ? &tmp_visited_nodes : visited_call_nodes));
if (output_with_index.second < output_vector.size()) {
ret.emplace_back(output_vector[output_with_index.second]);
break;
}
}
}
return ret;
}
// ops pair that dynamic input order is differ from the fixed shape ops // ops pair that dynamic input order is differ from the fixed shape ops
// pair: <real_input->ori_input, ori_input->real_input> // pair: <real_input->ori_input, ori_input->real_input>
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = { static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
@ -339,7 +392,8 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
return ret; return ret;
} }
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) { std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node,
std::set<AnfNodePtr> *visited_call_nodes) {
std::vector<KernelWithIndex> ret; std::vector<KernelWithIndex> ret;
std::vector<KernelWithIndex> ret_empty; std::vector<KernelWithIndex> ret_empty;
@ -348,7 +402,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
auto make_tuple = node->cast<CNodePtr>(); auto make_tuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple); MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t i = 1; i < make_tuple->inputs().size(); i++) { for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i)); auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i), visited_call_nodes);
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret)); (void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
} }
return ret; return ret;
@ -358,7 +412,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
auto depend_node = node->cast<CNodePtr>(); auto depend_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_node); MS_EXCEPTION_IF_NULL(depend_node);
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend)); auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend), visited_call_nodes);
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret)); (void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
return ret; return ret;
} }
@ -393,20 +447,16 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
// The makeTuple node need recurse. // The makeTuple node need recurse.
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) { if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
auto output_vector = GetAllOutputWithIndex(output_with_index.first); auto output_vector = GetAllOutputWithIndex(output_with_index.first, visited_call_nodes);
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret)); (void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
continue; continue;
} }
// Ignore the output of front call node. // Fetch outputs by control nodes.
if (output_with_index.first->isa<CNode>()) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || AnfAlgo::IsCallNode(node)) {
auto cnode = output_with_index.first->cast<CNodePtr>(); const auto &control_node_output = GetAllOutputByControlFlowNode(output_with_index, visited_call_nodes);
MS_EXCEPTION_IF_NULL(cnode); (void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret));
auto inputs = cnode->inputs(); continue;
if (inputs[0]->isa<CNode>()) {
MS_LOG(INFO) << "The output is call node: " << output_with_index.first->DebugString();
return ret_empty;
}
} }
// The InitDataSetQueue node has no output. // The InitDataSetQueue node has no output.
@ -2527,5 +2577,100 @@ size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
} }
return output_tensor_size_[index]; return output_tensor_size_[index];
} }
bool AnfRuntimeAlgorithm::IsCallNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
if (inputs.empty() || inputs[0] == nullptr) {
MS_LOG(EXCEPTION) << "Invalid call node:" << node->DebugString();
}
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
}
std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth) {
MS_EXCEPTION_IF_NULL(node);
std::set<FuncGraphPtr> func_graphs;
if (!node->isa<CNode>()) {
return func_graphs;
}
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &call_input0 = cnode->input(0);
MS_EXCEPTION_IF_NULL(call_input0);
if (AnfAlgo::IsCallNode(call_input0)) {
return AnfAlgo::GetFuncGraphbyCallNode(call_input0, ++call_depth);
}
if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
// First input node of call is switch node.
const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs();
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(switch_inputs[i]);
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
}
} else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
// First input node of call is switch layer node.
const auto &tuple_node = cnode->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
<< " for switch layer node:" << cnode->DebugString();
}
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(tuple_inputs[i]);
func_graphs.emplace(GetFuncGraphFromPartial(tuple_inputs[i], call_depth));
}
} else if (IsPartial(call_input0)) {
// First input node of call is partial node or value node of funcgraph.
(void)func_graphs.emplace(GetFuncGraphFromPartial(call_input0, call_depth));
} else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
}
return func_graphs;
}
bool AnfRuntimeAlgorithm::IsPartial(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial);
}
FuncGraphPtr AnfRuntimeAlgorithm::GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth) {
MS_EXCEPTION_IF_NULL(node);
if (depth == 1) {
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) {
// Value node of funcgraph.
return GetValueNode<FuncGraphPtr>(node);
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
// Partial cnode.
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
return GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
} else {
MS_LOG(EXCEPTION) << "Invalid partial construct node:" << node->DebugString();
}
}
// Get funcgraph in the output of inner call.
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) {
return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(node)->output(), depth - 1);
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos])->output(),
depth - 1);
} else {
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString();
}
return nullptr;
}
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore

View File

@ -83,7 +83,8 @@ class AnfRuntimeAlgorithm {
prim::kPrimMakeTuple}); prim::kPrimMakeTuple});
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node, static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types = {}); const std::vector<PrimitivePtr> &return_types = {});
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node); static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
std::set<AnfNodePtr> *visited_call_nodes = nullptr);
// get cnode primitive // get cnode primitive
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
@ -329,6 +330,20 @@ class AnfRuntimeAlgorithm {
static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph); static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph);
static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod); static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod); static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
// Check whether node is a call node, there are two types of call nodes:
// 1. First input of node is a cnode.
// 2. First input of node is a funcgraph value node.
static bool IsCallNode(const AnfNodePtr &node);
// Find all funcgraphs that the call node will call.
static std::set<FuncGraphPtr> GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth = 1);
// Check whether node has a partial structure, a node is a partial structure whicih:
// 1. a partial cnode.
// 2. a funcgraph value node.
static bool IsPartial(const AnfNodePtr &node);
// Get funcgraph in partial structure.
// Depth represents the number of layers of the call. When the first input of the call node is a call node,
// the funcgraph in the return value of the inner call needs to be returned.
static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth = 1);
}; };
} // namespace session } // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm; using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -1373,6 +1373,23 @@ bool KernelGraph::IsDatasetGraph() const {
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) {
std::vector<AnfNodePtr> child_graph_results;
for (const auto &child_graph_result : child_graph_result_) {
MS_EXCEPTION_IF_NULL(child_graph_result);
if (AnfAlgo::CheckPrimitiveType(child_graph_result, prim::kPrimMakeTuple)) {
const auto cnode = child_graph_result->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
child_graph_results.insert(child_graph_results.end(), inputs.begin(), inputs.end());
} else {
child_graph_results.emplace_back(child_graph_result);
}
}
return find(child_graph_results.begin(), child_graph_results.end(), node) != child_graph_results.end();
}
KernelGraph::~KernelGraph() { KernelGraph::~KernelGraph() {
try { try {
// Release the kernel resource. // Release the kernel resource.

View File

@ -260,6 +260,7 @@ class KernelGraph : public FuncGraph {
void UpdateChildGraphOrder(); void UpdateChildGraphOrder();
const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; } const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); } void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
bool IsChildGraphResult(const AnfNodePtr &node);
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
child_graph_result_ = child_graph_result; child_graph_result_ = child_graph_result;
} }

View File

@ -54,6 +54,7 @@ bool CheckValidFuncGraphInput(const AnfNodePtr &node) {
// Get the funcgraph in partial node. // Get the funcgraph in partial node.
FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) { FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &partial_inputs = node->cast<CNodePtr>()->inputs(); const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
return GetValueNode<FuncGraphPtr>(partial_inputs[1]); return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
} }
@ -313,7 +314,7 @@ std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std
if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) { if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
return outputs; return outputs;
} }
if (!IsCallNode(real_output.first)) { if (!AnfAlgo::IsCallNode(real_output.first)) {
outputs.push_back(real_output.first); outputs.push_back(real_output.first);
return outputs; return outputs;
} }
@ -349,7 +350,7 @@ std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std::
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes); const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
} else if (IsCallNode(graph_output)) { } else if (AnfAlgo::IsCallNode(graph_output)) {
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes); const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
} else if (graph_output->isa<CNode>()) { } else if (graph_output->isa<CNode>()) {
@ -388,7 +389,7 @@ std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, s
} else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes); const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes);
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
} else if (IsCallNode(inputs[i])) { } else if (AnfAlgo::IsCallNode(inputs[i])) {
const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes); const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes);
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
} else { } else {
@ -486,7 +487,7 @@ FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) { AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first; const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
if (!IsCallNode(real_node)) { if (!AnfAlgo::IsCallNode(real_node)) {
return real_node; return real_node;
} }
if ((*call_nodes).find(real_node) != (*call_nodes).end()) { if ((*call_nodes).find(real_node) != (*call_nodes).end()) {
@ -513,15 +514,6 @@ bool HasAbstractRef(const AnfNodePtr &node) {
return (abs != nullptr) && abs->isa<abstract::AbstractRef>(); return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
} }
bool IsCallNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
}
bool IsSubCallNode(const AnfNodePtr &node) { bool IsSubCallNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return false; return false;
@ -604,7 +596,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
func_graphs.emplace_back(func_graph); func_graphs.emplace_back(func_graph);
} }
} }
} else if (IsCallNode(cnode)) { } else if (AnfAlgo::IsCallNode(cnode)) {
return FetchFuncGraphbyCallNode(cnode); return FetchFuncGraphbyCallNode(cnode);
} else { } else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString(); MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
@ -618,7 +610,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
} }
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) { size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) {
if (!IsCallNode(node)) { if (!AnfAlgo::IsCallNode(node)) {
MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node); MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node);
} }
if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) { if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) {
@ -631,7 +623,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
const auto &output = func_graph->output(); const auto &output = func_graph->output();
const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0); const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
if (IsCallNode(real_output.first)) { if (AnfAlgo::IsCallNode(real_output.first)) {
size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes); size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes);
if (output_num > 0) { if (output_num > 0) {
return output_num; return output_num;
@ -642,7 +634,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
const auto &inputs = tuple_cnode->inputs(); const auto &inputs = tuple_cnode->inputs();
size_t i = 1; size_t i = 1;
for (; i < inputs.size(); ++i) { for (; i < inputs.size(); ++i) {
if (IsCallNode(inputs[i])) { if (AnfAlgo::IsCallNode(inputs[i])) {
size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes); size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes);
if (call_output_num == 0) { if (call_output_num == 0) {
break; break;
@ -872,7 +864,7 @@ void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node
if (node_value->isa<tensor::Tensor>()) { if (node_value->isa<tensor::Tensor>()) {
(void)((*value_nodes).emplace_back(input)); (void)((*value_nodes).emplace_back(input));
} }
} else if (IsCallNode(input)) { } else if (AnfAlgo::IsCallNode(input)) {
// If input is a call not, should check the switch node in its input. // If input is a call not, should check the switch node in its input.
const auto &call_node = input->cast<CNodePtr>(); const auto &call_node = input->cast<CNodePtr>();
const auto &call_inputs = call_node->inputs(); const auto &call_inputs = call_node->inputs();
@ -1050,7 +1042,7 @@ void ControlNodeParser::FetchFrontToFrontParameter(
std::vector<AnfNodePtr> call_inputs; std::vector<AnfNodePtr> call_inputs;
call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end()); call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end());
switch_input_parse(inputs[0], call_inputs); switch_input_parse(inputs[0], call_inputs);
} else if (IsCallNode(inputs[0])) { } else if (AnfAlgo::IsCallNode(inputs[0])) {
continue; continue;
} else { } else {
MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:" MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
@ -1098,7 +1090,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::
void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) { void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) { for (const auto &control_node : control_nodes) {
if (IsCallNode(control_node)) { if (AnfAlgo::IsCallNode(control_node)) {
const auto &func_graphs = FetchFuncGraphbyCallNode(control_node); const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
for (const auto &func_graph : func_graphs) { for (const auto &func_graph : func_graphs) {
@ -1123,7 +1115,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphP
const auto inputs = graph->input_nodes(); const auto inputs = graph->input_nodes();
for (const auto &input : inputs) { for (const auto &input : inputs) {
const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input); const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) { if (internal_parameter_with_index.first != nullptr && AnfAlgo::IsCallNode(internal_parameter_with_index.first)) {
call_input_kernel_graphs_[graph] = device_context; call_input_kernel_graphs_[graph] = device_context;
call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context}; call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
} }
@ -1162,12 +1154,12 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) { for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
if (inputs[i]->isa<Parameter>()) { if (inputs[i]->isa<Parameter>()) {
(void)parameters.emplace_back(inputs[i]); (void)parameters.emplace_back(inputs[i]);
} else if (IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { } else if (AnfAlgo::IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes); const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes);
(void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
} }
} }
} else if (IsCallNode(node)) { } else if (AnfAlgo::IsCallNode(node)) {
if ((*call_nodes).find(node) != (*call_nodes).end()) { if ((*call_nodes).find(node) != (*call_nodes).end()) {
return parameters; return parameters;
} }
@ -1296,7 +1288,7 @@ void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr>
} else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
// Switchlayer node. // Switchlayer node.
FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_); FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
} else if (IsCallNode(inputs[0])) { } else if (AnfAlgo::IsCallNode(inputs[0])) {
continue; continue;
} else { } else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString(); MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
@ -1373,7 +1365,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
for (const auto &switch_output : switch_outputs) { for (const auto &switch_output : switch_outputs) {
FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results); FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results);
} }
} else if (IsCallNode(front_output)) { } else if (AnfAlgo::IsCallNode(front_output)) {
// Output is a call. // Output is a call.
const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes); const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes);
@ -1429,7 +1421,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
} }
} else if (real_parameter->isa<ValueNode>()) { } else if (real_parameter->isa<ValueNode>()) {
(void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0); (void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
} else if (IsCallNode(real_parameter)) { } else if (AnfAlgo::IsCallNode(real_parameter)) {
const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter); const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
for (const auto func_graph : func_graphs) { for (const auto func_graph : func_graphs) {
FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters); FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters);

View File

@ -56,11 +56,6 @@ using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNode
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>; using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>; using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
// Check whether node is a call node, there are two types of call nodes:
// 1. First input of node is a cnode.
// 2. First input of node is a funcgraph value node.
bool IsCallNode(const AnfNodePtr &node);
// Check if the call node is the input of another call node. // Check if the call node is the input of another call node.
bool IsSubCallNode(const AnfNodePtr &node); bool IsSubCallNode(const AnfNodePtr &node);

View File

@ -1689,8 +1689,9 @@ void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, con
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(kernel_type); MS_EXCEPTION_IF_NULL(kernel_type);
MS_EXCEPTION_IF_NULL(kernel_name); MS_EXCEPTION_IF_NULL(kernel_name);
// In sink mode, the data exchange between child graphs is expressed as parameters. These parameters are stored
if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>())) { // in the graph and should be obtained from the super kernel actor.
if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>() || graph->IsChildGraphResult(node))) {
*kernel_type = KernelTransformType::kSuperKernelActor; *kernel_type = KernelTransformType::kSuperKernelActor;
*kernel_name = graph->ToString() + "_SuperKernelActor"; *kernel_name = graph->ToString() + "_SuperKernelActor";
return; return;

View File

@ -980,13 +980,6 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first; AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
size_t position = 0; size_t position = 0;
auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output); auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
if (runtime::IsCallNode(root_output)) {
std::vector<AnfNodePtr> call_nodes;
size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes);
for (size_t i = 0; i < call_output_num; ++i) {
(void)outputs.emplace_back(root_output, i);
}
}
outputs_num = outputs.size(); outputs_num = outputs.size();
for (const auto &output : outputs) { for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) { if (outputs_order.count(output) == 0) {