forked from mindspore-Ecosystem/mindspore
!48575 kernel object types of unreal nodes should be obtained from input nodes
Merge pull request !48575 from wYann/fix_kobj
This commit is contained in:
commit
99122f3304
|
@ -1086,6 +1086,24 @@ void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNode
|
|||
}
|
||||
}
|
||||
|
||||
kernel::KernelObjectType GetTupleGetItemOutputKernelObjectType(const AnfNodePtr &node) {
|
||||
auto tuple_get_item = node->cast<CNodePtr>();
|
||||
auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(tuple_get_item, 0);
|
||||
auto input_node = kernel_with_index.first;
|
||||
auto input_idx = kernel_with_index.second;
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
|
||||
if (kernel_info != nullptr && kernel_info->has_build_info()) {
|
||||
auto build_info = kernel_info->select_kernel_build_info();
|
||||
const auto &input_kernel_obj_types = build_info->GetAllInputKernelObjectTypes();
|
||||
const auto &output_kernel_obj_types = build_info->GetAllOutputKernelObjectTypes();
|
||||
if (input_idx < input_kernel_obj_types.size() && output_kernel_obj_types.size() > 0 &&
|
||||
output_kernel_obj_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD) {
|
||||
return input_kernel_obj_types[input_idx];
|
||||
}
|
||||
}
|
||||
return kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAbstractObjectType(node->abstract()));
|
||||
}
|
||||
|
||||
void KernelGraph::SetKernelObjectTypesForUnrealNodes() {
|
||||
auto SetKernelObjectTypesForUnrealNode = [](const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -1102,8 +1120,7 @@ void KernelGraph::SetKernelObjectTypesForUnrealNodes() {
|
|||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) &&
|
||||
(!kernel_info->has_build_info() || AnfAlgo::GetOutputKernelObjectTypes(node).empty())) {
|
||||
const auto &output_object_types = AnfAlgo::GetAllOutputObjectType(node);
|
||||
output_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(output_object_types);
|
||||
output_kernel_object_types = {GetTupleGetItemOutputKernelObjectType(node)};
|
||||
const auto &input_object_types = AnfAlgo::GetAllInputObjectType(node);
|
||||
input_kernel_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(input_object_types);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue