forked from OSSInnovation/mindspore
Add max match rule for cpu kernel selection
This commit is contained in:
parent
fe934520e6
commit
d96044fbd9
|
@ -78,33 +78,40 @@ void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &ke
|
|||
}
|
||||
}
|
||||
|
||||
bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes) {
|
||||
std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
||||
const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes) {
|
||||
if (kernel_attr.GetInputSize() != input_types.size()) {
|
||||
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
||||
return false;
|
||||
return std::make_pair(0, 0);
|
||||
}
|
||||
int data_type_matched_num = 0;
|
||||
int format_matched_num = 0;
|
||||
auto input_num = input_types.size();
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
|
||||
[i](size_t index) { return index == i; });
|
||||
bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size());
|
||||
if (have_cnode_input && is_not_cnode_idx) {
|
||||
data_type_matched_num++;
|
||||
format_matched_num++;
|
||||
continue;
|
||||
}
|
||||
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
|
||||
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
|
||||
<< ", actual input dtype:" << input_types[i];
|
||||
return false;
|
||||
} else {
|
||||
data_type_matched_num++;
|
||||
}
|
||||
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
|
||||
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
|
||||
<< ", actual input format:" << input_formats[i];
|
||||
return false;
|
||||
} else {
|
||||
format_matched_num++;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return std::make_pair(data_type_matched_num, format_matched_num);
|
||||
}
|
||||
|
||||
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
||||
|
@ -121,6 +128,18 @@ void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
|||
kernel_attr->AddOutputAttr(output_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
|
||||
const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
|
||||
AnfNode *kernel_node) {
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputsFormat(input_formats);
|
||||
builder->SetInputsDeviceType(input_types);
|
||||
builder->SetOutputsFormat(output_formats);
|
||||
builder->SetOutputsDeviceType(output_types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
|
@ -136,38 +155,49 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
auto kernel_attrs =
|
||||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||
|
||||
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
||||
auto kernel_attr = kernel_attrs[index];
|
||||
int max_type_matched_num = -1;
|
||||
int max_format_matched_num = -1;
|
||||
KernelAttr selected_kernel_attr;
|
||||
for (auto kernel_attr : kernel_attrs) {
|
||||
if (kernel_attr.GetAllSame()) {
|
||||
ExpandKernelAttr(kernel_node, &kernel_attr);
|
||||
}
|
||||
bool ignore_check = false;
|
||||
if (index == kernel_attrs.size() - 1 && input_types.size() == input_not_cnode_indexes.size()) {
|
||||
ignore_check = true;
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (kernel_attr.GetOutputSize() != output_num) {
|
||||
MS_LOG(DEBUG) << "Output num is not equal!";
|
||||
continue;
|
||||
}
|
||||
if (ignore_check || IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (kernel_attr.GetOutputSize() != output_num) {
|
||||
MS_LOG(DEBUG) << "Output num is not equal!";
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
|
||||
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
|
||||
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
|
||||
for (auto &input_index : input_not_cnode_indexes) {
|
||||
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
|
||||
}
|
||||
std::pair<int, int> input_type_format_matched_num =
|
||||
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes);
|
||||
// Data type first
|
||||
if (input_type_format_matched_num.first > max_type_matched_num) {
|
||||
max_type_matched_num = input_type_format_matched_num.first;
|
||||
max_format_matched_num = input_type_format_matched_num.second;
|
||||
selected_kernel_attr = kernel_attr;
|
||||
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
||||
input_type_format_matched_num.second > max_format_matched_num) {
|
||||
max_format_matched_num = input_type_format_matched_num.second;
|
||||
selected_kernel_attr = kernel_attr;
|
||||
}
|
||||
// All formats and data types matched
|
||||
if (max_type_matched_num == SizeToInt(input_types.size()) &&
|
||||
max_format_matched_num == SizeToInt(input_types.size())) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputsFormat(input_formats);
|
||||
builder->SetInputsDeviceType(input_types);
|
||||
builder->SetOutputsFormat(output_formats);
|
||||
builder->SetOutputsDeviceType(output_types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||
if ((max_type_matched_num == SizeToInt(input_types.size()) &&
|
||||
max_format_matched_num == SizeToInt(input_types.size())) ||
|
||||
input_types.size() == input_not_cnode_indexes.size()) {
|
||||
MS_LOG(INFO) << "Input format and dtype is matched, max_type_matched_num: " << max_type_matched_num
|
||||
<< ", max_format_matched_num: " << max_format_matched_num;
|
||||
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
|
||||
UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node);
|
||||
for (auto &input_index : input_not_cnode_indexes) {
|
||||
input_types[input_index] = selected_kernel_attr.GetInputAttr(input_index).first;
|
||||
}
|
||||
}
|
||||
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
|
||||
}
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
|
|
Loading…
Reference in New Issue