forked from mindspore-Ecosystem/mindspore
!47670 SparseAddGrad supports more types
Merge pull request !47670 from YijieChen/ops
This commit is contained in:
commit
140195f344
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <complex>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
|
@ -160,25 +161,29 @@ bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
|
|||
return true;
|
||||
}
|
||||
|
||||
#define CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
|
||||
{ \
|
||||
KernelAttr() \
|
||||
.AddInputAttr(ms_value_type) \
|
||||
.AddInputAttr(ms_index_type) \
|
||||
.AddInputAttr(ms_index_type) \
|
||||
.AddInputAttr(ms_index_type) \
|
||||
.AddOutputAttr(ms_value_type) \
|
||||
.AddOutputAttr(ms_value_type), \
|
||||
&SparseAddGradCpuKernelMod::LaunchKernel<value_type, index_type> \
|
||||
}
|
||||
|
||||
const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>>
|
||||
&SparseAddGradCpuKernelMod::GetFuncList() const {
|
||||
static const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>> func_list = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&SparseAddGradCpuKernelMod::LaunchKernel<float, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&SparseAddGradCpuKernelMod::LaunchKernel<float, int64_t>},
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat32, int64_t, float),
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat64, int64_t, double),
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt8, int64_t, int8_t),
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt16, int64_t, int16_t),
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t),
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex64, int64_t, std::complex<float>),
|
||||
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, int64_t, std::complex<double>),
|
||||
};
|
||||
return func_list;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue