forked from mindspore-Ecosystem/mindspore
!9998 fix cpu select kernel info
From: @huaweib Reviewed-by: @chujinjin,@kisnwang Signed-off-by: @kisnwang
This commit is contained in:
commit
dfc53b580a
|
@ -92,7 +92,7 @@ bool ReduceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
|
||||
void ReduceCPUKernel::CheckAxis(const CNodePtr &kernel_node) {
|
||||
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
|
||||
if (axis_addr->isa<ValueTuple>()) {
|
||||
if (axis_addr->isa<ValueTuple>() || axis_addr->isa<ValueList>()) {
|
||||
std::vector<int> attr_axis;
|
||||
std::vector<int64_t> attr_axis_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, AXIS);
|
||||
(void)std::transform(attr_axis_me.begin(), attr_axis_me.end(), std::back_inserter(attr_axis),
|
||||
|
|
|
@ -272,8 +272,8 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
|||
}
|
||||
}
|
||||
// All formats and data types matched
|
||||
if (max_type_matched_num == SizeToInt(input_types.size()) &&
|
||||
max_format_matched_num == SizeToInt(input_types.size())) {
|
||||
if (input_type_format_matched_num.first == SizeToInt(input_types.size()) &&
|
||||
input_type_format_matched_num.second == SizeToInt(input_types.size())) {
|
||||
matched->first = true;
|
||||
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
||||
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
||||
|
|
Loading…
Reference in New Issue