forked from mindspore-Ecosystem/mindspore
Support complex mstype for ms.numpy,roll function.
This commit is contained in:
parent
66ba9e952b
commit
c0560eb779
|
@ -21,6 +21,7 @@
|
|||
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
|
||||
"mindspore/mindspore/ccsrc/backend/common/optimizer/op_adaptation_info_factory.h" "runtime/explicit"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cu" "runtime/int"
|
||||
|
||||
# Modelzoo
|
||||
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
|
||||
|
|
|
@ -15,9 +15,18 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/arrays/concatv2_gpu_kernel.h"
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ConcatV2FwdGpuKernelMod, Complex<double>)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ConcatV2FwdGpuKernelMod, Complex<float>)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ConcatV2FwdGpuKernelMod, double)
|
||||
|
|
|
@ -19,6 +19,11 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cuh"
|
||||
#include "include/cuda_fp16.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
|
||||
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
|
||||
template <typename T>
|
||||
__global__ void Concat(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
|
||||
int *len_axis, T **inputs, T *output) {
|
||||
|
@ -52,6 +57,12 @@ void ConcatKernel(const size_t size, const int input_num, const int all_size_bef
|
|||
return;
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
|
||||
const int all_size_axis, int *len_axis, Complex<double> **inputs,
|
||||
Complex<double> *output, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
|
||||
const int all_size_axis, int *len_axis, Complex<float> **inputs,
|
||||
Complex<float> *output, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
|
||||
const int all_size_axis, int *len_axis, double **inputs, double *output,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -1078,7 +1078,8 @@ def roll(a, shift, axis=None):
|
|||
# F.strided_slice only supports float on cpu, this will change once more supports
|
||||
# are added.
|
||||
if not _check_is_float(original_dtype):
|
||||
a = a.astype(mstype.float32)
|
||||
if not original_dtype in (mstype.complex64, mstype.complex128):
|
||||
a = a.astype(mstype.float32)
|
||||
if axis is None:
|
||||
restore_shape = True
|
||||
axis = 0
|
||||
|
@ -1090,7 +1091,8 @@ def roll(a, shift, axis=None):
|
|||
if restore_shape:
|
||||
a = a.reshape(original_shape)
|
||||
if not _check_is_float(original_dtype):
|
||||
a = a.astype(original_dtype)
|
||||
if not original_dtype in (mstype.complex64, mstype.complex128):
|
||||
a = a.astype(original_dtype)
|
||||
return a
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue