forked from mindspore-Ecosystem/mindspore
fix bug to remove reshape when reshape is depend's input
This commit is contained in:
parent
166d886501
commit
21770e7b6f
|
@ -223,13 +223,21 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) {
|
|||
}
|
||||
|
||||
KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) {
|
||||
auto is_all_nop_node = opt::IsAllNopNode(graph_);
|
||||
if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) {
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
|
||||
<< AnfAlgo::GetInputTensorNum(kernel);
|
||||
}
|
||||
auto input_node = kernel->input(input_idx + 1);
|
||||
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
|
||||
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
||||
session::KernelWithIndex kernel_input;
|
||||
if (is_all_nop_node) {
|
||||
// The graph does not remove the nop node.
|
||||
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
||||
} else {
|
||||
// The graph removes the nop node.
|
||||
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
||||
}
|
||||
if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
|
||||
}
|
||||
|
@ -257,6 +265,7 @@ void MemReuseUtil::SetKernelDefMap() {
|
|||
}
|
||||
|
||||
void MemReuseUtil::SetKernelDefInputs() {
|
||||
auto is_all_nop_node = opt::IsAllNopNode(graph_);
|
||||
for (const auto &kernel : graph_->execution_order()) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
auto key = kernel.get();
|
||||
|
@ -272,7 +281,14 @@ void MemReuseUtil::SetKernelDefInputs() {
|
|||
// set the inputs of this kernel_def
|
||||
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
||||
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
|
||||
auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
||||
session::KernelWithIndex input;
|
||||
if (is_all_nop_node) {
|
||||
// The graph does not remove the nop node.
|
||||
input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
|
||||
} else {
|
||||
// The graph removes the nop node.
|
||||
input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
||||
}
|
||||
if (IsPrimitive(input.first, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue