forked from mindspore-Ecosystem/mindspore
!48846 Avoid promblem of get func from VmapAbstractClosure
Merge pull request !48846 from chenfei_mindspore/vmap-stuck-fix
This commit is contained in:
commit
c27f43ba0c
|
@ -1041,7 +1041,7 @@ bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
auto input0 = cnode->input(0);
|
||||
if (IsPrimitiveCNode(input0, prim::kPrimSwitch) || IsPrimitiveCNode(input0, prim::kPrimSwitchLayer) ||
|
||||
IsValueNode<FuncGraph>(input0)) {
|
||||
auto func_graphs = abstract::GetFuncGraphsFromAbs(input0);
|
||||
auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
|
||||
auto graph_has_function_output = [](const FuncGraphPtr &fg) {
|
||||
return HasAbstractFunction(fg->output()->abstract());
|
||||
};
|
||||
|
|
|
@ -1954,7 +1954,7 @@ void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &
|
|||
|
||||
const auto &cnode = control_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &func_graphs = abstract::GetFuncGraphsFromAbs(cnode->input(0));
|
||||
const auto &func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
|
||||
if (func_graphs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Get func graphs from abstract failed.";
|
||||
}
|
||||
|
|
|
@ -184,7 +184,7 @@ bool CSE::IsHiddenSideEffectCall(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
// If it is a func graph call node, get all graphs from abstract.
|
||||
auto func_graphs = abstract::GetFuncGraphsFromAbs(cnode->input(0));
|
||||
auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
|
||||
auto is_hidden_side_effect_graph = [this](const FuncGraphPtr &fg) -> bool {
|
||||
return hidden_side_effect_func_graphs_.find(fg) != hidden_side_effect_func_graphs_.end();
|
||||
};
|
||||
|
|
|
@ -335,64 +335,61 @@ AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type
|
|||
}
|
||||
|
||||
namespace {
|
||||
FuncGraphPtr GetFuncGraphFromAbs(const abstract::AbstractBasePtr &abs, const AnfNodePtr &anf_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
FuncGraphPtr GetFuncGraphFromAbs(const abstract::AbstractBasePtr &abs, const AnfNodePtr &call_node) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(ERROR) << "Null abstract, current node: " << anf_node->DebugString();
|
||||
MS_LOG(ERROR) << "Null abstract, current node: " << call_node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
auto abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_func_graph);
|
||||
if (!abs_func_graph->specialized()) {
|
||||
MS_LOG(INFO) << "Unspecilized func graph abstract: " << abs_func_graph->ToString()
|
||||
<< ", node: " << anf_node->DebugString();
|
||||
MS_LOG(INFO) << "Unspecialized func graph abstract: " << abs_func_graph->ToString()
|
||||
<< ", node: " << call_node->DebugString();
|
||||
}
|
||||
return abs_func_graph->func_graph();
|
||||
}
|
||||
|
||||
if (abs->isa<abstract::PartialAbstractClosure>()) {
|
||||
auto abs_partial_closure = abs->cast<abstract::PartialAbstractClosurePtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_partial_closure);
|
||||
auto abs_func = abs_partial_closure->fn();
|
||||
return GetFuncGraphFromAbs(abs_func, anf_node);
|
||||
return GetFuncGraphFromAbs(abs_func, call_node);
|
||||
}
|
||||
if (abs->isa<abstract::MetaFuncGraphAbstractClosure>()) {
|
||||
if (!IsValueNode<FuncGraph>(anf_node)) {
|
||||
MS_LOG(EXCEPTION) << "Got unexpected MetaFuncGraphAbstractClosure: " << abs->ToString()
|
||||
<< ", anf node: " << anf_node->DebugString();
|
||||
}
|
||||
return GetValueNode<FuncGraphPtr>(anf_node);
|
||||
}
|
||||
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", node: " << anf_node->DebugString();
|
||||
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call node: " << call_node->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<FuncGraphPtr> GetFuncGraphsFromAbs(const AnfNodePtr &anf_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (IsValueNode<FuncGraph>(anf_node)) {
|
||||
return {GetValueNode<FuncGraphPtr>(anf_node)};
|
||||
std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
auto func_node = call_node->input(0);
|
||||
if (IsPrimitiveCNode(func_node, prim::kPrimPartial)) {
|
||||
func_node = func_node->cast<CNodePtr>()->input(1);
|
||||
}
|
||||
auto abs = anf_node->abstract();
|
||||
if (IsValueNode<FuncGraph>(func_node)) {
|
||||
return {GetValueNode<FuncGraphPtr>(func_node)};
|
||||
}
|
||||
auto abs = func_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(ERROR) << "Null abstract, current node: " << anf_node->DebugString();
|
||||
MS_LOG(ERROR) << "Null abstract, current call node: " << call_node->DebugString();
|
||||
return {};
|
||||
}
|
||||
if (!abs->isa<abstract::AbstractFunction>()) {
|
||||
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", anf_node: " << anf_node->DebugString();
|
||||
MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call_node: " << call_node->DebugString();
|
||||
return {};
|
||||
}
|
||||
auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_func);
|
||||
std::vector<FuncGraphPtr> func_graphs;
|
||||
if (abs->isa<abstract::AbstractFuncUnion>()) {
|
||||
auto visit_func = [&func_graphs, &anf_node](const abstract::AbstractFuncAtomPtr &poss) {
|
||||
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(poss, anf_node));
|
||||
auto visit_func = [&func_graphs, &call_node](const abstract::AbstractFuncAtomPtr &poss) {
|
||||
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(poss, call_node));
|
||||
};
|
||||
abs_func->Visit(visit_func);
|
||||
} else {
|
||||
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(abs_func, anf_node));
|
||||
(void)func_graphs.emplace_back(GetFuncGraphFromAbs(abs_func, call_node));
|
||||
}
|
||||
bool exist_null_fg =
|
||||
std::any_of(func_graphs.cbegin(), func_graphs.cend(), [](const FuncGraphPtr &fg) { return fg == nullptr; });
|
||||
|
|
|
@ -56,7 +56,7 @@ T ShapeSize(const std::vector<T> &shape) {
|
|||
MS_CORE_API AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
|
||||
MS_CORE_API AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
|
||||
MS_CORE_API AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
|
||||
MS_CORE_API std::vector<FuncGraphPtr> GetFuncGraphsFromAbs(const AnfNodePtr &anf_node);
|
||||
MS_CORE_API std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node);
|
||||
|
||||
class MS_CORE_API EnvSetSparseResultMgr {
|
||||
public:
|
||||
|
|
Loading…
Reference in New Issue