forked from mindspore-Ecosystem/mindspore
!9378 infer shape for noop
From: @wilfchen Reviewed-by: @cristoval,@limingqi107,@cristoval Signed-off-by: @cristoval
This commit is contained in:
commit
a76668ce84
|
@ -22,8 +22,6 @@
|
|||
#include <exception>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <stack>
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "debug/data_dump/e2e_dump_util.h"
|
||||
#include "runtime/device/ascend/ascend_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
|
@ -41,7 +39,6 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "runtime/device/ascend/profiling/profiling_utils.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/ascend/ascend_memory_manager.h"
|
||||
#include "debug/tensor_load.h"
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
|
@ -114,34 +111,6 @@ std::string GetRankId() {
|
|||
}
|
||||
return rank_id_str;
|
||||
}
|
||||
|
||||
void InferShapeForNopNode(AnfNodePtr *input_node) {
|
||||
MS_EXCEPTION_IF_NULL(*input_node);
|
||||
if (!opt::IsNopNode(*input_node)) {
|
||||
MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Infer shape for nop node.";
|
||||
std::stack<AnfNodePtr> nop_road;
|
||||
nop_road.push(*input_node);
|
||||
|
||||
while (true) {
|
||||
auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
|
||||
auto in_node = input_node_with_idx.first;
|
||||
MS_EXCEPTION_IF_NULL(in_node);
|
||||
if (opt::IsNopNode(in_node)) {
|
||||
nop_road.push(in_node);
|
||||
*input_node = in_node;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!nop_road.empty()) {
|
||||
auto nop_node = nop_road.top();
|
||||
AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
|
||||
nop_road.pop();
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_;
|
||||
|
@ -665,15 +634,6 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
|
|||
}
|
||||
|
||||
if (dynamic_kernel->is_dynamic_shape()) {
|
||||
auto kernel_node = dynamic_kernel->kernel_node();
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i);
|
||||
auto input_node = input_node_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
InferShapeForNopNode(&input_node);
|
||||
}
|
||||
dynamic_kernel->InferShape();
|
||||
dynamic_kernel->UpdateArgs();
|
||||
}
|
||||
|
|
|
@ -16,8 +16,10 @@
|
|||
|
||||
#include "runtime/device/executor/dynamic_kernel.h"
|
||||
#include <vector>
|
||||
#include <stack>
|
||||
#include <algorithm>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "common/trans.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "abstract/dshape.h"
|
||||
|
@ -73,6 +75,7 @@ void DynamicKernel::InferShape() {
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr_);
|
||||
MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope();
|
||||
InferShapeRecursive();
|
||||
|
||||
auto inputs = cnode_ptr_->inputs();
|
||||
if (inputs.empty()) {
|
||||
|
@ -124,5 +127,43 @@ void DynamicKernel::InferShape() {
|
|||
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
|
||||
cnode_ptr_->set_abstract(eval_result);
|
||||
}
|
||||
|
||||
void DynamicKernel::InferShapeRecursive() {
|
||||
auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i);
|
||||
auto input_node = input_node_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
InferShapeForNopNode(&input_node);
|
||||
}
|
||||
}
|
||||
|
||||
void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) {
|
||||
MS_EXCEPTION_IF_NULL(*input_node);
|
||||
if (!opt::IsNopNode(*input_node) || !AnfAlgo::IsDynamicShape(*input_node)) {
|
||||
MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Infer shape for nop node.";
|
||||
std::stack<AnfNodePtr> nop_road;
|
||||
nop_road.push(*input_node);
|
||||
|
||||
while (true) {
|
||||
auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
|
||||
auto in_node = input_node_with_idx.first;
|
||||
MS_EXCEPTION_IF_NULL(in_node);
|
||||
if (opt::IsNopNode(in_node)) {
|
||||
nop_road.push(in_node);
|
||||
*input_node = in_node;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!nop_road.empty()) {
|
||||
auto nop_node = nop_road.top();
|
||||
AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
|
||||
nop_road.pop();
|
||||
}
|
||||
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,10 +48,11 @@ class DynamicKernel {
|
|||
virtual void Initialize();
|
||||
std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }
|
||||
int GetKernelType();
|
||||
CNodePtr kernel_node() const { return cnode_ptr_; }
|
||||
|
||||
protected:
|
||||
void RebuildDependTensor();
|
||||
void InferShapeRecursive();
|
||||
void InferShapeForNopNode(AnfNodePtr *input_node);
|
||||
|
||||
void *stream_;
|
||||
const CNodePtr cnode_ptr_;
|
||||
|
|
Loading…
Reference in New Issue