Support complex mstype for ms.numpy,roll function.

This commit is contained in:
huangziling 2022-09-16 14:19:21 +08:00
parent 66ba9e952b
commit c0560eb779
4 changed files with 26 additions and 3 deletions

View File

@ -21,6 +21,7 @@
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
"mindspore/mindspore/ccsrc/backend/common/optimizer/op_adaptation_info_factory.h" "runtime/explicit"
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cu" "runtime/int"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"

View File

@ -15,9 +15,18 @@
*/
#include "plugin/device/gpu/kernel/arrays/concatv2_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ConcatV2FwdGpuKernelMod, Complex<double>)
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ConcatV2FwdGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ConcatV2FwdGpuKernelMod, double)

View File

@ -19,6 +19,11 @@
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cuh"
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
__global__ void Concat(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
int *len_axis, T **inputs, T *output) {
@ -52,6 +57,12 @@ void ConcatKernel(const size_t size, const int input_num, const int all_size_bef
return;
}
template CUDA_LIB_EXPORT void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, Complex<double> **inputs,
Complex<double> *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, Complex<float> **inputs,
Complex<float> *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, double **inputs, double *output,
cudaStream_t cuda_stream);

View File

@ -1078,7 +1078,8 @@ def roll(a, shift, axis=None):
# F.strided_slice only supports float on cpu, this will change once more supports
# are added.
if not _check_is_float(original_dtype):
a = a.astype(mstype.float32)
if not original_dtype in (mstype.complex64, mstype.complex128):
a = a.astype(mstype.float32)
if axis is None:
restore_shape = True
axis = 0
@ -1090,7 +1091,8 @@ def roll(a, shift, axis=None):
if restore_shape:
a = a.reshape(original_shape)
if not _check_is_float(original_dtype):
a = a.astype(original_dtype)
if not original_dtype in (mstype.complex64, mstype.complex128):
a = a.astype(original_dtype)
return a