From cdacd5ca768526eb8a3d010b857067f3451ff3c0 Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Sat, 30 May 2020 18:52:38 +0800 Subject: [PATCH] add reducemean's kernel select rules --- mindspore/ccsrc/kernel/kernel_query.cc | 6 ++++-- mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc | 13 +++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index f8523e94e8d..8d3ee64591e 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -44,6 +44,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); } + kernel_info_list->clear(); MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; @@ -54,11 +55,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vectorempty()) { AicpuMetadataInfo(kernel_node, kernel_info_list); if (!kernel_info_list->empty()) { - MS_LOG(INFO) << "Warning The node [" << kernel_node->DebugString() - << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; + MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() + << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); } } diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index adf5280a8b2..aedb0b3eafe 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -581,6 +581,7 @@ bool IsShapeMatchFormat(const std::vector &shape, const std::string &for bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { MS_EXCEPTION_IF_NULL(kernel_node); + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); const size_t kCAxis = 1; for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); @@ -593,6 +594,12 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) { return false; } + if (kernel_name == "ReduceMean") { + auto keep_dims = AnfAlgo::GetNodeAttr(kernel_node, kAttrKeepDims); + if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) { + return false; + } + } } for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); @@ -605,6 +612,12 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: } return false; } + if (kernel_name == "ReduceMean") { + auto keep_dims = AnfAlgo::GetNodeAttr(kernel_node, kAttrKeepDims); + if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) { + return false; + } + } } if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&