diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_add_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_add_grad_cpu_kernel.cc index aa4a0c9f3df..e92d6bd106e 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_add_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sparse_add_grad_cpu_kernel.cc @@ -15,6 +15,7 @@ */ #include #include +#include #include #include #include @@ -160,25 +161,29 @@ bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector \ + } + const std::vector> &SparseAddGradCpuKernelMod::GetFuncList() const { static const std::vector> func_list = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseAddGradCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &SparseAddGradCpuKernelMod::LaunchKernel}, + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat32, int64_t, float), + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat64, int64_t, double), + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt8, int64_t, int8_t), + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt16, int64_t, int16_t), + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t), + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t), + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex64, int64_t, std::complex), + CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, int64_t, std::complex), }; return func_list; }