forked from mindspore-Ecosystem/mindspore
Skip node
This commit is contained in:
parent
7957890bb6
commit
ccfe2c8175
|
@ -638,6 +638,14 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTupleUnfold(const FuncGraphP
|
|||
// This pattern only supports user node is a TupleGetItem node.
|
||||
// If this pattern is matched but the user node is not TupleGetItem, throw exception.
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
// If this node supports any input types, do not process it.
|
||||
KernelBuildInfoPtr build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
||||
MS_EXCEPTION_IF_NULL(build_info);
|
||||
if (build_info->op_type() == kernel::OpType::SKIP) {
|
||||
MS_LOG(INFO) << "Node " << node->fullname_with_scope() << " skip TupleToTupleUnfold type matching.";
|
||||
*new_prim = false;
|
||||
return {input};
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Tuple to TupleUnfold pattern should have TupleGetItem as user node, but got "
|
||||
<< node->fullname_with_scope() << ", " << node->DebugString();
|
||||
}
|
||||
|
|
|
@ -53,6 +53,15 @@ std::string KernelTypeLabel(const KernelType &kernel_type) {
|
|||
return trans_map[kernel_type];
|
||||
}
|
||||
|
||||
std::string OpTypeLabel(const OpType &op_type) {
|
||||
std::unordered_map<OpType, std::string> trans_map{
|
||||
{OpType::UNKNOWN_OP_TYPE, "UNKNOWN_OP_TYPE"}, {OpType::DYNAMIC, "DYNAMIC"}, {OpType::SKIP, "SKIP"}};
|
||||
if (trans_map.find(op_type) == trans_map.end()) {
|
||||
return "UNKNOWN_OP_TYPE";
|
||||
}
|
||||
return trans_map[op_type];
|
||||
}
|
||||
|
||||
std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
|
||||
if (input_index >= inputs_format_.size()) {
|
||||
MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
|
||||
|
@ -133,6 +142,8 @@ const std::vector<KernelObjectType> &KernelBuildInfo::GetAllInputKernelObjectTyp
|
|||
return inputs_kernel_object_type_;
|
||||
}
|
||||
|
||||
void KernelBuildInfo::SetOpType(const OpType &op_type) { op_type_ = op_type; }
|
||||
|
||||
void KernelBuildInfo::SetOutputsKernelObjectType(const std::vector<KernelObjectType> &outputs_kernel_object_type) {
|
||||
outputs_kernel_object_type_ = outputs_kernel_object_type;
|
||||
}
|
||||
|
@ -235,6 +246,7 @@ std::string KernelBuildInfo::ToString() const {
|
|||
output_buffer << KernelObjectTypeLabel(output_object_types[index]);
|
||||
}
|
||||
output_buffer << "], kernel_type: " << KernelTypeLabel(kernel_type());
|
||||
output_buffer << ", op_type: " << OpTypeLabel(op_type());
|
||||
output_buffer << ")";
|
||||
return output_buffer.str();
|
||||
}
|
||||
|
@ -269,6 +281,11 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &ke
|
|||
kernel_build_info_->kernel_type_ = kernel_type;
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOpType(const OpType &op_type) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||
kernel_build_info_->op_type_ = op_type;
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||
kernel_build_info_->origin_data_format_ = origin_data_format;
|
||||
|
|
|
@ -40,8 +40,15 @@ enum KernelObjectType : int {
|
|||
TUPLE_UNFOLD,
|
||||
};
|
||||
|
||||
enum OpType : int {
|
||||
UNKNOWN_OP_TYPE = 0,
|
||||
DYNAMIC,
|
||||
SKIP,
|
||||
};
|
||||
|
||||
std::string KernelObjectTypeLabel(const KernelObjectType &obj_type);
|
||||
std::string KernelTypeLabel(const KernelType &kernel_type);
|
||||
std::string OpTypeLabel(const OpType &op_type);
|
||||
|
||||
class BACKEND_EXPORT KernelBuildInfo {
|
||||
public:
|
||||
|
@ -52,8 +59,13 @@ class BACKEND_EXPORT KernelBuildInfo {
|
|||
~KernelBuildInfo() = default;
|
||||
|
||||
KernelType kernel_type() const { return kernel_type_; }
|
||||
|
||||
OpType op_type() const { return op_type_; }
|
||||
|
||||
void set_kernel_type(KernelType kernel_type) { kernel_type_ = kernel_type; }
|
||||
|
||||
void set_op_type(OpType op_type) { op_type_ = op_type; }
|
||||
|
||||
std::string GetInputFormat(size_t input_index) const;
|
||||
|
||||
std::string GetOutputFormat(size_t output_index) const;
|
||||
|
@ -88,6 +100,8 @@ class BACKEND_EXPORT KernelBuildInfo {
|
|||
|
||||
const std::vector<KernelObjectType> &GetAllOutputKernelObjectTypes() const;
|
||||
|
||||
void SetOpType(const OpType &op_type);
|
||||
|
||||
void SetOutputsKernelObjectType(const std::vector<KernelObjectType> &outputs_kernel_object_type);
|
||||
|
||||
void SetInputsKernelObjectType(const std::vector<KernelObjectType> &inputs_kernel_object_type);
|
||||
|
@ -140,6 +154,7 @@ class BACKEND_EXPORT KernelBuildInfo {
|
|||
|
||||
private:
|
||||
KernelType kernel_type_{UNKNOWN_KERNEL_TYPE};
|
||||
OpType op_type_{UNKNOWN_OP_TYPE};
|
||||
std::string origin_data_format_{kOpFormat_DEFAULT};
|
||||
std::string core_type_;
|
||||
std::vector<std::string> inputs_format_;
|
||||
|
@ -166,6 +181,7 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
|
|||
explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info)
|
||||
: kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
|
||||
SetKernelType(kernel_build_info->kernel_type());
|
||||
SetOpType(kernel_build_info->op_type());
|
||||
SetFusionType(kernel_build_info->fusion_type());
|
||||
SetProcessor(kernel_build_info->processor());
|
||||
SetOpPattern(kernel_build_info->op_pattern());
|
||||
|
@ -196,6 +212,8 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
|
|||
|
||||
void SetKernelType(const KernelType &kernel_type);
|
||||
|
||||
void SetOpType(const OpType &op_type);
|
||||
|
||||
void SetOriginDataFormat(const std::string &origin_data_format);
|
||||
|
||||
void SetInputsFormat(const std::vector<std::string> &inputs_format);
|
||||
|
|
|
@ -283,6 +283,10 @@ void SetKernelBuildInfoWithSelectedAttr(const CNodePtr &kernel_node, const kerne
|
|||
input_formats.emplace_back(selected_kernel_attr.GetInputAttr(index).format);
|
||||
}
|
||||
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
|
||||
if (selected_kernel_attr.GetSkipCheck()) {
|
||||
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
kernel_build_info->SetOpType(kernel::OpType::SKIP);
|
||||
}
|
||||
kernel::SetKernelObjectTypeWithSelectedAttr(kernel_node, selected_kernel_attr);
|
||||
kernel::UnfoldKernelBuildInfo(kernel_node);
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
|
||||
|
|
|
@ -610,6 +610,10 @@ bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType ker
|
|||
auto output_object_types =
|
||||
kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node));
|
||||
kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types);
|
||||
if (!kernel_attrs.empty()) {
|
||||
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
kernel_build_info->SetOpType(kernel::OpType::SKIP);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue