Skip node

This commit is contained in:
ZPaC 2023-02-10 11:09:43 +08:00
parent 7957890bb6
commit ccfe2c8175
5 changed files with 51 additions and 0 deletions

View File

@ -638,6 +638,14 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfold(const FuncGraphP
// This pattern only supports user node is a TupleGetItem node.
// If this pattern is matched but the user node is not TupleGetItem, throw exception.
if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
// If this node supports any input types, do not process it.
KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
MS_EXCEPTION_IF_NULL(build_info);
if (build_info->op_type() == kernel::OpType::SKIP) {
MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " skip TupleToTupleUnfold type matching.";
*new_prim = false;
return {input};
}
MS_LOG(EXCEPTION) << "Tuple to TupleUnfold pattern should have TupleGetItem as user node, but got "
<< node->fullname_with_scope() << ", " << node->DebugString();
}

View File

@ -53,6 +53,15 @@ std::string KernelTypeLabel(const KernelType &kernel_type) {
return trans_map[kernel_type];
}
std::string OpTypeLabel(const OpType &op_type) {
std::unordered_map<OpType, std::string> trans_map{
{OpType::UNKNOWN_OP_TYPE, "UNKNOWN_OP_TYPE"}, {OpType::DYNAMIC, "DYNAMIC"}, {OpType::SKIP, "SKIP"}};
if (trans_map.find(op_type) == trans_map.end()) {
return "UNKNOWN_OP_TYPE";
}
return trans_map[op_type];
}
std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
if (input_index >= inputs_format_.size()) {
MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
@ -133,6 +142,8 @@ const std::vector<KernelObjectType> &KernelBuildInfo::GetAllInputKernelObjectTyp
return inputs_kernel_object_type_;
}
void KernelBuildInfo::SetOpType(const OpType &op_type) { op_type_ = op_type; }
void KernelBuildInfo::SetOutputsKernelObjectType(const std::vector<KernelObjectType> &outputs_kernel_object_type) {
outputs_kernel_object_type_ = outputs_kernel_object_type;
}
@ -235,6 +246,7 @@ std::string KernelBuildInfo::ToString() const {
output_buffer << KernelObjectTypeLabel(output_object_types[index]);
}
output_buffer << "], kernel_type: " << KernelTypeLabel(kernel_type());
output_buffer << ", op_type: " << OpTypeLabel(op_type());
output_buffer << ")";
return output_buffer.str();
}
@ -269,6 +281,11 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &ke
kernel_build_info_->kernel_type_ = kernel_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOpType(const OpType &op_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->op_type_ = op_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->origin_data_format_ = origin_data_format;

View File

@ -40,8 +40,15 @@ enum KernelObjectType : int {
TUPLE_UNFOLD,
};
enum OpType : int {
UNKNOWN_OP_TYPE = 0,
DYNAMIC,
SKIP,
};
std::string KernelObjectTypeLabel(const KernelObjectType &obj_type);
std::string KernelTypeLabel(const KernelType &kernel_type);
std::string OpTypeLabel(const OpType &op_type);
class BACKEND_EXPORT KernelBuildInfo {
public:
@ -52,8 +59,13 @@ class BACKEND_EXPORT KernelBuildInfo {
~KernelBuildInfo() = default;
KernelType kernel_type() const { return kernel_type_; }
OpType op_type() const { return op_type_; }
void set_kernel_type(KernelType kernel_type) { kernel_type_ = kernel_type; }
void set_op_type(OpType op_type) { op_type_ = op_type; }
std::string GetInputFormat(size_t input_index) const;
std::string GetOutputFormat(size_t output_index) const;
@ -88,6 +100,8 @@ class BACKEND_EXPORT KernelBuildInfo {
const std::vector<KernelObjectType> &GetAllOutputKernelObjectTypes() const;
void SetOpType(const OpType &op_type);
void SetOutputsKernelObjectType(const std::vector<KernelObjectType> &outputs_kernel_object_type);
void SetInputsKernelObjectType(const std::vector<KernelObjectType> &inputs_kernel_object_type);
@ -140,6 +154,7 @@ class BACKEND_EXPORT KernelBuildInfo {
private:
KernelType kernel_type_{UNKNOWN_KERNEL_TYPE};
OpType op_type_{UNKNOWN_OP_TYPE};
std::string origin_data_format_{kOpFormat_DEFAULT};
std::string core_type_;
std::vector<std::string> inputs_format_;
@ -166,6 +181,7 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info)
: kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
SetKernelType(kernel_build_info->kernel_type());
SetOpType(kernel_build_info->op_type());
SetFusionType(kernel_build_info->fusion_type());
SetProcessor(kernel_build_info->processor());
SetOpPattern(kernel_build_info->op_pattern());
@ -196,6 +212,8 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
void SetKernelType(const KernelType &kernel_type);
void SetOpType(const OpType &op_type);
void SetOriginDataFormat(const std::string &origin_data_format);
void SetInputsFormat(const std::vector<std::string> &inputs_format);

View File

@ -283,6 +283,10 @@ void SetKernelBuildInfoWithSelectedAttr(const CNodePtr &kernel_node, const kerne
input_formats.emplace_back(selected_kernel_attr.GetInputAttr(index).format);
}
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
if (selected_kernel_attr.GetSkipCheck()) {
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
kernel_build_info->SetOpType(kernel::OpType::SKIP);
}
kernel::SetKernelObjectTypeWithSelectedAttr(kernel_node, selected_kernel_attr);
kernel::UnfoldKernelBuildInfo(kernel_node);
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {

View File

@ -610,6 +610,10 @@ bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType ker
auto output_object_types =
kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node));
kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types);
if (!kernel_attrs.empty()) {
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
kernel_build_info->SetOpType(kernel::OpType::SKIP);
}
return true;
}