!47670 SparseAddGrad supports more types

Merge pull request !47670 from YijieChen/ops
This commit is contained in:
i-robot 2023-01-10 01:27:31 +00:00 committed by Gitee
commit 140195f344
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 21 additions and 16 deletions

View File

@ -15,6 +15,7 @@
*/
#include <algorithm>
#include <utility>
#include <complex>
#include <set>
#include <map>
#include <functional>
@ -160,25 +161,29 @@ bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
return true;
}
#define CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
{ \
KernelAttr() \
.AddInputAttr(ms_value_type) \
.AddInputAttr(ms_index_type) \
.AddInputAttr(ms_index_type) \
.AddInputAttr(ms_index_type) \
.AddOutputAttr(ms_value_type) \
.AddOutputAttr(ms_value_type), \
&SparseAddGradCpuKernelMod::LaunchKernel<value_type, index_type> \
}
const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>>
&SparseAddGradCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>> func_list = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseAddGradCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseAddGradCpuKernelMod::LaunchKernel<float, int64_t>},
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat32, int64_t, float),
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat64, int64_t, double),
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt8, int64_t, int8_t),
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt16, int64_t, int16_t),
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t),
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex64, int64_t, std::complex<float>),
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, int64_t, std::complex<double>),
};
return func_list;
}