!48575 kernel object types of unreal nodes should be obtained from input nodes

Merge pull request !48575 from wYann/fix_kobj
This commit is contained in:
i-robot 2023-02-13 03:52:02 +00:00 committed by Gitee
commit 99122f3304
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 19 additions and 2 deletions

View File

@ -1086,6 +1086,24 @@ void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNode
}
}
kernel::KernelObjectType GetTupleGetItemOutputKernelObjectType(const AnfNodePtr &node) {
auto tuple_get_item = node->cast<CNodePtr>();
auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(tuple_get_item, 0);
auto input_node = kernel_with_index.first;
auto input_idx = kernel_with_index.second;
auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
if (kernel_info != nullptr && kernel_info->has_build_info()) {
auto build_info = kernel_info->select_kernel_build_info();
const auto &input_kernel_obj_types = build_info->GetAllInputKernelObjectTypes();
const auto &output_kernel_obj_types = build_info->GetAllOutputKernelObjectTypes();
if (input_idx < input_kernel_obj_types.size() && output_kernel_obj_types.size() > 0 &&
output_kernel_obj_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD) {
return input_kernel_obj_types[input_idx];
}
}
return kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAbstractObjectType(node->abstract()));
}
void KernelGraph::SetKernelObjectTypesForUnrealNodes() {
auto SetKernelObjectTypesForUnrealNode = [](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
@ -1102,8 +1120,7 @@ void KernelGraph::SetKernelObjectTypesForUnrealNodes() {
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) &&
(!kernel_info->has_build_info() || AnfAlgo::GetOutputKernelObjectTypes(node).empty())) {
const auto &output_object_types = AnfAlgo::GetAllOutputObjectType(node);
output_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(output_object_types);
output_kernel_object_types = {GetTupleGetItemOutputKernelObjectType(node)};
const auto &input_object_types = AnfAlgo::GetAllInputObjectType(node);
input_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(input_object_types);
}