diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index cf237794159..808e87edc0b 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -39,8 +39,6 @@ namespace mindspore { namespace kernel { using FNodeAttrHandle = std::function &anf_node, mindspore::NodeDef *proto)>; -const std::vector local_framework_op_vec = {kInitData, kGetNext, kDropoutGenMask, kPrint}; - bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input_num, std::vector *input_size_list) { MS_EXCEPTION_IF_NULL(anf_node); @@ -298,19 +296,12 @@ KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node) { MS_EXCEPTION_IF_NULL(kernel_mod_ptr); kernel_mod_ptr->SetAnfNode(anf_node); kernel_mod_ptr->SetNodeName(op_name); - auto iter = std::find(local_framework_op_vec.begin(), local_framework_op_vec.end(), op_name); - if (iter != local_framework_op_vec.end()) { - if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) { - MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!"; - } - } else { - MS_LOG(EXCEPTION) << "Aicpu don't support node [" << op_name << "]"; + if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) { + MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!"; } - if (!SetIOSize(anf_node, kernel_mod_ptr)) { MS_LOG(EXCEPTION) << "Set input output size list failed."; } - return kernel_mod_ptr; } } // namespace kernel diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index d2464bce47d..c8cc1530ce3 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -94,6 +94,20 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) return ret; } +void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info) { + op_info->set_async_flag(obj.at(kAsyncFlag)); + op_info->set_binfile_name(obj.at(kBinfileName)); + op_info->set_compute_cost(obj.at(kComputeCost)); + op_info->set_kernel_name(obj.at(kKernelName)); + op_info->set_partial_flag(obj.at(kPartialFlag)); + if (obj.find(kOpPattern) != obj.end()) { + op_info->set_op_pattern(obj.at(kOpPattern)); + } + if (obj.find(kDynamicFormat) != obj.end()) { + op_info->set_dynamic_format(obj.at(kDynamicFormat)); + } +} + bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type, const std::string& impl_path) { std::shared_ptr op_info = std::make_shared(); @@ -103,17 +117,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI op_info->set_imply_type(imply_type); op_info->set_fusion_type(obj.at(kFusionType)); if (imply_type == kTBE) { - op_info->set_async_flag(obj.at(kAsyncFlag)); - op_info->set_binfile_name(obj.at(kBinfileName)); - op_info->set_compute_cost(obj.at(kComputeCost)); - op_info->set_kernel_name(obj.at(kKernelName)); - op_info->set_partial_flag(obj.at(kPartialFlag)); - if (obj.find(kOpPattern) != obj.end()) { - op_info->set_op_pattern(obj.at(kOpPattern)); - } - if (obj.find(kDynamicFormat) != obj.end()) { - op_info->set_dynamic_format(obj.at(kDynamicFormat)); - } + DecodeTBESpecificInfo(obj, op_info); } auto attrs = obj.at(kAttr); for (const auto& attr : attrs) { diff --git a/mindspore/ccsrc/kernel/oplib/oplib.h b/mindspore/ccsrc/kernel/oplib/oplib.h index a4c5e04bb19..0e11e28d580 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.h +++ b/mindspore/ccsrc/kernel/oplib/oplib.h @@ -40,6 +40,7 @@ class OpLib { const std::shared_ptr& op_info); static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, size_t index); + static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info); static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, const std::shared_ptr& op_info, const nlohmann::json& dtype_format); static bool GetRefInfo(const std::shared_ptr& op_info); diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 7dd7a9f729e..28821b621e8 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -57,7 +57,7 @@ def op_info_register(op_info): return register_decorator -class RegOp(): +class RegOp: """ Base class for op info register. @@ -483,9 +483,9 @@ class TBERegOp(RegOp): return self -class DataType(): +class DataType: """ - Various combinations of dtype and formatself. + Various combinations of dtype and format. The current list below maybe not completed. If necessary, please add it. """