forked from mindspore-Ecosystem/mindspore
!10131 dynamic shape optimize
From: @liubuyu Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
984f0fe124
|
@ -54,14 +54,16 @@ void ClearPythonParasMap() { python_paras = nullptr; }
|
|||
namespace {
|
||||
const int kSummaryGetItem = 2;
|
||||
const size_t max_depth = 128;
|
||||
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) {
|
||||
bool IsShapeDynamic(const abstract::ShapePtr &shape) {
|
||||
if (shape == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
|
||||
}
|
||||
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (*check_dynamic) {
|
||||
if (node->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(node->cast<CNodePtr>())) {
|
||||
return true;
|
||||
}
|
||||
} else if (AnfAlgo::IsRealKernel(node)) {
|
||||
if (AnfAlgo::IsRealKernel(node)) {
|
||||
return true;
|
||||
}
|
||||
(*idx) += 1;
|
||||
|
@ -69,7 +71,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
|
|||
if (*idx <= max_depth) {
|
||||
auto users = manager->node_users()[node];
|
||||
if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
||||
return RecursiveCheck(manager, kernel.first, idx, check_dynamic);
|
||||
return RecursiveCheck(manager, kernel.first, idx);
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
|
@ -82,24 +84,8 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_users = manager->node_users()[node];
|
||||
size_t idx = 0;
|
||||
bool check_dynamic = false;
|
||||
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
||||
return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic);
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_users = manager->node_users()[node];
|
||||
size_t idx = 0;
|
||||
bool check_dynamic = true;
|
||||
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
||||
return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic);
|
||||
return RecursiveCheck(manager, kernel.first, &idx);
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
|
@ -481,7 +467,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
|
|||
builder.SetOutputsFormat({format});
|
||||
d_kernel_info->set_select_kernel_build_info(builder.Build());
|
||||
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
|
||||
AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
|
||||
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
|
||||
parameter->Shape()->cast<abstract::BaseShapePtr>());
|
||||
parameter->set_abstract(abstract);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -954,7 +942,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
if (!IsUsedByRealKernel(manager, input_node)) {
|
||||
node_ptr->set_used_by_real_kernel();
|
||||
}
|
||||
if (IsUsedByDynamicKernel(manager, input_node)) {
|
||||
auto shape = node_ptr->Shape();
|
||||
if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
|
||||
node_ptr->set_used_by_dynamic_kernel();
|
||||
}
|
||||
}
|
||||
|
@ -1043,7 +1032,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
if (!IsUsedByRealKernel(manager, input_node)) {
|
||||
node_ptr->set_used_by_real_kernel();
|
||||
}
|
||||
if (IsUsedByDynamicKernel(manager, input_node)) {
|
||||
auto shape = node_ptr->Shape();
|
||||
if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
|
||||
node_ptr->set_used_by_dynamic_kernel();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue