From ccfe2c817513c6fde58fe4e3c1ef41f55e8b8d3f Mon Sep 17 00:00:00 2001 From: ZPaC Date: Fri, 10 Feb 2023 11:09:43 +0800 Subject: [PATCH] Skip node --- .../common/pass/insert_type_transform_op.cc | 8 ++++++++ mindspore/ccsrc/kernel/kernel_build_info.cc | 17 +++++++++++++++++ mindspore/ccsrc/kernel/kernel_build_info.h | 18 ++++++++++++++++++ .../device/cpu/hal/device/kernel_select_cpu.cc | 4 ++++ .../gpu/hal/device/kernel_info_setter.cc | 4 ++++ 5 files changed, 51 insertions(+) diff --git a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc index 3463ac8b5d3..08f48261c19 100644 --- a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc +++ b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc @@ -638,6 +638,14 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfold(const FuncGraphP // This pattern only supports user node is a TupleGetItem node. // If this pattern is matched but the user node is not TupleGetItem, throw exception. if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + // If this node supports any input types, do not process it. + KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(node); + MS_EXCEPTION_IF_NULL(build_info); + if (build_info->op_type() == kernel::OpType::SKIP) { + MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " skip TupleToTupleUnfold type matching."; + *new_prim = false; + return {input}; + } MS_LOG(EXCEPTION) << "Tuple to TupleUnfold pattern should have TupleGetItem as user node, but got " << node->fullname_with_scope() << ", " << node->DebugString(); } diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index d161933f97c..a30fb6e0cb5 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -53,6 +53,15 @@ std::string KernelTypeLabel(const KernelType &kernel_type) { return trans_map[kernel_type]; } +std::string OpTypeLabel(const OpType &op_type) { + std::unordered_map trans_map{ + {OpType::UNKNOWN_OP_TYPE, "UNKNOWN_OP_TYPE"}, {OpType::DYNAMIC, "DYNAMIC"}, {OpType::SKIP, "SKIP"}}; + if (trans_map.find(op_type) == trans_map.end()) { + return "UNKNOWN_OP_TYPE"; + } + return trans_map[op_type]; +} + std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { if (input_index >= inputs_format_.size()) { MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node"; @@ -133,6 +142,8 @@ const std::vector &KernelBuildInfo::GetAllInputKernelObjectTyp return inputs_kernel_object_type_; } +void KernelBuildInfo::SetOpType(const OpType &op_type) { op_type_ = op_type; } + void KernelBuildInfo::SetOutputsKernelObjectType(const std::vector &outputs_kernel_object_type) { outputs_kernel_object_type_ = outputs_kernel_object_type; } @@ -235,6 +246,7 @@ std::string KernelBuildInfo::ToString() const { output_buffer << KernelObjectTypeLabel(output_object_types[index]); } output_buffer << "], kernel_type: " << KernelTypeLabel(kernel_type()); + output_buffer << ", op_type: " << OpTypeLabel(op_type()); output_buffer << ")"; return output_buffer.str(); } @@ -269,6 +281,11 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &ke kernel_build_info_->kernel_type_ = kernel_type; } +void KernelBuildInfo::KernelBuildInfoBuilder::SetOpType(const OpType &op_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->op_type_ = op_type; +} + void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) { MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->origin_data_format_ = origin_data_format; diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h index 03c663b4931..bab804d1315 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ b/mindspore/ccsrc/kernel/kernel_build_info.h @@ -40,8 +40,15 @@ enum KernelObjectType : int { TUPLE_UNFOLD, }; +enum OpType : int { + UNKNOWN_OP_TYPE = 0, + DYNAMIC, + SKIP, +}; + std::string KernelObjectTypeLabel(const KernelObjectType &obj_type); std::string KernelTypeLabel(const KernelType &kernel_type); +std::string OpTypeLabel(const OpType &op_type); class BACKEND_EXPORT KernelBuildInfo { public: @@ -52,8 +59,13 @@ class BACKEND_EXPORT KernelBuildInfo { ~KernelBuildInfo() = default; KernelType kernel_type() const { return kernel_type_; } + + OpType op_type() const { return op_type_; } + void set_kernel_type(KernelType kernel_type) { kernel_type_ = kernel_type; } + void set_op_type(OpType op_type) { op_type_ = op_type; } + std::string GetInputFormat(size_t input_index) const; std::string GetOutputFormat(size_t output_index) const; @@ -88,6 +100,8 @@ class BACKEND_EXPORT KernelBuildInfo { const std::vector &GetAllOutputKernelObjectTypes() const; + void SetOpType(const OpType &op_type); + void SetOutputsKernelObjectType(const std::vector &outputs_kernel_object_type); void SetInputsKernelObjectType(const std::vector &inputs_kernel_object_type); @@ -140,6 +154,7 @@ class BACKEND_EXPORT KernelBuildInfo { private: KernelType kernel_type_{UNKNOWN_KERNEL_TYPE}; + OpType op_type_{UNKNOWN_OP_TYPE}; std::string origin_data_format_{kOpFormat_DEFAULT}; std::string core_type_; std::vector inputs_format_; @@ -166,6 +181,7 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder { explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info) : kernel_build_info_(std::make_shared()) { SetKernelType(kernel_build_info->kernel_type()); + SetOpType(kernel_build_info->op_type()); SetFusionType(kernel_build_info->fusion_type()); SetProcessor(kernel_build_info->processor()); SetOpPattern(kernel_build_info->op_pattern()); @@ -196,6 +212,8 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder { void SetKernelType(const KernelType &kernel_type); + void SetOpType(const OpType &op_type); + void SetOriginDataFormat(const std::string &origin_data_format); void SetInputsFormat(const std::vector &inputs_format); diff --git a/mindspore/ccsrc/plugin/device/cpu/hal/device/kernel_select_cpu.cc b/mindspore/ccsrc/plugin/device/cpu/hal/device/kernel_select_cpu.cc index 389d5639149..872fec593e9 100644 --- a/mindspore/ccsrc/plugin/device/cpu/hal/device/kernel_select_cpu.cc +++ b/mindspore/ccsrc/plugin/device/cpu/hal/device/kernel_select_cpu.cc @@ -283,6 +283,10 @@ void SetKernelBuildInfoWithSelectedAttr(const CNodePtr &kernel_node, const kerne input_formats.emplace_back(selected_kernel_attr.GetInputAttr(index).format); } SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get()); + if (selected_kernel_attr.GetSkipCheck()) { + auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node); + kernel_build_info->SetOpType(kernel::OpType::SKIP); + } kernel::SetKernelObjectTypeWithSelectedAttr(kernel_node, selected_kernel_attr); kernel::UnfoldKernelBuildInfo(kernel_node); if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) { diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc b/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc index ef8609cdef2..3fd552e4240 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/kernel_info_setter.cc @@ -610,6 +610,10 @@ bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType ker auto output_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node)); kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types); + if (!kernel_attrs.empty()) { + auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node); + kernel_build_info->SetOpType(kernel::OpType::SKIP); + } return true; }