forked from mindspore-Ecosystem/mindspore
tbe select broadcast reduce dynamic
This commit is contained in:
parent
553432c968
commit
caab25e09b
|
@ -20,7 +20,7 @@
|
||||||
#include "kernel/aicpu/aicpu_kernel_metadata.h"
|
#include "kernel/aicpu/aicpu_kernel_metadata.h"
|
||||||
#include "kernel/rts/rt_kernel_info.h"
|
#include "kernel/rts/rt_kernel_info.h"
|
||||||
#include "kernel/hccl/hccl_kernel_metadata.h"
|
#include "kernel/hccl/hccl_kernel_metadata.h"
|
||||||
#include "kernel/tbe/tbe_kernel_select.h"
|
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
|
||||||
#include "session/anf_runtime_algorithm.h"
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -63,7 +63,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||||
TbeMetadataInfo(kernel_node, kernel_info_list);
|
TbeMetadataInfo(kernel_node, kernel_info_list);
|
||||||
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
|
|
||||||
if (kernel_info_list->empty()) {
|
if (kernel_info_list->empty()) {
|
||||||
AicpuMetadataInfo(kernel_node, kernel_info_list);
|
AicpuMetadataInfo(kernel_node, kernel_info_list);
|
||||||
if (!kernel_info_list->empty()) {
|
if (!kernel_info_list->empty()) {
|
||||||
|
@ -114,7 +113,6 @@ bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
|
||||||
auto cnode = kernel_node->cast<CNodePtr>();
|
auto cnode = kernel_node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
TbeMetadataInfo(cnode, &kernel_info_list);
|
TbeMetadataInfo(cnode, &kernel_info_list);
|
||||||
FilterInvalidKernelInfo(cnode, &kernel_info_list);
|
|
||||||
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
||||||
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
|
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
|
|
@ -126,6 +126,8 @@ class OpInfo {
|
||||||
bool is_ref() const { return !ref_infos_.empty(); }
|
bool is_ref() const { return !ref_infos_.empty(); }
|
||||||
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
|
bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
|
||||||
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
|
void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
|
||||||
|
void ClearInputs() { (void)inputs_ptr_.clear(); }
|
||||||
|
void ClearOutputs() { (void)outputs_ptr_.clear(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string op_name_;
|
std::string op_name_;
|
||||||
|
|
|
@ -35,7 +35,7 @@ constexpr auto kKernelName = "kernel_name";
|
||||||
constexpr auto kPartialFlag = "partial_flag";
|
constexpr auto kPartialFlag = "partial_flag";
|
||||||
constexpr auto kReshapeType = "reshape_type";
|
constexpr auto kReshapeType = "reshape_type";
|
||||||
constexpr auto kOpPattern = "op_pattern";
|
constexpr auto kOpPattern = "op_pattern";
|
||||||
constexpr auto kDynamicFormat = "dynamic_format";
|
constexpr auto kDynamicFormat = "dynamicFormat";
|
||||||
constexpr auto kFormatAgnostic = "formatAgnostic";
|
constexpr auto kFormatAgnostic = "formatAgnostic";
|
||||||
constexpr auto kBroadcast = "broadcast";
|
constexpr auto kBroadcast = "broadcast";
|
||||||
constexpr auto kReduce = "reduce";
|
constexpr auto kReduce = "reduce";
|
||||||
|
@ -100,7 +100,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
|
||||||
|
|
||||||
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
|
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
|
||||||
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
|
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
|
||||||
{kFormatAgnostic, kBroadcastPattern},
|
{kBroadcast, kBroadcastPattern},
|
||||||
{kReduce, kReducePattern},
|
{kReduce, kReducePattern},
|
||||||
{kDynamicFormat, kDynamicFormatPattern}};
|
{kDynamicFormat, kDynamicFormatPattern}};
|
||||||
op_info->set_async_flag(obj.at(kAsyncFlag));
|
op_info->set_async_flag(obj.at(kAsyncFlag));
|
||||||
|
@ -108,13 +108,18 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
||||||
op_info->set_compute_cost(obj.at(kComputeCost));
|
op_info->set_compute_cost(obj.at(kComputeCost));
|
||||||
op_info->set_kernel_name(obj.at(kKernelName));
|
op_info->set_kernel_name(obj.at(kKernelName));
|
||||||
op_info->set_partial_flag(obj.at(kPartialFlag));
|
op_info->set_partial_flag(obj.at(kPartialFlag));
|
||||||
|
|
||||||
if (obj.find(kOpPattern) != obj.end()) {
|
if (obj.find(kOpPattern) != obj.end()) {
|
||||||
if (kOpPatternMap.find(obj.at(kOpPattern)) != kOpPatternMap.end()) {
|
std::string op_pattern = obj.at(kOpPattern);
|
||||||
op_info->set_op_pattern(obj.at(kOpPattern));
|
auto find_iter = kOpPatternMap.find(op_pattern);
|
||||||
|
if (find_iter == kOpPatternMap.end()) {
|
||||||
|
if (!op_pattern.empty()) {
|
||||||
|
MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
|
||||||
}
|
}
|
||||||
|
op_info->set_op_pattern(kCommonPattern);
|
||||||
|
} else {
|
||||||
|
op_info->set_op_pattern(find_iter->second);
|
||||||
}
|
}
|
||||||
if (obj.find(kDynamicFormat) != obj.end()) {
|
|
||||||
op_info->set_dynamic_format(obj.at(kDynamicFormat));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
|
||||||
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
|
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
|
||||||
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
|
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
|
||||||
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
|
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
|
||||||
{TypeId::kNumberTypeBool, "bool"},
|
{TypeId::kNumberTypeBool, "int8"},
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::map<std::string, std::string> type_str_maps = {
|
const std::map<std::string, std::string> type_str_maps = {
|
||||||
|
@ -85,7 +85,7 @@ std::string DtypeToString(const std::string &dtypes) {
|
||||||
std::string TypeIdToString(TypeId type_id) {
|
std::string TypeIdToString(TypeId type_id) {
|
||||||
auto iter = type_id_str_maps.find(type_id);
|
auto iter = type_id_str_maps.find(type_id);
|
||||||
if (iter == type_id_str_maps.end()) {
|
if (iter == type_id_str_maps.end()) {
|
||||||
MS_LOG(EXCEPTION) << "Illegal input dtype." << TypeIdLabel(type_id);
|
MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
|
||||||
}
|
}
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
|
@ -111,41 +111,20 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
|
||||||
if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
|
if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
|
||||||
TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
|
TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
|
||||||
} else {
|
} else {
|
||||||
// dtype : float16
|
auto dtype = GetDeviceInputType(anf_node, real_input_index);
|
||||||
auto tensor_dtype =
|
auto format = GetDeviceInputFormat(anf_node, real_input_index);
|
||||||
std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index)));
|
auto shape = GetDeviceInputShape(anf_node, real_input_index);
|
||||||
MS_EXCEPTION_IF_NULL(tensor_dtype);
|
|
||||||
std::string dtype = tensor_dtype->element()->ToString();
|
|
||||||
dtype = tbe::DtypeToString(dtype);
|
|
||||||
|
|
||||||
// format
|
|
||||||
std::string format = AnfAlgo::GetInputFormat(anf_node, real_input_index);
|
|
||||||
if (format == kOpFormat_DEFAULT) {
|
|
||||||
format = kOpFormat_NCHW;
|
|
||||||
} else if (format == kOpFormat_FRAC_Z) {
|
|
||||||
format = kOpFormat_FRACTAL_Z;
|
|
||||||
}
|
|
||||||
|
|
||||||
nlohmann::json input_desc_json;
|
|
||||||
input_desc_json["dtype"] = dtype;
|
|
||||||
input_desc_json["name"] = op_input_name + std::to_string(input_i);
|
|
||||||
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
|
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
|
||||||
if (ori_shape.empty()) {
|
if (ori_shape.empty()) {
|
||||||
ori_shape.emplace_back(1);
|
ori_shape.emplace_back(1);
|
||||||
}
|
}
|
||||||
|
nlohmann::json input_desc_json;
|
||||||
|
input_desc_json["dtype"] = dtype;
|
||||||
|
input_desc_json["name"] = op_input_name + std::to_string(input_i);
|
||||||
input_desc_json["ori_shape"] = ori_shape;
|
input_desc_json["ori_shape"] = ori_shape;
|
||||||
input_desc_json["ori_format"] = kOpFormat_NCHW;
|
input_desc_json["ori_format"] = kOpFormat_NCHW;
|
||||||
auto shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
|
|
||||||
if (shape.empty()) {
|
|
||||||
shape.emplace_back(1);
|
|
||||||
}
|
|
||||||
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
|
|
||||||
input_desc_json["shape"] = ori_shape;
|
|
||||||
input_desc_json["format"] = kOpFormat_NCHW;
|
|
||||||
} else {
|
|
||||||
input_desc_json["shape"] = shape;
|
input_desc_json["shape"] = shape;
|
||||||
input_desc_json["format"] = format;
|
input_desc_json["format"] = format;
|
||||||
}
|
|
||||||
input_desc_json["valid"] = value;
|
input_desc_json["valid"] = value;
|
||||||
input_desc_json["param_type"] = input_ptr->param_type();
|
input_desc_json["param_type"] = input_ptr->param_type();
|
||||||
input_list->emplace_back(input_desc_json);
|
input_list->emplace_back(input_desc_json);
|
||||||
|
@ -325,40 +304,22 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
|
||||||
MS_EXCEPTION_IF_NULL(output_idx);
|
MS_EXCEPTION_IF_NULL(output_idx);
|
||||||
MS_EXCEPTION_IF_NULL(output_list);
|
MS_EXCEPTION_IF_NULL(output_list);
|
||||||
for (size_t i = 0; i < output_obj_num; i++) {
|
for (size_t i = 0; i < output_obj_num; i++) {
|
||||||
nlohmann::json output_obj;
|
auto dtype = GetDeviceOutputType(anf_node, *output_idx);
|
||||||
auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, *output_idx)));
|
auto format = GetDeviceOutputFormat(anf_node, *output_idx);
|
||||||
std::string dtype = type_ptr->element()->ToString();
|
auto shape = GetDeviceOutputShape(anf_node, *output_idx);
|
||||||
dtype = tbe::DtypeToString(dtype);
|
std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
|
||||||
std::string format = AnfAlgo::GetOutputFormat(anf_node, *output_idx);
|
if (ori_shape.empty()) {
|
||||||
if (format == kOpFormat_DEFAULT) {
|
|
||||||
format = kOpFormat_NCHW;
|
|
||||||
} else if (format == kOpFormat_FRAC_Z) {
|
|
||||||
format = kOpFormat_FRACTAL_Z;
|
|
||||||
}
|
|
||||||
std::vector<size_t> ori_shape;
|
|
||||||
if (AnfAlgo::GetOutputInferShape(anf_node, *output_idx).empty()) {
|
|
||||||
ori_shape.emplace_back(1);
|
ori_shape.emplace_back(1);
|
||||||
} else {
|
|
||||||
ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
|
|
||||||
}
|
}
|
||||||
|
nlohmann::json output_obj;
|
||||||
output_obj["dtype"] = dtype;
|
output_obj["dtype"] = dtype;
|
||||||
auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, *output_idx);
|
|
||||||
if (shape.empty()) {
|
|
||||||
shape.emplace_back(1);
|
|
||||||
}
|
|
||||||
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
|
|
||||||
output_obj["shape"] = ori_shape;
|
|
||||||
output_obj["format"] = kOpFormat_NCHW;
|
|
||||||
} else {
|
|
||||||
output_obj["shape"] = shape;
|
output_obj["shape"] = shape;
|
||||||
output_obj["format"] = format;
|
output_obj["format"] = format;
|
||||||
}
|
|
||||||
output_obj["ori_shape"] = ori_shape;
|
output_obj["ori_shape"] = ori_shape;
|
||||||
output_obj["ori_format"] = kOpFormat_NCHW;
|
output_obj["ori_format"] = kOpFormat_NCHW;
|
||||||
output_obj["name"] = output_ptr->name();
|
output_obj["name"] = output_ptr->name();
|
||||||
output_obj["valid"] = true;
|
output_obj["valid"] = true;
|
||||||
output_obj["param_type"] = output_ptr->param_type();
|
output_obj["param_type"] = output_ptr->param_type();
|
||||||
|
|
||||||
output_list->emplace_back(output_obj);
|
output_list->emplace_back(output_obj);
|
||||||
(*output_idx)++;
|
(*output_idx)++;
|
||||||
}
|
}
|
||||||
|
@ -456,6 +417,84 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
std::vector<size_t> shape;
|
||||||
|
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
|
||||||
|
shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index);
|
||||||
|
} else {
|
||||||
|
shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index);
|
||||||
|
}
|
||||||
|
if (shape.empty()) {
|
||||||
|
shape.emplace_back(1);
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
TypeId type_id;
|
||||||
|
if (creater_type_ == OP_SELECT_FORMAT) {
|
||||||
|
type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index);
|
||||||
|
} else {
|
||||||
|
type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
|
||||||
|
}
|
||||||
|
return tbe::TypeIdToString(type_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
std::string format = kOpFormat_NCHW;
|
||||||
|
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
|
||||||
|
format = AnfAlgo::GetInputFormat(anf_node, real_index);
|
||||||
|
if (format == kOpFormat_FRAC_Z) {
|
||||||
|
format = kOpFormat_FRACTAL_Z;
|
||||||
|
} else if (format == kOpFormat_DEFAULT) {
|
||||||
|
format = kOpFormat_NCHW;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return format;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
std::vector<size_t> shape;
|
||||||
|
if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
|
||||||
|
shape = AnfAlgo::GetOutputInferShape(anf_node, real_index);
|
||||||
|
} else {
|
||||||
|
shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index);
|
||||||
|
}
|
||||||
|
if (shape.empty()) {
|
||||||
|
shape.emplace_back(1);
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
TypeId type_id;
|
||||||
|
if (creater_type_ == OP_SELECT_FORMAT) {
|
||||||
|
type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index);
|
||||||
|
} else {
|
||||||
|
type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index);
|
||||||
|
}
|
||||||
|
return tbe::TypeIdToString(type_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
std::string format = kOpFormat_NCHW;
|
||||||
|
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
|
||||||
|
format = AnfAlgo::GetOutputFormat(anf_node, real_index);
|
||||||
|
if (format == kOpFormat_FRAC_Z) {
|
||||||
|
format = kOpFormat_FRACTAL_Z;
|
||||||
|
} else if (format == kOpFormat_DEFAULT) {
|
||||||
|
format = kOpFormat_NCHW;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return format;
|
||||||
|
}
|
||||||
|
|
||||||
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
|
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
|
||||||
std::vector<size_t> *output_size_list) {
|
std::vector<size_t> *output_size_list) {
|
||||||
if (input_size_list == nullptr || output_size_list == nullptr) {
|
if (input_size_list == nullptr || output_size_list == nullptr) {
|
||||||
|
|
|
@ -93,7 +93,7 @@ class TbeKernelJsonCreator {
|
||||||
nlohmann::json *outputs_json);
|
nlohmann::json *outputs_json);
|
||||||
bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
bool GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
||||||
nlohmann::json *attrs_json);
|
nlohmann::json *attrs_json);
|
||||||
void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
|
static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj);
|
||||||
bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
|
bool GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
|
||||||
const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
|
const std::shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name, size_t input_i,
|
||||||
std::vector<nlohmann::json> *input_list);
|
std::vector<nlohmann::json> *input_list);
|
||||||
|
@ -105,6 +105,13 @@ class TbeKernelJsonCreator {
|
||||||
void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
|
void GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
|
||||||
const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
|
const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
|
||||||
std::vector<nlohmann::json> *output_list);
|
std::vector<nlohmann::json> *output_list);
|
||||||
|
std::vector<size_t> GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const;
|
||||||
|
std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const;
|
||||||
|
std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
|
||||||
|
std::vector<size_t> GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const;
|
||||||
|
std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const;
|
||||||
|
std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const;
|
||||||
|
|
||||||
kCreaterType creater_type_;
|
kCreaterType creater_type_;
|
||||||
std::string json_name_;
|
std::string json_name_;
|
||||||
std::string json_info_;
|
std::string json_info_;
|
||||||
|
|
|
@ -1,664 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "kernel/tbe/tbe_kernel_select.h"
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
#include "session/anf_runtime_algorithm.h"
|
|
||||||
#include "kernel/oplib/oplib.h"
|
|
||||||
#include "kernel/tbe/tbe_kernel_build.h"
|
|
||||||
#include "nlohmann/json.hpp"
|
|
||||||
#include "common/utils.h"
|
|
||||||
#include "utils/context/ms_context.h"
|
|
||||||
#include "kernel/tbe/tbe_python_funcs.h"
|
|
||||||
#include "pre_activate/common/helper.h"
|
|
||||||
#include "kernel/tbe/tbe_convert_utils.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace kernel {
|
|
||||||
constexpr auto kName = "name";
|
|
||||||
constexpr auto kDtype = "dtype";
|
|
||||||
constexpr auto kFormat = "format";
|
|
||||||
constexpr auto kPrefixInput = "input";
|
|
||||||
constexpr auto kPrefixOutput = "output";
|
|
||||||
const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
|
|
||||||
{"NHWC", "DefaultFormat"},
|
|
||||||
{"ND", "DefaultFormat"},
|
|
||||||
{"FRACTAL_Z", "FracZ"},
|
|
||||||
{"NDHWC", "DefaultFormat"}};
|
|
||||||
static const std::vector<std::string> CHECK_SUPPORTED_OPTYPE{
|
|
||||||
"MatMul", "BatchMatMul", "TopK", "InTopK", "Pack", "GatherNd", "UnsortedSegmentMinD", "UnsortedSegmentProdD", "Cast"};
|
|
||||||
|
|
||||||
bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
|
|
||||||
|
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
||||||
auto iter = std::find(CHECK_SUPPORTED_OPTYPE.begin(), CHECK_SUPPORTED_OPTYPE.end(), op_name);
|
|
||||||
if (iter == CHECK_SUPPORTED_OPTYPE.end()) {
|
|
||||||
MS_LOG(DEBUG) << "Op " << op_name << "this op does not need to check op supported.";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// replace kernel_info with current kernel info
|
|
||||||
auto ori_select_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(select_kernel_build_info, anf_node.get());
|
|
||||||
|
|
||||||
nlohmann::json kernel_json;
|
|
||||||
TbeKernelJsonCreator creator(CHECK_SUPPORTED);
|
|
||||||
bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
|
|
||||||
if (!ret) {
|
|
||||||
MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = TbePythonFuncs::CheckSupported(kernel_json);
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CheckJsonItemValidity(const nlohmann::json &json_obj, const std::string &key_name,
|
|
||||||
const std::vector<std::string> &keys) {
|
|
||||||
if (!json_obj[key_name].is_object()) {
|
|
||||||
MS_LOG(DEBUG) << key_name << "is not an object!";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (auto key : keys) {
|
|
||||||
if (json_obj[key_name].find(key) == json_obj[key_name].end()) {
|
|
||||||
MS_LOG(DEBUG) << "Key" << key << "of " << key_name << " is not found!";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> SplitStr(const std::string &string, const std::string &sep) {
|
|
||||||
std::vector<std::string> result;
|
|
||||||
size_t start = 0;
|
|
||||||
size_t index = string.find(sep, start);
|
|
||||||
std::string substr;
|
|
||||||
while (index != std::string::npos) {
|
|
||||||
if (string.size() > start) {
|
|
||||||
substr = string.substr(start, index - start);
|
|
||||||
}
|
|
||||||
(void)substr.erase(0, substr.find_first_not_of(' '));
|
|
||||||
(void)substr.erase(substr.find_last_not_of(' ') + 1);
|
|
||||||
auto iter = DYNAMIC_FORMAT_MAP.find(substr);
|
|
||||||
if (iter != DYNAMIC_FORMAT_MAP.end()) {
|
|
||||||
substr = iter->second;
|
|
||||||
}
|
|
||||||
result.push_back(substr);
|
|
||||||
start = index + sep.size();
|
|
||||||
index = string.find(sep, start);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (string.size() > start) {
|
|
||||||
substr = string.substr(start);
|
|
||||||
}
|
|
||||||
(void)substr.erase(0, substr.find_first_not_of(' '));
|
|
||||||
(void)substr.erase(substr.find_last_not_of(' ') + 1);
|
|
||||||
auto iter = DYNAMIC_FORMAT_MAP.find(substr);
|
|
||||||
if (iter != DYNAMIC_FORMAT_MAP.end()) {
|
|
||||||
substr = iter->second;
|
|
||||||
}
|
|
||||||
result.push_back(substr);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ConvertFormatDtype(const std::string &format, const std::string &dtype, const std::shared_ptr<OpIOInfo> &io_info) {
|
|
||||||
MS_EXCEPTION_IF_NULL(io_info);
|
|
||||||
std::vector<std::string> format_vec = SplitStr(format, ",");
|
|
||||||
std::vector<std::string> dtype_vec = SplitStr(dtype, ",");
|
|
||||||
io_info->set_formats(format_vec);
|
|
||||||
io_info->set_dtypes(dtype_vec);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_ptr<OpIOInfo>> *const inputs,
|
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> *const outputs) {
|
|
||||||
nlohmann::json json_obj = nlohmann::json::parse(jsonStr);
|
|
||||||
if (!json_obj.is_object()) {
|
|
||||||
MS_LOG(DEBUG) << "JsonStr is not an object, the jsonStr is:" << jsonStr;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
std::vector<std::string> keys = {kName, kDtype, kFormat};
|
|
||||||
for (const auto &item : json_obj.items()) {
|
|
||||||
std::string key_name;
|
|
||||||
key_name = item.key();
|
|
||||||
if (key_name.empty()) {
|
|
||||||
MS_LOG(DEBUG) << "Key name is empty!";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
|
|
||||||
std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
|
|
||||||
MS_EXCEPTION_IF_NULL(input);
|
|
||||||
input->set_name(json_obj[key_name].at(kName));
|
|
||||||
ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
|
|
||||||
inputs->emplace_back(input);
|
|
||||||
} else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
|
|
||||||
std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
|
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
|
||||||
output->set_name(json_obj[key_name].at(kName));
|
|
||||||
ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), output);
|
|
||||||
outputs->emplace_back(output);
|
|
||||||
} else {
|
|
||||||
MS_LOG(DEBUG) << "Key name:" << key_name << " is undefined!";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string OpSelectFormat(const std::shared_ptr<AnfNode> &anf_node) {
|
|
||||||
nlohmann::json kernel_json;
|
|
||||||
std::string res_json_str;
|
|
||||||
TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
|
|
||||||
bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
|
|
||||||
if (!ret) {
|
|
||||||
MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
|
|
||||||
return res_json_str;
|
|
||||||
}
|
|
||||||
res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
|
|
||||||
MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str;
|
|
||||||
return res_json_str;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetTidyInputsInfo(const std::shared_ptr<AnfNode> &anf_node,
|
|
||||||
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
|
|
||||||
const std::vector<std::shared_ptr<OpIOInfo>> &inputs) {
|
|
||||||
std::vector<TypeId> inputs_type;
|
|
||||||
std::vector<std::string> inputs_format;
|
|
||||||
std::vector<int> dyn_input_sizes;
|
|
||||||
size_t dyn_input_idx = 0;
|
|
||||||
size_t kernel_info_index = 0;
|
|
||||||
size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
|
|
||||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
|
|
||||||
dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < inputs.size(); i++) {
|
|
||||||
MS_EXCEPTION_IF_NULL(inputs[i]);
|
|
||||||
std::string param_type = inputs[i]->param_type();
|
|
||||||
if (i >= real_input_num) {
|
|
||||||
MS_LOG(INFO) << "Input index: " << i << " is out of real_input_num:" << real_input_num;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, i);
|
|
||||||
auto format = kOpFormat_DEFAULT;
|
|
||||||
if (param_type == "dynamic") {
|
|
||||||
if (!dyn_input_sizes.empty()) {
|
|
||||||
for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
|
|
||||||
kernel_info_index++;
|
|
||||||
inputs_type.emplace_back(type_id);
|
|
||||||
inputs_format.emplace_back(format);
|
|
||||||
}
|
|
||||||
dyn_input_idx++;
|
|
||||||
}
|
|
||||||
} else if (param_type == "required") {
|
|
||||||
kernel_info_index++;
|
|
||||||
inputs_type.emplace_back(type_id);
|
|
||||||
inputs_format.emplace_back(format);
|
|
||||||
} else {
|
|
||||||
if (kernel_info_index < real_input_num) {
|
|
||||||
MS_LOG(INFO) << "Input type is optional, input index is :" << kernel_info_index;
|
|
||||||
kernel_info_index++;
|
|
||||||
inputs_type.emplace_back(type_id);
|
|
||||||
inputs_format.emplace_back(format);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
builder->SetInputsDeviceType(inputs_type);
|
|
||||||
builder->SetInputsFormat(inputs_format);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetTidyOutputsInfo(const std::shared_ptr<AnfNode> &anf_node,
|
|
||||||
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
|
|
||||||
const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
|
|
||||||
std::vector<TypeId> outputs_type;
|
|
||||||
std::vector<std::string> outputs_format;
|
|
||||||
auto real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
|
||||||
size_t output_idx = 0;
|
|
||||||
for (const auto &output : outputs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
|
||||||
if (output_idx >= real_output_num) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
size_t output_num = 0;
|
|
||||||
if (output->param_type() == "dynamic") {
|
|
||||||
if (outputs.size() > 1) {
|
|
||||||
MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
|
|
||||||
}
|
|
||||||
output_num = real_output_num;
|
|
||||||
} else if (output->param_type() == "required") {
|
|
||||||
output_num = 1;
|
|
||||||
} else {
|
|
||||||
if (output_idx < real_output_num) {
|
|
||||||
MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
|
|
||||||
output_num = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < output_num; i++) {
|
|
||||||
auto type_id = AnfAlgo::GetOutputInferDataType(anf_node, output_idx);
|
|
||||||
outputs_type.emplace_back(type_id);
|
|
||||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
|
||||||
output_idx++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
builder->SetOutputsDeviceType(outputs_type);
|
|
||||||
builder->SetOutputsFormat(outputs_format);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GenTidyKernelBuildInfo(const std::shared_ptr<AnfNode> &anf_node,
|
|
||||||
const std::vector<std::shared_ptr<OpIOInfo>> &inputs,
|
|
||||||
const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
|
|
||||||
auto builder_tmp = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
||||||
builder_tmp->SetKernelType(TBE_KERNEL);
|
|
||||||
SetTidyInputsInfo(anf_node, builder_tmp, inputs);
|
|
||||||
SetTidyOutputsInfo(anf_node, builder_tmp, outputs);
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder_tmp->Build(), anf_node.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
|
|
||||||
const std::shared_ptr<OpInfo> &op_info_new_ptr) {
|
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_static = op_info_ptr->inputs_ptr();
|
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> outputs_static = op_info_ptr->outputs_ptr();
|
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_dyn;
|
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> outputs_dyn;
|
|
||||||
if ((op_info_ptr->imply_type() == kTBE) && (!mindspore::opt::IsNopNode(kernel_node->cast<AnfNodePtr>()))) {
|
|
||||||
// 1. create tidy kernelBuildInfo in order to generate json for calling op_select_format
|
|
||||||
auto anf_node = kernel_node->cast<std::shared_ptr<AnfNode>>();
|
|
||||||
auto kernel_build_info_ptr = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
|
|
||||||
GenTidyKernelBuildInfo(kernel_node, inputs_static, outputs_static);
|
|
||||||
|
|
||||||
// 2.get dynamic format from op_impl
|
|
||||||
std::string res_json_str;
|
|
||||||
auto context_ptr = MsContext::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
||||||
if (context_ptr->execution_mode() != kPynativeMode) {
|
|
||||||
res_json_str = OpSelectFormat(kernel_node);
|
|
||||||
}
|
|
||||||
if (!res_json_str.empty()) {
|
|
||||||
(void)ParseDynamicFormatJson(res_json_str, &inputs_dyn, &outputs_dyn);
|
|
||||||
}
|
|
||||||
if (inputs_static.size() != inputs_dyn.size()) {
|
|
||||||
inputs_dyn.clear();
|
|
||||||
}
|
|
||||||
if (outputs_static.size() != outputs_dyn.size()) {
|
|
||||||
outputs_dyn.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. resume kernel node's SelectKernelBuildInfo
|
|
||||||
// As it has been replaced by GenTidyKernelBuildInfo in order to call python func
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_ptr, anf_node.get());
|
|
||||||
}
|
|
||||||
// 4.replace by dynamic format and dtype
|
|
||||||
if (inputs_dyn.empty() && outputs_dyn.empty()) {
|
|
||||||
MS_LOG(INFO) << "Dynamic select format response is empty, use static register info.";
|
|
||||||
op_info_new_ptr->set_inputs_ptr(inputs_static);
|
|
||||||
op_info_new_ptr->set_outputs_ptr(outputs_static);
|
|
||||||
} else {
|
|
||||||
MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
|
|
||||||
for (size_t i = 0; i < inputs_static.size(); i++) {
|
|
||||||
inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
|
|
||||||
inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
|
|
||||||
}
|
|
||||||
for (size_t j = 0; j < outputs_static.size(); j++) {
|
|
||||||
outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
|
|
||||||
outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
|
|
||||||
}
|
|
||||||
op_info_new_ptr->set_inputs_ptr(inputs_dyn);
|
|
||||||
op_info_new_ptr->set_outputs_ptr(outputs_dyn);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5.copy other opinfo to new op_info_new
|
|
||||||
op_info_new_ptr->set_op_name(op_info_ptr->op_name());
|
|
||||||
op_info_new_ptr->set_imply_type(op_info_ptr->imply_type());
|
|
||||||
op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
|
|
||||||
for (const auto &c : reshape_type_str) {
|
|
||||||
switch (c) {
|
|
||||||
case 'N':
|
|
||||||
reshape_type_vec->push_back(kernel::N);
|
|
||||||
break;
|
|
||||||
case 'C':
|
|
||||||
reshape_type_vec->push_back(kernel::C);
|
|
||||||
break;
|
|
||||||
case 'H':
|
|
||||||
reshape_type_vec->push_back(kernel::H);
|
|
||||||
break;
|
|
||||||
case 'W':
|
|
||||||
reshape_type_vec->push_back(kernel::W);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Unknown axis " << c << "in reshape type.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populates `builder` with the input dtypes/formats/reshape-types of the
// `builder_idex`-th registered dtype/format combination.
//
// inputs:          registered OpIOInfo entries (must be non-empty; caller
//                  guarantees this — inputs[0] is dereferenced below).
// real_input_num:  number of actual input tensors on the node; bounds how
//                  many optional inputs are consumed.
// builder_idex:    which column of each entry's dtypes()/formats() to use.
// dyn_input_sizes: per-dynamic-input expansion counts from the node's
//                  "dyn_input_sizes" attribute (may be empty if no input is
//                  dynamic).
// Returns false on any inconsistency in the registration data.
bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
                               size_t builder_idex, const std::vector<int> &dyn_input_sizes,
                               const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  MS_EXCEPTION_IF_NULL(builder);

  std::vector<TypeId> inputs_device_type;
  std::vector<std::string> inputs_format;
  size_t dyn_input_idx = 0;       // index into dyn_input_sizes, advanced per dynamic entry
  size_t kernel_info_index = 0;   // running count of real input slots filled
  MS_EXCEPTION_IF_NULL(inputs[0]);
  // Every entry must register the same number of dtype/format combinations.
  size_t kernel_info_cnt = inputs[0]->dtypes().size();

  std::vector<std::vector<Axis>> reshape_types;
  for (const auto &input : inputs) {
    MS_EXCEPTION_IF_NULL(input);
    std::string param_type = input->param_type();
    std::vector<std::string> dtypes = input->dtypes();
    std::vector<std::string> formats = input->formats();
    if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
      MS_LOG(ERROR) << "Set input kernel builder info, dtyps size != formats size.";
      return false;
    }

    std::vector<Axis> reshape_type;
    if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
      return false;
    }

    if (param_type == "dynamic") {
      if (dyn_input_sizes.empty()) {
        MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
        return false;
      }

      // A dynamic entry expands into dyn_input_sizes[dyn_input_idx] real
      // input slots, all sharing this entry's dtype/format/reshape type.
      for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
        kernel_info_index++;
        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
        reshape_types.push_back(reshape_type);
      }
      dyn_input_idx++;
    } else if (param_type == "required") {
      // Required entry fills exactly one slot.
      kernel_info_index++;
      auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
      inputs_device_type.push_back(type_id);
      inputs_format.push_back(formats[builder_idex]);
      reshape_types.push_back(reshape_type);
    } else {
      // Optional entry: consumed only while real input slots remain.
      if (kernel_info_index < real_input_num) {
        MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
        kernel_info_index++;
        auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
        reshape_types.push_back(reshape_type);
      }
    }
  }

  builder->SetInputReshapeType(reshape_types);
  builder->SetInputsDeviceType(inputs_device_type);
  builder->SetInputsFormat(inputs_format);
  return true;
}
|
|
||||||
|
|
||||||
bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
|
|
||||||
const size_t &real_output_num,
|
|
||||||
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
|
|
||||||
// not now but in the next we need to support dynamic output case
|
|
||||||
MS_EXCEPTION_IF_NULL(builder);
|
|
||||||
|
|
||||||
size_t output_idx = 0;
|
|
||||||
std::vector<TypeId> outputs_device_type;
|
|
||||||
std::vector<std::string> outputs_format;
|
|
||||||
MS_EXCEPTION_IF_NULL(outputs[0]);
|
|
||||||
size_t kernel_info_cnt = outputs[0]->dtypes().size();
|
|
||||||
|
|
||||||
std::vector<std::vector<Axis>> reshape_types;
|
|
||||||
for (const auto &output : outputs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
|
||||||
if (output_idx >= real_output_num) {
|
|
||||||
MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
std::vector<Axis> reshape_type;
|
|
||||||
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t output_num = 0;
|
|
||||||
if (output->param_type() == "dynamic") {
|
|
||||||
if (outputs.size() > 1) {
|
|
||||||
MS_LOG(EXCEPTION) << "Dynamic output is unsupported multi output!";
|
|
||||||
}
|
|
||||||
output_num = real_output_num;
|
|
||||||
} else if (output->param_type() == "required") {
|
|
||||||
output_num = 1;
|
|
||||||
} else {
|
|
||||||
if (output_idx < real_output_num) {
|
|
||||||
MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is " << output_idx;
|
|
||||||
output_num = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < output_num; i++) {
|
|
||||||
std::vector<std::string> dtypes = output->dtypes();
|
|
||||||
std::vector<std::string> formats = output->formats();
|
|
||||||
if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
|
|
||||||
MS_LOG(ERROR) << "Set output kernel builder info, dtyps size != formats size.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
|
|
||||||
outputs_device_type.push_back(type_id);
|
|
||||||
outputs_format.push_back(formats[builder_idex]);
|
|
||||||
reshape_types.push_back(reshape_type);
|
|
||||||
output_idx++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
builder->SetOutputReshapeType(reshape_types);
|
|
||||||
builder->SetOutputsFormat(outputs_format);
|
|
||||||
builder->SetOutputsDeviceType(outputs_device_type);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets the per-op attributes shared by every dtype/format combination:
// processor, fusion type (only if recognized), op pattern, and kernel type.
void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
                              Processor processor, const std::shared_ptr<const OpInfo> &op_info_ptr) {
  MS_EXCEPTION_IF_NULL(builder);
  MS_EXCEPTION_IF_NULL(op_info_ptr);

  builder->SetProcessor(processor);
  // Look the fusion type up once (the original evaluated tbe::GetFusionType
  // twice — in the condition and again in the setter).
  auto fusion_type = tbe::GetFusionType(op_info_ptr->fusion_type());
  if (fusion_type != UNKNOWN_FUSION_TYPE) {
    builder->SetFusionType(fusion_type);
  }
  builder->SetOpPattern(op_info_ptr->op_pattern());
  builder->SetKernelType(TBE_KERNEL);
}
|
|
||||||
|
|
||||||
// Expands an op's registered OpInfo into one KernelBuildInfo per registered
// dtype/format combination and appends them to *kernel_info_list.
// Returns false if any combination's registration data is inconsistent.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  // Dynamic-input ops carry their per-input expansion counts as an attribute.
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (!inputs.empty()) {
    MS_EXCEPTION_IF_NULL(inputs[0]);
    // Number of registered dtype/format combinations (one build info each).
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);

      if (!SetKernelBuilderInputInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(ERROR) << "Parse kernel metadata, set inputs kernel builder info failed.";
        return false;
      }

      if (!outputs.empty()) {
        if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
          return false;
        }
      }

      kernel_info_list->push_back(builder->Build());
    }
  } else if (!outputs.empty()) {
    // Op with no registered inputs: combination count comes from the outputs.
    MS_EXCEPTION_IF_NULL(outputs[0]);
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);

      if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
        return false;
      }

      kernel_info_list->push_back(builder->Build());
    }
  }
  // Neither inputs nor outputs registered: nothing to emit, still a success.
  return true;
}
|
|
||||||
|
|
||||||
// Returns true when `shape` is a plausible candidate for device format
// `format`. Throws on a format string that is not in kOpFormatList.
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  // Unknown format strings are a registration error, not a mismatch.
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
  }
  // The default format accepts every shape.
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  // NDHWC requires a 5-D shape.
  // NOTE(review): a 5-D shape passing this check is still rejected by the
  // `shape.size() > kShape4dDims` check below, so NDHWC can never return
  // true here — looks like a latent bug; confirm intended check order.
  if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
    return false;
  }
  // if shape size is 0, the shape will be a scalar; scalars match any format
  if (shape.empty()) {
    return true;
  }
  if (shape.size() > kShape4dDims) {
    return false;
  }
  // FRAC_NZ needs at least two dimensions (last two are tiled).
  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
    return false;
  }
  return true;
}
|
|
||||||
|
|
||||||
// Rejects kernel build candidates whose input/output formats are impossible
// for this node's inferred shapes, plus op-specific restrictions:
//  - FRACTAL_Z_C04 is rejected outright (inputs and outputs);
//  - ReduceMean without keep_dims only accepts the default format;
//  - Cast additionally requires infer dtypes to match the device dtypes.
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
    if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
      // FRACTAL_Z_C04 outputs are rejected unconditionally. The original code
      // contained an inner 4-D/C<=4 shape check whose both branches returned
      // false (dead code); it has been removed with no behavior change.
      return false;
    }
    if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
      return false;
    }
    if (kernel_name == "ReduceMean") {
      // Without keep_dims the output rank shrinks, so only the default
      // format is meaningful.
      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
      if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
        return false;
      }
    }
  }
  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
    if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
      return false;
    }
    if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
      // Same unconditional rejection (and removed dead check) as the outputs.
      return false;
    }
    if (kernel_name == "ReduceMean") {
      auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
      if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
        return false;
      }
    }
  }
  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
  }
  return true;
}
|
|
||||||
|
|
||||||
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
|
||||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list;
|
|
||||||
|
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
|
||||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
|
|
||||||
if (op_info_ptr == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// dynamic get op format and dtype and replace opinfo
|
|
||||||
auto op_info_new_ptr = std::make_shared<OpInfo>();
|
|
||||||
ReplaceByDynamicFormatDtype(kernel_node, op_info_ptr, op_info_new_ptr);
|
|
||||||
|
|
||||||
if (!ParseMetadata(kernel_node, op_info_new_ptr, &parse_info_list)) {
|
|
||||||
MS_LOG(INFO) << "Tbe parsed metadata of op[" << op_name << "] failed.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto context_ptr = MsContext::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
||||||
for (const auto &parse_info : parse_info_list) {
|
|
||||||
if (IsValidKernelInfo(kernel_node, *(parse_info))) {
|
|
||||||
if (CheckSupported(kernel_node, parse_info)) {
|
|
||||||
kernel_info_list->push_back(parse_info);
|
|
||||||
} else {
|
|
||||||
MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (kernel_info_list->empty()) {
|
|
||||||
MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "].";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,20 +13,18 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
|
||||||
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
|
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_SELECT_COMMON_UTILS_H_
|
||||||
#define MINDSPORE_TBE_KERNEL_SELECT_H
|
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
|
||||||
#include "kernel/oplib/opinfo.h"
|
|
||||||
#include "kernel/kernel_build_info.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
|
struct SupportFormat {
|
||||||
|
std::vector<std::vector<std::string>> input_format;
|
||||||
|
std::vector<std::vector<std::string>> output_format;
|
||||||
|
};
|
||||||
|
using SupportFormatItem = std::vector<std::string>;
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_TBE_KERNEL_SELECT_H
|
#endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_
|
|
@ -0,0 +1,319 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
constexpr char kDynInputKey[] = "dyn_input_sizes";
|
||||||
|
constexpr size_t kInputIndex_0 = 0;
|
||||||
|
constexpr size_t kChannelN = 0;
|
||||||
|
constexpr size_t kChannelC = 1;
|
||||||
|
constexpr size_t kAlignmented16 = 16;
|
||||||
|
// 1. all shape no scalar and same
|
||||||
|
// 2. part scalar : no_scalar (shape size > xxx && alig xxx)
|
||||||
|
// 3. all no_scalar and not same (broad cast xxx dim)
|
||||||
|
// Caches the node's padded input/output shapes into the member state
// (input_num_/output_num_/input_shapes_/output_shapes_) and seeds
// *support_format with the always-valid default format.
// For a dynamic-input node only the first input's shape is recorded (the
// dyn_input_sizes attr asserts at least two homogeneous inputs).
// Always returns true.
bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
  MS_EXCEPTION_IF_NULL(support_format);
  input_num_ = 0;
  output_num_ = 0;
  input_shapes_.clear();
  output_shapes_.clear();
  if (AnfAlgo::HasNodeAttr(kDynInputKey, cnode_ptr_)) {
    MS_LOG(INFO) << "This broadcast node has dynamic input.";
    auto dynamic_size_vec = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynInputKey);
    // A dynamic input must group at least two real inputs.
    if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) {
      MS_LOG(EXCEPTION) << "dynamic attr set error, please check.";
    }
    // Only the first grouped input's shape is representative here.
    auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
    PadScalarShape(&dynamic_input_shape0_);
    input_shapes_.emplace_back(dynamic_input_shape0_);
    input_num_ = 1;
  } else {
    input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
    for (size_t i = 0; i < input_num_; ++i) {
      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
      PadScalarShape(&input_shape);
      input_shapes_.emplace_back(input_shape);
    }
  }

  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  for (size_t i = 0; i < output_num_; ++i) {
    auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
    PadScalarShape(&output);
    output_shapes_.emplace_back(output);
  }
  // The default format is supported unconditionally.
  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
  return true;
}
|
||||||
|
|
||||||
|
// Decides whether this broadcast node can run in NC1HWC0 (5HD) and, if so,
// appends the per-input/per-output format rows to *support_format.
// Rules: identical non-scalar shapes are always OK; with scalar inputs every
// non-scalar input must be 4-D with C a multiple of 16 (scalars stay in the
// default format); otherwise all inputs must be 4-D and must not broadcast
// along the C axis.
bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        // Scalars are kept in the default format alongside 5HD tensors.
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return false;
        }
        // C must tile exactly into C1*C0 (C0 == 16).
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_NC1HWC0);
      }
    }
  } else {
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return false;
      }
    }
    // Broadcasting across the (tiled) C axis is not representable in 5HD.
    auto shape_tmp = input_shapes_[0];
    auto broadcast_c_axis = std::any_of(
      input_shapes_.begin(), input_shapes_.end(),
      [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
    if (broadcast_c_axis) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return false;
    }
    input_support_format.assign(input_num_, kOpFormat_NC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
|
||||||
|
|
||||||
|
// Decides whether this broadcast node can run in FRAC_Z and, if so, appends
// the format rows to *support_format.
// Rules: identical non-scalar shapes are OK; with scalar inputs every
// non-scalar input must be 4-D with BOTH N and C multiples of 16; any other
// shape mix (non-scalar, non-identical) is rejected outright.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_FRAC_Z, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return false;
        }
        // FRAC_Z tiles both N and C into 16-blocks.
        if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_FRAC_Z);
      }
    }
  } else {
    // Non-identical, non-scalar broadcast is not supported in FRAC_Z.
    return false;
  }
  GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
