forked from mindspore-Ecosystem/mindspore
!18979 visit nop node for output addr exist check
Merge pull request !18979 from kisnwang/fix-output-nop-error
This commit is contained in:
commit
0d15aeb524
|
@ -1006,8 +1006,17 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
|
|||
}
|
||||
|
||||
// get output device addr of anf_node
|
||||
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
|
||||
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (opt::IsNopNode(node) && visit_nop_node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() > 1) {
|
||||
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, 0);
|
||||
return OutputAddrExist(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
|
||||
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
|
|
|
@ -174,7 +174,7 @@ class AnfRuntimeAlgorithm {
|
|||
// get mutable output device addr of anf_node
|
||||
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
|
||||
// check whether output addr is exist or not
|
||||
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
|
||||
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = false);
|
||||
// check whether workspace addr is exist or not
|
||||
static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx);
|
||||
// get address from prev node,input_index is the input index of current node related to prev node
|
||||
|
|
|
@ -1612,7 +1612,7 @@ void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
|
|||
if (iter != tensor_to_node.end()) {
|
||||
const auto &node = iter->second.first;
|
||||
const auto &output_index = iter->second.second;
|
||||
if (!AnfAlgo::OutputAddrExist(node, output_index)) {
|
||||
if (!AnfAlgo::OutputAddrExist(node, output_index, true)) {
|
||||
continue;
|
||||
}
|
||||
const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);
|
||||
|
|
Loading…
Reference in New Issue