!10131 dynamic shape optimization

From: @liubuyu
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
mindspore-ci-bot 2020-12-18 11:50:26 +08:00 committed by Gitee
commit 984f0fe124
1 changed file with 17 additions and 27 deletions
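The commit drops the IsUsedByDynamicKernel() walk over a node's users and instead asks the parameter's own shape whether it is dynamic: the new IsShapeDynamic() helper in the first hunk returns true when any dimension is negative. A minimal standalone sketch of that test, using a plain std::vector<int64_t> in place of the abstract::ShapePtr the real helper takes (everything outside the diff below is illustrative only):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Same predicate as the lambda in the diff: a negative dimension (e.g. -1)
// marks that axis as unknown, so the shape as a whole is dynamic.
bool IsShapeDynamic(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
}

int main() {
  std::cout << IsShapeDynamic({32, 224, 224, 3}) << '\n';  // 0: fully static shape
  std::cout << IsShapeDynamic({-1, 224, 224, 3}) << '\n';  // 1: dynamic batch dimension
  return 0;
}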

@@ -54,14 +54,16 @@ void ClearPythonParasMap() { python_paras = nullptr; }
 namespace {
 const int kSummaryGetItem = 2;
 const size_t max_depth = 128;
-bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) {
+bool IsShapeDynamic(const abstract::ShapePtr &shape) {
+  if (shape == nullptr) {
+    return false;
+  }
+  return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
+}
+bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx) {
   MS_EXCEPTION_IF_NULL(manager);
   MS_EXCEPTION_IF_NULL(node);
-  if (*check_dynamic) {
-    if (node->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(node->cast<CNodePtr>())) {
-      return true;
-    }
-  } else if (AnfAlgo::IsRealKernel(node)) {
+  if (AnfAlgo::IsRealKernel(node)) {
     return true;
   }
   (*idx) += 1;
@@ -69,7 +71,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
   if (*idx <= max_depth) {
     auto users = manager->node_users()[node];
     if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
-          return RecursiveCheck(manager, kernel.first, idx, check_dynamic);
+          return RecursiveCheck(manager, kernel.first, idx);
        })) {
      return true;
    }
@@ -82,24 +84,8 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no
   MS_EXCEPTION_IF_NULL(node);
   auto node_users = manager->node_users()[node];
   size_t idx = 0;
-  bool check_dynamic = false;
-  if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
-        return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic);
-      })) {
-    return true;
-  }
-  return false;
-}
-bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(manager);
-  MS_EXCEPTION_IF_NULL(node);
-  auto node_users = manager->node_users()[node];
-  size_t idx = 0;
-  bool check_dynamic = true;
   if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
-        return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic);
+        return RecursiveCheck(manager, kernel.first, &idx);
      })) {
    return true;
  }
@@ -481,7 +467,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
     builder.SetOutputsFormat({format});
     d_kernel_info->set_select_kernel_build_info(builder.Build());
     AnfAlgo::SetOutputAddr(address, 0, parameter.get());
-    AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
+    auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
+                                                               parameter->Shape()->cast<abstract::BaseShapePtr>());
+    parameter->set_abstract(abstract);
   }
 }
@@ -954,7 +942,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
     if (!IsUsedByRealKernel(manager, input_node)) {
       node_ptr->set_used_by_real_kernel();
     }
-    if (IsUsedByDynamicKernel(manager, input_node)) {
+    auto shape = node_ptr->Shape();
+    if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
       node_ptr->set_used_by_dynamic_kernel();
     }
   }
@@ -1043,7 +1032,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
     if (!IsUsedByRealKernel(manager, input_node)) {
       node_ptr->set_used_by_real_kernel();
     }
-    if (IsUsedByDynamicKernel(manager, input_node)) {
+    auto shape = node_ptr->Shape();
+    if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
       node_ptr->set_used_by_dynamic_kernel();
     }
   }