|
||||||
|
// Decides whether this broadcast node can run in C1HWNCoC0 and, if so,
// appends the format rows to *support_format.
// Rules: identical non-scalar shapes are OK; with scalar inputs every
// non-scalar input must be 4-D with N a multiple of 16; otherwise all inputs
// must be 4-D and must not broadcast along either the N or C axis.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return false;
        }
        // N is tiled into 16-blocks in this layout.
        if (shape[kChannelN] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
      }
    }
  } else {
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return false;
      }
    }
    // Broadcasting across the tiled N or C axis is not representable.
    auto shape_tmp = input_shapes_[0];
    auto broadcast_nc_axis =
      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
        return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN));
      });
    if (broadcast_nc_axis) {
      MS_LOG(INFO) << "This node broadcast n || c channel.";
      return false;
    }
    input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
  }
  GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
|
||||||
|
|
||||||
|
// Decides whether this broadcast node can run in FRAC_NZ and, if so, appends
// the format rows to *support_format.
// Rules: identical non-scalar shapes are OK; with scalar inputs every
// non-scalar input needs >= 2 dims with the last two multiples of 16;
// otherwise all inputs need >= 2 dims and must not broadcast along either of
// the last two (tiled) dimensions.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
      return true;
    } else {
      return false;
    }
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (shape.size() < kShape2dDims) {
          return false;
        }
        // FRAC_NZ tiles the trailing two dimensions into 16-blocks.
        if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) {
          return false;
        }
        input_support_format.emplace_back(kOpFormat_FRAC_NZ);
      }
    }
  } else {
    auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(),
                                  [](const std::vector<size_t> &elem) { return elem.size() < kShape2dDims; });
    if (less_2dims) {
      MS_LOG(INFO) << "This node dim less 2.";
      return false;
    }

    // Broadcasting across the tiled trailing dimensions is not representable.
    auto shape_tmp = input_shapes_[0];
    auto broadcast_last_dim =
      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
        return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) ||
               (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2));
      });
    if (broadcast_last_dim) {
      MS_LOG(INFO) << "This node broadcast last channel.";
      return false;
    }

    input_support_format.assign(input_num_, kOpFormat_FRAC_NZ);
  }
  GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
