diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 7d39239d3fe..df0b4e7ecf7 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -54,14 +54,16 @@ void ClearPythonParasMap() { python_paras = nullptr; } namespace { const int kSummaryGetItem = 2; const size_t max_depth = 128; -bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) { +bool IsShapeDynamic(const abstract::ShapePtr &shape) { + if (shape == nullptr) { + return false; + } + return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; }); +} +bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx) { MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(node); - if (*check_dynamic) { - if (node->isa() && AnfAlgo::IsNodeDynamicShape(node->cast())) { - return true; - } - } else if (AnfAlgo::IsRealKernel(node)) { + if (AnfAlgo::IsRealKernel(node)) { return true; } (*idx) += 1; @@ -69,7 +71,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, if (*idx <= max_depth) { auto users = manager->node_users()[node]; if (std::any_of(users.begin(), users.end(), [&](const std::pair &kernel) { - return RecursiveCheck(manager, kernel.first, idx, check_dynamic); + return RecursiveCheck(manager, kernel.first, idx); })) { return true; } @@ -82,24 +84,8 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no MS_EXCEPTION_IF_NULL(node); auto node_users = manager->node_users()[node]; size_t idx = 0; - bool check_dynamic = false; if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair &kernel) { - return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); - })) { - return true; - } - - return false; -} - -bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(node); - auto node_users = manager->node_users()[node]; - size_t idx = 0; - bool check_dynamic = true; - if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair &kernel) { - return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); + return RecursiveCheck(manager, kernel.first, &idx); })) { return true; } @@ -481,7 +467,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const builder.SetOutputsFormat({format}); d_kernel_info->set_select_kernel_build_info(builder.Build()); AnfAlgo::SetOutputAddr(address, 0, parameter.get()); - AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get()); + auto abstract = std::make_shared(TypeIdToType(type), + parameter->Shape()->cast()); + parameter->set_abstract(abstract); } } @@ -954,7 +942,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con if (!IsUsedByRealKernel(manager, input_node)) { node_ptr->set_used_by_real_kernel(); } - if (IsUsedByDynamicKernel(manager, input_node)) { + auto shape = node_ptr->Shape(); + if (IsShapeDynamic(shape->cast())) { node_ptr->set_used_by_dynamic_kernel(); } } @@ -1043,7 +1032,8 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP if (!IsUsedByRealKernel(manager, input_node)) { node_ptr->set_used_by_real_kernel(); } - if (IsUsedByDynamicKernel(manager, input_node)) { + auto shape = node_ptr->Shape(); + if (IsShapeDynamic(shape->cast())) { node_ptr->set_used_by_dynamic_kernel(); } }