From 1068ae4b501f309277e72fe8c2bb8b2904466456 Mon Sep 17 00:00:00 2001 From: wYann Date: Fri, 10 Feb 2023 15:42:49 +0800 Subject: [PATCH] get kernel object types from input node of TupleGetItem --- .../backend/common/session/kernel_graph.cc | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/common/session/kernel_graph.cc b/mindspore/ccsrc/backend/common/session/kernel_graph.cc index c0118646438..709bea20146 100644 --- a/mindspore/ccsrc/backend/common/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/common/session/kernel_graph.cc @@ -1086,6 +1086,24 @@ void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vectorcast(); + auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(tuple_get_item, 0); + auto input_node = kernel_with_index.first; + auto input_idx = kernel_with_index.second; + auto kernel_info = dynamic_cast(input_node->kernel_info()); + if (kernel_info != nullptr && kernel_info->has_build_info()) { + auto build_info = kernel_info->select_kernel_build_info(); + const auto &input_kernel_obj_types = build_info->GetAllInputKernelObjectTypes(); + const auto &output_kernel_obj_types = build_info->GetAllOutputKernelObjectTypes(); + if (input_idx < input_kernel_obj_types.size() && output_kernel_obj_types.size() > 0 && + output_kernel_obj_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD) { + return input_kernel_obj_types[input_idx]; + } + } + return kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAbstractObjectType(node->abstract())); +} + void KernelGraph::SetKernelObjectTypesForUnrealNodes() { auto SetKernelObjectTypesForUnrealNode = [](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); @@ -1102,8 +1120,7 @@ void KernelGraph::SetKernelObjectTypesForUnrealNodes() { } if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) && (!kernel_info->has_build_info() || AnfAlgo::GetOutputKernelObjectTypes(node).empty())) { - const auto &output_object_types = AnfAlgo::GetAllOutputObjectType(node); - output_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(output_object_types); + output_kernel_object_types = {GetTupleGetItemOutputKernelObjectType(node)}; const auto &input_object_types = AnfAlgo::GetAllInputObjectType(node); input_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(input_object_types); }