forked from mindspore-Ecosystem/mindspore
!21913 add cpu mask_select type and fix visitkernel bug
Merge pull request !21913 from baihuawei/mask_select_and_visit_kernel_bug
This commit is contained in:
commit
1c7c7dc6da
|
@ -52,6 +52,26 @@ MS_REG_CPU_KERNEL_T(
|
|||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
|
||||
MaskedSelectCPUKernel, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16),
|
||||
MaskedSelectCPUKernel, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
|
||||
MaskedSelectCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16),
|
||||
MaskedSelectCPUKernel, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
MaskedSelect,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64),
|
||||
MaskedSelectCPUKernel, double);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_
|
||||
|
|
|
@ -58,6 +58,38 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
|||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
MaskedSelectGradCPUKernel, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
MaskedSelectGradCPUKernel, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
MaskedSelectGradCPUKernel, double);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
MaskedSelectGradCPUKernel, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
MaskedSelectGradCPUKernel, int64_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_
|
||||
|
|
|
@ -203,6 +203,9 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
|
|||
auto input0 = cnode->input(0);
|
||||
MS_EXCEPTION_IF_NULL(input0);
|
||||
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
||||
if (AnfAlgo::GetInputTensorNum(cnode) == 0) {
|
||||
return std::make_pair(nullptr, 0);
|
||||
}
|
||||
auto node = cnode->input(index + IntToSize(1));
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return VisitKernel(node, 0);
|
||||
|
|
Loading…
Reference in New Issue