forked from mindspore-Ecosystem/mindspore
!2691 use two condition, false branch caculate error
Merge pull request !2691 from hexia/deal-switch-input
This commit is contained in:
commit
381bbc4db5
|
@ -562,10 +562,17 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern
|
|||
MS_LOG(WARNING) << "kernel [" << (kernel_info_list.size() + index)
|
||||
<< "] :" << aicpu_kernel_info_list[index]->ToString();
|
||||
}
|
||||
MS_LOG(WARNING) << " <<<";
|
||||
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid kernel info, not supported the type:" << buffer.str()
|
||||
<< ", please refer to the supported dtypes in candidates kernel info list";
|
||||
if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) {
|
||||
auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
||||
// Set format and data type for input tensor.
|
||||
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
||||
} else {
|
||||
MS_LOG(WARNING) << " <<<";
|
||||
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid kernel info, not supported the type:" << buffer.str()
|
||||
<< ", please refer to the supported dtypes in candidates kernel info list";
|
||||
}
|
||||
}
|
||||
return select_status;
|
||||
}
|
||||
|
|
|
@ -75,8 +75,8 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr
|
|||
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{};
|
||||
vector<string> input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT};
|
||||
vector<TypeId> input_type{kNumberTypeUInt32, kNumberTypeBool};
|
||||
vector<string> input_format{kOpFormat_DEFAULT};
|
||||
vector<TypeId> input_type{kNumberTypeInt32};
|
||||
if (input_format.size() != input_type.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size "
|
||||
<< input_type.size();
|
||||
|
|
Loading…
Reference in New Issue