From c00e5d8847993ae4f08063e3c5751d7dc6736a85 Mon Sep 17 00:00:00 2001 From: zhanzhan1 Date: Wed, 4 Jan 2023 14:03:01 +0800 Subject: [PATCH] Add dtype support for ReduceMax --- .../kernel/arrays/array_reduce_gpu_kernel.cc | 40 ++++++++++++++++--- .../kernel/arrays/array_reduce_gpu_kernel.h | 3 +- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc index 24966518167..af274976211 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc @@ -59,14 +59,20 @@ std::vector> ArrayRed {REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)}, {REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}}; std::vector> ArrayReduceGpuKernelMod::prod_list_ = { - {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)}, - {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)}, {REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)}, {REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)}, {REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)}, {REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)}, {REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)}, {REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}, + {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)}, + {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)}, + {REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)}, + {REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)}, + {REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)}, + {REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int32_t)}, + {REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t)}, + {REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t)}, {REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex)}, {REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex)}, {REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex)}, @@ -81,6 +87,8 @@ std::vector> ArrayRed {REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)}, {REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)}, {REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}, + {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)}, + {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)}, {REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)}, {REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)}, {REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)}, @@ -92,7 +100,27 @@ std::vector> ArrayRed {REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex)}, {REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex)}, }; -std::vector> ArrayReduceGpuKernelMod::max_min_mean_list_ = { +std::vector> ArrayReduceGpuKernelMod::max_min_list_ = { + {REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)}, + {REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)}, + {REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)}, + {REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)}, + {REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)}, + {REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}, + {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)}, + {REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)}, + {REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, int16_t)}, + {REDUCE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, int16_t)}, + {REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t)}, + {REDUCE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, int32_t)}, + {REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt32, int64_t)}, + {REDUCE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t)}, + {REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex)}, + {REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex)}, + {REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex)}, + {REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex)}, +}; +std::vector> ArrayReduceGpuKernelMod::mean_list_ = { {REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)}, {REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)}, {REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)}, @@ -106,9 +134,9 @@ std::vector> ArrayRed }; std::map>> ArrayReduceGpuKernelMod::kernel_attr_list_ = { - {prim::kPrimReduceSum->name(), sum_list_}, {prim::kPrimReduceMean->name(), max_min_mean_list_}, - {prim::kPrimReduceProd->name(), prod_list_}, {prim::kPrimReduceMax->name(), max_min_mean_list_}, - {prim::kPrimReduceMin->name(), max_min_mean_list_}, {prim::kPrimReduceAll->name(), all_any_list_}, + {prim::kPrimReduceSum->name(), sum_list_}, {prim::kPrimReduceMean->name(), mean_list_}, + {prim::kPrimReduceProd->name(), prod_list_}, {prim::kPrimReduceMax->name(), max_min_list_}, + {prim::kPrimReduceMin->name(), max_min_list_}, {prim::kPrimReduceAll->name(), all_any_list_}, {prim::kPrimReduceAny->name(), all_any_list_}}; std::vector ArrayReduceGpuKernelMod::GetOpSupport() { diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h index 681772df2ce..a2268fded48 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h @@ -109,7 +109,8 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod { static std::vector> all_any_list_; static std::vector> prod_list_; static std::vector> sum_list_; - static std::vector> max_min_mean_list_; + static std::vector> max_min_list_; + static std::vector> mean_list_; static std::map>> kernel_attr_list_; private: