!1736 add reducemean's special kernel filter rule
Merge pull request !1736 from lianliguang/master
This commit is contained in:
commit
72fd41786c
|
@ -44,6 +44,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
|
|||
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
|
||||
MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
|
||||
}
|
||||
kernel_info_list->clear();
|
||||
MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
|
||||
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
|
||||
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
|
||||
|
@ -54,11 +55,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
TbeMetadataInfo(kernel_node, kernel_info_list);
|
||||
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
|
||||
if (kernel_info_list->empty()) {
|
||||
AicpuMetadataInfo(kernel_node, kernel_info_list);
|
||||
if (!kernel_info_list->empty()) {
|
||||
MS_LOG(INFO) << "Warning The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
|
||||
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
|
||||
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -581,6 +581,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
|
|||
|
||||
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
const size_t kCAxis = 1;
|
||||
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
|
||||
|
@ -593,6 +594,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
|
|||
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
|
||||
return false;
|
||||
}
|
||||
if (kernel_name == "ReduceMean") {
|
||||
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
|
||||
if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
|
||||
|
@ -605,6 +612,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
|
|||
}
|
||||
return false;
|
||||
}
|
||||
if (kernel_name == "ReduceMean") {
|
||||
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
|
||||
if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
|
||||
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
|
||||
|
|
Loading…
Reference in New Issue