|
||||||
|
|
||||||
|
// NDC1HWC0 support is not implemented yet; this stub always reports the
// format as unsupported (support_format is validated but never modified).
bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  return false;
}
|
||||||
|
|
||||||
|
bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
|
||||||
|
return shape.size() == kShape4dDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TbeKernelBroadCastSelecter::IsSameShape() const {
|
||||||
|
auto shape = input_shapes_.begin();
|
||||||
|
for (const auto &item : input_shapes_) {
|
||||||
|
if (shape->size() != item.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < shape->size(); ++i) {
|
||||||
|
if (shape->at(i) != item.at(i)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TbeKernelBroadCastSelecter::PadScalarShape(std::vector<size_t> *shape) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(shape);
|
||||||
|
if (shape->empty()) {
|
||||||
|
shape->emplace_back(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A shape counts as scalar when it is exactly {1}. PadScalarShape() has
// already normalized genuine rank-0 shapes to this form, so this single
// test covers both cases.
bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector<size_t> &shape) const {
  return (shape.size() == 1 && shape[0] == 1);
}
|
||||||
|
|
||||||
|
bool TbeKernelBroadCastSelecter::HasScalarInput() const {
|
||||||
|
bool ret = false;
|
||||||
|
for (const auto &shape : input_shapes_) {
|
||||||
|
if (IsScalarShape(shape)) {
|
||||||
|
ret = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format,
|
||||||
|
SupportFormatItem *output_support_item) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(output_support_item);
|
||||||
|
for (const auto &shape : output_shapes_) {
|
||||||
|
if (IsScalarShape(shape)) {
|
||||||
|
output_support_item->emplace_back(kOpFormat_DEFAULT);
|
||||||
|
} else {
|
||||||
|
output_support_item->emplace_back(support_format);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Appends one uniform format row to `support_format`: every input and every
// output gets the same format string. Used when a candidate format applies
// to all tensors alike (no per-tensor distinction needed).
void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str,
                                                     mindspore::kernel::SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  // One entry per input / output tensor, all with the same format.
  input_support_format.assign(input_num_, support_format_str);
  output_support_format.assign(output_num_, support_format_str);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
}
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
|
||||||
|
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
// Decides which TBE data formats a broadcast-pattern op can run in, based on
// the input/output shapes of the node. GetShapeInfo() must be called first
// to populate the cached shape members; the IsBroadCastSupport* probes then
// append matching rows to the caller's SupportFormat.
class TbeKernelBroadCastSelecter {
 public:
  explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelBroadCastSelecter() = default;
  // Caches input/output shapes of the node and seeds the default format row.
  bool GetShapeInfo(SupportFormat *support_format);
  // Each probe returns true (and appends a format row) when the shapes allow it.
  bool IsBroadCastSupport5HD(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const;
  bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const;
  bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const;

 private:
  // True when all cached input shapes are identical.
  bool IsSameShape() const;
  // Pads a rank-0 shape to {1}.
  void PadScalarShape(std::vector<size_t> *shape) const;
  bool Is4DShape(const std::vector<size_t> &shape) const;
  bool IsScalarShape(const std::vector<size_t> &shape) const;
  // True when at least one input is scalar.
  bool HasScalarInput() const;
  // Per-output formats: default for scalar outputs, `support_format` otherwise.
  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
  // Uniform format row for all inputs and outputs.
  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
  // broadcast
  CNodePtr cnode_ptr_;
  size_t input_num_{};
  size_t output_num_{};
  std::vector<std::vector<size_t>> input_shapes_;
  std::vector<std::vector<size_t>> output_shapes_;
};
|
||||||
|
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
|
|
@ -0,0 +1,180 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
constexpr char kKeepDims[] = "keep_dims";
|
||||||
|
constexpr char kAxis[] = "axis";
|
||||||
|
constexpr char kTypeInt32[] = "Int32";
|
||||||
|
constexpr size_t kInputIndex_0 = 0;
|
||||||
|
constexpr size_t kOutputIndex_0 = 0;
|
||||||
|
constexpr size_t kChannelN = 0;
|
||||||
|
constexpr size_t kChannelC = 1;
|
||||||
|
constexpr size_t kReduceNZMinDim = 3;
|
||||||
|
|
||||||
|
// Caches the reduce node's input/output shapes and its keep_dims/axis attrs,
// then seeds `support_format` with the default-format row. Must be called
// before any IsReduceSupport* probe, which all read the cached members.
bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) {
  MS_EXCEPTION_IF_NULL(support_format);
  // Reset cached state in case this selecter instance is reused.
  input_shape_.clear();
  output_shape_.clear();
  axis_.clear();
  auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  if (input_num != 1 || output_num != 1) {
    MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num
                      << ", output num: " << output_num;
  }
  // get input/output shape
  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
  PadScalarShape(&input_shape_);
  output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0);
  PadScalarShape(&output_shape_);
  // get keep dim attr
  GetReduceAttrKeepDim();
  // get axis attr (must run after input_shape_ is set: negative axes are
  // rebased against the input rank)
  GetReduceAttrAxis();
  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
  return true;
}
|
||||||
|
|
||||||
|
// NC1HWC0 (5HD) is usable for a reduce only when: the input is 4-D, the
// output keeps its rank (keep_dims), an explicit axis exists, and the C
// channel is NOT reduced (C is split into C1/C0 in 5HD, so reducing it
// would cross the packed dimension).
bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (!Is4DShape(input_shape_)) {
    return false;
  }
  if (!keep_dims_ || axis_.empty()) {
    return false;
  }
  auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); });
  if (reduce_c_axis) {
    return false;
  }
  AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
  return true;
}
|
||||||
|
|
||||||
|
// NDC1HWC0 support for reduce ops is not implemented yet (would mirror the
// 5HD rules above); always reports unsupported.
bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // like to 5HD
  return false;
}
|
||||||
|
|
||||||
|
// FRAC_Z shares its admissibility rules with C1HWNCoC0; see
// IsFracZAndC1HWNCoC0Common for the actual conditions.
bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
  return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format);
}
|
||||||
|
|
||||||
|
// C1HWNCoC0 shares its admissibility rules with FRAC_Z; see
// IsFracZAndC1HWNCoC0Common for the actual conditions.
bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const {
  return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format);
}
|
||||||
|
|
||||||
|
// FRAC_NZ is usable for a reduce only when the input has at least 3 dims,
// an explicit axis exists, and neither of the two innermost dimensions is
// reduced (FRAC_NZ tiles the last two dims, so reducing them would cross
// the fractal blocks).
bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (input_shape_.size() < kReduceNZMinDim) {
    return false;
  }
  if (axis_.empty()) {
    return false;
  }
  auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) {
    return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2));
  });
  if (reduce_last_axis) {
    return false;
  }
  AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
  return true;
}
|
||||||
|
|
||||||
|
bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format,
|
||||||
|
mindspore::kernel::SupportFormat *support_format) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(support_format);
|
||||||
|
if (!Is4DShape(input_shape_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!keep_dims_ || axis_.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(),
|
||||||
|
[](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); });
|
||||||
|
if (reduce_n_c_axis) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
AssignSupportFormat(format, support_format);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TbeKernelReduceSelecter::GetReduceAttrAxis() {
|
||||||
|
auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto axis = primitive->GetAttr(kAxis);
|
||||||
|
if (axis == nullptr) {
|
||||||
|
MS_LOG(INFO) << "This node does't have axie attr.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto type = axis->type();
|
||||||
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
std::vector<int> axis_list;
|
||||||
|
if (type->ToString() == kTypeInt32) {
|
||||||
|
axis_list.emplace_back(GetValue<int>(axis));
|
||||||
|
} else {
|
||||||
|
axis_list = GetValue<std::vector<int>>(axis);
|
||||||
|
}
|
||||||
|
for (const auto &elem : axis_list) {
|
||||||
|
if (elem < 0) {
|
||||||
|
axis_.emplace_back(input_shape_.size() + elem);
|
||||||
|
} else {
|
||||||
|
axis_.emplace_back(IntToSize(elem));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TbeKernelReduceSelecter::GetReduceAttrKeepDim() {
|
||||||
|
if (!AnfAlgo::HasNodeAttr(kKeepDims, cnode_ptr_)) {
|
||||||
|
MS_LOG(INFO) << "This node does't have keep_attr.";
|
||||||
|
keep_dims_ = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kKeepDims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Appends one format row to `support_format` with the same format string for
// the single input and single output (reduce ops are validated as 1-in/1-out
// in GetShapeInfo).
void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
                                                  mindspore::kernel::SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  input_support_format.emplace_back(support_format_str);
  output_support_format.emplace_back(support_format_str);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
}
|
||||||
|
|
||||||
|
// True when `shape` is rank-4; the fractal-format specializations above only apply to 4-D inputs.
bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; }
|
||||||
|
|
||||||
|
// Normalizes a rank-0 (scalar) shape to {1} so later rank checks never see
// an empty shape vector.
void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->empty()) {
    shape->emplace_back(1);
  }
}
|
||||||
|
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
|
||||||
|
#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
// Decides which TBE data formats a reduce-pattern op can run in, based on
// the node's input shape and its axis/keep_dims attributes. GetShapeInfo()
// must be called first to populate the cached members; the IsReduceSupport*
// probes then append matching rows to the caller's SupportFormat.
class TbeKernelReduceSelecter {
 public:
  explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelReduceSelecter() = default;
  // Caches shapes/attrs and seeds the default format row. Requires 1-in/1-out.
  bool GetShapeInfo(SupportFormat *support_format);
  // Each probe returns true (and appends a format row) when shapes/attrs allow it.
  bool IsReduceSupport5HD(SupportFormat *support_format) const;
  bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const;
  bool IsReduceSupportFracZ(SupportFormat *support_format) const;
  bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
  bool IsReduceSupportFracNZ(SupportFormat *support_format) const;

 private:
  // Shared rule set for FRAC_Z and C1HWNCoC0.
  bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const;
  // Populates axis_ (normalized to non-negative indices) from the node attr.
  void GetReduceAttrAxis();
  // Populates keep_dims_ from the node attr (false when absent).
  void GetReduceAttrKeepDim();
  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
  bool Is4DShape(const std::vector<size_t> &shape) const;
  void PadScalarShape(std::vector<size_t> *shape) const;
  CNodePtr cnode_ptr_;
  std::vector<size_t> input_shape_{};
  std::vector<size_t> output_shape_{};
  std::vector<size_t> axis_{};
  bool keep_dims_ = false;
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_
|
|
@ -0,0 +1,633 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <utility>
|
||||||
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "kernel/oplib/oplib.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_build.h"
|
||||||
|
#include "nlohmann/json.hpp"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "kernel/tbe/tbe_python_funcs.h"
|
||||||
|
#include "pre_activate/common/helper.h"
|
||||||
|
#include "kernel/tbe/tbe_convert_utils.h"
|
||||||
|
#include "parallel/ops_info/ops_utils.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
constexpr auto kName = "name";
|
||||||
|
constexpr auto kDtype = "dtype";
|
||||||
|
constexpr auto kFormat = "format";
|
||||||
|
constexpr auto kPrefixInput = "input";
|
||||||
|
constexpr auto kPrefixOutput = "output";
|
||||||
|
constexpr char kDynInputKey[] = "dyn_input_sizes";
|
||||||
|
constexpr char kParamTypeDynamic[] = "dynamic";
|
||||||
|
constexpr char kParamTypeRequre[] = "required";
|
||||||
|
constexpr char kParamTypeOptional[] = "optional";
|
||||||
|
// Public entry point: runs TBE kernel selection for `kernel_node` and
// appends the resulting build-info candidates to `kernel_info_list`.
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
  auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
  tbe_selecter.TbeMetadataInfoEx();
}
|
||||||
|
|
||||||
|
// Stores the node and the (caller-owned, non-null expected) output list;
// selection itself happens in TbeMetadataInfoEx().
TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
    : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
|
||||||
|
|
||||||
|
void TbeKernelSelect::TbeMetadataInfoEx() {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode_ptr_);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_info_list_);
|
||||||
|
node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
|
||||||
|
auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
|
||||||
|
if (!op_info_ptr) {
|
||||||
|
MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_
|
||||||
|
<< ", node name: " << cnode_ptr_->fullname_with_scope();
|
||||||
|
OpPattern pattern = op_info_ptr->op_pattern();
|
||||||
|
if (pattern == kCommonPattern) {
|
||||||
|
GetCommonPatternKernelInfo(*op_info_ptr);
|
||||||
|
} else if (pattern == kDynamicFormatPattern) {
|
||||||
|
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
|
||||||
|
} else if (pattern == kFormatAgnosticPattern) {
|
||||||
|
GetAgnosticPatternKernelInfo(*op_info_ptr);
|
||||||
|
} else if (pattern == kBroadcastPattern) {
|
||||||
|
GetBroadcastPatternKernelInfo(*op_info_ptr);
|
||||||
|
} else if (pattern == kReducePattern) {
|
||||||
|
GetReducePatternKernelInfo(*op_info_ptr);
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "Warning: op pattern is invailed.";
|
||||||
|
}
|
||||||
|
// check support
|
||||||
|
FilterInVaildKernelInfo();
|
||||||
|
MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds one KernelBuildInfo per registered dtype/format column of the op
// info. Column k of every input/output descriptor forms candidate k; if any
// column cannot be expanded (GenBuilderItem fails), generation stops there.
void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
  MS_LOG(INFO) << "start.";
  // get dynamic inputs: dyn_input_sizes[i] tells how many real tensors the
  // i-th dynamic parameter expands to.
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(primitive);
  std::vector<int> dyn_input_sizes;
  if (primitive->HasAttr(kDynInputKey)) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kDynInputKey));
  }
  // get real input/output num
  size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  const auto inputs_info = op_info.inputs_ptr();
  size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  const auto outputs_info = op_info.outputs_ptr();
  if (inputs_info.empty() && outputs_info.empty()) {
    MS_LOG(EXCEPTION) << "op info input & output is null, please check.";
  }
  // create kernel build info from opinfo: the number of candidates equals
  // the number of dtype columns registered for the first input (or output
  // when the op has no inputs).
  size_t kernel_build_info_num =
    inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size();
  for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) {
    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
    SetTbeBuildCommonInfo(op_info, &builder);
    std::vector<std::string> inputs_format;
    std::vector<TypeId> inputs_device_type;
    std::vector<std::vector<Axis>> inputs_reshape_type;
    // input: expand this column's dtype/format over all real input tensors.
    if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
                        &inputs_format, &inputs_device_type, &inputs_reshape_type)) {
      break;
    }
    builder.SetInputsDeviceType(inputs_device_type);
    builder.SetInputsFormat(inputs_format);
    builder.SetInputReshapeType(inputs_reshape_type);
    // output: same expansion for outputs.
    std::vector<std::string> outputs_format;
    std::vector<TypeId> outputs_device_type;
    std::vector<std::vector<Axis>> outputs_reshape_type;
    if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
                        &outputs_format, &outputs_device_type, &outputs_reshape_type)) {
      break;
    }
    builder.SetOutputsDeviceType(outputs_device_type);
    builder.SetOutputsFormat(outputs_format);
    builder.SetOutputReshapeType(outputs_reshape_type);
    kernel_info_list_->emplace_back(builder.Build());
  }
  MS_LOG(INFO) << "end.";
}
|
||||||
|
|
||||||
|
// Dynamic-format pattern: rewrite the op info (CreateNewOpInfo queries the
// supported formats dynamically) and reuse the common-pattern generator on
// the expanded registration.
void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) {
  MS_LOG(INFO) << "start.";
  // Expand op_info into op_info_new with dynamically determined formats.
  OpInfo op_info_new;
  CreateNewOpInfo(op_info, &op_info_new);
  GetCommonPatternKernelInfo(op_info_new);
  MS_LOG(INFO) << "end.";
}
|
||||||
|
|
||||||
|
// Format-agnostic pattern: the op works in whatever format its (single)
// input arrives in, so propagate the producer's output format to every
// input and output, falling back to the default format when the producer's
// format is unknown.
void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
  MS_LOG(INFO) << "start.";
  if (op_info.inputs_ptr().size() != 1) {
    MS_LOG(EXCEPTION) << "AgnosticPattern only support one input.";
  }
  auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_LOG(INFO) << "Got the unknown format " << format;
    format = kOpFormat_DEFAULT;
  }
  SupportFormat support_format;
  SupportFormatItem input_item;
  SupportFormatItem output_item;
  // Same format for every input and output slot.
  input_item.assign(op_info.inputs_ptr().size(), format);
  output_item.assign(op_info.outputs_ptr().size(), format);
  support_format.input_format.emplace_back(input_item);
  support_format.output_format.emplace_back(output_item);
  PrintSupportedFormat(support_format);
  OpInfo op_info_new;
  CreateNewOpInfo(op_info, support_format, &op_info_new);
  GetCommonPatternKernelInfo(op_info_new);
  MS_LOG(INFO) << "end.";
}
|
||||||
|
|
||||||
|
// Broadcast pattern: probe each special format via TbeKernelBroadCastSelecter
// (each successful probe appends a format row to support_format), expand the
// op info with the collected rows, and reuse the common-pattern generator.
// Unsupported formats are merely logged — the default row from GetShapeInfo
// always remains.
void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
  MS_LOG(INFO) << "start.";
  auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_);
  SupportFormat support_format;
  broadcast_selecter.GetShapeInfo(&support_format);
  if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) {
    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD.";
  }
  if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) {
    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ.";
  }
  if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) {
    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0.";
  }
  if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) {
    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ.";
  }
  PrintSupportedFormat(support_format);
  OpInfo op_info_new;
  CreateNewOpInfo(op_info, support_format, &op_info_new);
  GetCommonPatternKernelInfo(op_info_new);
  MS_LOG(INFO) << "end.";
}
|
||||||
|
|
||||||
|
void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
|
||||||
|
MS_LOG(INFO) << "start.";
|
||||||
|
auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
|
||||||
|
SupportFormat support_format;
|
||||||
|
reduce_selecter.GetShapeInfo(&support_format);
|
||||||
|
if (!reduce_selecter.IsReduceSupport5HD(&support_format)) {
|
||||||
|
MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support 5HD.";
|
||||||
|
}
|
||||||
|
if (reduce_selecter.IsReduceSupportFracZ(&support_format)) {
|
||||||
|
MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracZ.";
|
||||||
|
}
|
||||||
|
if (reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) {
|
||||||
|
MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support C1HWNCoC0.";
|
||||||
|
}
|
||||||
|
if (reduce_selecter.IsReduceSupportFracNZ(&support_format)) {
|
||||||
|
MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracNZ.";
|
||||||
|
}
|
||||||
|
PrintSupportedFormat(support_format);
|
||||||
|
OpInfo op_info_new;
|
||||||
|
CreateNewOpInfo(op_info, support_format, &op_info_new);
|
||||||
|
GetCommonPatternKernelInfo(op_info_new);
|
||||||
|
MS_LOG(INFO) << "end.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walks the candidate list and erases every build info whose shape/format
// combination is invalid (FilterInVaildShape) or that TBE's checker rejects
// (TbeCheckSupported). Erase-while-iterating: the iterator is advanced only
// when the current element survives both checks.
void TbeKernelSelect::FilterInVaildKernelInfo() {
  if (kernel_info_list_->empty()) {
    MS_LOG(INFO) << "Warning: get kernel build info failed.";
    return;
  }
  auto kernel_build_info_iter = kernel_info_list_->begin();
  while (kernel_build_info_iter != kernel_info_list_->end()) {
    if (!FilterInVaildShape(kernel_build_info_iter)) {
      MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*kernel_build_info_iter)->ToString();
      // erase() returns the iterator to the next element.
      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
      continue;
    }
    if (!TbeCheckSupported(kernel_build_info_iter)) {
      MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString();
      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
      continue;
    }
    kernel_build_info_iter++;
  }
}
|
||||||
|
|
||||||
|
bool TbeKernelSelect::FilterInVaildShape(
|
||||||
|
const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
|
||||||
|
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
|
||||||
|
auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
|
||||||
|
for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
|
||||||
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
|
||||||
|
auto format = kernel_build_info_inputs_format.at(i);
|
||||||
|
if (!IsShapeMatchFormat(shape, format)) {
|
||||||
|
MS_LOG(INFO) << "The " << i << "th input check failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
|
||||||
|
for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
|
||||||
|
auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
|
||||||
|
auto format = kernel_build_info_outputs_format.at(j);
|
||||||
|
if (!IsShapeMatchFormat(shape, format)) {
|
||||||
|
MS_LOG(INFO) << "The " << j << "th input check failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true when `shape` can legally be laid out in `format`.
// The default format accepts any shape; an unregistered format is treated
// as a programming error and raises.
bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  // The default format places no constraint on the shape.
  if (format == kOpFormat_DEFAULT) {
    return true;
  }
  static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
  // An unknown (unregistered) format indicates a bug, not a selection miss.
  if (kOpFormatList.find(format) == kOpFormatList.end()) {
    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
  }
  // server not support format with C04 suffix
  if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
      kServerNotSupportFormat.end()) {
    MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
    return false;
  }
  // Shape/format compatibility rules:
  // 1 NDHWC requires exactly 5 dims
  // 2 FRAC_NZ requires at least 2 dims
  // 3 every other format requires at most 4 dims
  if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
      (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) ||
      (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
    MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
    return false;
  }
  return true;
}
|
||||||
|
|
||||||
|
bool TbeKernelSelect::TbeCheckSupported(
|
||||||
|
const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
|
||||||
|
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
|
||||||
|
static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL,
|
||||||
|
parallel::BATCHMATMUL,
|
||||||
|
parallel::TOPK,
|
||||||
|
parallel::IN_TOPK,
|
||||||
|
parallel::PACK,
|
||||||
|
parallel::GATHER_ND,
|
||||||
|
parallel::UNSORTEF_SEGMENT_MIND,
|
||||||
|
parallel::UNSORTEF_SEGMENT_PRODD,
|
||||||
|
parallel::CAST};
|
||||||
|
auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_);
|
||||||
|
if (iter == kCheckSupportedOpType.end()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Check support start.";
|
||||||
|
// replace kernel_info with current kernel info
|
||||||
|
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
|
||||||
|
nlohmann::json kernel_json;
|
||||||
|
TbeKernelJsonCreator creator(CHECK_SUPPORTED);
|
||||||
|
bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed.";
|
||||||
|
}
|
||||||
|
ret = TbePythonFuncs::CheckSupported(kernel_json);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fills the builder fields shared by every candidate of this op: processor,
// fusion type (only when the registered string maps to a known type),
// op pattern, and kernel type.
// Improvement: tbe::GetFusionType was called twice; resolve it once.
void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info,
                                            mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) {
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetProcessor(AICORE);
  std::string fusion_type = op_info.fusion_type();
  auto fusion = tbe::GetFusionType(fusion_type);
  if (fusion != UNKNOWN_FUSION_TYPE) {
    builder->SetFusionType(fusion);
  }
  builder->SetOpPattern(op_info.op_pattern());
  builder->SetKernelType(TBE_KERNEL);
}
|
||||||
|
|
||||||
|
// Expands one kernel variant (column `kernel_build_info_index` of each io's
// dtype/format tables) of the registered io infos into three parallel output
// vectors: formats, device types and reshape types — one entry per real
// tensor on the node.
//
// is_input:                 true when expanding inputs, false for outputs.
// kernel_build_info_index:  which registered dtype/format column to read.
// real_io_tensor_num:       number of actual tensors on the node for this side.
// ios_info:                 registered io infos from the op info library.
// dyn_input_sizes:          per-dynamic-input tensor counts (inputs only).
// formats/device_types/reshape_types: outputs, appended to in lock step.
//
// Returns false (after an INFO log) when the registered io infos do not cover
// exactly real_io_tensor_num tensors; raises on malformed registration.
bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                                     const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                                     const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats,
                                     std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) {
  MS_EXCEPTION_IF_NULL(formats);
  MS_EXCEPTION_IF_NULL(device_types);
  MS_EXCEPTION_IF_NULL(reshape_types);
  size_t dynamic_input_index = 0;   // next slot of dyn_input_sizes to consume
  size_t real_io_tensor_index = 0;  // how many real tensors are covered so far
  size_t io_info_index = 0;
  size_t io_info_num = ios_info.size();
  // Walk the registered io infos until either side is exhausted.
  for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
    std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
    auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index);
    std::string kernel_build_info_format;
    // Some registrations omit formats (e.g. pattern-based ops); keep "" then.
    if (!io_info_item->formats().empty()) {
      kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index);
    }
    std::string io_param_type = io_info_item->param_type();
    std::vector<Axis> reshape_type;
    StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
    if (io_param_type == kParamTypeDynamic) {
      // dynamic io
      if (is_input) {
        // A dynamic input consumes dyn_input_sizes[dynamic_input_index]
        // real tensors, all with the same dtype/format/reshape type.
        if (dynamic_input_index >= dyn_input_sizes.size()) {
          MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index
                            << ", dyn_input_sizes size: " << dyn_input_sizes.size();
        }
        int dynamic_input_size = dyn_input_sizes[dynamic_input_index];
        for (int i = 0; i < dynamic_input_size; ++i) {
          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
          formats->emplace_back(kernel_build_info_format);
          reshape_types->emplace_back(reshape_type);
        }
        dynamic_input_index++;
        real_io_tensor_index += dynamic_input_size;
      } else {
        // A dynamic output must be the only registered output; it then covers
        // all real output tensors.
        if (ios_info.size() != 1) {
          MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
        }
        for (size_t i = 0; i < real_io_tensor_num; ++i) {
          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
          formats->emplace_back(kernel_build_info_format);
          reshape_types->emplace_back(reshape_type);
        }
        real_io_tensor_index += real_io_tensor_num;
      }
    } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
      // requre or optional io
      // A required/optional io covers exactly one real tensor.
      device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
      formats->emplace_back(kernel_build_info_format);
      reshape_types->emplace_back(reshape_type);
      real_io_tensor_index++;
    } else {
      MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
    }
  }

  // Leftover registered ios are tolerated (optional ios may be absent on the
  // actual node), so only log.
  if (io_info_index != io_info_num) {
    MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num
                 << "), this node may has optional input/output.";
  }
  // Leftover real tensors mean the registration cannot describe this node;
  // the caller must drop this kernel variant.
  if (real_io_tensor_index != real_io_tensor_num) {
    std::string io_type = is_input ? "inputs " : "outputs";
    MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num
                 << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index
                 << ") != real_io_tensor_num(" << real_io_tensor_num << ")";
    return false;
  }
  return true;
}
|
||||||
|
|
||||||
|
void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
|
||||||
|
MS_EXCEPTION_IF_NULL(reshape_type_vec);
|
||||||
|
for (const auto &c : reshape_type_str) {
|
||||||
|
switch (c) {
|
||||||
|
case 'N':
|
||||||
|
reshape_type_vec->push_back(kernel::N);
|
||||||
|
break;
|
||||||
|
case 'C':
|
||||||
|
reshape_type_vec->push_back(kernel::C);
|
||||||
|
break;
|
||||||
|
case 'H':
|
||||||
|
reshape_type_vec->push_back(kernel::H);
|
||||||
|
break;
|
||||||
|
case 'W':
|
||||||
|
reshape_type_vec->push_back(kernel::W);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expands one registered io info so that it covers every group in
// support_format_item: the registered dtype list is repeated once per format
// group, and for each group the format at position `index` (this io's slot)
// is repeated once per dtype, keeping the two lists index-aligned.
void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
                                        const std::vector<std::vector<std::string>> &support_format_item, size_t index,
                                        mindspore::kernel::OpIOInfo *op_io_info_new) {
  MS_EXCEPTION_IF_NULL(op_io_info_new);
  // Scalar attributes are taken over unchanged.
  op_io_info_new->set_index(op_io_info.index());
  op_io_info_new->set_name(op_io_info.name());
  op_io_info_new->set_param_type(op_io_info.param_type());
  op_io_info_new->set_need_compile(op_io_info.need_compile());
  op_io_info_new->set_reshape_type(op_io_info.reshape_type());
  op_io_info_new->set_shape(op_io_info.shape());
  // dtype: one full copy of the registered dtypes per supported format group.
  const auto base_dtypes = op_io_info.dtypes();
  std::vector<std::string> expanded_dtypes;
  expanded_dtypes.reserve(base_dtypes.size() * support_format_item.size());
  for (size_t group = 0; group < support_format_item.size(); ++group) {
    expanded_dtypes.insert(expanded_dtypes.end(), base_dtypes.begin(), base_dtypes.end());
  }
  op_io_info_new->set_dtypes(expanded_dtypes);
  // format: this io's format from each group, repeated once per dtype.
  std::vector<std::string> expanded_formats;
  expanded_formats.reserve(expanded_dtypes.size());
  for (const auto &group_formats : support_format_item) {
    const auto &fmt = group_formats.at(index);
    expanded_formats.insert(expanded_formats.end(), base_dtypes.size(), fmt);
  }
  op_io_info_new->set_formats(expanded_formats);
}
|
||||||
|
|
||||||
|
std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) {
|
||||||
|
const std::map<std::string, std::string> kDynamicFormatMap = {
|
||||||
|
{"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}};
|
||||||
|
if (op_select_json_item.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Op select ret item is null.";
|
||||||
|
}
|
||||||
|
const char space = ' ';
|
||||||
|
const char sep = ',';
|
||||||
|
std::string op_select_tmp = op_select_json_item + ",";
|
||||||
|
std::vector<std::string> ret;
|
||||||
|
auto begin = op_select_tmp.find_first_not_of(space, 0);
|
||||||
|
auto sep_pos = op_select_tmp.find(sep);
|
||||||
|
while (sep_pos != std::string::npos) {
|
||||||
|
auto obj = op_select_tmp.substr(begin, sep_pos - begin);
|
||||||
|
if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) {
|
||||||
|
obj = kDynamicFormatMap.at(obj);
|
||||||
|
}
|
||||||
|
ret.emplace_back(obj);
|
||||||
|
begin = op_select_tmp.find_first_not_of(space, sep_pos + 1);
|
||||||
|
sep_pos = op_select_tmp.find(sep, begin);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string TbeKernelSelect::OpSelectFormat() {
|
||||||
|
nlohmann::json kernel_json;
|
||||||
|
std::string res_json_str;
|
||||||
|
TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
|
||||||
|
bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed.";
|
||||||
|
}
|
||||||
|
res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
|
||||||
|
if (res_json_str.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "op select format error.";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str;
|
||||||
|
return res_json_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a new OpInfo whose io entries are expanded to enumerate every
// input/output format combination in support_format (used for the
// broadcast/reduce patterns). All other OpInfo fields are copied from
// op_info.
void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format,
                                      mindspore::kernel::OpInfo *op_info_new) {
  MS_EXCEPTION_IF_NULL(op_info_new);
  // Fix: the original indexed input_format[0]/output_format[0] without first
  // checking the vectors are non-empty (undefined behavior on an empty list).
  if (support_format.input_format.empty() || support_format.output_format.empty()) {
    MS_LOG(EXCEPTION) << "Support format input or output is empty, node: " << node_name_;
  }
  if (op_info.inputs_ptr().size() != support_format.input_format[0].size() ||
      op_info.outputs_ptr().size() != support_format.output_format[0].size()) {
    MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size()
                      << ", input support size: " << support_format.input_format[0].size()
                      << ", op info output size: " << op_info.outputs_ptr().size()
                      << ", output support size: " << support_format.output_format[0].size();
  }
  *op_info_new = op_info;
  op_info_new->ClearInputs();
  op_info_new->ClearOutputs();
  // Rebuild every input io expanded across the supported format groups.
  for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
    auto input = op_info.inputs_ptr().at(i);
    auto input_new = std::make_shared<OpIOInfo>();
    CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get());
    op_info_new->add_inputs_ptr(input_new);
  }
  // Same for outputs.
  for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) {
    auto output = op_info.outputs_ptr().at(j);
    auto output_new = std::make_shared<OpIOInfo>();
    CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get());
    op_info_new->add_outputs_ptr(output_new);
  }
}
|
||||||
|
|
||||||
|
// Parsed form of one input/output entry from the op-select-format json reply:
// the io name plus the comma-split dtype and format lists.
struct SelectOpIOInfo {
  std::string name;                  // io name reported by the python side
  std::vector<std::string> dtypes;   // supported dtypes, one per kernel variant
  std::vector<std::string> formats;  // supported formats, parallel to dtypes
};
|
||||||
|
|
||||||
|
// Builds a new OpInfo for a "dynamicFormat" op: runs the TBE python
// op_select_format query, parses the returned json and replaces the
// registered io dtypes/formats with the dynamically reported ones.
// If the reply is empty, *op_info_new is left untouched (same as before).
void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
                                      mindspore::kernel::OpInfo *op_info_new) {
  MS_EXCEPTION_IF_NULL(op_info_new);
  // Renamed local (was "op_seclect_json").
  auto op_select_json = OpSelectFormat();
  if (!op_select_json.empty()) {
    nlohmann::json json_obj = nlohmann::json::parse(op_select_json);
    if (!json_obj.is_object()) {
      MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_select_json;
    }
    // Shared parser for one io entry; the input and output branches were
    // previously duplicated line for line.
    auto parse_io = [](const nlohmann::json &value) {
      SelectOpIOInfo io;
      io.name = value.at(kName);
      std::string dtype_item = value.at(kDtype);
      io.dtypes = SplitStrToVec(dtype_item);
      std::string format_item = value.at(kFormat);
      io.formats = SplitStrToVec(format_item);
      return io;
    };
    std::vector<SelectOpIOInfo> inputs;
    std::vector<SelectOpIOInfo> outputs;
    for (const auto &item : json_obj.items()) {
      const std::string &item_name = item.key();
      // Keys look like "input0"/"output0"; anything else is a protocol error.
      bool is_input = (item_name.find(kPrefixInput) != std::string::npos);
      bool is_output = (item_name.find(kPrefixOutput) != std::string::npos);
      if (!is_input && !is_output) {
        MS_LOG(EXCEPTION) << "op select ret json is error.";
      }
      if (is_input) {
        inputs.emplace_back(parse_io(item.value()));
      } else {
        outputs.emplace_back(parse_io(item.value()));
      }
    }

    if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) {
      MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register.";
    }

    // Rebuild the op info with the dynamically reported dtype/format lists.
    *op_info_new = op_info;
    op_info_new->ClearInputs();
    op_info_new->ClearOutputs();
    for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
      auto input_new = std::make_shared<OpIOInfo>();
      CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get());
      op_info_new->add_inputs_ptr(input_new);
    }
    for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) {
      auto output_new = std::make_shared<OpIOInfo>();
      CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get());
      op_info_new->add_outputs_ptr(output_new);
    }
  }
}
|
||||||
|
|
||||||
|
// Rebuilds one io info so that it enumerates every (format, dtype) pair from
// the op-select-format reply: the dtype list is repeated once per supported
// format, and each format is repeated once per dtype, keeping both lists
// index-aligned.
void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
                                        const std::vector<std::string> &support_dtype,
                                        const std::vector<std::string> &support_format,
                                        mindspore::kernel::OpIOInfo *op_io_info_new) {
  MS_EXCEPTION_IF_NULL(op_io_info_new);
  // Scalar attributes are taken over unchanged.
  op_io_info_new->set_index(op_io_info.index());
  op_io_info_new->set_name(op_io_info.name());
  op_io_info_new->set_param_type(op_io_info.param_type());
  op_io_info_new->set_need_compile(op_io_info.need_compile());
  op_io_info_new->set_reshape_type(op_io_info.reshape_type());
  op_io_info_new->set_shape(op_io_info.shape());
  // dtype: one full copy of support_dtype per supported format.
  std::vector<std::string> expanded_dtypes;
  expanded_dtypes.reserve(support_dtype.size() * support_format.size());
  for (size_t fmt_idx = 0; fmt_idx < support_format.size(); ++fmt_idx) {
    expanded_dtypes.insert(expanded_dtypes.end(), support_dtype.begin(), support_dtype.end());
  }
  op_io_info_new->set_dtypes(expanded_dtypes);
  // format: each supported format repeated once per dtype.
  std::vector<std::string> expanded_formats;
  expanded_formats.reserve(expanded_dtypes.size());
  for (const auto &fmt : support_format) {
    expanded_formats.insert(expanded_formats.end(), support_dtype.size(), fmt);
  }
  op_io_info_new->set_formats(expanded_formats);
}
|
||||||
|
|
||||||
|
void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) {
|
||||||
|
if (support_format.input_format.size() != support_format.output_format.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output("
|
||||||
|
<< support_format.output_format.size() << ") size not match.";
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < support_format.input_format.size(); ++i) {
|
||||||
|
auto input_items = support_format.input_format.at(i);
|
||||||
|
auto output_items = support_format.output_format.at(i);
|
||||||
|
std::string print_str = "[";
|
||||||
|
for (const auto &input : input_items) {
|
||||||
|
print_str.append(input);
|
||||||
|
print_str.append(", ");
|
||||||
|
}
|
||||||
|
print_str.append("] -->");
|
||||||
|
for (const auto &output : output_items) {
|
||||||
|
print_str.append(output);
|
||||||
|
print_str.append(", ");
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Support format: " << print_str;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,77 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_TBE_KERNEL_SELECT_H
|
||||||
|
#define MINDSPORE_TBE_KERNEL_SELECT_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "kernel/oplib/opinfo.h"
|
||||||
|
#include "kernel/kernel_build_info.h"
|
||||||
|
#include "kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
|
||||||
|
|
||||||
|
// Selects the candidate TBE kernel build infos for one cnode. The entry point
// TbeMetadataInfoEx appears (from the method names) to dispatch on the op's
// registered pattern — common / dynamic format / format agnostic / broadcast /
// reduce — and appends the resulting KernelBuildInfo objects to the
// caller-provided list.
class TbeKernelSelect {
  using OpInfoPtr = std::shared_ptr<OpInfo>;
  using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator;

 public:
  // kernel_node: the cnode to select kernels for.
  // kernel_info_list: caller-owned output container; must outlive this object.
  TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
  ~TbeKernelSelect() = default;
  // Main entry: fills kernel_info_list_ for the node.
  void TbeMetadataInfoEx();

 private:
  // One builder per op pattern registered in the op info library.
  void GetCommonPatternKernelInfo(const OpInfo &op_info);
  void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info);
  void GetAgnosticPatternKernelInfo(const OpInfo &op_info);
  void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
  void GetReducePatternKernelInfo(const OpInfo &op_info);
  // NOTE(review): "InVaild" is a misspelling of "Invalid"; renaming would
  // touch the out-of-view definitions and callers, so it is only flagged here.
  void FilterInVaildKernelInfo();
  bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter);
  static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format);
  // Asks the TBE python side whether the candidate build info is supported.
  bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
  static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
  // Expands one kernel variant of the registered io infos into parallel
  // format / device-type / reshape-type vectors.
  bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                      const std::vector<std::shared_ptr<OpIOInfo>> &ios_info, const std::vector<int> &dyn_input_sizes,
                      std::vector<std::string> *formats, std::vector<TypeId> *device_types,
                      std::vector<std::vector<Axis>> *reshape_types);
  // Converts a reshape-type string such as "NC" into Axis enum values.
  static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
  static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new);
  static void CreateNewOpIOInfo(const OpIOInfo &op_io_info,
                                const std::vector<std::vector<std::string>> &support_format_item, size_t index,
                                OpIOInfo *op_io_info_new);
  // op select(dynamic)
  void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new);
  static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector<std::string> &support_dtype,
                                const std::vector<std::string> &support_format, OpIOInfo *op_io_info_new);
  // Splits a comma separated dtype/format string from the op-select json.
  static std::vector<std::string> SplitStrToVec(const std::string &op_select_json_item);
  // Runs the TBE python op_select_format query and returns its raw json reply.
  std::string OpSelectFormat();

  static void PrintSupportedFormat(const SupportFormat &support_format);

 private:
  CNodePtr cnode_ptr_;  // node kernels are selected for
  std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_;  // caller-owned output list
  std::string node_name_;  // presumably the cached op name of cnode_ptr_ — confirm in constructor
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_TBE_KERNEL_SELECT_H
|
|
@ -216,6 +216,13 @@ constexpr char NEG[] = "Neg";
|
||||||
constexpr char BATCH_MATMUL[] = "BatchMatMul";
|
constexpr char BATCH_MATMUL[] = "BatchMatMul";
|
||||||
constexpr char EXPAND_DIMS[] = "ExpandDims";
|
constexpr char EXPAND_DIMS[] = "ExpandDims";
|
||||||
constexpr char SQUARE[] = "Square";
|
constexpr char SQUARE[] = "Square";
|
||||||
|
constexpr char BATCHMATMUL[] = "BatchMatMul";
|
||||||
|
constexpr char TOPK[] = "TopK";
|
||||||
|
constexpr char IN_TOPK[] = "InTopK";
|
||||||
|
constexpr char PACK[] = "Pack";
|
||||||
|
constexpr char GATHER_ND[] = "GatherNd";
|
||||||
|
constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
|
||||||
|
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
|
||||||
|
|
||||||
// Parallel don't care
|
// Parallel don't care
|
||||||
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
|
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "device/ascend/kernel_select_ascend.h"
|
#include "device/ascend/kernel_select_ascend.h"
|
||||||
#include "kernel/kernel_query.h"
|
#include "kernel/kernel_query.h"
|
||||||
#include "kernel/tbe/tbe_kernel_select.h"
|
|
||||||
#include "kernel/oplib/oplib.h"
|
#include "kernel/oplib/oplib.h"
|
||||||
#include "session/anf_runtime_algorithm.h"
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto node_name = AnfAlgo::GetCNodeName(node);
|
auto node_name = AnfAlgo::GetCNodeName(node);
|
||||||
if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) {
|
if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
||||||
|
|
|
@ -26,12 +26,9 @@ abs_op_info = TBERegOp("Abs") \
|
||||||
.op_pattern("formatAgnostic") \
|
.op_pattern("formatAgnostic") \
|
||||||
.input(0, "x", None, "required", None) \
|
.input(0, "x", None, "required", None) \
|
||||||
.output(0, "y", True, "required", "all") \
|
.output(0, "y", True, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.I32_None, DataType.I32_None) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
|
||||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ abs_grad_op_info = TBERegOp("AbsGrad") \
|
||||||
.compute_cost(10) \
|
.compute_cost(10) \
|
||||||
.kernel_name("abs_grad") \
|
.kernel_name("abs_grad") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
.op_pattern("formatAgnostic") \
|
|
||||||
.input(0, "y", None, "required", None) \
|
.input(0, "y", None, "required", None) \
|
||||||
.input(1, "dy", None, "required", None) \
|
.input(1, "dy", None, "required", None) \
|
||||||
.output(0, "z", False, "required", "all") \
|
.output(0, "z", False, "required", "all") \
|
||||||
|
|
|
@ -26,6 +26,7 @@ add_op_info = TBERegOp("Add") \
|
||||||
.input(0, "x1", False, "required", "all") \
|
.input(0, "x1", False, "required", "all") \
|
||||||
.input(1, "x2", False, "required", "all") \
|
.input(1, "x2", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
|
|
@ -26,17 +26,10 @@ add_n_op_info = TBERegOp("AddN") \
|
||||||
.attr("n", "required", "int", "all") \
|
.attr("n", "required", "int", "all") \
|
||||||
.input(0, "x", False, "dynamic", "all") \
|
.input(0, "x", False, "dynamic", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.op_pattern("broadcast") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
|
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
.dtype_format(DataType.I32_None, DataType.I32_None) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
|
|
||||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
|
||||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
|
||||||
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \
|
||||||
.input(1, "x2", False, "required", "all") \
|
.input(1, "x2", False, "required", "all") \
|
||||||
.input(2, "bias", False, "optional", "all") \
|
.input(2, "bias", False, "optional", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
|
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F16_FracNZ) \
|
||||||
|
|
|
@ -27,6 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.input(1, "bias", False, "required", "all") \
|
.input(1, "bias", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
|
|
@ -26,6 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
|
||||||
.input(0, "x", False, "required", "all", reshape_type="NC") \
|
.input(0, "x", False, "required", "all", reshape_type="NC") \
|
||||||
.output(0, "sum", False, "required", "all") \
|
.output(0, "sum", False, "required", "all") \
|
||||||
.output(1, "square_sum", False, "required", "all") \
|
.output(1, "square_sum", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
|
@ -32,6 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
|
||||||
.input(5, "batch_mean", False, "required", "all") \
|
.input(5, "batch_mean", False, "required", "all") \
|
||||||
.input(6, "batch_variance", False, "required", "all") \
|
.input(6, "batch_variance", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all", reshape_type="NC") \
|
.output(0, "y", False, "required", "all", reshape_type="NC") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
|
|
@ -30,6 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
|
||||||
.input(3, "batch_variance", False, "required", "all") \
|
.input(3, "batch_variance", False, "required", "all") \
|
||||||
.output(0, "diff_scale", False, "required", "all") \
|
.output(0, "diff_scale", False, "required", "all") \
|
||||||
.output(1, "diff_offset", False, "required", "all") \
|
.output(1, "diff_offset", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
DataType.F32_5HD, DataType.F32_5HD) \
|
DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
|
|
|
@ -32,6 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \
|
||||||
.output(0, "y", False, "required", "all", reshape_type="NC") \
|
.output(0, "y", False, "required", "all", reshape_type="NC") \
|
||||||
.output(1, "batch_mean", False, "required", "all") \
|
.output(1, "batch_mean", False, "required", "all") \
|
||||||
.output(2, "batch_variance", False, "required", "all") \
|
.output(2, "batch_variance", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
|
||||||
DataType.F32_5HD, DataType.F32_5HD) \
|
DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
|
|
@ -26,32 +26,27 @@ cast_op_info = TBERegOp("Cast") \
|
||||||
.attr("dst_type", "required", "int", "all") \
|
.attr("dst_type", "required", "int", "all") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
|
.op_pattern("formatAgnostic") \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
|
.dtype_format(DataType.BOOL_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
|
.dtype_format(DataType.BOOL_None, DataType.U8_None) \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
|
.dtype_format(DataType.BOOL_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
|
.dtype_format(DataType.BOOL_None, DataType.I32_None) \
|
||||||
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
|
.dtype_format(DataType.I8_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
|
.dtype_format(DataType.I8_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
|
.dtype_format(DataType.I8_None, DataType.I32_None) \
|
||||||
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
|
.dtype_format(DataType.U8_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
|
.dtype_format(DataType.U8_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
.dtype_format(DataType.U8_None, DataType.I32_None) \
|
||||||
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
|
.dtype_format(DataType.I32_None, DataType.BOOL_None) \
|
||||||
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.I32_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
|
.dtype_format(DataType.I32_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
|
.dtype_format(DataType.I32_None, DataType.I8_None) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
|
.dtype_format(DataType.I32_None, DataType.U8_None) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F16_None, DataType.U8_None) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
.dtype_format(DataType.F16_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F32_5HD) \
|
.dtype_format(DataType.F16_None, DataType.I32_None) \
|
||||||
.dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \
|
.dtype_format(DataType.F32_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \
|
.dtype_format(DataType.F32_None, DataType.I32_None) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \
|
|
||||||
.dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ concat_op_info = TBERegOp("Concat") \
|
||||||
.attr("axis", "required", "int", "all") \
|
.attr("axis", "required", "int", "all") \
|
||||||
.input(0, "input_values", False, "dynamic", "all") \
|
.input(0, "input_values", False, "dynamic", "all") \
|
||||||
.output(0, "output_data", False, "required", "all") \
|
.output(0, "output_data", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
|
.dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
|
|
@ -23,6 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
|
||||||
.compute_cost(10) \
|
.compute_cost(10) \
|
||||||
.kernel_name("conv2d") \
|
.kernel_name("conv2d") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.attr("stride", "required", "listInt", "all") \
|
.attr("stride", "required", "listInt", "all") \
|
||||||
.attr("pad_list", "required", "listInt", "all") \
|
.attr("pad_list", "required", "listInt", "all") \
|
||||||
.attr("dilation", "required", "listInt", "all") \
|
.attr("dilation", "required", "listInt", "all") \
|
||||||
|
@ -32,8 +33,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
|
||||||
.input(2, "bias", False, "optional", "all") \
|
.input(2, "bias", False, "optional", "all") \
|
||||||
.input(3, "offset_w", False, "optional", "all") \
|
.input(3, "offset_w", False, "optional", "all") \
|
||||||
.output(0, "y", True, "required", "all") \
|
.output(0, "y", True, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default,
|
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \
|
||||||
DataType.F16_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
|
||||||
.input(1, "mask", False, "required", "all") \
|
.input(1, "mask", False, "required", "all") \
|
||||||
.input(2, "keep_prob", False, "required", "all") \
|
.input(2, "keep_prob", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
|
@ -28,9 +28,7 @@ elu_op_info = TBERegOp("Elu") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,7 @@ erf_op_info = TBERegOp("Erf") \
|
||||||
.op_pattern("formatAgnostic") \
|
.op_pattern("formatAgnostic") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,7 @@ erfc_op_info = TBERegOp("Erfc") \
|
||||||
.op_pattern("formatAgnostic") \
|
.op_pattern("formatAgnostic") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
|
@ -27,9 +27,7 @@ expm1_op_info = TBERegOp("Expm1") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
|
||||||
.input(1, "x2", False, "required", "all") \
|
.input(1, "x2", False, "required", "all") \
|
||||||
.input(2, "x3", False, "required", "all") \
|
.input(2, "x3", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
||||||
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
|
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
|
||||||
|
|
|
@ -32,6 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.output(1, "mean", False, "required", "all") \
|
.output(1, "mean", False, "required", "all") \
|
||||||
.output(2, "variance", False, "required", "all") \
|
.output(2, "variance", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default) \
|
DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
|
|
|
@ -30,6 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop")
|
||||||
.input(3, "mean", False, "required", "all") \
|
.input(3, "mean", False, "required", "all") \
|
||||||
.output(0, "pd_gamma", False, "required", "all") \
|
.output(0, "pd_gamma", False, "required", "all") \
|
||||||
.output(1, "pd_beta", False, "required", "all") \
|
.output(1, "pd_beta", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F32_Default, DataType.F32_Default) \
|
DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
|
|
|
@ -29,6 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
|
||||||
.input(3, "mean", False, "required", "all") \
|
.input(3, "mean", False, "required", "all") \
|
||||||
.input(4, "gamma", False, "required", "all") \
|
.input(4, "gamma", False, "required", "all") \
|
||||||
.output(0, "pd_x", False, "required", "all") \
|
.output(0, "pd_x", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default) \
|
DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
|
|
|
@ -26,21 +26,8 @@ mul_op_info = TBERegOp("Mul") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.input(1, "y", False, "required", "all") \
|
.input(1, "y", False, "required", "all") \
|
||||||
.output(0, "output", False, "required", "all") \
|
.output(0, "output", False, "required", "all") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
|
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||||
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
|
|
||||||
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ, DataType.I32_FracNZ) \
|
|
||||||
.dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
|
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
|
||||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
|
||||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
|
||||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
|
||||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,10 +26,9 @@ realdiv_op_info = TBERegOp("RealDiv") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.input(1, "y", False, "required", "all") \
|
.input(1, "y", False, "required", "all") \
|
||||||
.output(0, "z", False, "required", "all") \
|
.output(0, "z", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
.op_pattern("broadcast") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
|
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
|
||||||
|
|
|
@ -27,11 +27,11 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \
|
||||||
.attr("keep_dims", "optional", "bool", "all") \
|
.attr("keep_dims", "optional", "bool", "all") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
.op_pattern("reduce") \
|
||||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
.dtype_format(DataType.I8_None, DataType.I8_None) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.U8_None, DataType.U8_None) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ relu_grad_v2_op_info = TBERegOp("ReluGradV2") \
|
||||||
.kernel_name("relu_grad_v2") \
|
.kernel_name("relu_grad_v2") \
|
||||||
.partial_flag(True) \
|
.partial_flag(True) \
|
||||||
.input(0, "gradients", False, "required", "all") \
|
.input(0, "gradients", False, "required", "all") \
|
||||||
.input(1, "mask", False, "rerequired", "all") \
|
.input(1, "mask", False, "required", "all") \
|
||||||
.output(0, "backprops", True, "required", "all") \
|
.output(0, "backprops", True, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \
|
.dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \
|
.dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \
|
||||||
|
|
|
@ -27,6 +27,7 @@ select_op_info = TBERegOp("Select") \
|
||||||
.input(1, "x1", False, "required", "all") \
|
.input(1, "x1", False, "required", "all") \
|
||||||
.input(2, "x2", False, "required", "all") \
|
.input(2, "x2", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
.dtype_format(DataType.BOOL_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
.dtype_format(DataType.BOOL_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
|
|
@ -27,11 +27,8 @@ sign_op_info = TBERegOp("Sign") \
|
||||||
.input(0, "x", None, "required", None) \
|
.input(0, "x", None, "required", None) \
|
||||||
.output(0, "y", True, "required", "all") \
|
.output(0, "y", True, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
|
||||||
.input(1, "x1", False, "required", "all") \
|
.input(1, "x1", False, "required", "all") \
|
||||||
.input(2, "x2", False, "required", "all") \
|
.input(2, "x2", False, "required", "all") \
|
||||||
.output(0, "y", True, "required", "all") \
|
.output(0, "y", True, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default) \
|
DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD,
|
||||||
|
|
|
@ -27,9 +27,7 @@ softplus_op_info = TBERegOp("Softplus") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,9 +28,7 @@ softplus_grad_op_info = TBERegOp("SoftplusGrad") \
|
||||||
.input(1, "features", False, "required", "all") \
|
.input(1, "features", False, "required", "all") \
|
||||||
.output(0, "backprops", False, "required", "all") \
|
.output(0, "backprops", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ split_d_op_info = TBERegOp("Split") \
|
||||||
.attr("output_num", "required", "int", "all") \
|
.attr("output_num", "required", "int", "all") \
|
||||||
.input(0, "value", False, "required", "all") \
|
.input(0, "value", False, "required", "all") \
|
||||||
.output(0, "output", False, "dynamic", "all") \
|
.output(0, "output", False, "dynamic", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
|
.dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
|
|
@ -26,6 +26,7 @@ tensor_add_op_info = TBERegOp("TensorAdd") \
|
||||||
.input(0, "x1", False, "required", "all") \
|
.input(0, "x1", False, "required", "all") \
|
||||||
.input(1, "x2", False, "required", "all") \
|
.input(1, "x2", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
|
|
@ -27,6 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.input(1, "segment_ids", False, "required", "all") \
|
.input(1, "segment_ids", False, "required", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("dynamicFormat") \
|
||||||
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
||||||
.dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
|
.dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
|
||||||
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
|
|
@ -97,6 +97,7 @@ class RegOp:
|
||||||
"""
|
"""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise TypeError("%s value must be str" % str(value))
|
raise TypeError("%s value must be str" % str(value))
|
||||||
|
return True
|
||||||
|
|
||||||
def _is_int(self, value):
|
def _is_int(self, value):
|
||||||
"""
|
"""
|
||||||
|
@ -110,6 +111,7 @@ class RegOp:
|
||||||
"""
|
"""
|
||||||
if not isinstance(value, int):
|
if not isinstance(value, int):
|
||||||
raise TypeError("%s value must be int" % str(value))
|
raise TypeError("%s value must be int" % str(value))
|
||||||
|
return True
|
||||||
|
|
||||||
def _is_bool(self, value):
|
def _is_bool(self, value):
|
||||||
"""
|
"""
|
||||||
|
@ -123,6 +125,7 @@ class RegOp:
|
||||||
"""
|
"""
|
||||||
if not isinstance(value, bool):
|
if not isinstance(value, bool):
|
||||||
raise TypeError("%s value must be bool" % str(value))
|
raise TypeError("%s value must be bool" % str(value))
|
||||||
|
return True
|
||||||
|
|
||||||
def _check_param(self, param_list, key_list, fn_list, kwargs):
|
def _check_param(self, param_list, key_list, fn_list, kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -494,6 +497,7 @@ class DataType:
|
||||||
The current list below maybe not completed. If necessary, please add it.
|
The current list below maybe not completed. If necessary, please add it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
None_None = ("", "")
|
||||||
BOOL_None = ("bool", "")
|
BOOL_None = ("bool", "")
|
||||||
BOOL_Default = ("bool", "DefaultFormat")
|
BOOL_Default = ("bool", "DefaultFormat")
|
||||||
BOOL_5HD = ("bool", "NC1HWC0")
|
BOOL_5HD = ("bool", "NC1HWC0")
|
||||||
|
|
Loading…
Reference in New Issue