!47489 Add dtype support for ReduceMax, ReduceMin, ReduceProd, ReduceSum

Merge pull request !47489 from zhanzhan/reducesum
This commit is contained in:
i-robot 2023-01-05 09:43:01 +00:00 committed by Gitee
commit 140418aa72
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 36 additions and 7 deletions

View File

@ -59,14 +59,20 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::prod_list_ = {
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int32_t)},
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t)},
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
@ -81,6 +87,8 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
@ -92,7 +100,27 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_mean_list_ = {
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_list_ = {
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)},
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
{REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int32_t)},
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t)},
{REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::mean_list_ = {
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
@ -106,9 +134,9 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
};
std::map<std::string, std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>>>
ArrayReduceGpuKernelMod::kernel_attr_list_ = {
{prim::kPrimReduceSum->name(), sum_list_}, {prim::kPrimReduceMean->name(), max_min_mean_list_},
{prim::kPrimReduceProd->name(), prod_list_}, {prim::kPrimReduceMax->name(), max_min_mean_list_},
{prim::kPrimReduceMin->name(), max_min_mean_list_}, {prim::kPrimReduceAll->name(), all_any_list_},
{prim::kPrimReduceSum->name(), sum_list_}, {prim::kPrimReduceMean->name(), mean_list_},
{prim::kPrimReduceProd->name(), prod_list_}, {prim::kPrimReduceMax->name(), max_min_list_},
{prim::kPrimReduceMin->name(), max_min_list_}, {prim::kPrimReduceAll->name(), all_any_list_},
{prim::kPrimReduceAny->name(), all_any_list_}};
std::vector<KernelAttr> ArrayReduceGpuKernelMod::GetOpSupport() {

View File

@ -109,7 +109,8 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
static std::vector<std::pair<KernelAttr, ReduceFunc>> all_any_list_;
static std::vector<std::pair<KernelAttr, ReduceFunc>> prod_list_;
static std::vector<std::pair<KernelAttr, ReduceFunc>> sum_list_;
static std::vector<std::pair<KernelAttr, ReduceFunc>> max_min_mean_list_;
static std::vector<std::pair<KernelAttr, ReduceFunc>> max_min_list_;
static std::vector<std::pair<KernelAttr, ReduceFunc>> mean_list_;
static std::map<std::string, std::vector<std::pair<KernelAttr, ReduceFunc>>> kernel_attr_list_;
private: