!28967 Actor get funcgraphs from abstract

Merge pull request !28967 from chenfei_mindspore/suit-renormalize-elimiate-dead-node
This commit is contained in:
i-robot 2022-01-16 04:08:43 +00:00 committed by Gitee
commit f20b69ffc5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 48 additions and 1 deletions

View File

@ -20,6 +20,7 @@
#include "utils/convert_utils.h"
#include "abstract/utils.h"
#include "ir/tensor.h"
#include "abstract/abstract_function.h"
namespace mindspore {
namespace runtime {
@ -1520,6 +1521,47 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra
}
}
FuncGraphPtr GetFuncGraph(const abstract::AbstractBasePtr &abs, const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_CHECK_FAIL(abs != nullptr, "Null abstract, current node: " + anf_node->DebugString());
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
auto abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
if (!abs_func_graph->specialized()) {
MS_LOG(EXCEPTION) << "Unspecilized func graph abstract: " << abs_func_graph->ToString()
<< ", node: " << anf_node->DebugString();
}
return abs_func_graph->func_graph();
}
if (abs->isa<abstract::PartialAbstractClosure>()) {
auto abs_partial_closure = abs->cast<abstract::PartialAbstractClosurePtr>();
auto abs_func = abs_partial_closure->fn();
return GetFuncGraph(abs_func, anf_node);
}
MS_LOG(EXCEPTION) << "Unexpected abs: " << abs->ToString();
}
std::vector<FuncGraphPtr> GetFuncGraphs(const AnfNodePtr &anf_node) {
if (IsValueNode<FuncGraph>(anf_node)) {
return {GetValueNode<FuncGraphPtr>(anf_node)};
}
auto abs = anf_node->abstract();
MS_EXCEPTION_IF_CHECK_FAIL(abs != nullptr, "Null abstract of node: " + anf_node->DebugString());
if (!abs->isa<abstract::AbstractFunction>()) {
MS_LOG(EXCEPTION) << "Unexpected abs: " << abs->ToString() << ", anf_node: " << anf_node->DebugString();
}
auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
std::vector<FuncGraphPtr> ret;
if (abs->isa<abstract::AbstractFuncUnion>()) {
auto visit_func = [&ret, &anf_node](const abstract::AbstractFuncAtomPtr &poss) {
ret.emplace_back(GetFuncGraph(poss, anf_node));
};
abs_func->Visit(visit_func);
} else {
ret.emplace_back(GetFuncGraph(abs_func, anf_node));
}
return ret;
}
void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
auto func_graph_analyzer = std::make_shared<FuncGraphAnalyzer>(root_func_graph_);
func_graph_analyzer->Run();
@ -1529,8 +1571,13 @@ void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &
if (!AnfAlgo::IsCallNode(control_node)) {
continue;
}
std::vector<FuncGraphPtr> func_graphs;
if (common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT") == "1") {
func_graphs = GetFuncGraphs(control_node->cast<CNodePtr>()->input(0));
} else {
func_graphs = func_graph_analyzer->GetCallerFuncGraphs(control_node);
}
auto func_graphs = func_graph_analyzer->GetCallerFuncGraphs(control_node);
for (auto func_graph : func_graphs) {
(void)call_node_to_func_graphs_[control_node].emplace(func_graph);
}