forked from mindspore-Ecosystem/mindspore
!26092 Get output node with call node.
Merge pull request !26092 from gaoyong10/runtime_second8
This commit is contained in:
commit
aaa6212c58
|
@ -122,54 +122,6 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
|
|||
return inputs->push_back(std::make_pair(node, output_index));
|
||||
}
|
||||
|
||||
// Fetch all outputs of control nodes, visited nodes indicates the call node that has been processed. In control flow,
|
||||
// there are recursive calls between funcgraphs, so the processed call nodes are recorded to prevent infinite loops.
|
||||
std::vector<KernelWithIndex> GetAllOutputByControlFlowNode(const KernelWithIndex &output_with_index,
|
||||
std::set<AnfNodePtr> *visited_call_nodes) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
const auto &node = output_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
|
||||
const auto &switch_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_cnode);
|
||||
const auto &switch_inputs = switch_cnode->inputs();
|
||||
auto output_vector = AnfAlgo::GetAllOutputWithIndex(switch_inputs[kSwitchTrueBranchIndex], visited_call_nodes);
|
||||
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
||||
} else if (AnfAlgo::IsCallNode(node)) {
|
||||
if (visited_call_nodes != nullptr) {
|
||||
if (visited_call_nodes->find(node) != visited_call_nodes->end()) {
|
||||
return ret;
|
||||
} else {
|
||||
visited_call_nodes->emplace(node);
|
||||
}
|
||||
}
|
||||
|
||||
// The output of the call node is the output of the funcgraph actually called.
|
||||
const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(node);
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
// The call in the graph kernel does not need to be parsed, and the node is directly output.
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
ret.emplace_back(output_with_index);
|
||||
break;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(func_graph->output());
|
||||
const auto &func_graph_output =
|
||||
AnfAlgo::VisitKernelWithReturnType(func_graph->output(), output_with_index.second);
|
||||
std::set<AnfNodePtr> tmp_visited_nodes = {node};
|
||||
auto output_vector = AnfAlgo::GetAllOutputWithIndex(
|
||||
func_graph_output.first, (visited_call_nodes == nullptr ? &tmp_visited_nodes : visited_call_nodes));
|
||||
if (output_with_index.second < output_vector.size()) {
|
||||
ret.emplace_back(output_vector[output_with_index.second]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// ops pair that dynamic input order is differ from the fixed shape ops
|
||||
// pair: <real_input->ori_input, ori_input->real_input>
|
||||
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
|
||||
|
@ -367,8 +319,54 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||
std::set<AnfNodePtr> *visited_call_nodes) {
|
||||
size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_abstract) {
|
||||
MS_EXCEPTION_IF_NULL(node_abstract);
|
||||
if (!node_abstract->isa<abstract::AbstractTuple>()) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t result = 0;
|
||||
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
for (const auto &sub_abstract : sub_abstracts) {
|
||||
result += GetOutputNumByAbstract(sub_abstract);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputByCallNode(const KernelWithIndex &output_with_index) {
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
auto node_abstract = output_with_index.first->abstract();
|
||||
MS_EXCEPTION_IF_NULL(node_abstract);
|
||||
if (!node_abstract->isa<abstract::AbstractTuple>()) {
|
||||
return {output_with_index};
|
||||
}
|
||||
|
||||
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
if (sub_abstracts.size() <= output_with_index.second) {
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << output_with_index.second
|
||||
<< "for node:" << output_with_index.first->DebugString();
|
||||
}
|
||||
|
||||
// There may be tuples in the output of the call node, these outputs will be all numbered, so it is necessary
|
||||
// to count the number of outputs before the target in order to accurately obtain its number.
|
||||
size_t pre_output_num = 0;
|
||||
for (size_t i = 0; i < output_with_index.second; ++i) {
|
||||
pre_output_num += GetOutputNumByAbstract(sub_abstracts[i]);
|
||||
}
|
||||
|
||||
size_t output_num = GetOutputNumByAbstract(sub_abstracts[output_with_index.second]);
|
||||
std::vector<KernelWithIndex> results;
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
results.emplace_back(output_with_index.first, pre_output_num + i);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
std::vector<KernelWithIndex> ret_empty;
|
||||
|
||||
|
@ -377,7 +375,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
|
||||
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i), visited_call_nodes);
|
||||
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i));
|
||||
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
|
||||
}
|
||||
return ret;
|
||||
|
@ -387,7 +385,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||
auto depend_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend), visited_call_nodes);
|
||||
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend));
|
||||
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
|
||||
return ret;
|
||||
}
|
||||
|
@ -424,14 +422,14 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
|
||||
// The makeTuple node need recurse.
|
||||
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
|
||||
auto output_vector = GetAllOutputWithIndex(output_with_index.first, visited_call_nodes);
|
||||
auto output_vector = GetAllOutputWithIndex(output_with_index.first);
|
||||
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fetch outputs by control nodes.
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || AnfAlgo::IsCallNode(node)) {
|
||||
const auto &control_node_output = GetAllOutputByControlFlowNode(output_with_index, visited_call_nodes);
|
||||
if (AnfAlgo::IsCallNode(node)) {
|
||||
const auto &control_node_output = GetAllOutputByCallNode(output_with_index);
|
||||
(void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret));
|
||||
continue;
|
||||
}
|
||||
|
@ -445,7 +443,6 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
<< " with output index: " << output_with_index.second;
|
||||
ret.push_back(output_with_index);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -84,8 +84,7 @@ class AnfRuntimeAlgorithm {
|
|||
prim::kPrimMakeTuple});
|
||||
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types = {});
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||
std::set<AnfNodePtr> *visited_call_nodes = nullptr);
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
|
||||
// get cnode primitive
|
||||
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
|
||||
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
|
||||
|
@ -344,6 +343,10 @@ class AnfRuntimeAlgorithm {
|
|||
// Depth represents the number of layers of the call. When the first input of the call node is a call node,
|
||||
// the funcgraph in the return value of the inner call needs to be returned.
|
||||
static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth = 1);
|
||||
// Get the output number according to abstract, when there is a tuple in abstract, it needs to get recursively.
|
||||
static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract);
|
||||
// Fetch all outputs of call node.
|
||||
static std::vector<KernelWithIndex> GetAllOutputByCallNode(const KernelWithIndex &output_with_index);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
Loading…
Reference in New Issue