!48846 Avoid promblem of get func from VmapAbstractClosure

Merge pull request !48846 from chenfei_mindspore/vmap-stuck-fix
This commit is contained in:
i-robot 2023-03-06 13:06:14 +00:00 committed by Gitee
commit c27f43ba0c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 26 additions and 29 deletions

View File

@ -1041,7 +1041,7 @@ bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
auto input0 = cnode->input(0);
if (IsPrimitiveCNode(input0, prim::kPrimSwitch) || IsPrimitiveCNode(input0, prim::kPrimSwitchLayer) ||
IsValueNode<FuncGraph>(input0)) {
auto func_graphs = abstract::GetFuncGraphsFromAbs(input0);
auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
auto graph_has_function_output = [](const FuncGraphPtr &fg) {
return HasAbstractFunction(fg->output()->abstract());
};

View File

@ -1954,7 +1954,7 @@ void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &
const auto &cnode = control_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &func_graphs = abstract::GetFuncGraphsFromAbs(cnode->input(0));
const auto &func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
if (func_graphs.empty()) {
MS_LOG(EXCEPTION) << "Get func graphs from abstract failed.";
}

View File

@ -184,7 +184,7 @@ bool CSE::IsHiddenSideEffectCall(const AnfNodePtr &node) {
return false;
}
// If it is a func graph call node, get all graphs from abstract.
auto func_graphs = abstract::GetFuncGraphsFromAbs(cnode->input(0));
auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
auto is_hidden_side_effect_graph = [this](const FuncGraphPtr &fg) -> bool {
return hidden_side_effect_func_graphs_.find(fg) != hidden_side_effect_func_graphs_.end();
};

View File

@ -335,64 +335,61 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
}
namespace {
FuncGraphPtr GetFuncGraphFromAbs(const abstract::AbstractBasePtr &abs, const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
FuncGraphPtr GetFuncGraphFromAbs(const abstract::AbstractBasePtr &abs, const AnfNodePtr &call_node) {
MS_EXCEPTION_IF_NULL(call_node);
if (abs == nullptr) {
MS_LOG(ERROR) << "Null abstract, current node: " << anf_node->DebugString();
MS_LOG(ERROR) << "Null abstract, current node: " << call_node->DebugString();
return nullptr;
}
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
auto abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
MS_EXCEPTION_IF_NULL(abs_func_graph);
if (!abs_func_graph->specialized()) {
MS_LOG(INFO) << "Unspecilized func graph abstract: " << abs_func_graph->ToString()
<< ", node: " << anf_node->DebugString();
MS_LOG(INFO) << "Unspecialized func graph abstract: " << abs_func_graph->ToString()
<< ", node: " << call_node->DebugString();
}
return abs_func_graph->func_graph();
}
if (abs->isa<abstract::PartialAbstractClosure>()) {
auto abs_partial_closure = abs->cast<abstract::PartialAbstractClosurePtr>();
MS_EXCEPTION_IF_NULL(abs_partial_closure);
auto abs_func = abs_partial_closure->fn();
return GetFuncGraphFromAbs(abs_func, anf_node);
return GetFuncGraphFromAbs(abs_func, call_node);
}
if (abs->isa<abstract::MetaFuncGraphAbstractClosure>()) {
if (!IsValueNode<FuncGraph>(anf_node)) {
MS_LOG(EXCEPTION) << "Got unexpected MetaFuncGraphAbstractClosure: " << abs->ToString()
<< ", anf node: " << anf_node->DebugString();
}
return GetValueNode<FuncGraphPtr>(anf_node);
}
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", node: " << anf_node->DebugString();
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call node: " << call_node->DebugString();
return nullptr;
}
} // namespace
std::vector<FuncGraphPtr> GetFuncGraphsFromAbs(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
if (IsValueNode<FuncGraph>(anf_node)) {
return {GetValueNode<FuncGraphPtr>(anf_node)};
std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node) {
MS_EXCEPTION_IF_NULL(call_node);
auto func_node = call_node->input(0);
if (IsPrimitiveCNode(func_node, prim::kPrimPartial)) {
func_node = func_node->cast<CNodePtr>()->input(1);
}
auto abs = anf_node->abstract();
if (IsValueNode<FuncGraph>(func_node)) {
return {GetValueNode<FuncGraphPtr>(func_node)};
}
auto abs = func_node->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs == nullptr) {
MS_LOG(ERROR) << "Null abstract, current node: " << anf_node->DebugString();
MS_LOG(ERROR) << "Null abstract, current call node: " << call_node->DebugString();
return {};
}
if (!abs->isa<abstract::AbstractFunction>()) {
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", anf_node: " << anf_node->DebugString();
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call_node: " << call_node->DebugString();
return {};
}
auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
MS_EXCEPTION_IF_NULL(abs_func);
std::vector<FuncGraphPtr> func_graphs;
if (abs->isa<abstract::AbstractFuncUnion>()) {
auto visit_func = [&func_graphs, &anf_node](const abstract::AbstractFuncAtomPtr &poss) {
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(poss, anf_node));
auto visit_func = [&func_graphs, &call_node](const abstract::AbstractFuncAtomPtr &poss) {
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(poss, call_node));
};
abs_func->Visit(visit_func);
} else {
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(abs_func, anf_node));
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(abs_func, call_node));
}
bool exist_null_fg =
std::any_of(func_graphs.cbegin(), func_graphs.cend(), [](const FuncGraphPtr &fg) { return fg == nullptr; });

View File

@ -56,7 +56,7 @@ T ShapeSize(const std::vector<T> &shape) {
MS_CORE_API AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
MS_CORE_API std::vector<FuncGraphPtr> GetFuncGraphsFromAbs(const AnfNodePtr &anf_node);
MS_CORE_API std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node);
class MS_CORE_API EnvSetSparseResultMgr {
public: