!26092 Get output node with call node.

Merge pull request !26092 from gaoyong10/runtime_second8
This commit is contained in:
i-robot 2021-11-11 13:23:34 +00:00 committed by Gitee
commit aaa6212c58
2 changed files with 58 additions and 58 deletions

View File

@ -122,54 +122,6 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
return inputs->push_back(std::make_pair(node, output_index));
}
// Fetch all outputs of control nodes, visited nodes indicates the call node that has been processed. In control flow,
// there are recursive calls between funcgraphs, so the processed call nodes are recorded to prevent infinite loops.
std::vector<KernelWithIndex> GetAllOutputByControlFlowNode(const KernelWithIndex &output_with_index,
std::set<AnfNodePtr> *visited_call_nodes) {
std::vector<KernelWithIndex> ret;
const auto &node = output_with_index.first;
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
const auto &switch_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
const auto &switch_inputs = switch_cnode->inputs();
auto output_vector = AnfAlgo::GetAllOutputWithIndex(switch_inputs[kSwitchTrueBranchIndex], visited_call_nodes);
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
} else if (AnfAlgo::IsCallNode(node)) {
if (visited_call_nodes != nullptr) {
if (visited_call_nodes->find(node) != visited_call_nodes->end()) {
return ret;
} else {
visited_call_nodes->emplace(node);
}
}
// The output of the call node is the output of the funcgraph actually called.
const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(node);
for (const auto &func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);
// The call in the graph kernel does not need to be parsed, and the node is directly output.
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
ret.emplace_back(output_with_index);
break;
}
MS_EXCEPTION_IF_NULL(func_graph->output());
const auto &func_graph_output =
AnfAlgo::VisitKernelWithReturnType(func_graph->output(), output_with_index.second);
std::set<AnfNodePtr> tmp_visited_nodes = {node};
auto output_vector = AnfAlgo::GetAllOutputWithIndex(
func_graph_output.first, (visited_call_nodes == nullptr ? &tmp_visited_nodes : visited_call_nodes));
if (output_with_index.second < output_vector.size()) {
ret.emplace_back(output_vector[output_with_index.second]);
break;
}
}
}
return ret;
}
// ops pair that dynamic input order is differ from the fixed shape ops
// pair: <real_input->ori_input, ori_input->real_input>
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
@ -367,8 +319,54 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
return ret;
}
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node,
std::set<AnfNodePtr> *visited_call_nodes) {
size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_abstract) {
MS_EXCEPTION_IF_NULL(node_abstract);
if (!node_abstract->isa<abstract::AbstractTuple>()) {
return 1;
}
size_t result = 0;
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
for (const auto &sub_abstract : sub_abstracts) {
result += GetOutputNumByAbstract(sub_abstract);
}
return result;
}
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputByCallNode(const KernelWithIndex &output_with_index) {
MS_EXCEPTION_IF_NULL(output_with_index.first);
auto node_abstract = output_with_index.first->abstract();
MS_EXCEPTION_IF_NULL(node_abstract);
if (!node_abstract->isa<abstract::AbstractTuple>()) {
return {output_with_index};
}
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
if (sub_abstracts.size() <= output_with_index.second) {
MS_LOG(EXCEPTION) << "Invalid index:" << output_with_index.second
<< "for node:" << output_with_index.first->DebugString();
}
// There may be tuples in the output of the call node, these outputs will be all numbered, so it is necessary
// to count the number of outputs before the target in order to accurately obtain its number.
size_t pre_output_num = 0;
for (size_t i = 0; i < output_with_index.second; ++i) {
pre_output_num += GetOutputNumByAbstract(sub_abstracts[i]);
}
size_t output_num = GetOutputNumByAbstract(sub_abstracts[output_with_index.second]);
std::vector<KernelWithIndex> results;
for (size_t i = 0; i < output_num; ++i) {
results.emplace_back(output_with_index.first, pre_output_num + i);
}
return results;
}
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
std::vector<KernelWithIndex> ret;
std::vector<KernelWithIndex> ret_empty;
@ -377,7 +375,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
auto make_tuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i), visited_call_nodes);
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i));
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
}
return ret;
@ -387,7 +385,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
auto depend_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_node);
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend), visited_call_nodes);
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend));
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
return ret;
}
@ -424,14 +422,14 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
// The makeTuple node need recurse.
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
auto output_vector = GetAllOutputWithIndex(output_with_index.first, visited_call_nodes);
auto output_vector = GetAllOutputWithIndex(output_with_index.first);
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
continue;
}
// Fetch outputs by control nodes.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || AnfAlgo::IsCallNode(node)) {
const auto &control_node_output = GetAllOutputByControlFlowNode(output_with_index, visited_call_nodes);
if (AnfAlgo::IsCallNode(node)) {
const auto &control_node_output = GetAllOutputByCallNode(output_with_index);
(void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret));
continue;
}
@ -445,7 +443,6 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
<< " with output index: " << output_with_index.second;
ret.push_back(output_with_index);
}
return ret;
}

View File

@ -84,8 +84,7 @@ class AnfRuntimeAlgorithm {
prim::kPrimMakeTuple});
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types = {});
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
std::set<AnfNodePtr> *visited_call_nodes = nullptr);
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
// get cnode primitive
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
@ -344,6 +343,10 @@ class AnfRuntimeAlgorithm {
// Depth represents the number of layers of the call. When the first input of the call node is a call node,
// the funcgraph in the return value of the inner call needs to be returned.
static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth = 1);
// Get the output number according to abstract, when there is a tuple in abstract, it needs to get recursively.
static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract);
// Fetch all outputs of call node.
static std::vector<KernelWithIndex> GetAllOutputByCallNode(const KernelWithIndex &output_with_index);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;