forked from mindspore-Ecosystem/mindspore
!47670 SparseAddGrad supports more types
Merge pull request !47670 from YijieChen/ops
This commit is contained in:
commit
140195f344
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <complex>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
@ -160,25 +161,29 @@ bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(ms_index_type, ms_value_type, index_type, value_type) \
|
||||||
|
{ \
|
||||||
|
KernelAttr() \
|
||||||
|
.AddInputAttr(ms_value_type) \
|
||||||
|
.AddInputAttr(ms_index_type) \
|
||||||
|
.AddInputAttr(ms_index_type) \
|
||||||
|
.AddInputAttr(ms_index_type) \
|
||||||
|
.AddOutputAttr(ms_value_type) \
|
||||||
|
.AddOutputAttr(ms_value_type), \
|
||||||
|
&SparseAddGradCpuKernelMod::LaunchKernel<value_type, index_type> \
|
||||||
|
}
|
||||||
|
|
||||||
const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>>
|
const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>>
|
||||||
&SparseAddGradCpuKernelMod::GetFuncList() const {
|
&SparseAddGradCpuKernelMod::GetFuncList() const {
|
||||||
static const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>> func_list = {
|
static const std::vector<std::pair<KernelAttr, SparseAddGradCpuKernelMod::KernelRunFunc>> func_list = {
|
||||||
{KernelAttr()
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat32, int64_t, float),
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeFloat64, int64_t, double),
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt8, int64_t, int8_t),
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt16, int64_t, int16_t),
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t),
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex64, int64_t, std::complex<float>),
|
||||||
&SparseAddGradCpuKernelMod::LaunchKernel<float, int32_t>},
|
CPU_SPARSE_ADD_GRAD_KERNEL_REGISTER(kNumberTypeInt64, kNumberTypeComplex128, int64_t, std::complex<double>),
|
||||||
{KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
&SparseAddGradCpuKernelMod::LaunchKernel<float, int64_t>},
|
|
||||||
};
|
};
|
||||||
return func_list;
|
return func_list;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue