forked from mindspore-Ecosystem/mindspore
!1929 use VisitKernelWithReturnType instead of VisitKernel to get node's input in mem_reuse
Merge pull request !1929 from laiyongqiang/mem
This commit is contained in:
commit
5770df0951
|
@ -226,7 +226,10 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
|
|||
<< AnfAlgo::GetInputTensorNum(kernel);
|
||||
}
|
||||
auto input_node = kernel->input(input_idx + 1);
|
||||
auto kernel_input = AnfAlgo::VisitKernel(input_node, 0);
|
||||
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
||||
if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
|
||||
}
|
||||
auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second));
|
||||
return result;
|
||||
}
|
||||
|
@ -264,7 +267,10 @@ void MemReuseUtil::SetKernelDefInputs() {
|
|||
if (ref_ptr != nullptr) {
|
||||
// set the inputs of this kernel_def
|
||||
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
||||
auto input = AnfAlgo::VisitKernel(input_node, 0);
|
||||
auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
||||
if (IsPrimitive(input.first, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
|
||||
}
|
||||
auto input_key = (input.first).get();
|
||||
auto input_iter = kernel_map_.find(input_key);
|
||||
if (input_iter == kernel_map_.end()) {
|
||||
|
|
|
@ -48,7 +48,8 @@ void MemReuseChecker::CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr
|
|||
auto iter = kernel_refs.find(key);
|
||||
auto node_name = AnfAlgo::GetCNodeName(c_node);
|
||||
if (iter == kernel_refs.end()) {
|
||||
MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor";
|
||||
MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor, node: " << c_node->DebugString()
|
||||
<< " output index: " << output_idx;
|
||||
}
|
||||
if (output_idx >= iter->second.size()) {
|
||||
MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str();
|
||||
|
|
Loading…
Reference in New Issue