forked from mindspore-Ecosystem/mindspore
!10131 dynamic shape optimize
From: @liubuyu Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
984f0fe124
|
@ -54,14 +54,16 @@ void ClearPythonParasMap() { python_paras = nullptr; }
|
||||||
namespace {
|
namespace {
|
||||||
const int kSummaryGetItem = 2;
|
const int kSummaryGetItem = 2;
|
||||||
const size_t max_depth = 128;
|
const size_t max_depth = 128;
|
||||||
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) {
|
bool IsShapeDynamic(const abstract::ShapePtr &shape) {
|
||||||
|
if (shape == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
|
||||||
|
}
|
||||||
|
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx) {
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (*check_dynamic) {
|
if (AnfAlgo::IsRealKernel(node)) {
|
||||||
if (node->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(node->cast<CNodePtr>())) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} else if (AnfAlgo::IsRealKernel(node)) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
(*idx) += 1;
|
(*idx) += 1;
|
||||||
|
@ -69,7 +71,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
|
||||||
if (*idx <= max_depth) {
|
if (*idx <= max_depth) {
|
||||||
auto users = manager->node_users()[node];
|
auto users = manager->node_users()[node];
|
||||||
if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
||||||
return RecursiveCheck(manager, kernel.first, idx, check_dynamic);
|
return RecursiveCheck(manager, kernel.first, idx);
|
||||||
})) {
|
})) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -82,24 +84,8 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto node_users = manager->node_users()[node];
|
auto node_users = manager->node_users()[node];
|
||||||
size_t idx = 0;
|
size_t idx = 0;
|
||||||
bool check_dynamic = false;
|
|
||||||
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
||||||
return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic);
|
return RecursiveCheck(manager, kernel.first, &idx);
|
||||||
})) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
auto node_users = manager->node_users()[node];
|
|
||||||
size_t idx = 0;
|
|
||||||
bool check_dynamic = true;
|
|
||||||
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
|
||||||
return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic);
|
|
||||||
})) {
|
})) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -481,7 +467,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
|
||||||
builder.SetOutputsFormat({format});
|
builder.SetOutputsFormat({format});
|
||||||
d_kernel_info->set_select_kernel_build_info(builder.Build());
|
d_kernel_info->set_select_kernel_build_info(builder.Build());
|
||||||
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
|
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
|
||||||
AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
|
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
|
||||||
|
parameter->Shape()->cast<abstract::BaseShapePtr>());
|
||||||
|
parameter->set_abstract(abstract);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -954,7 +942,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
||||||
if (!IsUsedByRealKernel(manager, input_node)) {
|
if (!IsUsedByRealKernel(manager, input_node)) {
|
||||||
node_ptr->set_used_by_real_kernel();
|
node_ptr->set_used_by_real_kernel();
|
||||||
}
|
}
|
||||||
if (IsUsedByDynamicKernel(manager, input_node)) {
|
auto shape = node_ptr->Shape();
|
||||||
|
if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
|
||||||
node_ptr->set_used_by_dynamic_kernel();
|
node_ptr->set_used_by_dynamic_kernel();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1043,7 +1032,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
||||||
if (!IsUsedByRealKernel(manager, input_node)) {
|
if (!IsUsedByRealKernel(manager, input_node)) {
|
||||||
node_ptr->set_used_by_real_kernel();
|
node_ptr->set_used_by_real_kernel();
|
||||||
}
|
}
|
||||||
if (IsUsedByDynamicKernel(manager, input_node)) {
|
auto shape = node_ptr->Shape();
|
||||||
|
if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
|
||||||
node_ptr->set_used_by_dynamic_kernel();
|
node_ptr->set_used_by_dynamic_kernel();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue