From f1deedc570e43ca1094cc47f32f0852221ee57e8 Mon Sep 17 00:00:00 2001
From: muxiyin
Date: Tue, 11 Oct 2022 10:12:38 +0800
Subject: [PATCH] reduce sup complex

---
 .../kernel/arrays/array_reduce_gpu_kernel.cc | 129 +++++++++++++++---
 .../kernel/arrays/array_reduce_gpu_kernel.h  |   7 +
 tests/st/ops/gpu/test_complex_op.py          |  22 +++
 3 files changed, 136 insertions(+), 22 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc
index 5874d780d71..20c4f2fd5e0 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.cc
@@ -16,10 +16,14 @@
 #include
 #include "plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
 #include "ops/reduce.h"
 
 namespace mindspore {
 namespace kernel {
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
 constexpr auto kReduceMean = "ReduceMean";
 constexpr auto kReduceMax = "ReduceMax";
 constexpr auto kReduceSum = "ReduceSum";
@@ -33,6 +37,9 @@ constexpr size_t kIndex1 = 1;
 constexpr size_t kIndex2 = 2;
 constexpr size_t kIndex3 = 3;
 constexpr size_t kDynamicAxisInputNum = 2;
+constexpr size_t kComplexFloatFlag = 1;
+constexpr size_t kComplexDoubleFlag = 2;
+constexpr size_t kComplexRate = 2;
 const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
   {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX},
   {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG},
@@ -56,6 +63,8 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
   {STATIC_REGISTER(kNumberTypeFloat32, float)},
   {STATIC_REGISTER(kNumberTypeFloat16, half)},
   {STATIC_REGISTER(kNumberTypeInt8, int8_t)},
+  {STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
+  {STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
   {DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
   {DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
@@ -63,12 +72,19 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
   {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}};
+  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+};
 std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::sum_list_ = {
   {STATIC_REGISTER(kNumberTypeFloat64, double)},
   {STATIC_REGISTER(kNumberTypeFloat32, float)},
   {STATIC_REGISTER(kNumberTypeFloat16, half)},
   {STATIC_REGISTER(kNumberTypeBool, bool)},
+  {STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
+  {STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
   {DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
   {DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
@@ -76,17 +92,29 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
   {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}};
+  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+};
 std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_mean_list_ = {
   {STATIC_REGISTER(kNumberTypeFloat64, double)},
   {STATIC_REGISTER(kNumberTypeFloat32, float)},
   {STATIC_REGISTER(kNumberTypeFloat16, half)},
+  {STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
+  {STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
   {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}};
+  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+};
 std::map<std::string, std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>>> ArrayReduceGpuKernelMod::kernel_attr_list_ = {
   {prim::kPrimReduceSum->name(), sum_list_},
   {prim::kPrimReduceMean->name(), max_min_mean_list_},
@@ -133,7 +161,16 @@ bool ArrayReduceGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
   auto type_id = kernel_attr.GetInputAttr(kIndex0).first;
   auto type_name = TypeIdLabel(type_id);
-  data_type_ = GetCudnnDataType(type_name);
+  if (type_id == kNumberTypeComplex64) {
+    data_type_ = CUDNN_DATA_FLOAT;
+    complex_op_type = kComplexFloatFlag;
+  } else if (type_id == kNumberTypeComplex128) {
+    data_type_ = CUDNN_DATA_DOUBLE;
+    complex_op_type = kComplexDoubleFlag;
+  } else {
+    data_type_ = GetCudnnDataType(type_name);
+  }
+
   InitResource();
   return true;
 }
@@ -147,6 +184,12 @@ void ArrayReduceGpuKernelMod::InitCudnnResource() {
                                                      &workspace_size_),
                                       "cudnnGetReductionWorkspaceSize failed.");
   workspace_size_list_.push_back(workspace_size_);
+  if (complex_op_type != 0) {
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(outputC_descriptor_, &output_size_),
+                                        "cudnnGetTensorSizeInBytes failed.");
+    output_size_list_.clear();
+    output_size_list_.push_back(output_size_ * kComplexRate);
+  }
   return;
 }
@@ -198,7 +241,6 @@ int ArrayReduceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
   InferInAndOutDesc(inputA_shape, outputC_shape);
   InferArrayReduceType();
-  InitCudnnResource();
   return KRET_OK;
 }
@@ -279,6 +321,43 @@ void ArrayReduceGpuKernelMod::InferArrayReduceType() {
   return;
 }
+template <typename T, typename S>
+void ArrayReduceGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &inputs,
+                                                  const std::vector<AddressPtr> &workspace,
+                                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  T *input_addr = GetDeviceAddress<T>(inputs, 0);
+  T *output_addr = GetDeviceAddress<T>(outputs, 0);
+  S alpha = static_cast<S>(1.0f);
+  S beta = static_cast<S>(0.0f);
+  T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
+  S *input_real = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(input_size_));
+  S *input_imag = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(input_size_));
+  S *output_real = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(output_size_));
+  S *output_imag = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(output_size_));
+  if (input_real == nullptr || input_imag == nullptr || output_real == nullptr || output_imag == nullptr) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the memory allocation failed.";
+  }
+  int output_count = output_size_ / sizeof(S);
+  int input_count = input_size_ / sizeof(S);
+  Real(input_addr, input_real, input_count, reinterpret_cast<cudaStream_t>(stream_ptr));
+  Imag(input_addr, input_imag, input_count, reinterpret_cast<cudaStream_t>(stream_ptr));
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
+                      inputA_descriptor_, input_real, &beta, outputC_descriptor_, output_real),
+    "cudnnReduceTensor failed.");
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
+                      inputA_descriptor_, input_imag, &beta, outputC_descriptor_, output_imag),
+    "cudnnReduceTensor failed.");
+  ElewiseComplexArith(output_count, BROADCAST_TYPE_COMPLEX, output_real, output_imag, output_addr,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(input_real);
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(input_imag);
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(output_real);
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(output_imag);
+  return;
+}
+
 template <typename T>
 bool ArrayReduceGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                            const std::vector<AddressPtr> &workspace,
@@ -288,30 +367,36 @@ bool ArrayReduceGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
   }
   T *input_addr = GetDeviceAddress<T>(inputs, 0);
   T *output_addr = GetDeviceAddress<T>(outputs, 0);
-
-  T alpha = static_cast<T>(1.0f);
-  T beta = static_cast<T>(0.0f);
   if (all_match_ || need_skip_execute_) {
     MS_LOG(DEBUG) << "The corresponding dimensions of the input and output tensors all match. "
                      "No need to call cuDNN kernel.";
     CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(output_addr, input_addr, input_size_, cudaMemcpyDeviceToDevice,
                                                        reinterpret_cast<cudaStream_t>(stream_ptr)),
                                        "cudaMemcpyAsync failed in ArrayReduceGpuKernelMod::Launch.");
+    return true;
+  }
+  if (complex_op_type == kComplexFloatFlag) {
+    LaunchComplexKernel<Complex<float>, float>(inputs, workspace, outputs, stream_ptr);
+    return true;
+  } else if (complex_op_type == kComplexDoubleFlag) {
+    LaunchComplexKernel<Complex<double>, double>(inputs, workspace, outputs, stream_ptr);
+    return true;
+  }
+  T alpha = static_cast<T>(1.0f);
+  T beta = static_cast<T>(0.0f);
+  T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
+  if (data_type_ == CUDNN_DATA_DOUBLE) {
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
+                        inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr),
+      "cudnnReduceTensor failed.");
   } else {
-    T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
-    if (data_type_ == CUDNN_DATA_DOUBLE) {
-      CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
-        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
-                          inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr),
-        "cudnnReduceTensor failed.");
-    } else {
-      const float alphaf = static_cast<float>(alpha);
-      const float betaf = static_cast<float>(beta);
-      CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
-        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_,
-                          &alphaf, inputA_descriptor_, input_addr, &betaf, outputC_descriptor_, output_addr),
-        "cudnnReduceTensor failed.");
-    }
+    const float alphaf = static_cast<float>(alpha);
+    const float betaf = static_cast<float>(beta);
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alphaf,
+                        inputA_descriptor_, input_addr, &betaf, outputC_descriptor_, output_addr),
+      "cudnnReduceTensor failed.");
   }
   return true;
 }
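Note on the approach: cuDNN's cudnnReduceTensor has no complex data type, so LaunchComplexKernel splits the complex input into its real and imaginary planes (Real/Imag), runs the real-valued reduction on each plane, and recombines the two partial results into a complex output (ElewiseComplexArith). A minimal NumPy sketch of that strategy, for illustration only (reduce_complex_by_parts is a hypothetical name, not a MindSpore function):

    import numpy as np

    def reduce_complex_by_parts(x, axis, keepdims, real_reduce=np.sum):
        # Mirror the kernel: reduce the real and imaginary planes separately,
        # then recombine them into a single complex result.
        real_part = real_reduce(x.real, axis=axis, keepdims=keepdims)
        imag_part = real_reduce(x.imag, axis=axis, keepdims=keepdims)
        return real_part + 1j * imag_part

    x = np.random.rand(2, 3, 4) + 1j * np.random.rand(2, 3, 4)
    assert np.allclose(reduce_complex_by_parts(x, axis=2, keepdims=True),
                       np.sum(x, axis=2, keepdims=True))

The recombination is exact for linear reductions such as ReduceSum and ReduceMean, since summation commutes with taking real and imaginary parts.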
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h
index 7037c03d913..0905c2ca299 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h
@@ -70,7 +70,9 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
     keep_dims_ = false;
     all_match_ = false;
     is_null_input_ = false;
+    complex_op_type = 0;
     input_size_ = 0;
+    output_size_ = 0;
     workspace_size_ = 0;
     kernel_name_ = "ArrayReduce";
     axis_.clear();
@@ -93,6 +95,9 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
   template <typename T>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs, void *stream_ptr);
+  template <typename T, typename S>
+  void LaunchComplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                           const std::vector<AddressPtr> &outputs, void *stream_ptr);
   using ReduceFunc = std::function<bool(ArrayReduceGpuKernelMod *, const std::vector<AddressPtr> &,
                                         const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
@@ -118,7 +123,9 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
   bool need_skip_execute_;
   bool all_match_;
   bool is_null_input_;
+  int complex_op_type;
   size_t input_size_;
+  size_t output_size_;
   size_t workspace_size_;
   static constexpr size_t kAxisIndex_{1};
   std::string kernel_type_{"Unknown"};
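Note on output_size_ and kComplexRate: the cuDNN descriptors are created with the underlying real type (CUDNN_DATA_FLOAT or CUDNN_DATA_DOUBLE), so cudnnGetTensorSizeInBytes reports the byte size of a single real-valued plane, while the complex output buffer has to hold both the real and the imaginary plane. A quick size check in plain Python (illustrative only; the shape is an arbitrary example):

    import numpy as np

    shape = (2, 3, 1)  # e.g. a (2, 3, 4) Complex64 tensor reduced over axis 2 with keep_dims=True
    plane_bytes = int(np.prod(shape)) * np.dtype(np.float32).itemsize  # what the float descriptor reports: 24 bytes
    kComplexRate = 2                                                   # real plane + imaginary plane
    assert plane_bytes * kComplexRate == int(np.prod(shape)) * np.dtype(np.complex64).itemsize  # 48 bytes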
diff --git a/tests/st/ops/gpu/test_complex_op.py b/tests/st/ops/gpu/test_complex_op.py
index 1ed08d44d50..2ffad1380fe 100644
--- a/tests/st/ops/gpu/test_complex_op.py
+++ b/tests/st/ops/gpu/test_complex_op.py
@@ -15,6 +15,8 @@
 import numpy as np
 import pytest
+import torch
+import mindspore.ops.operations as P
 from mindspore import Tensor
 from mindspore.nn import Cell
 from mindspore import ops
@@ -133,3 +135,23 @@ def test_complex_broadcast():
     res_ms = complex1 / complex2
     res_to = np.divide(complex_1, complex_2)
     assert complex_compare(res_ms, res_to)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_complex_reducesum():
+    """
+    Feature: complex reduce Operation.
+    Description: Test complex reduce Operation.
+    Expectation: the result matches the expected one.
+    """
+    x0_real = np.random.rand(2, 3, 4).astype(np.float32)
+    x0_imag = np.random.rand(2, 3, 4).astype(np.float32)
+    x0 = Complex()(Tensor(x0_real), Tensor(x0_imag))
+    axis0 = 2
+    keep_dims0 = True
+    res_ms = P.ReduceSum(keep_dims0)(x0, axis0)
+    x0_torch = torch.complex(torch.tensor(x0_real), torch.tensor(x0_imag))
+    res_torch = torch.sum(x0_torch, axis0, keep_dims0)
+    assert np.allclose(res_ms.asnumpy(), res_torch.numpy(), rtol=5e-03, atol=1.e-8)
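For reference, the expected values in test_complex_reducesum can also be derived without torch, because NumPy sums complex arrays natively; a minimal sketch of an equivalent reference computation (illustrative only, res_ms stands for the kernel result produced as in the test above):

    import numpy as np

    x0_real = np.random.rand(2, 3, 4).astype(np.float32)
    x0_imag = np.random.rand(2, 3, 4).astype(np.float32)
    expected = np.sum(x0_real + 1j * x0_imag, axis=2, keepdims=True)  # NumPy reference for ReduceSum(keep_dims=True)
    # The kernel output would then be checked with:
    #   assert np.allclose(res_ms.asnumpy(), expected, rtol=5e-03, atol=1.e-8)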