diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 2e1d274179e..6d877145cc9 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1176,9 +1176,12 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - if (node->isa() || IsPrimitiveCNode(node, prim::kPrimLoad)) { + if (node->isa()) { return false; } + if (IsPrimitiveCNode(node, prim::kPrimLoad)) { + return IsFeatureMapOutput(node->cast()->input(1)); + } auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->is_feature_map(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc index cbda67bbadf..bdacaa3165f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc @@ -1220,13 +1220,9 @@ class AutoMonadConverter { } CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) { - static const std::string primitive_target = "primitive_target"; // Create Load cnode. auto load_prim = NewValueNode(prim::kPrimLoad); auto load_cnode = func_graph_->NewCNode({load_prim, ref, u}); - // Set device target for Load CNode. - std::string target = GetCNodeTarget(cnode); - load_cnode->set_user_data(primitive_target, std::make_shared(target)); // Set load_cnode abstract to Tensor according the input Ref[Tensor]. auto ref_abs = dyn_cast(ref->abstract()); MS_EXCEPTION_IF_NULL(ref_abs); diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 379d0919a25..645b1833a14 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -454,6 +454,8 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) { return GetCNodeTarget(inputs[1]); } + } else if (IsPrimitiveCNode(node, prim::kPrimLoad)) { + return GetCNodeTarget(cnode->input(1)); } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { return GetMaketupleNodeTarget(cnode); } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {