forked from mindspore-Ecosystem/mindspore
!47489 Add dtype support for ReduceMax, ReduceMin, ReduceProd, ReduceSum
Merge pull request !47489 from zhanzhan/reducesum
This commit is contained in:
commit
140418aa72
|
@ -59,14 +59,20 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
|
|||
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
|
||||
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}};
|
||||
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::prod_list_ = {
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int32_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
|
||||
|
@ -81,6 +87,8 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
|
|||
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
|
||||
|
@ -92,7 +100,27 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
|
|||
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
|
||||
};
|
||||
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_mean_list_ = {
|
||||
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_list_ = {
|
||||
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int32_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
|
||||
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
|
||||
};
|
||||
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::mean_list_ = {
|
||||
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
|
||||
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
|
||||
|
@ -106,9 +134,9 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
|
|||
};
|
||||
std::map<std::string, std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>>>
|
||||
ArrayReduceGpuKernelMod::kernel_attr_list_ = {
|
||||
{prim::kPrimReduceSum->name(), sum_list_}, {prim::kPrimReduceMean->name(), max_min_mean_list_},
|
||||
{prim::kPrimReduceProd->name(), prod_list_}, {prim::kPrimReduceMax->name(), max_min_mean_list_},
|
||||
{prim::kPrimReduceMin->name(), max_min_mean_list_}, {prim::kPrimReduceAll->name(), all_any_list_},
|
||||
{prim::kPrimReduceSum->name(), sum_list_}, {prim::kPrimReduceMean->name(), mean_list_},
|
||||
{prim::kPrimReduceProd->name(), prod_list_}, {prim::kPrimReduceMax->name(), max_min_list_},
|
||||
{prim::kPrimReduceMin->name(), max_min_list_}, {prim::kPrimReduceAll->name(), all_any_list_},
|
||||
{prim::kPrimReduceAny->name(), all_any_list_}};
|
||||
|
||||
std::vector<KernelAttr> ArrayReduceGpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -109,7 +109,8 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
|
|||
static std::vector<std::pair<KernelAttr, ReduceFunc>> all_any_list_;
|
||||
static std::vector<std::pair<KernelAttr, ReduceFunc>> prod_list_;
|
||||
static std::vector<std::pair<KernelAttr, ReduceFunc>> sum_list_;
|
||||
static std::vector<std::pair<KernelAttr, ReduceFunc>> max_min_mean_list_;
|
||||
static std::vector<std::pair<KernelAttr, ReduceFunc>> max_min_list_;
|
||||
static std::vector<std::pair<KernelAttr, ReduceFunc>> mean_list_;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, ReduceFunc>>> kernel_attr_list_;
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue