diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc
index 875c29547b9..f96d0cbebf3 100755
--- a/mindspore/ccsrc/kernel/kernel_query.cc
+++ b/mindspore/ccsrc/kernel/kernel_query.cc
@@ -20,7 +20,7 @@
 #include "kernel/aicpu/aicpu_kernel_metadata.h"
 #include "kernel/rts/rt_kernel_info.h"
 #include "kernel/hccl/hccl_kernel_metadata.h"
-#include "kernel/tbe/tbe_kernel_select.h"
+#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
 #include "session/anf_runtime_algorithm.h"
 
 namespace mindspore {
@@ -63,7 +63,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   TbeMetadataInfo(kernel_node, kernel_info_list);
-  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
   if (kernel_info_list->empty()) {
     AicpuMetadataInfo(kernel_node, kernel_info_list);
     if (!kernel_info_list->empty()) {
@@ -114,7 +113,6 @@ bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
   auto cnode = kernel_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   TbeMetadataInfo(cnode, &kernel_info_list);
-  FilterInvalidKernelInfo(cnode, &kernel_info_list);
   return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
                      [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
                        MS_EXCEPTION_IF_NULL(item);
diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h
index 8d7b543ea62..4d133085982 100644
--- a/mindspore/ccsrc/kernel/oplib/opinfo.h
+++ b/mindspore/ccsrc/kernel/oplib/opinfo.h
@@ -126,6 +126,8 @@ class OpInfo {
   bool is_ref() const { return !ref_infos_.empty(); }
   bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
   void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
+  void ClearInputs() { (void)inputs_ptr_.clear(); }
+  void ClearOutputs() { (void)outputs_ptr_.clear(); }
 
  private:
   std::string op_name_;
diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc
index b1bff365189..1b367120b41 100644
--- a/mindspore/ccsrc/kernel/oplib/oplib.cc
+++ b/mindspore/ccsrc/kernel/oplib/oplib.cc
@@ -35,7 +35,7 @@ constexpr auto kKernelName = "kernel_name";
 constexpr auto kPartialFlag = "partial_flag";
 constexpr auto kReshapeType = "reshape_type";
 constexpr auto kOpPattern = "op_pattern";
-constexpr auto kDynamicFormat = "dynamic_format";
+constexpr auto kDynamicFormat = "dynamicFormat";
 constexpr auto kFormatAgnostic = "formatAgnostic";
 constexpr auto kBroadcast = "broadcast";
 constexpr auto kReduce = "reduce";
@@ -100,7 +100,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
 void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
   const std::map<std::string, OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
-                                                          {kFormatAgnostic, kBroadcastPattern},
+                                                          {kBroadcast, kBroadcastPattern},
                                                           {kReduce, kReducePattern},
                                                           {kDynamicFormat, kDynamicFormatPattern}};
   op_info->set_async_flag(obj.at(kAsyncFlag));
@@ -108,14 +108,19 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
   op_info->set_compute_cost(obj.at(kComputeCost));
   op_info->set_kernel_name(obj.at(kKernelName));
   op_info->set_partial_flag(obj.at(kPartialFlag));
+
   if (obj.find(kOpPattern) != obj.end()) {
-    if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) {
-      op_info->set_op_pattern(obj.at(kOpPattern));
+    std::string op_pattern = obj.at(kOpPattern);
+    auto find_iter = kOpPatternMap.find(op_pattern);
+    if (find_iter == kOpPatternMap.end()) {
+      if (!op_pattern.empty()) {
+        MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
+      }
+      op_info->set_op_pattern(kCommonPattern);
+    } else {
+      op_info->set_op_pattern(find_iter->second);
     }
   }
-  if (obj.find(kDynamicFormat) != obj.end()) {
-    op_info->set_dynamic_format(obj.at(kDynamicFormat));
-  }
 }
 
 bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc b/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc
index 1159bd888d9..e6f5b3985be 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc
@@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
   {TypeId::kNumberTypeInt64, "int64"},   {TypeId::kNumberTypeUInt, "uint"},
   {TypeId::kNumberTypeUInt8, "uint8"},   {TypeId::kNumberTypeUInt16, "uint16"},
   {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
-  {TypeId::kNumberTypeBool, "bool"},
+  {TypeId::kNumberTypeBool, "int8"},
 };
 
 const std::map<std::string, TypeId> type_str_maps = {
@@ -85,7 +85,7 @@ std::string DtypeToString(const std::string &dtypes) {
 std::string TypeIdToString(TypeId type_id) {
   auto iter = type_id_str_maps.find(type_id);
   if (iter == type_id_str_maps.end()) {
-    MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id);
+    MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
   }
   return iter->second;
 }
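The kNumberTypeBool entry above changes how bool tensors are described to TBE: they are now reported as int8. A minimal standalone sketch of the lookup-with-error behavior in TypeIdToString (an approximation for illustration, not the MindSpore source; the enum subset and the exception type are stand-ins):

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

enum TypeId { kNumberTypeBool, kNumberTypeFloat16, kNumberTypeFloat32 };  // illustrative subset

std::string TypeIdToString(TypeId type_id) {
  // Mirrors type_id_str_maps: bool tensors are reported to TBE as int8.
  static const std::map<TypeId, std::string> type_id_str_maps = {
    {kNumberTypeBool, "int8"}, {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}};
  auto iter = type_id_str_maps.find(type_id);
  if (iter == type_id_str_maps.end()) {
    throw std::invalid_argument("Illegal input dtype");  // stands in for MS_LOG(EXCEPTION)
  }
  return iter->second;
}

int main() {
  std::cout << TypeIdToString(kNumberTypeBool) << std::endl;  // prints "int8"
  return 0;
}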
input_desc_json["param_type"] = input_ptr->param_type(); input_list->emplace_back(input_desc_json); @@ -325,40 +304,22 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr &anf_node, co MS_EXCEPTION_IF_NULL(output_idx); MS_EXCEPTION_IF_NULL(output_list); for (size_t i = 0; i < output_obj_num; i++) { - nlohmann::json output_obj; - auto type_ptr = std::make_shared(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, *output_idx))); - std::string dtype = type_ptr->element()->ToString(); - dtype = tbe::DtypeToString(dtype); - std::string format = AnfAlgo::GetOutputFormat(anf_node, *output_idx); - if (format == kOpFormat_DEFAULT) { - format = kOpFormat_NCHW; - } else if (format == kOpFormat_FRAC_Z) { - format = kOpFormat_FRACTAL_Z; - } - std::vector ori_shape; - if (AnfAlgo::GetOutputInferShape(anf_node, *output_idx).empty()) { + auto dtype = GetDeviceOutputType(anf_node, *output_idx); + auto format = GetDeviceOutputFormat(anf_node, *output_idx); + auto shape = GetDeviceOutputShape(anf_node, *output_idx); + std::vector ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); + if (ori_shape.empty()) { ori_shape.emplace_back(1); - } else { - ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); } + nlohmann::json output_obj; output_obj["dtype"] = dtype; - auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, *output_idx); - if (shape.empty()) { - shape.emplace_back(1); - } - if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { - output_obj["shape"] = ori_shape; - output_obj["format"] = kOpFormat_NCHW; - } else { - output_obj["shape"] = shape; - output_obj["format"] = format; - } + output_obj["shape"] = shape; + output_obj["format"] = format; output_obj["ori_shape"] = ori_shape; output_obj["ori_format"] = kOpFormat_NCHW; output_obj["name"] = output_ptr->name(); output_obj["valid"] = true; output_obj["param_type"] = output_ptr->param_type(); - output_list->emplace_back(output_obj); (*output_idx)++; } @@ -456,6 +417,84 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo } } +std::vector TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector shape; + if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { + shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index); + } else { + shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index); + } + if (shape.empty()) { + shape.emplace_back(1); + } + return shape; +} + +std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + TypeId type_id; + if (creater_type_ == OP_SELECT_FORMAT) { + type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index); + } else { + type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index); + } + return tbe::TypeIdToString(type_id); +} + +std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::string format = kOpFormat_NCHW; + if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { + format = AnfAlgo::GetInputFormat(anf_node, real_index); + if (format == kOpFormat_FRAC_Z) { + format = kOpFormat_FRACTAL_Z; + } else if (format == kOpFormat_DEFAULT) { + format = kOpFormat_NCHW; + } + } + return format; +} + +std::vector TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const { 
+ MS_EXCEPTION_IF_NULL(anf_node); + std::vector shape; + if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { + shape = AnfAlgo::GetOutputInferShape(anf_node, real_index); + } else { + shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index); + } + if (shape.empty()) { + shape.emplace_back(1); + } + return shape; +} + +std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + TypeId type_id; + if (creater_type_ == OP_SELECT_FORMAT) { + type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index); + } else { + type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index); + } + return tbe::TypeIdToString(type_id); +} + +std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::string format = kOpFormat_NCHW; + if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { + format = AnfAlgo::GetOutputFormat(anf_node, real_index); + if (format == kOpFormat_FRAC_Z) { + format = kOpFormat_FRACTAL_Z; + } else if (format == kOpFormat_DEFAULT) { + format = kOpFormat_NCHW; + } + } + return format; +} + bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, std::vector *output_size_list) { if (input_size_list == nullptr || output_size_list == nullptr) { diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h index 2ddab34d495..eef02efa87e 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h @@ -93,7 +93,7 @@ class TbeKernelJsonCreator { nlohmann::json *outputs_json); bool GenTbeAttrJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, nlohmann::json *attrs_json); - void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); + static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); bool GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, bool value, const std::shared_ptr &input_ptr, const string &op_input_name, size_t input_i, std::vector *input_list); @@ -105,6 +105,13 @@ class TbeKernelJsonCreator { void GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, const std::shared_ptr &output_ptr, size_t *output_idx, std::vector *output_list); + std::vector GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const; + std::vector GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const; + kCreaterType creater_type_; std::string json_name_; std::string json_info_; diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc index e3973d2403a..7f6fc2dc25b 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc @@ -230,7 +230,7 @@ std::pair ParallelBuildManager::TaskFinishProcess(int32_t task_iter->second.output_size_list, kernel_pack); MS_EXCEPTION_IF_NULL(kernel_mod); if 
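The six new GetDevice* helpers centralize one rule that the deleted inline branches used to repeat: OP_SELECT_FORMAT and CHECK_SUPPORTED queries describe a tensor by its inferred shape in NCHW, while a real build uses the device shape and format. A standalone sketch of that dispatch, under assumed names (the enum values and shape sources are illustrative, not the real AnfAlgo API):

#include <cstddef>
#include <vector>

enum kCreaterType { SINGLE_BUILD, OP_SELECT_FORMAT, CHECK_SUPPORTED };  // assumed subset

std::vector<size_t> InferShape() { return {32, 3, 224, 224}; }       // graph-level (original) shape
std::vector<size_t> DeviceShape() { return {32, 1, 224, 224, 16}; }  // device layout, e.g. NC1HWC0

// Mirrors GetDeviceInputShape: format-selection queries see the inferred
// shape; build queries see the device layout. Scalars are padded to [1].
std::vector<size_t> PickShape(kCreaterType creater_type) {
  std::vector<size_t> shape =
    (creater_type == OP_SELECT_FORMAT || creater_type == CHECK_SUPPORTED) ? InferShape() : DeviceShape();
  if (shape.empty()) {
    shape.emplace_back(1);
  }
  return shape;
}

One subtlety: GetDeviceInputType/GetDeviceOutputType switch on OP_SELECT_FORMAT only, so CHECK_SUPPORTED still reports the device data type while using the inferred shape.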
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc
index e3973d2403a..7f6fc2dc25b 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc
@@ -230,7 +230,7 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
                                             task_iter->second.output_size_list, kernel_pack);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   if (set_kernel_mod) {
-    AnfAlgo ::SetKernelMod(kernel_mod, task_iter->second.node);
+    AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node);
   }
   auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod);
   (void)task_map_.erase(task_iter);
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
deleted file mode 100644
index aedb0b3eafe..00000000000
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+++ /dev/null
@@ -1,664 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/tbe/tbe_kernel_select.h"
-
-#include
-#include
-#include
-#include
-
-#include "session/anf_runtime_algorithm.h"
-#include "kernel/oplib/oplib.h"
-#include "kernel/tbe/tbe_kernel_build.h"
-#include "nlohmann/json.hpp"
-#include "common/utils.h"
-#include "utils/context/ms_context.h"
-#include "kernel/tbe/tbe_python_funcs.h"
-#include "pre_activate/common/helper.h"
-#include "kernel/tbe/tbe_convert_utils.h"
-
-namespace mindspore {
-namespace kernel {
-constexpr auto kName = "name";
-constexpr auto kDtype = "dtype";
-constexpr auto kFormat = "format";
-constexpr auto kPrefixInput = "input";
-constexpr auto kPrefixOutput = "output";
-const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
-                                                               {"NHWC", "DefaultFormat"},
-                                                               {"ND", "DefaultFormat"},
-                                                               {"FRACTAL_Z", "FracZ"},
-                                                               {"NDHWC", "DefaultFormat"}};
-static const std::vector<std::string> CHECK_SUPPORTED_OPTYPE{
-  "MatMul", "BatchMatMul", "TopK", "InTopK", "Pack", "GatherNd", "UnsortedSegmentMinD", "UnsortedSegmentProdD", "Cast"};
-
-bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(select_kernel_build_info);
-  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
-  auto iter = std::find(CHECK_SUPPORTED_OPTYPE.begin(), CHECK_SUPPORTED_OPTYPE.end(), op_name);
-  if (iter == CHECK_SUPPORTED_OPTYPE.end()) {
-    MS_LOG(DEBUG) << "Op " << op_name << "this op does not need to check op supported.";
-    return true;
-  }
-
-  // replace kernel_info with current kernel info
-  auto ori_select_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
-  AnfAlgo::SetSelectKernelBuildInfo(select_kernel_build_info, anf_node.get());
-
-  nlohmann::json kernel_json;
-  TbeKernelJsonCreator creator(CHECK_SUPPORTED);
-  bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
-  if (!ret) {
-    MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
-    AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
-    return false;
-  }
-
-  ret = TbePythonFuncs::CheckSupported(kernel_json);
-  AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
-  return ret;
-}
-
-bool CheckJsonItemValidity(const nlohmann::json &json_obj, const std::string &key_name,
-                           const std::vector<std::string> &keys) {
-  if (!json_obj[key_name].is_object()) {
-    MS_LOG(DEBUG) << key_name << "is not an object!";
-    return false;
-  }
-  for (auto key : keys) {
-    if (json_obj[key_name].find(key) == json_obj[key_name].end()) {
-      MS_LOG(DEBUG) << "Key" << key << "of " << key_name << " is not found!";
-      return false;
-    }
-  }
-  return true;
-}
-
-std::vector<std::string> SplitStr(const std::string &string, const std::string &sep) {
-  std::vector<std::string> result;
-  size_t start = 0;
-  size_t index = string.find(sep, start);
-  std::string substr;
-  while (index != std::string::npos) {
-    if (string.size() > start) {
-      substr = string.substr(start, index - start);
-    }
-    (void)substr.erase(0, substr.find_first_not_of(' '));
-    (void)substr.erase(substr.find_last_not_of(' ') + 1);
-    auto iter = DYNAMIC_FORMAT_MAP.find(substr);
-    if (iter != DYNAMIC_FORMAT_MAP.end()) {
-      substr = iter->second;
-    }
-    result.push_back(substr);
-    start = index + sep.size();
-    index = string.find(sep, start);
-  }
-
-  if (string.size() > start) {
-    substr = string.substr(start);
-  }
-  (void)substr.erase(0, substr.find_first_not_of(' '));
-  (void)substr.erase(substr.find_last_not_of(' ') + 1);
-  auto iter = DYNAMIC_FORMAT_MAP.find(substr);
-  if (iter != DYNAMIC_FORMAT_MAP.end()) {
-    substr = iter->second;
-  }
-  result.push_back(substr);
-  return result;
-}
-
-void ConvertFormatDtype(const std::string &format, const std::string &dtype, const std::shared_ptr<OpIOInfo> &io_info) {
-  MS_EXCEPTION_IF_NULL(io_info);
-  std::vector<std::string> format_vec = SplitStr(format, ",");
-  std::vector<std::string> dtype_vec = SplitStr(dtype, ",");
-  io_info->set_formats(format_vec);
-  io_info->set_dtypes(dtype_vec);
-}
-
-bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_ptr<OpIOInfo>> *const inputs,
-                            std::vector<std::shared_ptr<OpIOInfo>> *const outputs) {
-  nlohmann::json json_obj = nlohmann::json::parse(jsonStr);
-  if (!json_obj.is_object()) {
-    MS_LOG(DEBUG) << "JsonStr is not an object, the jsonStr is:" << jsonStr;
-    return false;
-  }
-  std::vector<std::string> keys = {kName, kDtype, kFormat};
-  for (const auto &item : json_obj.items()) {
-    std::string key_name;
-    key_name = item.key();
-    if (key_name.empty()) {
-      MS_LOG(DEBUG) << "Key name is empty!";
-      return false;
-    }
-    if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
-      return false;
-    }
-    if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
-      std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
-      MS_EXCEPTION_IF_NULL(input);
-      input->set_name(json_obj[key_name].at(kName));
-      ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
-      inputs->emplace_back(input);
-    } else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
-      std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
-      MS_EXCEPTION_IF_NULL(output);
-      output->set_name(json_obj[key_name].at(kName));
-      ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), output);
-      outputs->emplace_back(output);
-    } else {
-      MS_LOG(DEBUG) << "Key name:" << key_name << " is undefined!";
-      return false;
-    }
-  }
-  return true;
-}
-
-std::string OpSelectFormat(const std::shared_ptr<AnfNode> &anf_node) {
-  nlohmann::json kernel_json;
-  std::string res_json_str;
-  TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
-  bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
-  if (!ret) {
-    MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
-    return res_json_str;
-  }
-  res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
-  MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str;
-  return res_json_str;
-}
-
-void SetTidyInputsInfo(const std::shared_ptr<AnfNode> &anf_node,
-                       const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
-                       const std::vector<std::shared_ptr<OpIOInfo>> &inputs) {
-  std::vector<TypeId> inputs_type;
-  std::vector<std::string> inputs_format;
-  std::vector<int> dyn_input_sizes;
-  size_t dyn_input_idx = 0;
-  size_t kernel_info_index = 0;
-  size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
-  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
-    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
-  }
-  for (size_t i = 0; i < inputs.size(); i++) {
-    MS_EXCEPTION_IF_NULL(inputs[i]);
-    std::string param_type = inputs[i]->param_type();
-    if (i >= real_input_num) {
-      MS_LOG(INFO) << "Input index: " << i << " is out of real_input_num:" << real_input_num;
-      continue;
-    }
-    auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, i);
-    auto format = kOpFormat_DEFAULT;
-    if (param_type == "dynamic") {
-      if (!dyn_input_sizes.empty()) {
-        for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
-          kernel_info_index++;
-          inputs_type.emplace_back(type_id);
-          inputs_format.emplace_back(format);
-        }
-        dyn_input_idx++;
-      }
-    } else if (param_type == "required") {
-      kernel_info_index++;
-      inputs_type.emplace_back(type_id);
-      inputs_format.emplace_back(format);
-    } else {
-      if (kernel_info_index < real_input_num) {
-        MS_LOG(INFO) << "Input type is optional, input index is :" << kernel_info_index;
-        kernel_info_index++;
-        inputs_type.emplace_back(type_id);
-        inputs_format.emplace_back(format);
-      }
-    }
-  }
-  builder->SetInputsDeviceType(inputs_type);
-  builder->SetInputsFormat(inputs_format);
-}
-
-void SetTidyOutputsInfo(const std::shared_ptr<AnfNode> &anf_node,
-                        const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
-                        const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
-  std::vector<TypeId> outputs_type;
-  std::vector<std::string> outputs_format;
-  auto real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
-  size_t output_idx = 0;
-  for (const auto &output : outputs) {
-    MS_EXCEPTION_IF_NULL(output);
-    if (output_idx >= real_output_num) {
-      continue;
-    }
-    size_t output_num = 0;
-    if (output->param_type() == "dynamic") {
-      if (outputs.size() > 1) {
-        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
-      }
-      output_num = real_output_num;
-    } else if (output->param_type() == "required") {
-      output_num = 1;
-    } else {
-      if (output_idx < real_output_num) {
-        MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
-        output_num = 1;
-      }
-    }
-    for (size_t i = 0; i < output_num; i++) {
-      auto type_id = AnfAlgo::GetOutputInferDataType(anf_node, output_idx);
-      outputs_type.emplace_back(type_id);
-      outputs_format.emplace_back(kOpFormat_DEFAULT);
-      output_idx++;
-    }
-  }
-  builder->SetOutputsDeviceType(outputs_type);
-  builder->SetOutputsFormat(outputs_format);
-}
-
-void GenTidyKernelBuildInfo(const std::shared_ptr<AnfNode> &anf_node,
-                            const std::vector<std::shared_ptr<OpIOInfo>> &inputs,
-                            const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
-  auto builder_tmp = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-  builder_tmp->SetKernelType(TBE_KERNEL);
-  SetTidyInputsInfo(anf_node, builder_tmp, inputs);
-  SetTidyOutputsInfo(anf_node, builder_tmp, outputs);
-  AnfAlgo::SetSelectKernelBuildInfo(builder_tmp->Build(), anf_node.get());
-}
-
-void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_ptr<OpInfo> &op_info_ptr,
-                                 const std::shared_ptr<OpInfo> &op_info_new_ptr) {
-  std::vector<std::shared_ptr<OpIOInfo>> inputs_static = op_info_ptr->inputs_ptr();
-  std::vector<std::shared_ptr<OpIOInfo>> outputs_static = op_info_ptr->outputs_ptr();
-  std::vector<std::shared_ptr<OpIOInfo>> inputs_dyn;
-  std::vector<std::shared_ptr<OpIOInfo>> outputs_dyn;
-  if ((op_info_ptr->imply_type() == kTBE) && (!mindspore::opt::IsNopNode(kernel_node->cast<CNodePtr>()))) {
-    // 1. create tidy kernelBuildInfo in order to generate json for calling op_select_format
-    auto anf_node = kernel_node->cast<std::shared_ptr<AnfNode>>();
-    auto kernel_build_info_ptr = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
-    GenTidyKernelBuildInfo(kernel_node, inputs_static, outputs_static);
-
-    // 2. get dynamic format from op_impl
-    std::string res_json_str;
-    auto context_ptr = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context_ptr);
-    if (context_ptr->execution_mode() != kPynativeMode) {
-      res_json_str = OpSelectFormat(kernel_node);
-    }
-    if (!res_json_str.empty()) {
-      (void)ParseDynamicFormatJson(res_json_str, &inputs_dyn, &outputs_dyn);
-    }
-    if (inputs_static.size() != inputs_dyn.size()) {
-      inputs_dyn.clear();
-    }
-    if (outputs_static.size() != outputs_dyn.size()) {
-      outputs_dyn.clear();
-    }
-
-    // 3. resume kernel node's SelectKernelBuildInfo
-    // As it has been replaced by GenTidyKernelBuildInfo in order to call python func
-    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_ptr, anf_node.get());
-  }
-  // 4. replace by dynamic format and dtype
-  if (inputs_dyn.empty() && outputs_dyn.empty()) {
-    MS_LOG(INFO) << "Dynamic select format response is empty, use static register info.";
-    op_info_new_ptr->set_inputs_ptr(inputs_static);
-    op_info_new_ptr->set_outputs_ptr(outputs_static);
-  } else {
-    MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
-    for (size_t i = 0; i < inputs_static.size(); i++) {
-      inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
-      inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
-    }
-    for (size_t j = 0; j < outputs_static.size(); j++) {
-      outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
-      outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
-    }
-    op_info_new_ptr->set_inputs_ptr(inputs_dyn);
-    op_info_new_ptr->set_outputs_ptr(outputs_dyn);
-  }
-
-  // 5. copy other opinfo to new op_info_new
-  op_info_new_ptr->set_op_name(op_info_ptr->op_name());
-  op_info_new_ptr->set_imply_type(op_info_ptr->imply_type());
-  op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
-}
-
-bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
-  for (const auto &c : reshape_type_str) {
-    switch (c) {
-      case 'N':
-        reshape_type_vec->push_back(kernel::N);
-        break;
-      case 'C':
-        reshape_type_vec->push_back(kernel::C);
-        break;
-      case 'H':
-        reshape_type_vec->push_back(kernel::H);
-        break;
-      case 'W':
-        reshape_type_vec->push_back(kernel::W);
-        break;
-      default:
-        MS_LOG(ERROR) << "Unknown axis " << c << "in reshape type.";
-        return false;
-    }
-  }
-  return true;
-}
-
-bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
-                               size_t builder_idex, const std::vector<int> &dyn_input_sizes,
-                               const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
-  MS_EXCEPTION_IF_NULL(builder);
-
-  std::vector<TypeId> inputs_device_type;
-  std::vector<std::string> inputs_format;
-  size_t dyn_input_idx = 0;
-  size_t kernel_info_index = 0;
-  MS_EXCEPTION_IF_NULL(inputs[0]);
-  size_t kernel_info_cnt = inputs[0]->dtypes().size();
-
-  std::vector<std::vector<Axis>> reshape_types;
-  for (const auto &input : inputs) {
-    MS_EXCEPTION_IF_NULL(input);
-    std::string param_type = input->param_type();
-    std::vector<std::string> dtypes = input->dtypes();
-    std::vector<std::string> formats = input->formats();
-    if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
-      MS_LOG(ERROR) << "Set input kernel builder info, dtyps size != formats size.";
-      return false;
-    }
-
-    std::vector<Axis> reshape_type;
-    if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
-      return false;
-    }
-
-    if (param_type == "dynamic") {
-      if (dyn_input_sizes.empty()) {
-        MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
-        return false;
-      }
-
-      for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
-        kernel_info_index++;
-        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-        inputs_device_type.push_back(type_id);
-        inputs_format.push_back(formats[builder_idex]);
-        reshape_types.push_back(reshape_type);
-      }
-      dyn_input_idx++;
-    } else if (param_type == "required") {
-      kernel_info_index++;
-      auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-      inputs_device_type.push_back(type_id);
-      inputs_format.push_back(formats[builder_idex]);
-      reshape_types.push_back(reshape_type);
-    } else {
-      if (kernel_info_index < real_input_num) {
-        MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
-        kernel_info_index++;
-        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-        inputs_device_type.push_back(type_id);
-        inputs_format.push_back(formats[builder_idex]);
-        reshape_types.push_back(reshape_type);
-      }
-    }
-  }
-
-  builder->SetInputReshapeType(reshape_types);
-  builder->SetInputsDeviceType(inputs_device_type);
-  builder->SetInputsFormat(inputs_format);
-  return true;
-}
-
-bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
-                                const size_t &real_output_num,
-                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
-  // not now but in the next we need to support dynamic output case
-  MS_EXCEPTION_IF_NULL(builder);
-
-  size_t output_idx = 0;
-  std::vector<TypeId> outputs_device_type;
-  std::vector<std::string> outputs_format;
-  MS_EXCEPTION_IF_NULL(outputs[0]);
-  size_t kernel_info_cnt = outputs[0]->dtypes().size();
-
-  std::vector<std::vector<Axis>> reshape_types;
-  for (const auto &output : outputs) {
-    MS_EXCEPTION_IF_NULL(output);
-    if (output_idx >= real_output_num) {
-      MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
-      continue;
-    }
-    std::vector<Axis> reshape_type;
-    if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
-      return false;
-    }
-
-    size_t output_num = 0;
-    if (output->param_type() == "dynamic") {
-      if (outputs.size() > 1) {
-        MS_LOG(EXCEPTION) << "Dynamic output is unsupported multi output!";
-      }
-      output_num = real_output_num;
-    } else if (output->param_type() == "required") {
-      output_num = 1;
-    } else {
-      if (output_idx < real_output_num) {
-        MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is " << output_idx;
-        output_num = 1;
-      }
-    }
-
-    for (size_t i = 0; i < output_num; i++) {
-      std::vector<std::string> dtypes = output->dtypes();
-      std::vector<std::string> formats = output->formats();
-      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
-        MS_LOG(ERROR) << "Set output kernel builder info, dtyps size != formats size.";
-        return false;
-      }
-      auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
-      outputs_device_type.push_back(type_id);
-      outputs_format.push_back(formats[builder_idex]);
-      reshape_types.push_back(reshape_type);
-      output_idx++;
-    }
-  }
-
-  builder->SetOutputReshapeType(reshape_types);
-  builder->SetOutputsFormat(outputs_format);
-  builder->SetOutputsDeviceType(outputs_device_type);
-  return true;
-}
-
-void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
-                              Processor processor, const std::shared_ptr<OpInfo> &op_info_ptr) {
-  MS_EXCEPTION_IF_NULL(builder);
-  MS_EXCEPTION_IF_NULL(op_info_ptr);
-
-  builder->SetProcessor(processor);
-  std::string fusion_type = op_info_ptr->fusion_type();
-  if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
-    builder->SetFusionType(tbe::GetFusionType(fusion_type));
-  }
-  builder->SetOpPattern(op_info_ptr->op_pattern());
-  builder->SetKernelType(TBE_KERNEL);
-}
-
-bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<OpInfo> &op_info_ptr,
-                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
-  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
-  std::vector<int> dyn_input_sizes;
-  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
-    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
-  }
-  if (!inputs.empty()) {
-    MS_EXCEPTION_IF_NULL(inputs[0]);
-    size_t kernel_info_cnt = inputs[0]->dtypes().size();
-    for (size_t j = 0; j < kernel_info_cnt; j++) {
-      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-      MS_EXCEPTION_IF_NULL(builder);
-      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
-
-      if (!SetKernelBuilderInputInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
-        MS_LOG(ERROR) << "Parse kernel metadata, set inputs kernel builder info failed.";
-        return false;
-      }
-
-      if (!outputs.empty()) {
-        if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
-          MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
-          return false;
-        }
-      }
-
-      kernel_info_list->push_back(builder->Build());
-    }
-  } else if (!outputs.empty()) {
-    MS_EXCEPTION_IF_NULL(outputs[0]);
-    size_t kernel_info_cnt = outputs[0]->dtypes().size();
-    for (size_t j = 0; j < kernel_info_cnt; j++) {
-      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-      MS_EXCEPTION_IF_NULL(builder);
-      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
-
-      if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
-        MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
-        return false;
-      }
-
-      kernel_info_list->push_back(builder->Build());
-    }
-  }
-  return true;
-}
-
-bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
-  // if format is default, it remarkes support all format
-  if (kOpFormatList.find(format) == kOpFormatList.end()) {
-    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
-  }
-  if (format == kOpFormat_DEFAULT) {
-    return true;
-  }
-  if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
-    return false;
-  }
-  // if shape size is 0, the shape will be a scalar
-  if (shape.empty()) {
-    return true;
-  }
-  if (shape.size() > kShape4dDims) {
-    return false;
-  }
-  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
-    return false;
-  }
-  return true;
-}
-
-bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  const size_t kCAxis = 1;
-  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
-    if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
-      if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) {
-        return false;
-      }
-      return false;
-    }
-    if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
-      return false;
-    }
-    if (kernel_name == "ReduceMean") {
-      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
-      if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
-        return false;
-      }
-    }
-  }
-  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
-    if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
-      return false;
-    }
-    if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
-      if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) {
-        return false;
-      }
-      return false;
-    }
-    if (kernel_name == "ReduceMean") {
-      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
-      if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
-        return false;
-      }
-    }
-  }
-  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
-    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
-           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
-  }
-  return true;
-}
-
-void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-  std::vector<std::shared_ptr<KernelBuildInfo>> parse_info_list;
-
-  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
-  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
-  if (op_info_ptr == nullptr) {
-    return;
-  }
-  // dynamic get op format and dtype and replace opinfo
-  auto op_info_new_ptr = std::make_shared<OpInfo>();
-  ReplaceByDynamicFormatDtype(kernel_node, op_info_ptr, op_info_new_ptr);
-
-  if (!ParseMetadata(kernel_node, op_info_new_ptr, &parse_info_list)) {
-    MS_LOG(INFO) << "Tbe parsed metadata of op[" << op_name << "] failed.";
-    return;
-  }
-
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  for (const auto &parse_info : parse_info_list) {
-    if (IsValidKernelInfo(kernel_node, *(parse_info))) {
-      if (CheckSupported(kernel_node, parse_info)) {
-        kernel_info_list->push_back(parse_info);
-      } else {
-        MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
-      }
-    }
-    if (kernel_info_list->empty()) {
-      MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "].";
-    }
-  }
-}
-}  // namespace kernel
-}  // namespace mindspore
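For context on what the deleted parser consumed: ParseDynamicFormatJson walked a JSON object keyed by input*/output*, each entry carrying comma-separated dtype/format candidate lists that ConvertFormatDtype split via SplitStr (mapping entries such as "NCHW" to "DefaultFormat" through DYNAMIC_FORMAT_MAP). A hand-written illustration of that response shape, not captured from a real op_select_format run:

#include <iostream>
#include "nlohmann/json.hpp"

int main() {
  // Illustrative response for a one-input, one-output op; each position in the
  // comma-separated lists is one (dtype, format) candidate pairing.
  const auto res = nlohmann::json::parse(R"({
    "input0": {"name": "x", "dtype": "float16,float", "format": "NC1HWC0,NCHW"},
    "output0": {"name": "y", "dtype": "float16,float", "format": "NC1HWC0,NCHW"}
  })");
  std::cout << res.at("input0").at("format") << std::endl;  // prints "NC1HWC0,NCHW"
  return 0;
}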
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h
similarity index 62%
rename from mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h
rename to mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h
index 3ce66b51484..c07197610e9 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.h
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,20 +13,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
-#define MINDSPORE_TBE_KERNEL_SELECT_H
-
+#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
+#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
 #include <string>
 #include <vector>
-#include <memory>
-#include "kernel/oplib/opinfo.h"
-#include "kernel/kernel_build_info.h"
-
 namespace mindspore {
 namespace kernel {
-void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
+struct SupportFormat {
+  std::vector<std::vector<std::string>> input_format;
+  std::vector<std::vector<std::string>> output_format;
+};
+using SupportFormatItem = std::vector<std::string>;
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_TBE_KERNEL_SELECT_H
+#endif  // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_
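To make the new structure concrete: each element of input_format/output_format is one complete format assignment across all of a node's inputs/outputs, so candidate i for the inputs pairs with candidate i for the outputs. A hand-built sketch (the format strings are illustrative):

#include <string>
#include <vector>

struct SupportFormat {
  std::vector<std::vector<std::string>> input_format;
  std::vector<std::vector<std::string>> output_format;
};
using SupportFormatItem = std::vector<std::string>;

int main() {
  SupportFormat support_format;
  // Candidate 0: a two-input op stays in the default format end to end.
  support_format.input_format.push_back(SupportFormatItem{"DefaultFormat", "DefaultFormat"});
  support_format.output_format.push_back(SupportFormatItem{"DefaultFormat"});
  // Candidate 1: everything in NC1HWC0.
  support_format.input_format.push_back(SupportFormatItem{"NC1HWC0", "NC1HWC0"});
  support_format.output_format.push_back(SupportFormatItem{"NC1HWC0"});
  return 0;
}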
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
new file mode 100644
index 00000000000..9d28af3f3fe
--- /dev/null
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
@@ -0,0 +1,319 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
+#include "utils/utils.h"
+#include "session/anf_runtime_algorithm.h"
+#include "kernel/tbe/tbe_kernel_select/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+constexpr char kDynInputKey[] = "dyn_input_sizes";
+constexpr size_t kInputIndex_0 = 0;
+constexpr size_t kChannelN = 0;
+constexpr size_t kChannelC = 1;
+constexpr size_t kAlignmented16 = 16;
+// Broadcast format-selection cases:
+// 1. no scalar inputs and all shapes are identical
+// 2. some inputs are scalar: the non-scalar inputs must pass the shape-size and alignment checks below
+// 3. no scalar inputs but shapes differ (broadcast on some dims)
+bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
+  MS_EXCEPTION_IF_NULL(support_format);
+  input_num_ = 0;
+  output_num_ = 0;
+  input_shapes_.clear();
+  output_shapes_.clear();
+  if (AnfAlgo::HasNodeAttr(kDynInputKey, cnode_ptr_)) {
+    MS_LOG(INFO) << "This broadcast node has dynamic input.";
+    auto dynamic_size_vec = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynInputKey);
+    if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) {
+      MS_LOG(EXCEPTION) << "dynamic attr set error, please check.";
+    }
+    auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
+    PadScalarShape(&dynamic_input_shape0_);
+    input_shapes_.emplace_back(dynamic_input_shape0_);
+    input_num_ = 1;
+  } else {
+    input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+    for (size_t i = 0; i < input_num_; ++i) {
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
+      PadScalarShape(&input_shape);
+      input_shapes_.emplace_back(input_shape);
+    }
+  }
+
+  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
+  for (size_t i = 0; i < output_num_; ++i) {
+    auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
+    PadScalarShape(&output);
+    output_shapes_.emplace_back(output);
+  }
+  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
+  return true;
+}
+
+bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  if (IsSameShape()) {
+    if (!HasScalarInput()) {
+      AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
+      return true;
+    } else {
+      return false;
+    }
+  }
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  if (HasScalarInput()) {
+    for (const auto &shape : input_shapes_) {
+      if (IsScalarShape(shape)) {
+        input_support_format.emplace_back(kOpFormat_DEFAULT);
+      } else {
+        if (!Is4DShape(shape)) {
+          return false;
+        }
+        if (shape[kChannelC] % kAlignmented16 != 0) {
+          return false;
+        }
+        input_support_format.emplace_back(kOpFormat_NC1HWC0);
+      }
+    }
+  } else {
+    for (const auto &shape : input_shapes_) {
+      if (!Is4DShape(shape)) {
+        return false;
+      }
+    }
+    auto shape_tmp = input_shapes_[0];
+    auto broadcast_c_axis = std::any_of(
+      input_shapes_.begin(), input_shapes_.end(),
+      [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
+    if (broadcast_c_axis) {
+      MS_LOG(INFO) << "This node broadcast c channel.";
+      return false;
+    }
+    input_support_format.assign(input_num_, kOpFormat_NC1HWC0);
+  }
+  GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+  return true;
+}
+
+bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  if (IsSameShape()) {
+    if (!HasScalarInput()) {
+      AssignSupportFormat(kOpFormat_FRAC_Z, support_format);
+      return true;
+    } else {
+      return false;
+    }
+  }
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  if (HasScalarInput()) {
+    for (const auto &shape : input_shapes_) {
+      if (IsScalarShape(shape)) {
+        input_support_format.emplace_back(kOpFormat_DEFAULT);
+      } else {
+        if (!Is4DShape(shape)) {
+          return false;
+        }
+        if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) {
+          return false;
+        }
+        input_support_format.emplace_back(kOpFormat_FRAC_Z);
+      }
+    }
+  } else {
+    return false;
+  }
+  GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+  return true;
+}
+bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  if (IsSameShape()) {
+    if (!HasScalarInput()) {
+      AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format);
+      return true;
+    } else {
+      return false;
+    }
+  }
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  if (HasScalarInput()) {
+    for (const auto &shape : input_shapes_) {
+      if (IsScalarShape(shape)) {
+        input_support_format.emplace_back(kOpFormat_DEFAULT);
+      } else {
+        if (!Is4DShape(shape)) {
+          return false;
+        }
+        if (shape[kChannelN] % kAlignmented16 != 0) {
+          return false;
+        }
+        input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
+      }
+    }
+  } else {
+    for (const auto &shape : input_shapes_) {
+      if (!Is4DShape(shape)) {
+        return false;
+      }
+    }
+    auto shape_tmp = input_shapes_[0];
+    auto broadcast_nc_axis =
+      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
+        return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN));
+      });
+    if (broadcast_nc_axis) {
+      MS_LOG(INFO) << "This node broadcast n || c channel.";
+      return false;
+    }
+    input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
+  }
+  GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+  return true;
+}
+
+bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  if (IsSameShape()) {
+    if (!HasScalarInput()) {
+      AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
+      return true;
+    } else {
+      return false;
+    }
+  }
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  if (HasScalarInput()) {
+    for (const auto &shape : input_shapes_) {
+      if (IsScalarShape(shape)) {
+        input_support_format.emplace_back(kOpFormat_DEFAULT);
+      } else {
+        if (shape.size() < kShape2dDims) {
+          return false;
+        }
+        if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) {
+          return false;
+        }
+        input_support_format.emplace_back(kOpFormat_FRAC_NZ);
+      }
+    }
+  } else {
+    auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(),
+                                  [](const std::vector<size_t> &elem) { return elem.size() < kShape2dDims; });
+    if (less_2dims) {
+      MS_LOG(INFO) << "This node has a dim less than 2.";
+      return false;
+    }
+
+    auto shape_tmp = input_shapes_[0];
+    auto broadcast_last_dim =
+      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
+        return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) ||
+               (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2));
+      });
+    if (broadcast_last_dim) {
+      MS_LOG(INFO) << "This node broadcast last channel.";
+      return false;
+    }
+
+    input_support_format.assign(input_num_, kOpFormat_FRAC_NZ);
+  }
+  GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+  return true;
+}
+
+bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  return false;
+}
+
+bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
+  return shape.size() == kShape4dDims;
+}
+
+bool TbeKernelBroadCastSelecter::IsSameShape() const {
+  auto shape = input_shapes_.begin();
+  for (const auto &item : input_shapes_) {
+    if (shape->size() != item.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < shape->size(); ++i) {
+      if (shape->at(i) != item.at(i)) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+void TbeKernelBroadCastSelecter::PadScalarShape(std::vector<size_t> *shape) const {
+  MS_EXCEPTION_IF_NULL(shape);
+  if (shape->empty()) {
+    shape->emplace_back(1);
+  }
+}
+
+bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector<size_t> &shape) const {
+  return (shape.size() == 1 && shape[0] == 1);
+}
+
+bool TbeKernelBroadCastSelecter::HasScalarInput() const {
+  bool ret = false;
+  for (const auto &shape : input_shapes_) {
+    if (IsScalarShape(shape)) {
+      ret = true;
+      break;
+    }
+  }
+  return ret;
+}
+
+void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format,
+                                                        SupportFormatItem *output_support_item) const {
+  MS_EXCEPTION_IF_NULL(output_support_item);
+  for (const auto &shape : output_shapes_) {
+    if (IsScalarShape(shape)) {
+      output_support_item->emplace_back(kOpFormat_DEFAULT);
+    } else {
+      output_support_item->emplace_back(support_format);
+    }
+  }
+}
+
+void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str,
+                                                     mindspore::kernel::SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  input_support_format.assign(input_num_, support_format_str);
+  output_support_format.assign(output_num_, support_format_str);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+}
+}  // namespace kernel
+}  // namespace mindspore
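A sketch of how a caller might drive the selecter; the real call site lives in the new tbe_kernel_select.cc, so this is assumed usage rather than a quote from it (CollectBroadcastFormats is a hypothetical helper name):

#include "ir/anf.h"
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"

namespace mindspore {
namespace kernel {
// Gather every format assignment a broadcast-pattern node can take.
// GetShapeInfo seeds the default-format candidate; each Is* query appends
// one more candidate when it returns true.
SupportFormat CollectBroadcastFormats(const CNodePtr &cnode) {
  SupportFormat support_format;
  TbeKernelBroadCastSelecter selecter(cnode);
  (void)selecter.GetShapeInfo(&support_format);
  (void)selecter.IsBroadCastSupport5HD(&support_format);
  (void)selecter.IsBroadCastSupportFracZ(&support_format);
  (void)selecter.IsBroadCastSupportC1HWNCoC0(&support_format);
  (void)selecter.IsBroadCastSupportFracNZ(&support_format);
  return support_format;
}
}  // namespace kernel
}  // namespace mindspore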
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
new file mode 100644
index 00000000000..2beb81927ff
--- /dev/null
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
+#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "kernel/tbe/tbe_kernel_select/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+class TbeKernelBroadCastSelecter {
+ public:
+  explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
+  ~TbeKernelBroadCastSelecter() = default;
+  bool GetShapeInfo(SupportFormat *support_format);
+  bool IsBroadCastSupport5HD(SupportFormat *support_format) const;
+  bool IsBroadCastSupportFracZ(SupportFormat *support_format) const;
+  bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const;
+  bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const;
+  bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const;
+
+ private:
+  bool IsSameShape() const;
+  void PadScalarShape(std::vector<size_t> *shape) const;
+  bool Is4DShape(const std::vector<size_t> &shape) const;
+  bool IsScalarShape(const std::vector<size_t> &shape) const;
+  bool HasScalarInput() const;
+  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
+  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
+  // broadcast
+  CNodePtr cnode_ptr_;
+  size_t input_num_{};
+  size_t output_num_{};
+  std::vector<std::vector<size_t>> input_shapes_;
+  std::vector<std::vector<size_t>> output_shapes_;
+};
+
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_TBE_KERNEL_BROADCAST_SELECTER_HELPER_H
+ */ + +#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" +#include +#include +#include "utils/utils.h" +#include "session/anf_runtime_algorithm.h" +#include "kernel/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +constexpr char kKeepDims[] = "keep_dims"; +constexpr char kAxis[] = "axis"; +constexpr char kTypeInt32[] = "Int32"; +constexpr size_t kInputIndex_0 = 0; +constexpr size_t kOutputIndex_0 = 0; +constexpr size_t kChannelN = 0; +constexpr size_t kChannelC = 1; +constexpr size_t kReduceNZMinDim = 3; + +bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) { + MS_EXCEPTION_IF_NULL(support_format); + input_shape_.clear(); + output_shape_.clear(); + axis_.clear(); + auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); + auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); + if (input_num != 1 || output_num != 1) { + MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num + << ", output num: " << output_num; + } + // get input/output shape + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); + PadScalarShape(&input_shape_); + output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0); + PadScalarShape(&output_shape_); + // get keep dim attr + GetReduceAttrKeepDim(); + // get axis attr + GetReduceAttrAxis(); + AssignSupportFormat(kOpFormat_DEFAULT, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (!Is4DShape(input_shape_)) { + return false; + } + if (!keep_dims_ || axis_.empty()) { + return false; + } + auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); + if (reduce_c_axis) { + return false; + } + AssignSupportFormat(kOpFormat_NC1HWC0, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + // like to 5HD + return false; +} + +bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { + return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format); +} + +bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const { + return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format); +} + +bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (input_shape_.size() < kReduceNZMinDim) { + return false; + } + if (axis_.empty()) { + return false; + } + auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) { + return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2)); + }); + if (reduce_last_axis) { + return false; + } + AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format, + mindspore::kernel::SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (!Is4DShape(input_shape_)) { + return false; + } + if (!keep_dims_ || axis_.empty()) { + return false; + } + auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(), + [](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); }); + if (reduce_n_c_axis) { + return false; + } + AssignSupportFormat(format, 
+
+void TbeKernelReduceSelecter::GetReduceAttrAxis() {
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto axis = primitive->GetAttr(kAxis);
+  if (axis == nullptr) {
+    MS_LOG(INFO) << "This node doesn't have an axis attr.";
+    return;
+  }
+  auto type = axis->type();
+  MS_EXCEPTION_IF_NULL(type);
+  std::vector<int> axis_list;
+  if (type->ToString() == kTypeInt32) {
+    axis_list.emplace_back(GetValue<int>(axis));
+  } else {
+    axis_list = GetValue<std::vector<int>>(axis);
+  }
+  for (const auto &elem : axis_list) {
+    if (elem < 0) {
+      axis_.emplace_back(input_shape_.size() + elem);
+    } else {
+      axis_.emplace_back(IntToSize(elem));
+    }
+  }
+}
+
+void TbeKernelReduceSelecter::GetReduceAttrKeepDim() {
+  if (!AnfAlgo::HasNodeAttr(kKeepDims, cnode_ptr_)) {
+    MS_LOG(INFO) << "This node doesn't have a keep_dims attr.";
+    keep_dims_ = false;
+    return;
+  }
+  keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kKeepDims);
+}
+
+void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
+                                                  mindspore::kernel::SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  input_support_format.emplace_back(support_format_str);
+  output_support_format.emplace_back(support_format_str);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+}
+
+bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const {
+  return shape.size() == kShape4dDims;
+}
+
+void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
+  MS_EXCEPTION_IF_NULL(shape);
+  if (shape->empty()) {
+    shape->emplace_back(1);
+  }
+}
+
+}  // namespace kernel
+}  // namespace mindspore
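Axis normalization note: GetReduceAttrAxis above converts a negative axis value by adding the
input rank, so the later kChannelN/kChannelC comparisons always see non-negative indices. A
worked example with assumed values (illustrative only, not part of the patch):

    // input_shape_ = {8, 16, 32, 32}        rank 4, NCHW order
    // axis attr    = {-1, 1}
    // axis_        = {4 + (-1), 1} = {3, 1}
    // IsReduceSupport5HD fails: axis_ contains kChannelC (1)
    // IsReduceSupportFracNZ fails: axis_ contains the last axis (3)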
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
new file mode 100644
index 00000000000..e66525fd646
--- /dev/null
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
+#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "kernel/tbe/tbe_kernel_select/common_utils.h"
+namespace mindspore {
+namespace kernel {
+class TbeKernelReduceSelecter {
+ public:
+  explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
+  ~TbeKernelReduceSelecter() = default;
+  bool GetShapeInfo(SupportFormat *support_format);
+  bool IsReduceSupport5HD(SupportFormat *support_format) const;
+  bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const;
+  bool IsReduceSupportFracZ(SupportFormat *support_format) const;
+  bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
+  bool IsReduceSupportFracNZ(SupportFormat *support_format) const;
+
+ private:
+  bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const;
+  void GetReduceAttrAxis();
+  void GetReduceAttrKeepDim();
+  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
+  bool Is4DShape(const std::vector<size_t> &shape) const;
+  void PadScalarShape(std::vector<size_t> *shape) const;
+  CNodePtr cnode_ptr_;
+  std::vector<size_t> input_shape_{};
+  std::vector<size_t> output_shape_{};
+  std::vector<size_t> axis_{};
+  bool keep_dims_ = false;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc
new file mode 100644
index 00000000000..e9ad1283358
--- /dev/null
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc
@@ -0,0 +1,633 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include "session/anf_runtime_algorithm.h"
+#include "kernel/oplib/oplib.h"
+#include "kernel/tbe/tbe_kernel_build.h"
+#include "nlohmann/json.hpp"
+#include "utils/context/ms_context.h"
+#include "kernel/tbe/tbe_python_funcs.h"
+#include "pre_activate/common/helper.h"
+#include "kernel/tbe/tbe_convert_utils.h"
+#include "parallel/ops_info/ops_utils.h"
+#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
+#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
+#include "kernel/tbe/tbe_kernel_select/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+constexpr auto kName = "name";
+constexpr auto kDtype = "dtype";
+constexpr auto kFormat = "format";
+constexpr auto kPrefixInput = "input";
+constexpr auto kPrefixOutput = "output";
+constexpr char kDynInputKey[] = "dyn_input_sizes";
+constexpr char kParamTypeDynamic[] = "dynamic";
+constexpr char kParamTypeRequired[] = "required";
+constexpr char kParamTypeOptional[] = "optional";
+
+void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
+  auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
+  tbe_selecter.TbeMetadataInfoEx();
+}
+
+TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
+    : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
+
+void TbeKernelSelect::TbeMetadataInfoEx() {
+  MS_EXCEPTION_IF_NULL(cnode_ptr_);
+  MS_EXCEPTION_IF_NULL(kernel_info_list_);
+  node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
+  auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
+  if (!op_info_ptr) {
+    MS_LOG(INFO) << "Warning: can't find TBE op info, node type: " << node_name_;
+    return;
+  }
+  MS_LOG(INFO) << "Start getting TBE metadata info. Node type: " << node_name_
+               << ", node name: " << cnode_ptr_->fullname_with_scope();
+  OpPattern pattern = op_info_ptr->op_pattern();
+  if (pattern == kCommonPattern) {
+    GetCommonPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kDynamicFormatPattern) {
+    GetDynamicFormatPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kFormatAgnosticPattern) {
+    GetAgnosticPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kBroadcastPattern) {
+    GetBroadcastPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kReducePattern) {
+    GetReducePatternKernelInfo(*op_info_ptr);
+  } else {
+    MS_LOG(INFO) << "Warning: op pattern is invalid.";
+  }
+  // check support
+  FilterInVaildKernelInfo();
+  MS_LOG(INFO) << "End TBE select, kernel build info list size: " << kernel_info_list_->size() << ".";
+}
+
+void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
+  MS_LOG(INFO) << "start.";
+  // get dynamic inputs
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
+  MS_EXCEPTION_IF_NULL(primitive);
+  std::vector<int> dyn_input_sizes;
+  if (primitive->HasAttr(kDynInputKey)) {
+    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kDynInputKey));
+  }
+  // get real input/output num
+  size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+  const auto inputs_info = op_info.inputs_ptr();
+  size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
+  const auto outputs_info = op_info.outputs_ptr();
+  if (inputs_info.empty() && outputs_info.empty()) {
+    MS_LOG(EXCEPTION) << "Op info input & output is null, please check.";
+  }
+  // create kernel build info from op info
+  size_t kernel_build_info_num =
+    inputs_info.empty() ?
outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size(); + for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) { + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + SetTbeBuildCommonInfo(op_info, &builder); + std::vector inputs_format; + std::vector inputs_device_type; + std::vector> inputs_reshape_type; + // input + if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes, + &inputs_format, &inputs_device_type, &inputs_reshape_type)) { + break; + } + builder.SetInputsDeviceType(inputs_device_type); + builder.SetInputsFormat(inputs_format); + builder.SetInputReshapeType(inputs_reshape_type); + // output + std::vector outputs_format; + std::vector outputs_device_type; + std::vector> outputs_reshape_type; + if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes, + &outputs_format, &outputs_device_type, &outputs_reshape_type)) { + break; + } + builder.SetOutputsDeviceType(outputs_device_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputReshapeType(outputs_reshape_type); + kernel_info_list_->emplace_back(builder.Build()); + } + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + // + OpInfo op_info_new; + CreateNewOpInfo(op_info, &op_info_new); + GetCommonPatternKernelInfo(op_info_new); + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + if (op_info.inputs_ptr().size() != 1) { + MS_LOG(EXCEPTION) << "AgnosticPattern only support one input."; + } + auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0); + if (kOpFormatList.find(format) == kOpFormatList.end()) { + MS_LOG(INFO) << "Got the unknown format " << format; + format = kOpFormat_DEFAULT; + } + SupportFormat support_format; + SupportFormatItem input_item; + SupportFormatItem output_item; + input_item.assign(op_info.inputs_ptr().size(), format); + output_item.assign(op_info.outputs_ptr().size(), format); + support_format.input_format.emplace_back(input_item); + support_format.output_format.emplace_back(output_item); + PrintSupportedFormat(support_format); + OpInfo op_info_new; + CreateNewOpInfo(op_info, support_format, &op_info_new); + GetCommonPatternKernelInfo(op_info_new); + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_); + SupportFormat support_format; + broadcast_selecter.GetShapeInfo(&support_format); + if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD."; + } + if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ."; + } + if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0."; + } + if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; + } + PrintSupportedFormat(support_format); + OpInfo op_info_new; + CreateNewOpInfo(op_info, support_format, &op_info_new); + GetCommonPatternKernelInfo(op_info_new); + MS_LOG(INFO) << "end."; +} + +void 
TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
+  MS_LOG(INFO) << "start.";
+  auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
+  SupportFormat support_format;
+  reduce_selecter.GetShapeInfo(&support_format);
+  if (!reduce_selecter.IsReduceSupport5HD(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support 5HD.";
+  }
+  if (!reduce_selecter.IsReduceSupportFracZ(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracZ.";
+  }
+  if (!reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support C1HWNCoC0.";
+  }
+  if (!reduce_selecter.IsReduceSupportFracNZ(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracNZ.";
+  }
+  PrintSupportedFormat(support_format);
+  OpInfo op_info_new;
+  CreateNewOpInfo(op_info, support_format, &op_info_new);
+  GetCommonPatternKernelInfo(op_info_new);
+  MS_LOG(INFO) << "end.";
+}
+
+void TbeKernelSelect::FilterInVaildKernelInfo() {
+  if (kernel_info_list_->empty()) {
+    MS_LOG(INFO) << "Warning: get kernel build info failed.";
+    return;
+  }
+  auto kernel_build_info_iter = kernel_info_list_->begin();
+  while (kernel_build_info_iter != kernel_info_list_->end()) {
+    if (!FilterInVaildShape(kernel_build_info_iter)) {
+      MS_LOG(INFO) << "Filtered out by invalid shape, item info: " << (*kernel_build_info_iter)->ToString();
+      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
+      continue;
+    }
+    if (!TbeCheckSupported(kernel_build_info_iter)) {
+      MS_LOG(INFO) << "Filtered out by check-supported, item info: " << (*kernel_build_info_iter)->ToString();
+      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
+      continue;
+    }
+    kernel_build_info_iter++;
+  }
+}
+
+bool TbeKernelSelect::FilterInVaildShape(
+  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
+  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
+  auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
+  for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
+    auto format = kernel_build_info_inputs_format.at(i);
+    if (!IsShapeMatchFormat(shape, format)) {
+      MS_LOG(INFO) << "The " << i << "th input check failed.";
+      return false;
+    }
+  }
+  auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
+  for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
+    auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
+    auto format = kernel_build_info_outputs_format.at(j);
+    if (!IsShapeMatchFormat(shape, format)) {
+      MS_LOG(INFO) << "The " << j << "th output check failed.";
+      return false;
+    }
+  }
+  return true;
+}
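+
+// IsShapeMatchFormat (below) encodes per-format rank constraints only. Illustrative
+// expectations with hand-picked shapes (an explanatory sketch, not exercised by this patch):
+//   IsShapeMatchFormat({16}, kOpFormat_DEFAULT);             // true: default matches any shape
+//   IsShapeMatchFormat({16}, kOpFormat_FRAC_NZ);             // false: FRAC_NZ needs rank >= 2
+//   IsShapeMatchFormat({2, 3, 4, 5, 6}, kOpFormat_NDHWC);    // true: NDHWC requires exactly rank 5
+//   IsShapeMatchFormat({2, 3, 4, 5, 6}, kOpFormat_NC1HWC0);  // false: other formats allow rank <= 4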
+
+bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
+  // the default format matches any shape
+  if (format == kOpFormat_DEFAULT) {
+    return true;
+  }
+  static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
+  // reject unknown formats
+  if (kOpFormatList.find(format) == kOpFormatList.end()) {
+    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
+  }
+  // the server does not support formats with a C04 suffix
+  if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
+      kServerNotSupportFormat.end()) {
+    MS_LOG(INFO) << "Warning: the server does not support formats with a C04 suffix.";
+    return false;
+  }
+  // unsupported shape/format combinations:
+  // 1 NDHWC with shape size != 5
+  // 2 FRAC_NZ with shape size < 2
+  // 3 !NDHWC with shape size > 4
+  if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
+      (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) ||
+      (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
+    MS_LOG(INFO) << "Warning: shape/format check failed, format: " << format << ", size: " << shape.size();
+    return false;
+  }
+  return true;
+}
+
+bool TbeKernelSelect::TbeCheckSupported(
+  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
+  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
+  static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL,
+                                                              parallel::BATCHMATMUL,
+                                                              parallel::TOPK,
+                                                              parallel::IN_TOPK,
+                                                              parallel::PACK,
+                                                              parallel::GATHER_ND,
+                                                              parallel::UNSORTEF_SEGMENT_MIND,
+                                                              parallel::UNSORTEF_SEGMENT_PRODD,
+                                                              parallel::CAST};
+  auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_);
+  if (iter == kCheckSupportedOpType.end()) {
+    return true;
+  }
+  MS_LOG(INFO) << "Check support start.";
+  // temporarily replace the node's kernel info with the candidate kernel info
+  auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
+  AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
+  nlohmann::json kernel_json;
+  TbeKernelJsonCreator creator(CHECK_SUPPORTED);
+  bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "Gen TBE single kernel json for check support failed.";
+  }
+  ret = TbePythonFuncs::CheckSupported(kernel_json);
+  // restore the original kernel info
+  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
+  return ret;
+}
+
+void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info,
+                                            mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) {
+  MS_EXCEPTION_IF_NULL(builder);
+  builder->SetProcessor(AICORE);
+  std::string fusion_type = op_info.fusion_type();
+  if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
+    builder->SetFusionType(tbe::GetFusionType(fusion_type));
+  }
+  builder->SetOpPattern(op_info.op_pattern());
+  builder->SetKernelType(TBE_KERNEL);
+}
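+
+// GenBuilderItem (below) expands each op io info column into per-tensor entries; a "dynamic"
+// entry fans out according to the node's dyn_input_sizes attr. Assumed example for an
+// AddN-like node (values illustrative, not from this patch):
+//   op info inputs: one entry "x" with param_type "dynamic"; selected column has dtype
+//   "float16" and format "DefaultFormat"; node attr dyn_input_sizes = {3}
+//   => formats      = {"DefaultFormat", "DefaultFormat", "DefaultFormat"}
+//   => device_types = {kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16}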
" << dyn_input_sizes.size(); + } + int dynamic_input_size = dyn_input_sizes[dynamic_input_index]; + for (int i = 0; i < dynamic_input_size; ++i) { + device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); + formats->emplace_back(kernel_build_info_format); + reshape_types->emplace_back(reshape_type); + } + dynamic_input_index++; + real_io_tensor_index += dynamic_input_size; + } else { + if (ios_info.size() != 1) { + MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output."; + } + for (size_t i = 0; i < real_io_tensor_num; ++i) { + device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); + formats->emplace_back(kernel_build_info_format); + reshape_types->emplace_back(reshape_type); + } + real_io_tensor_index += real_io_tensor_num; + } + } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) { + // requre or optional io + device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); + formats->emplace_back(kernel_build_info_format); + reshape_types->emplace_back(reshape_type); + real_io_tensor_index++; + } else { + MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type; + } + } + + if (io_info_index != io_info_num) { + MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num + << "), this node may has optional input/output."; + } + if (real_io_tensor_index != real_io_tensor_num) { + std::string io_type = is_input ? "inputs " : "outputs"; + MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num + << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index + << ") != real_io_tensor_num(" << real_io_tensor_num << ")"; + return false; + } + return true; +} + +void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec) { + MS_EXCEPTION_IF_NULL(reshape_type_vec); + for (const auto &c : reshape_type_str) { + switch (c) { + case 'N': + reshape_type_vec->push_back(kernel::N); + break; + case 'C': + reshape_type_vec->push_back(kernel::C); + break; + case 'H': + reshape_type_vec->push_back(kernel::H); + break; + case 'W': + reshape_type_vec->push_back(kernel::W); + break; + default: + MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; + } + } +} + +void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, + const std::vector> &support_format_item, size_t index, + mindspore::kernel::OpIOInfo *op_io_info_new) { + MS_EXCEPTION_IF_NULL(op_io_info_new); + op_io_info_new->set_index(op_io_info.index()); + op_io_info_new->set_name(op_io_info.name()); + op_io_info_new->set_param_type(op_io_info.param_type()); + op_io_info_new->set_need_compile(op_io_info.need_compile()); + op_io_info_new->set_reshape_type(op_io_info.reshape_type()); + op_io_info_new->set_shape(op_io_info.shape()); + // dtype + std::vector dtype_new; + auto dtype = op_io_info.dtypes(); + for (size_t i = 0; i < support_format_item.size(); ++i) { + dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end()); + } + op_io_info_new->set_dtypes(dtype_new); + // format + std::vector format_new; + for (const auto &formats : support_format_item) { + auto format = formats.at(index); + for (size_t j = 0; j < dtype.size(); ++j) { + format_new.emplace_back(format); + } + } + op_io_info_new->set_formats(format_new); +} + +std::vector TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { + const std::map kDynamicFormatMap = { + 
{"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; + if (op_select_json_item.empty()) { + MS_LOG(EXCEPTION) << "Op select ret item is null."; + } + const char space = ' '; + const char sep = ','; + std::string op_select_tmp = op_select_json_item + ","; + std::vector ret; + auto begin = op_select_tmp.find_first_not_of(space, 0); + auto sep_pos = op_select_tmp.find(sep); + while (sep_pos != std::string::npos) { + auto obj = op_select_tmp.substr(begin, sep_pos - begin); + if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) { + obj = kDynamicFormatMap.at(obj); + } + ret.emplace_back(obj); + begin = op_select_tmp.find_first_not_of(space, sep_pos + 1); + sep_pos = op_select_tmp.find(sep, begin); + } + return ret; +} + +std::string TbeKernelSelect::OpSelectFormat() { + nlohmann::json kernel_json; + std::string res_json_str; + TbeKernelJsonCreator creator(OP_SELECT_FORMAT); + bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); + if (!ret) { + MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; + } + res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); + if (res_json_str.empty()) { + MS_LOG(EXCEPTION) << "op select format error."; + } + MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str; + return res_json_str; +} + +void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format, + mindspore::kernel::OpInfo *op_info_new) { + MS_EXCEPTION_IF_NULL(op_info_new); + if (op_info.inputs_ptr().size() != support_format.input_format[0].size() || + op_info.outputs_ptr().size() != support_format.output_format[0].size()) { + MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size() + << ", input support size: " << support_format.input_format[0].size() + << ", op info output size: " << op_info.outputs_ptr().size() + << ", output support size: " << support_format.output_format[0].size(); + } + *op_info_new = op_info; + op_info_new->ClearInputs(); + op_info_new->ClearOutputs(); + for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { + auto input = op_info.inputs_ptr().at(i); + auto input_new = std::make_shared(); + CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get()); + op_info_new->add_inputs_ptr(input_new); + } + for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) { + auto output = op_info.outputs_ptr().at(j); + auto output_new = std::make_shared(); + CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get()); + op_info_new->add_outputs_ptr(output_new); + } +} + +struct SelectOpIOInfo { + std::string name; + std::vector dtypes; + std::vector formats; +}; + +void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, + mindspore::kernel::OpInfo *op_info_new) { + MS_EXCEPTION_IF_NULL(op_info_new); + auto op_seclect_json = OpSelectFormat(); + if (!op_seclect_json.empty()) { + nlohmann::json json_obj = nlohmann::json::parse(op_seclect_json); + if (!json_obj.is_object()) { + MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json; + } + std::vector inputs; + std::vector outputs; + for (const auto &item : json_obj.items()) { + const std::string &item_name = item.key(); + bool is_input = (item_name.find(kPrefixInput) != std::string::npos); + bool is_output = (item_name.find(kPrefixOutput) != std::string::npos); + if (!is_input && !is_output) { + MS_LOG(EXCEPTION) << "op select ret json is error."; + } + if (is_input) { + 
SelectOpIOInfo select_input; + select_input.name = item.value().at(kName); + std::string input_dtype_item = item.value().at(kDtype); + select_input.dtypes = SplitStrToVec(input_dtype_item); + std::string input_format_item = item.value().at(kFormat); + select_input.formats = SplitStrToVec(input_format_item); + inputs.emplace_back(select_input); + } else if (is_output) { + SelectOpIOInfo select_output; + select_output.name = item.value().at(kName); + std::string input_dtype_item = item.value().at(kDtype); + select_output.dtypes = SplitStrToVec(input_dtype_item); + std::string input_format_item = item.value().at(kFormat); + select_output.formats = SplitStrToVec(input_format_item); + outputs.emplace_back(select_output); + } + } + + if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) { + MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register."; + } + + *op_info_new = op_info; + op_info_new->ClearInputs(); + op_info_new->ClearOutputs(); + for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { + auto input_new = std::make_shared(); + CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get()); + op_info_new->add_inputs_ptr(input_new); + } + for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) { + auto output_new = std::make_shared(); + CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get()); + op_info_new->add_outputs_ptr(output_new); + } + } +} + +void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, + const std::vector &support_dtype, + const std::vector &support_format, + mindspore::kernel::OpIOInfo *op_io_info_new) { + MS_EXCEPTION_IF_NULL(op_io_info_new); + op_io_info_new->set_index(op_io_info.index()); + op_io_info_new->set_name(op_io_info.name()); + op_io_info_new->set_param_type(op_io_info.param_type()); + op_io_info_new->set_need_compile(op_io_info.need_compile()); + op_io_info_new->set_reshape_type(op_io_info.reshape_type()); + op_io_info_new->set_shape(op_io_info.shape()); + // dtype + std::vector dtype_new; + for (size_t i = 0; i < support_format.size(); ++i) { + dtype_new.insert(dtype_new.end(), support_dtype.begin(), support_dtype.end()); + } + op_io_info_new->set_dtypes(dtype_new); + // format + std::vector format_new; + for (const auto &format : support_format) { + for (size_t j = 0; j < support_dtype.size(); ++j) { + format_new.emplace_back(format); + } + } + op_io_info_new->set_formats(format_new); +} + +void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { + if (support_format.input_format.size() != support_format.output_format.size()) { + MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output(" + << support_format.output_format.size() << ") size not match."; + } + for (size_t i = 0; i < support_format.input_format.size(); ++i) { + auto input_items = support_format.input_format.at(i); + auto output_items = support_format.output_format.at(i); + std::string print_str = "["; + for (const auto &input : input_items) { + print_str.append(input); + print_str.append(", "); + } + print_str.append("] -->"); + for (const auto &output : output_items) { + print_str.append(output); + print_str.append(", "); + } + MS_LOG(INFO) << "Support format: " << print_str; + } +} + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h 
b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h new file mode 100644 index 00000000000..c400bdbb6f8 --- /dev/null +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_TBE_KERNEL_SELECT_H +#define MINDSPORE_TBE_KERNEL_SELECT_H + +#include +#include +#include +#include "kernel/oplib/opinfo.h" +#include "kernel/kernel_build_info.h" +#include "kernel/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); + +class TbeKernelSelect { + using OpInfoPtr = std::shared_ptr; + using KernelBuildInfoIter = std::vector>::iterator; + + public: + TbeKernelSelect(CNodePtr kernel_node, std::vector> *kernel_info_list); + ~TbeKernelSelect() = default; + void TbeMetadataInfoEx(); + + private: + void GetCommonPatternKernelInfo(const OpInfo &op_info); + void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info); + void GetAgnosticPatternKernelInfo(const OpInfo &op_info); + void GetBroadcastPatternKernelInfo(const OpInfo &op_info); + void GetReducePatternKernelInfo(const OpInfo &op_info); + void FilterInVaildKernelInfo(); + bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); + static bool IsShapeMatchFormat(const std::vector &shape, const std::string &format); + bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); + static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); + bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, + const std::vector> &ios_info, const std::vector &dyn_input_sizes, + std::vector *formats, std::vector *device_types, + std::vector> *reshape_types); + static void StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec); + static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); + static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, + const std::vector> &support_format_item, size_t index, + OpIOInfo *op_io_info_new); + // op select(dynamic) + void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new); + static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector &support_dtype, + const std::vector &support_format, OpIOInfo *op_io_info_new); + static std::vector SplitStrToVec(const std::string &op_select_json_item); + std::string OpSelectFormat(); + + static void PrintSupportedFormat(const SupportFormat &support_format); + + private: + CNodePtr cnode_ptr_; + std::vector> *kernel_info_list_; + std::string node_name_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_TBE_KERNEL_SELECT_H diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h 
b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index bc0d669baa4..952b87a14d7 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -216,6 +216,13 @@ constexpr char NEG[] = "Neg"; constexpr char BATCH_MATMUL[] = "BatchMatMul"; constexpr char EXPAND_DIMS[] = "ExpandDims"; constexpr char SQUARE[] = "Square"; +constexpr char BATCHMATMUL[] = "BatchMatMul"; +constexpr char TOPK[] = "TopK"; +constexpr char IN_TOPK[] = "InTopK"; +constexpr char PACK[] = "Pack"; +constexpr char GATHER_ND[] = "GatherNd"; +constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; +constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; // Parallel don't care constexpr char TUPLE_GETITEM[] = "tuple_getitem"; diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index e47270d8478..a385e574a42 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -21,7 +21,6 @@ #include #include "device/ascend/kernel_select_ascend.h" #include "kernel/kernel_query.h" -#include "kernel/tbe/tbe_kernel_select.h" #include "kernel/oplib/oplib.h" #include "session/anf_runtime_algorithm.h" diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc index cfa4e423424..c0f99ed4159 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -34,7 +34,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph return nullptr; } auto node_name = AnfAlgo::GetCNodeName(node); - if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) { + if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { return nullptr; } auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); diff --git a/mindspore/ops/_op_impl/tbe/abs.py b/mindspore/ops/_op_impl/tbe/abs.py index 30a75812bde..66c1d409fbf 100644 --- a/mindspore/ops/_op_impl/tbe/abs.py +++ b/mindspore/ops/_op_impl/tbe/abs.py @@ -26,12 +26,9 @@ abs_op_info = TBERegOp("Abs") \ .op_pattern("formatAgnostic") \ .input(0, "x", None, "required", None) \ .output(0, "y", True, "required", "all") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/abs_grad.py b/mindspore/ops/_op_impl/tbe/abs_grad.py index ba630f6570f..3e7ac70d80e 100644 --- a/mindspore/ops/_op_impl/tbe/abs_grad.py +++ b/mindspore/ops/_op_impl/tbe/abs_grad.py @@ -23,7 +23,6 @@ abs_grad_op_info = TBERegOp("AbsGrad") \ .compute_cost(10) \ .kernel_name("abs_grad") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .input(0, "y", None, "required", None) \ .input(1, "dy", None, "required", None) \ .output(0, "z", False, "required", "all") \ diff --git 
a/mindspore/ops/_op_impl/tbe/add.py b/mindspore/ops/_op_impl/tbe/add.py index 63e1efb1c65..d3db3de0ad9 100644 --- a/mindspore/ops/_op_impl/tbe/add.py +++ b/mindspore/ops/_op_impl/tbe/add.py @@ -26,6 +26,7 @@ add_op_info = TBERegOp("Add") \ .input(0, "x1", False, "required", "all") \ .input(1, "x2", False, "required", "all") \ .output(0, "y", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ diff --git a/mindspore/ops/_op_impl/tbe/add_n.py b/mindspore/ops/_op_impl/tbe/add_n.py index 3e8a6c00164..1c42b4bb2d2 100644 --- a/mindspore/ops/_op_impl/tbe/add_n.py +++ b/mindspore/ops/_op_impl/tbe/add_n.py @@ -26,17 +26,10 @@ add_n_op_info = TBERegOp("AddN") \ .attr("n", "required", "int", "all") \ .input(0, "x", False, "dynamic", "all") \ .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ - .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ - .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ - .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ - .dtype_format(DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ - .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \ + .op_pattern("broadcast") \ + .dtype_format(DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/batch_matmul.py b/mindspore/ops/_op_impl/tbe/batch_matmul.py index 4efcf8031c6..02f2dd58809 100644 --- a/mindspore/ops/_op_impl/tbe/batch_matmul.py +++ b/mindspore/ops/_op_impl/tbe/batch_matmul.py @@ -29,6 +29,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \ .input(1, "x2", False, "required", "all") \ .input(2, "bias", False, "optional", "all") \ .output(0, "y", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \ diff --git a/mindspore/ops/_op_impl/tbe/bias_add.py b/mindspore/ops/_op_impl/tbe/bias_add.py index 24607af141b..5ab19162997 100644 --- a/mindspore/ops/_op_impl/tbe/bias_add.py +++ b/mindspore/ops/_op_impl/tbe/bias_add.py @@ -27,6 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \ .input(0, "x", False, "required", "all") \ .input(1, "bias", False, "required", "all") \ .output(0, "y", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ diff --git a/mindspore/ops/_op_impl/tbe/bn_training_reduce.py b/mindspore/ops/_op_impl/tbe/bn_training_reduce.py index e19d4b65ffd..f33cba2110d 100644 --- 
a/mindspore/ops/_op_impl/tbe/bn_training_reduce.py +++ b/mindspore/ops/_op_impl/tbe/bn_training_reduce.py @@ -26,6 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \ .input(0, "x", False, "required", "all", reshape_type="NC") \ .output(0, "sum", False, "required", "all") \ .output(1, "square_sum", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py b/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py index 66dc55ab105..89736a00971 100644 --- a/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +++ b/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py @@ -32,6 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \ .input(5, "batch_mean", False, "required", "all") \ .input(6, "batch_variance", False, "required", "all") \ .output(0, "y", False, "required", "all", reshape_type="NC") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, diff --git a/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py b/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py index 50989232815..1aa822a3c18 100644 --- a/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +++ b/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py @@ -30,6 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \ .input(3, "batch_variance", False, "required", "all") \ .output(0, "diff_scale", False, "required", "all") \ .output(1, "diff_offset", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, diff --git a/mindspore/ops/_op_impl/tbe/bn_training_update_v2.py b/mindspore/ops/_op_impl/tbe/bn_training_update_v2.py index 03a51664e86..a54d91a483b 100644 --- a/mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +++ b/mindspore/ops/_op_impl/tbe/bn_training_update_v2.py @@ -32,6 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \ .output(0, "y", False, "required", "all", reshape_type="NC") \ .output(1, "batch_mean", False, "required", "all") \ .output(2, "batch_variance", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \ diff --git a/mindspore/ops/_op_impl/tbe/cast.py b/mindspore/ops/_op_impl/tbe/cast.py index 07e14139daa..0a809e28a7c 100644 --- a/mindspore/ops/_op_impl/tbe/cast.py +++ b/mindspore/ops/_op_impl/tbe/cast.py @@ -26,32 +26,27 @@ cast_op_info = TBERegOp("Cast") \ .attr("dst_type", "required", "int", "all") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.BOOL_Default, DataType.F16_Default) \ - .dtype_format(DataType.BOOL_Default, DataType.U8_Default) \ - .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ - .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ - .dtype_format(DataType.I8_Default, DataType.F16_Default) \ - 
.dtype_format(DataType.I8_Default, DataType.F32_Default) \ - .dtype_format(DataType.I8_Default, DataType.I32_Default) \ - .dtype_format(DataType.U8_Default, DataType.F16_Default) \ - .dtype_format(DataType.U8_Default, DataType.F32_Default) \ - .dtype_format(DataType.U8_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \ - .dtype_format(DataType.I32_Default, DataType.F16_Default) \ - .dtype_format(DataType.I32_Default, DataType.F32_Default) \ - .dtype_format(DataType.I32_Default, DataType.I8_Default) \ - .dtype_format(DataType.I32_Default, DataType.U8_Default) \ - .dtype_format(DataType.F16_Default, DataType.U8_Default) \ - .dtype_format(DataType.F16_Default, DataType.F32_Default) \ - .dtype_format(DataType.F16_Default, DataType.I32_Default) \ - .dtype_format(DataType.F16_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \ - .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \ - .dtype_format(DataType.F32_5HD, DataType.F16_5HD) \ - .dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \ - .dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \ - .dtype_format(DataType.F32_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .op_pattern("formatAgnostic") \ + .dtype_format(DataType.BOOL_None, DataType.F16_None) \ + .dtype_format(DataType.BOOL_None, DataType.U8_None) \ + .dtype_format(DataType.BOOL_None, DataType.F32_None) \ + .dtype_format(DataType.BOOL_None, DataType.I32_None) \ + .dtype_format(DataType.I8_None, DataType.F16_None) \ + .dtype_format(DataType.I8_None, DataType.F32_None) \ + .dtype_format(DataType.I8_None, DataType.I32_None) \ + .dtype_format(DataType.U8_None, DataType.F16_None) \ + .dtype_format(DataType.U8_None, DataType.F32_None) \ + .dtype_format(DataType.U8_None, DataType.I32_None) \ + .dtype_format(DataType.I32_None, DataType.BOOL_None) \ + .dtype_format(DataType.I32_None, DataType.F16_None) \ + .dtype_format(DataType.I32_None, DataType.F32_None) \ + .dtype_format(DataType.I32_None, DataType.I8_None) \ + .dtype_format(DataType.I32_None, DataType.U8_None) \ + .dtype_format(DataType.F16_None, DataType.U8_None) \ + .dtype_format(DataType.F16_None, DataType.F32_None) \ + .dtype_format(DataType.F16_None, DataType.I32_None) \ + .dtype_format(DataType.F32_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.I32_None) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/concat.py b/mindspore/ops/_op_impl/tbe/concat.py index 56807b15fc2..0bf636016f4 100644 --- a/mindspore/ops/_op_impl/tbe/concat.py +++ b/mindspore/ops/_op_impl/tbe/concat.py @@ -26,6 +26,7 @@ concat_op_info = TBERegOp("Concat") \ .attr("axis", "required", "int", "all") \ .input(0, "input_values", False, "dynamic", "all") \ .output(0, "output_data", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \ .dtype_format(DataType.I8_Default, DataType.I8_Default) \ diff --git a/mindspore/ops/_op_impl/tbe/conv2d.py b/mindspore/ops/_op_impl/tbe/conv2d.py index 425521901d7..a2879d521aa 100644 --- a/mindspore/ops/_op_impl/tbe/conv2d.py +++ b/mindspore/ops/_op_impl/tbe/conv2d.py @@ -23,6 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \ .compute_cost(10) \ .kernel_name("conv2d") \ .partial_flag(True) \ + .op_pattern("dynamicFormat") \ .attr("stride", "required", "listInt", "all") \ .attr("pad_list", "required", "listInt", "all") \ .attr("dilation", "required", 
"listInt", "all") \ @@ -32,8 +33,7 @@ conv2d_op_info = TBERegOp("Conv2D") \ .input(2, "bias", False, "optional", "all") \ .input(3, "offset_w", False, "optional", "all") \ .output(0, "y", True, "required", "all") \ - .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default, - DataType.F16_5HD) \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/dropout_do_mask.py b/mindspore/ops/_op_impl/tbe/dropout_do_mask.py index 2bef489b961..a24e02f9640 100644 --- a/mindspore/ops/_op_impl/tbe/dropout_do_mask.py +++ b/mindspore/ops/_op_impl/tbe/dropout_do_mask.py @@ -27,6 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \ .input(1, "mask", False, "required", "all") \ .input(2, "keep_prob", False, "required", "all") \ .output(0, "y", False, "required", "all") \ + .op_pattern("dynamicFormat") \ .dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/elu.py b/mindspore/ops/_op_impl/tbe/elu.py index 9125d14727a..e61e2851af0 100644 --- a/mindspore/ops/_op_impl/tbe/elu.py +++ b/mindspore/ops/_op_impl/tbe/elu.py @@ -28,9 +28,7 @@ elu_op_info = TBERegOp("Elu") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/erf.py b/mindspore/ops/_op_impl/tbe/erf.py index 2247197c4e7..4c4893d505b 100644 --- a/mindspore/ops/_op_impl/tbe/erf.py +++ b/mindspore/ops/_op_impl/tbe/erf.py @@ -26,9 +26,7 @@ erf_op_info = TBERegOp("Erf") \ .op_pattern("formatAgnostic") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/erfc.py b/mindspore/ops/_op_impl/tbe/erfc.py index 7e1b76649aa..7b0eccf52e9 100644 --- a/mindspore/ops/_op_impl/tbe/erfc.py +++ b/mindspore/ops/_op_impl/tbe/erfc.py @@ -26,9 +26,7 @@ erfc_op_info = TBERegOp("Erfc") \ .op_pattern("formatAgnostic") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/expm1.py b/mindspore/ops/_op_impl/tbe/expm1.py index 2f62c33a3ef..a126aca36f7 100644 --- a/mindspore/ops/_op_impl/tbe/expm1.py +++ b/mindspore/ops/_op_impl/tbe/expm1.py @@ -27,9 +27,7 @@ expm1_op_info = TBERegOp("Expm1") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, 
                  DataType.F32_5HD) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/fused_mul_add.py b/mindspore/ops/_op_impl/tbe/fused_mul_add.py
index ad3c601e5d1..fa104fb5611 100644
--- a/mindspore/ops/_op_impl/tbe/fused_mul_add.py
+++ b/mindspore/ops/_op_impl/tbe/fused_mul_add.py
@@ -27,6 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
     .input(1, "x2", False, "required", "all") \
     .input(2, "x3", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
     .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
     .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
diff --git a/mindspore/ops/_op_impl/tbe/layer_norm.py b/mindspore/ops/_op_impl/tbe/layer_norm.py
index c52be2d4eff..03ddd2dc6cb 100644
--- a/mindspore/ops/_op_impl/tbe/layer_norm.py
+++ b/mindspore/ops/_op_impl/tbe/layer_norm.py
@@ -32,6 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \
     .output(0, "y", False, "required", "all") \
     .output(1, "mean", False, "required", "all") \
     .output(2, "variance", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                   DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
diff --git a/mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py b/mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py
index ef254465bc7..deca3840322 100644
--- a/mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py
+++ b/mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py
@@ -30,6 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop")
     .input(3, "mean", False, "required", "all") \
     .output(0, "pd_gamma", False, "required", "all") \
     .output(1, "pd_beta", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                   DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
diff --git a/mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py b/mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py
index bbab66816df..1d4f1ef231e 100644
--- a/mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py
+++ b/mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py
@@ -29,6 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
     .input(3, "mean", False, "required", "all") \
     .input(4, "gamma", False, "required", "all") \
     .output(0, "pd_x", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                   DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
diff --git a/mindspore/ops/_op_impl/tbe/mul.py b/mindspore/ops/_op_impl/tbe/mul.py
index fa74c88de36..5433bf0b532 100644
--- a/mindspore/ops/_op_impl/tbe/mul.py
+++ b/mindspore/ops/_op_impl/tbe/mul.py
@@ -26,21 +26,8 @@ mul_op_info = TBERegOp("Mul") \
     .input(0, "x", False, "required", "all") \
     .input(1, "y", False, "required", "all") \
     .output(0, "output", False, "required", "all") \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
-    .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
-    .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
-    .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
-    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
-    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
-    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
-    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/real_div.py b/mindspore/ops/_op_impl/tbe/real_div.py
index b39948971db..9c6d9e0b27f 100644
--- a/mindspore/ops/_op_impl/tbe/real_div.py
+++ b/mindspore/ops/_op_impl/tbe/real_div.py
@@ -26,10 +26,9 @@ realdiv_op_info = TBERegOp("RealDiv") \
     .input(0, "x", False, "required", "all") \
     .input(1, "y", False, "required", "all") \
     .output(0, "z", False, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .op_pattern("broadcast") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/reciprocal.py b/mindspore/ops/_op_impl/tbe/reciprocal.py
index dfa126384c6..77f3bfac271 100644
--- a/mindspore/ops/_op_impl/tbe/reciprocal.py
+++ b/mindspore/ops/_op_impl/tbe/reciprocal.py
@@ -25,6 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
     .partial_flag(True) \
     .input(0, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
diff --git a/mindspore/ops/_op_impl/tbe/reduce_mean.py b/mindspore/ops/_op_impl/tbe/reduce_mean.py
index 67b96933a18..b01fd3bebd8 100644
--- a/mindspore/ops/_op_impl/tbe/reduce_mean.py
+++ b/mindspore/ops/_op_impl/tbe/reduce_mean.py
@@ -27,11 +27,11 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \
     .attr("keep_dims", "optional", "bool", "all") \
     .input(0, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .op_pattern("reduce") \
+    .dtype_format(DataType.I8_None, DataType.I8_None) \
+    .dtype_format(DataType.U8_None, DataType.U8_None) \
+    .dtype_format(DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/relu_grad_v2.py b/mindspore/ops/_op_impl/tbe/relu_grad_v2.py
index 93d7dede629..e5f82c8b787 100644
--- a/mindspore/ops/_op_impl/tbe/relu_grad_v2.py
+++ b/mindspore/ops/_op_impl/tbe/relu_grad_v2.py
@@ -24,7 +24,7 @@ relu_grad_v2_op_info = TBERegOp("ReluGradV2") \
     .kernel_name("relu_grad_v2") \
     .partial_flag(True) \
     .input(0, "gradients", False, "required", "all") \
-    .input(1, "mask", False, "rerequired", "all") \
+    .input(1, "mask", False, "required", "all") \
     .output(0, "backprops", True, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \
     .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \
diff --git a/mindspore/ops/_op_impl/tbe/select.py b/mindspore/ops/_op_impl/tbe/select.py
index 4af43253123..e924f050214 100644
--- a/mindspore/ops/_op_impl/tbe/select.py
+++ b/mindspore/ops/_op_impl/tbe/select.py
@@ -27,6 +27,7 @@ select_op_info = TBERegOp("Select") \
     .input(1, "x1", False, "required", "all") \
     .input(2, "x2", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
     .dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
     .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
diff --git a/mindspore/ops/_op_impl/tbe/sign.py b/mindspore/ops/_op_impl/tbe/sign.py
index 823715aa9f5..99f79703163 100644
--- a/mindspore/ops/_op_impl/tbe/sign.py
+++ b/mindspore/ops/_op_impl/tbe/sign.py
@@ -27,11 +27,8 @@ sign_op_info = TBERegOp("Sign") \
     .input(0, "x", None, "required", None) \
     .output(0, "y", True, "required", "all") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/softmax_grad_ext.py b/mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
index 1953b41223b..d43183dcb76 100644
--- a/mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
+++ b/mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
@@ -30,6 +30,7 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
     .input(1, "x1", False, "required", "all") \
     .input(2, "x2", False, "required", "all") \
     .output(0, "y", True, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                   DataType.F16_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD,
diff --git a/mindspore/ops/_op_impl/tbe/softplus.py b/mindspore/ops/_op_impl/tbe/softplus.py
index d362cd06db8..92261d91ef2 100644
--- a/mindspore/ops/_op_impl/tbe/softplus.py
+++ b/mindspore/ops/_op_impl/tbe/softplus.py
@@ -27,9 +27,7 @@ softplus_op_info = TBERegOp("Softplus") \
     .input(0, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/softplus_grad.py b/mindspore/ops/_op_impl/tbe/softplus_grad.py
index 4bf7a82440d..3dc0e7ee0c3 100644
--- a/mindspore/ops/_op_impl/tbe/softplus_grad.py
+++ b/mindspore/ops/_op_impl/tbe/softplus_grad.py
@@ -28,9 +28,7 @@ softplus_grad_op_info = TBERegOp("SoftplusGrad") \
     .input(1, "features", False, "required", "all") \
     .output(0, "backprops", False, "required", "all") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/split_d.py b/mindspore/ops/_op_impl/tbe/split_d.py
index dcc8219fd4d..d2faf310969 100644
--- a/mindspore/ops/_op_impl/tbe/split_d.py
+++ b/mindspore/ops/_op_impl/tbe/split_d.py
@@ -27,6 +27,7 @@ split_d_op_info = TBERegOp("Split") \
     .attr("output_num", "required", "int", "all") \
     .input(0, "value", False, "required", "all") \
     .output(0, "output", False, "dynamic", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
     .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
     .dtype_format(DataType.I8_Default, DataType.I8_Default) \
diff --git a/mindspore/ops/_op_impl/tbe/tensor_add.py b/mindspore/ops/_op_impl/tbe/tensor_add.py
index 255c1b12784..a1f21bee776 100644
--- a/mindspore/ops/_op_impl/tbe/tensor_add.py
+++ b/mindspore/ops/_op_impl/tbe/tensor_add.py
@@ -26,6 +26,7 @@ tensor_add_op_info = TBERegOp("TensorAdd") \
     .input(0, "x1", False, "required", "all") \
     .input(1, "x2", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py
index 5dc07dd59ff..b1f81b72b01 100644
--- a/mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py
+++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py
@@ -27,6 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
     .input(0, "x", False, "required", "all") \
     .input(1, "segment_ids", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
     .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
     .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
     .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py
index 3096e902506..a7a60b7181c 100644
--- a/mindspore/ops/op_info_register.py
+++ b/mindspore/ops/op_info_register.py
@@ -97,6 +97,7 @@ class RegOp:
         """
         if not isinstance(value, str):
             raise TypeError("%s value must be str" % str(value))
+        return True
 
     def _is_int(self, value):
         """
@@ -110,6 +111,7 @@ class RegOp:
         """
         if not isinstance(value, int):
             raise TypeError("%s value must be int" % str(value))
+        return True
 
     def _is_bool(self, value):
         """
@@ -123,6 +125,7 @@ class RegOp:
         """
         if not isinstance(value, bool):
             raise TypeError("%s value must be bool" % str(value))
+        return True
 
     def _check_param(self, param_list, key_list, fn_list, kwargs):
         """
@@ -494,6 +497,7 @@ class DataType:
 
     The current list below maybe not completed. If necessary, please add it.
     """
+    None_None = ("", "")
     BOOL_None = ("bool", "")
     BOOL_Default = ("bool", "DefaultFormat")
     BOOL_5HD = ("bool", "NC1HWC0")
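
Reviewer note: the registration change repeated across the TBE op files above follows one pattern. An op declares an op_pattern ("dynamicFormat", "broadcast", or "reduce"), and its dtype_format table stops pinning concrete device formats: "broadcast" and "reduce" ops keep their dtypes but switch to the format-less *_None entries (see real_div.py and reduce_mean.py), while a "dynamicFormat" op such as Mul can collapse the whole table to the new None_None wildcard, apparently leaving dtype and format selection to TBE at kernel-select time. Below is a minimal sketch of the wildcard style with a hypothetical op ("MyEltwise" and its kernel name are not from this patch); only the fields this patch actually touches are shown, and real registrations also carry boilerplate such as fusion_type and binfile_name:

    from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

    # Hypothetical single-input op, for illustration only: one wildcard
    # dtype_format line replaces an enumerated dtype/format table.
    my_eltwise_op_info = TBERegOp("MyEltwise") \
        .kernel_name("my_eltwise") \
        .partial_flag(True) \
        .input(0, "x", False, "required", "all") \
        .output(0, "y", False, "required", "all") \
        .op_pattern("dynamicFormat") \
        .dtype_format(DataType.None_None, DataType.None_None) \
        .get_op_info()

    @op_info_register(my_eltwise_op_info)
    def _my_eltwise_tbe():
        """MyEltwise TBE register"""
        return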
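Reviewer note: the return True additions in op_info_register.py give the RegOp validators a truthy success value instead of the implicit None; they still raise TypeError on bad input. A sketch of the resulting contract, assuming a caller that wants to chain several checks (the composition below is illustrative, not code from this patch):

    op = TBERegOp("DemoOp")  # any RegOp subclass instance; "DemoOp" is hypothetical

    # Each validator now returns True on success, so results can be and-ed or asserted.
    assert op._is_int(3) and op._is_bool(True)

    # Invalid input still raises rather than returning False:
    # op._is_int("3")  # TypeError: 3 value must be int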