diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.cc index 140747d4661..bdce332579c 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.cc @@ -140,6 +140,9 @@ bool ScatterUpdateCpuKernelMod::Launch(const std::vector &in case kNumberTypeInt64: LaunchKernel(inputs, outputs); break; + case kNumberTypeBool: + LaunchKernel(inputs, outputs); + break; default: MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dtype of 'input_x' must be float16, float32, float64, int32 or int64, but got " diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.h index 6aeaf0312d6..7a90af5c93e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_update_cpu_kernel.h @@ -121,7 +121,19 @@ class TensorScatterUpdateCpuKernelMod : public ScatterUpdateCpuKernelMod { .AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64)}; + .AddOutputAttr(kNumberTypeInt64), + + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeBool) + .AddOutputAttr(kNumberTypeBool), + + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeBool) + .AddOutputAttr(kNumberTypeBool)}; return support_list; } };