[auto-monad] Fix primitive_target and IsFeatureMapOutput check for Load
This commit is contained in:
parent
ddffb61c62
commit
ce690a5489
|
@ -1176,9 +1176,12 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index)
|
|||
|
||||
bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>() || IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
return false;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
return IsFeatureMapOutput(node->cast<CNodePtr>()->input(1));
|
||||
}
|
||||
auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
return kernel_info->is_feature_map();
|
||||
|
|
|
@ -1220,13 +1220,9 @@ class AutoMonadConverter {
|
|||
}
|
||||
|
||||
CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) {
|
||||
static const std::string primitive_target = "primitive_target";
|
||||
// Create Load cnode.
|
||||
auto load_prim = NewValueNode(prim::kPrimLoad);
|
||||
auto load_cnode = func_graph_->NewCNode({load_prim, ref, u});
|
||||
// Set device target for Load CNode.
|
||||
std::string target = GetCNodeTarget(cnode);
|
||||
load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target));
|
||||
// Set load_cnode abstract to Tensor according the input Ref[Tensor].
|
||||
auto ref_abs = dyn_cast<abstract::AbstractRef>(ref->abstract());
|
||||
MS_EXCEPTION_IF_NULL(ref_abs);
|
||||
|
|
|
@ -454,6 +454,8 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
|
|||
if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) {
|
||||
return GetCNodeTarget(inputs[1]);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
return GetCNodeTarget(cnode->input(1));
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
return GetMaketupleNodeTarget(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
|
|
Loading…
Reference in New Issue