!43646 support reduce complex

Merge pull request !43646 from 黄晓/reduce_complex_1010

Commit: d0a4fabd6b
@@ -16,10 +16,14 @@
 #include <memory>
 #include "plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
 #include "ops/reduce.h"

 namespace mindspore {
 namespace kernel {
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
 constexpr auto kReduceMean = "ReduceMean";
 constexpr auto kReduceMax = "ReduceMax";
 constexpr auto kReduceSum = "ReduceSum";

@@ -33,6 +37,9 @@ constexpr size_t kIndex1 = 1;
 constexpr size_t kIndex2 = 2;
 constexpr size_t kIndex3 = 3;
 constexpr size_t kDynamicAxisInputNum = 2;
+constexpr size_t kComplexFloatFlag = 1;
+constexpr size_t kComplexDoubleFlag = 2;
+constexpr size_t kComplexRate = 2;

 const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
   {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG},

@@ -56,6 +63,8 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
   {STATIC_REGISTER(kNumberTypeFloat32, float)},
   {STATIC_REGISTER(kNumberTypeFloat16, half)},
   {STATIC_REGISTER(kNumberTypeInt8, int8_t)},
+  {STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
+  {STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
   {DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
   {DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},

@@ -63,12 +72,19 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
   {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}};
+  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+};
 std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::sum_list_ = {
   {STATIC_REGISTER(kNumberTypeFloat64, double)},
   {STATIC_REGISTER(kNumberTypeFloat32, float)},
   {STATIC_REGISTER(kNumberTypeFloat16, half)},
   {STATIC_REGISTER(kNumberTypeBool, bool)},
+  {STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
+  {STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
   {DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
   {DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},

@@ -76,17 +92,29 @@ std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayRed
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
   {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}};
+  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+};
 std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_mean_list_ = {
   {STATIC_REGISTER(kNumberTypeFloat64, double)},
   {STATIC_REGISTER(kNumberTypeFloat32, float)},
   {STATIC_REGISTER(kNumberTypeFloat16, half)},
+  {STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
+  {STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
   {DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
   {DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
   {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)}};
+  {DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+  {DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+};
 std::map<std::string, std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>>>
   ArrayReduceGpuKernelMod::kernel_attr_list_ = {
     {prim::kPrimReduceSum->name(), sum_list_}, {prim::kPrimReduceMean->name(), max_min_mean_list_},

@@ -133,7 +161,16 @@ bool ArrayReduceGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s

   auto type_id = kernel_attr.GetInputAttr(kIndex0).first;
   auto type_name = TypeIdLabel(type_id);
-  data_type_ = GetCudnnDataType(type_name);
+  if (type_id == kNumberTypeComplex64) {
+    data_type_ = CUDNN_DATA_FLOAT;
+    complex_op_type = kComplexFloatFlag;
+  } else if (type_id == kNumberTypeComplex128) {
+    data_type_ = CUDNN_DATA_DOUBLE;
+    complex_op_type = kComplexDoubleFlag;
+  } else {
+    data_type_ = GetCudnnDataType(type_name);
+  }

   InitResource();
   return true;
 }

@@ -147,6 +184,12 @@ void ArrayReduceGpuKernelMod::InitCudnnResource() {
                                       &workspace_size_),
                                       "cudnnGetReductionWorkspaceSize failed.");
   workspace_size_list_.push_back(workspace_size_);
+  if (complex_op_type != 0) {
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(outputC_descriptor_, &output_size_),
+                                        "cudnnGetTensorSizeInBytes failed.");
+    output_size_list_.clear();
+    output_size_list_.push_back(output_size_ * kComplexRate);
+  }
   return;
 }

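Note: cuDNN's output tensor descriptor describes a single real-valued component, which is why InitCudnnResource reports `output_size_ * kComplexRate` bytes for complex results. A minimal NumPy sketch of the byte accounting, assuming a hypothetical (3, 4) float32 component output:

```python
import numpy as np

# Hypothetical component shape (3, 4): the size cuDNN reports covers one
# real-valued component, so a complex64 result needs kComplexRate times that.
component_bytes = 3 * 4 * np.dtype(np.float32).itemsize
kComplexRate = 2  # real part + imaginary part
assert component_bytes * kComplexRate == 3 * 4 * np.dtype(np.complex64).itemsize
```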
@@ -198,7 +241,6 @@ int ArrayReduceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const

   InferInAndOutDesc(inputA_shape, outputC_shape);
   InferArrayReduceType();
-
   InitCudnnResource();
   return KRET_OK;
 }

@@ -279,6 +321,43 @@ void ArrayReduceGpuKernelMod::InferArrayReduceType() {
   return;
 }

+template <typename T, typename S>
+void ArrayReduceGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &inputs,
+                                                  const std::vector<AddressPtr> &workspace,
+                                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  T *input_addr = GetDeviceAddress<T>(inputs, 0);
+  T *output_addr = GetDeviceAddress<T>(outputs, 0);
+  S alpha = static_cast<S>(1.0f);
+  S beta = static_cast<S>(0.0f);
+  T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
+  S *input_real = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(input_size_));
+  S *input_imag = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(input_size_));
+  S *output_real = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(output_size_));
+  S *output_imag = reinterpret_cast<S *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(output_size_));
+  if (input_real == nullptr || input_imag == nullptr || output_real == nullptr || output_imag == nullptr) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the memory allocation failed.";
+  }
+  int output_count = output_size_ / sizeof(S);
+  int input_count = input_size_ / sizeof(S);
+  Real(input_addr, input_real, input_count, reinterpret_cast<cudaStream_t>(stream_ptr));
+  Imag(input_addr, input_imag, input_count, reinterpret_cast<cudaStream_t>(stream_ptr));
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
+                      inputA_descriptor_, input_real, &beta, outputC_descriptor_, output_real),
+    "cudnnReduceTensor failed.");
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
+                      inputA_descriptor_, input_imag, &beta, outputC_descriptor_, output_imag),
+    "cudnnReduceTensor failed.");
+  ElewiseComplexArith(output_count, BROADCAST_TYPE_COMPLEX, output_real, output_imag, output_addr,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(input_real);
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(input_imag);
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(output_real);
+  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(output_imag);
+  return;
+}
+
 template <typename T>
 bool ArrayReduceGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                            const std::vector<AddressPtr> &workspace,

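Note: LaunchComplexKernel is correct because sum and mean are linear: reducing a complex tensor equals reducing its real and imaginary parts independently and recombining them, which is what Real/Imag plus ElewiseComplexArith implement on device. A minimal NumPy sketch of that identity (illustrative, not the MindSpore API):

```python
import numpy as np

# reduce(z) == reduce(Re z) + 1j * reduce(Im z) for linear reductions.
z = np.random.rand(2, 3, 4) + 1j * np.random.rand(2, 3, 4)
direct = z.sum(axis=2)
split = z.real.sum(axis=2) + 1j * z.imag.sum(axis=2)
assert np.allclose(direct, split)
```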
@@ -288,30 +367,36 @@ bool ArrayReduceGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
   }
   T *input_addr = GetDeviceAddress<T>(inputs, 0);
   T *output_addr = GetDeviceAddress<T>(outputs, 0);

-  T alpha = static_cast<T>(1.0f);
-  T beta = static_cast<T>(0.0f);
   if (all_match_ || need_skip_execute_) {
     MS_LOG(DEBUG)
       << "The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel.";
     CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(output_addr, input_addr, input_size_, cudaMemcpyDeviceToDevice,
                                                        reinterpret_cast<cudaStream_t>(stream_ptr)),
                                        "cudaMemcpyAsync failed in ArrayReduceGpuKernelMod::Launch.");
     return true;
   }
+  if (complex_op_type == kComplexFloatFlag) {
+    LaunchComplexKernel<Complex<float>, float>(inputs, workspace, outputs, stream_ptr);
+    return true;
+  } else if (complex_op_type == kComplexDoubleFlag) {
+    LaunchComplexKernel<Complex<double>, double>(inputs, workspace, outputs, stream_ptr);
+    return true;
+  }
+  T alpha = static_cast<T>(1.0f);
+  T beta = static_cast<T>(0.0f);
   T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 0);
   if (data_type_ == CUDNN_DATA_DOUBLE) {
     CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
       cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
                         inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr),
       "cudnnReduceTensor failed.");
   } else {
     const float alphaf = static_cast<float>(alpha);
     const float betaf = static_cast<float>(beta);
     CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
       cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_,
                         &alphaf, inputA_descriptor_, input_addr, &betaf, outputC_descriptor_, output_addr),
       "cudnnReduceTensor failed.");
   }
   return true;
 }

@@ -70,7 +70,9 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
     keep_dims_ = false;
     all_match_ = false;
     is_null_input_ = false;
+    complex_op_type = 0;
     input_size_ = 0;
     output_size_ = 0;
     workspace_size_ = 0;
     kernel_name_ = "ArrayReduce";
     axis_.clear();

@@ -93,6 +95,9 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
   template <typename T>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs, void *stream_ptr);
+  template <typename T, typename S>
+  void LaunchComplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                           const std::vector<AddressPtr> &outputs, void *stream_ptr);
   using ReduceFunc =
     std::function<bool(ArrayReduceGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                        const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;

@@ -118,7 +123,9 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
   bool need_skip_execute_;
   bool all_match_;
   bool is_null_input_;
+  int complex_op_type;
   size_t input_size_;
   size_t output_size_;
   size_t workspace_size_;
   static constexpr size_t kAxisIndex_{1};
   std::string kernel_type_{"Unknown"};

@@ -15,6 +15,8 @@

 import numpy as np
 import pytest
+import torch
 import mindspore.ops.operations as P
 from mindspore import Tensor
 from mindspore.nn import Cell
+from mindspore import ops

@@ -133,3 +135,23 @@ def test_complex_broadcast():
     res_ms = complex1 / complex2
     res_to = np.divide(complex_1, complex_2)
     assert complex_compare(res_ms, res_to)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_complex_reducesum():
+    """
+    Feature: complex reduce operation.
+    Description: Test the complex ReduceSum operation on GPU.
+    Expectation: the result matches the torch reference.
+    """
+    x0_real = np.random.rand(2, 3, 4).astype(np.float32)
+    x0_imag = np.random.rand(2, 3, 4).astype(np.float32)
+    x0 = Complex()(Tensor(x0_real), Tensor(x0_imag))
+    axis0 = 2
+    keep_dims0 = True
+    res_ms = P.ReduceSum(keep_dims0)(x0, axis0)
+    x0_torch = torch.complex(torch.tensor(x0_real), torch.tensor(x0_imag))
+    res_torch = torch.sum(x0_torch, axis0, keep_dims0)
+    assert np.allclose(res_ms.asnumpy(), res_torch.numpy(), rtol=5e-03, atol=1.e-8)
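Note: the new test leans on torch only to build the reference value; an equivalent check can be built with NumPy alone, since np.sum handles complex dtypes natively. A standalone sketch (illustrative):

```python
import numpy as np

# Build the same reference as torch.complex + torch.sum, torch-free.
real = np.random.rand(2, 3, 4).astype(np.float32)
imag = np.random.rand(2, 3, 4).astype(np.float32)
reference = (real + 1j * imag).sum(axis=2, keepdims=True)
assert reference.dtype == np.complex64
```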