!1736 add reducemean's special kernel fileter rule

Merge pull request !1736 from lianliguang/master
This commit is contained in:
mindspore-ci-bot 2020-06-01 10:02:09 +08:00 committed by Gitee
commit 72fd41786c
2 changed files with 17 additions and 2 deletions

View File

@ -44,6 +44,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
}
kernel_info_list->clear();
MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
@ -54,11 +55,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list);
if (!kernel_info_list->empty()) {
MS_LOG(INFO) << "Warning The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
}
}

View File

@ -581,6 +581,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
const size_t kCAxis = 1;
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
@ -593,6 +594,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false;
}
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
}
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
@ -605,6 +612,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
}
return false;
}
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
}
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&