diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index aaa0c155e40..ca4cbf5158d 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -226,7 +226,10 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t << AnfAlgo::GetInputTensorNum(kernel); } auto input_node = kernel->input(input_idx + 1); - auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); + auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { + MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; + } auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second)); return result; } @@ -264,7 +267,10 @@ void MemReuseUtil::SetKernelDefInputs() { if (ref_ptr != nullptr) { // set the inputs of this kernel_def auto input_node = AnfAlgo::GetInputNode(kernel, i); - auto input = AnfAlgo::VisitKernel(input_node, 0); + auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { + MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; + } auto input_key = (input.first).get(); auto input_iter = kernel_map_.find(input_key); if (input_iter == kernel_map_.end()) { diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc index cf92679187f..5cd6a5f50ec 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc @@ -48,7 +48,8 @@ void MemReuseChecker::CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr auto iter = kernel_refs.find(key); auto node_name = AnfAlgo::GetCNodeName(c_node); if (iter == kernel_refs.end()) { - MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor"; + MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor, node: " << c_node->DebugString() + << " output index: " << output_idx; } if (output_idx >= iter->second.size()) { MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str();