forked from mindspore-Ecosystem/mindspore
!28967 Actor get funcgraphs from abstract
Merge pull request !28967 from chenfei_mindspore/suit-renormalize-elimiate-dead-node
This commit is contained in:
commit
f20b69ffc5
|
@ -20,6 +20,7 @@
|
|||
#include "utils/convert_utils.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
@ -1520,6 +1521,47 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vector<KernelGra
|
|||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr GetFuncGraph(const abstract::AbstractBasePtr &abs, const AnfNodePtr &anf_node) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(abs != nullptr, "Null abstract, current node: " + anf_node->DebugString());
|
||||
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
auto abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||
if (!abs_func_graph->specialized()) {
|
||||
MS_LOG(EXCEPTION) << "Unspecilized func graph abstract: " << abs_func_graph->ToString()
|
||||
<< ", node: " << anf_node->DebugString();
|
||||
}
|
||||
return abs_func_graph->func_graph();
|
||||
}
|
||||
|
||||
if (abs->isa<abstract::PartialAbstractClosure>()) {
|
||||
auto abs_partial_closure = abs->cast<abstract::PartialAbstractClosurePtr>();
|
||||
auto abs_func = abs_partial_closure->fn();
|
||||
return GetFuncGraph(abs_func, anf_node);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected abs: " << abs->ToString();
|
||||
}
|
||||
|
||||
std::vector<FuncGraphPtr> GetFuncGraphs(const AnfNodePtr &anf_node) {
|
||||
if (IsValueNode<FuncGraph>(anf_node)) {
|
||||
return {GetValueNode<FuncGraphPtr>(anf_node)};
|
||||
}
|
||||
auto abs = anf_node->abstract();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(abs != nullptr, "Null abstract of node: " + anf_node->DebugString());
|
||||
if (!abs->isa<abstract::AbstractFunction>()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected abs: " << abs->ToString() << ", anf_node: " << anf_node->DebugString();
|
||||
}
|
||||
auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
|
||||
std::vector<FuncGraphPtr> ret;
|
||||
if (abs->isa<abstract::AbstractFuncUnion>()) {
|
||||
auto visit_func = [&ret, &anf_node](const abstract::AbstractFuncAtomPtr &poss) {
|
||||
ret.emplace_back(GetFuncGraph(poss, anf_node));
|
||||
};
|
||||
abs_func->Visit(visit_func);
|
||||
} else {
|
||||
ret.emplace_back(GetFuncGraph(abs_func, anf_node));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
auto func_graph_analyzer = std::make_shared<FuncGraphAnalyzer>(root_func_graph_);
|
||||
func_graph_analyzer->Run();
|
||||
|
@ -1529,8 +1571,13 @@ void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &
|
|||
if (!AnfAlgo::IsCallNode(control_node)) {
|
||||
continue;
|
||||
}
|
||||
std::vector<FuncGraphPtr> func_graphs;
|
||||
if (common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT") == "1") {
|
||||
func_graphs = GetFuncGraphs(control_node->cast<CNodePtr>()->input(0));
|
||||
} else {
|
||||
func_graphs = func_graph_analyzer->GetCallerFuncGraphs(control_node);
|
||||
}
|
||||
|
||||
auto func_graphs = func_graph_analyzer->GetCallerFuncGraphs(control_node);
|
||||
for (auto func_graph : func_graphs) {
|
||||
(void)call_node_to_func_graphs_[control_node].emplace(func_graph);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue