!21913 add cpu mask_select type and fix visitkernel bug

Merge pull request !21913 from baihuawei/mask_select_and_visit_kernel_bug
This commit is contained in:
i-robot 2021-08-18 07:54:48 +00:00 committed by Gitee
commit 1c7c7dc6da
3 changed files with 55 additions and 0 deletions

View File

@ -52,6 +52,26 @@ MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
MaskedSelectCPUKernel, int);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16),
MaskedSelectCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
MaskedSelectCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16),
MaskedSelectCPUKernel, float16);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64),
MaskedSelectCPUKernel, double);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_

View File

@ -58,6 +58,38 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MaskedSelectGradCPUKernel, int);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
MaskedSelectGradCPUKernel, float16);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MaskedSelectGradCPUKernel, double);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
MaskedSelectGradCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MaskedSelectGradCPUKernel, int64_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_

View File

@ -203,6 +203,9 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
auto input0 = cnode->input(0);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
if (AnfAlgo::GetInputTensorNum(cnode) == 0) {
return std::make_pair(nullptr, 0);
}
auto node = cnode->input(index + IntToSize(1));
MS_EXCEPTION_IF_NULL(node);
return VisitKernel(node, 0);