!18979 visit nop node for output addr exist check

Merge pull request !18979 from kisnwang/fix-output-nop-error
This commit is contained in:
i-robot 2021-06-28 07:34:39 +00:00 committed by Gitee
commit 0d15aeb524
3 changed files with 12 additions and 3 deletions

View File

@ -1006,8 +1006,17 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
}
// get output device addr of anf_node
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node) {
MS_EXCEPTION_IF_NULL(node);
if (opt::IsNopNode(node) && visit_nop_node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() > 1) {
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, 0);
return OutputAddrExist(kernel_with_index.first, kernel_with_index.second);
}
return false;
}
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);

View File

@ -174,7 +174,7 @@ class AnfRuntimeAlgorithm {
// get mutable output device addr of anf_node
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
// check whether output addr is exist or not
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = false);
// check whether workspace addr is exist or not
static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx);
// get address from prev node,input_index is the input index of current node related to prev node

View File

@ -1612,7 +1612,7 @@ void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
if (iter != tensor_to_node.end()) {
const auto &node = iter->second.first;
const auto &output_index = iter->second.second;
if (!AnfAlgo::OutputAddrExist(node, output_index)) {
if (!AnfAlgo::OutputAddrExist(node, output_index, true)) {
continue;
}
const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);