add gpu complex ops

This commit is contained in:
zhouyaqiang 2021-08-03 18:47:18 +08:00
parent 2ee8bf46ca
commit dad375abb9
29 changed files with 1063 additions and 130 deletions

View File

@ -21,6 +21,7 @@
"mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc" "containerOutOfBounds"
"mindspore/mindspore/core/ops/strided_slice.cc" "zerodivcond"
"mindspore/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc" "useStlAlgorithm"
"mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro"
# MindData
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"

View File

@ -36,25 +36,41 @@ namespace mindspore {
namespace kernel {
constexpr char kAxis[] = "axis";
constexpr char kTypeInt32[] = "Int32";
const std::unordered_map<std::string, TypeId> type_id_maps = {
{"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
{"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
{"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8},
{"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32},
{"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
{"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
{"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
{"bool", TypeId::kNumberTypeBool}, {"complex64", TypeId::kNumberTypeComplex64}};
const std::unordered_map<std::string, TypeId> type_id_maps = {{"float", TypeId::kNumberTypeFloat32},
{"float16", TypeId::kNumberTypeFloat16},
{"float32", TypeId::kNumberTypeFloat32},
{"float64", TypeId::kNumberTypeFloat64},
{"int", TypeId::kNumberTypeInt},
{"int8", TypeId::kNumberTypeInt8},
{"int16", TypeId::kNumberTypeInt16},
{"int32", TypeId::kNumberTypeInt32},
{"int64", TypeId::kNumberTypeInt64},
{"uint", TypeId::kNumberTypeUInt},
{"uint8", TypeId::kNumberTypeUInt8},
{"uint16", TypeId::kNumberTypeUInt16},
{"uint32", TypeId::kNumberTypeUInt32},
{"uint64", TypeId::kNumberTypeUInt64},
{"bool", TypeId::kNumberTypeBool},
{"complex64", TypeId::kNumberTypeComplex64},
{"complex128", TypeId::kNumberTypeComplex128}};
const std::map<TypeId, std::string> type_id_str_map = {
{TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
{TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"},
{TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"},
{TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"},
{TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
{TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
{TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
{TypeId::kNumberTypeBool, "bool"}, {TypeId::kNumberTypeComplex64, "complex64"}};
const std::map<TypeId, std::string> type_id_str_map = {{TypeId::kNumberTypeFloat32, "float32"},
{TypeId::kNumberTypeFloat16, "float16"},
{TypeId::kNumberTypeFloat, "float"},
{TypeId::kNumberTypeFloat64, "float64"},
{TypeId::kNumberTypeInt, "int"},
{TypeId::kNumberTypeInt8, "int8"},
{TypeId::kNumberTypeInt16, "int16"},
{TypeId::kNumberTypeInt32, "int32"},
{TypeId::kNumberTypeInt64, "int64"},
{TypeId::kNumberTypeUInt, "uint"},
{TypeId::kNumberTypeUInt8, "uint8"},
{TypeId::kNumberTypeUInt16, "uint16"},
{TypeId::kNumberTypeUInt32, "uint32"},
{TypeId::kNumberTypeUInt64, "uint64"},
{TypeId::kNumberTypeBool, "bool"},
{TypeId::kNumberTypeComplex64, "complex64"},
{TypeId::kNumberTypeComplex128, "complex128"}};
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},

View File

@ -42,6 +42,12 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutput
int8_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
int8_t, bool)
template <typename T>
using Complex = mindspore::utils::Complex<T>;
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, int8_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, int8_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
int16_t, int8_t)
@ -67,6 +73,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutpu
CastGpuKernel, int16_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
int16_t, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, int16_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, int16_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
int32_t, int8_t)
@ -92,6 +102,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutpu
CastGpuKernel, int32_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
int32_t, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, int32_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, int32_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
int64_t, int8_t)
@ -117,6 +131,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutpu
CastGpuKernel, int64_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
int64_t, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, int64_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, int64_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
uint8_t, int8_t)
@ -142,6 +160,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutpu
CastGpuKernel, uint8_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
uint8_t, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, uint8_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, uint8_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
uint16_t, int8_t)
@ -167,6 +189,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutp
CastGpuKernel, uint16_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
uint16_t, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, uint16_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, uint16_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
uint32_t, int8_t)
@ -192,6 +218,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutp
CastGpuKernel, uint32_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
uint32_t, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, uint32_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, uint32_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
uint64_t, int8_t)
@ -217,6 +247,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutp
CastGpuKernel, uint64_t, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
uint64_t, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, uint64_t, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, uint64_t, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
half, int8_t)
@ -242,6 +276,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOut
CastGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
half, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, half, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, half, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
float, int8_t)
@ -267,6 +305,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
CastGpuKernel, float, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
float, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, float, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, float, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
double, int8_t)
@ -292,6 +334,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOut
CastGpuKernel, double, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
double, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, double, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, double, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
bool, int8_t)
@ -317,5 +363,64 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutput
bool, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
bool, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, bool, Complex<float>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, bool, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt8),
CastGpuKernel, Complex<float>, int8_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt16),
CastGpuKernel, Complex<float>, int16_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32),
CastGpuKernel, Complex<float>, int32_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64),
CastGpuKernel, Complex<float>, int64_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt8),
CastGpuKernel, Complex<float>, uint8_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt16),
CastGpuKernel, Complex<float>, uint16_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt32),
CastGpuKernel, Complex<float>, uint32_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt64),
CastGpuKernel, Complex<float>, uint64_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
CastGpuKernel, Complex<float>, float)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat64),
CastGpuKernel, Complex<float>, double)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat16),
CastGpuKernel, Complex<float>, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeBool),
CastGpuKernel, Complex<float>, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex128),
CastGpuKernel, Complex<float>, Complex<double>)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt8),
CastGpuKernel, Complex<double>, int8_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt16),
CastGpuKernel, Complex<double>, int16_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32),
CastGpuKernel, Complex<double>, int32_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64),
CastGpuKernel, Complex<double>, int64_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt8),
CastGpuKernel, Complex<double>, uint8_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt16),
CastGpuKernel, Complex<double>, uint16_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt32),
CastGpuKernel, Complex<double>, uint32_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt64),
CastGpuKernel, Complex<double>, uint64_t)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat32),
CastGpuKernel, Complex<double>, float)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
CastGpuKernel, Complex<double>, double)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat16),
CastGpuKernel, Complex<double>, half)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeBool),
CastGpuKernel, Complex<double>, bool)
MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex64),
CastGpuKernel, Complex<double>, Complex<float>)
} // namespace kernel
} // namespace mindspore

View File

@ -37,14 +37,14 @@ struct EqualFunc {
};
template <>
struct EqualFunc <half> {
struct EqualFunc<half> {
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? true : false;
}
};
template <>
struct EqualFunc <float> {
struct EqualFunc<float> {
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
return std::abs(lhs - rhs) < 1e-9 ? true : false;
}
@ -56,15 +56,16 @@ struct GreaterEqualFunc {
};
template <>
struct GreaterEqualFunc <half> {
struct GreaterEqualFunc<half> {
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ?
true : (__half2float(lhs) > __half2float(rhs) ? true : false);
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9
? true
: (__half2float(lhs) > __half2float(rhs) ? true : false);
}
};
template <>
struct GreaterEqualFunc <float> {
struct GreaterEqualFunc<float> {
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
return std::abs(lhs - rhs) < 1e-9 ? true : (lhs > rhs ? true : false);
}
@ -76,15 +77,16 @@ struct LessEqualFunc {
};
template <>
struct LessEqualFunc <half> {
struct LessEqualFunc<half> {
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ?
true : (__half2float(lhs) < __half2float(rhs) ? true : false);
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9
? true
: (__half2float(lhs) < __half2float(rhs) ? true : false);
}
};
template <>
struct LessEqualFunc <float> {
struct LessEqualFunc<float> {
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
return std::abs(lhs - rhs) < 1e-9 ? true : (lhs < rhs ? true : false);
}
@ -96,14 +98,14 @@ struct NotEqualFunc {
};
template <>
struct NotEqualFunc <half> {
struct NotEqualFunc<half> {
__device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? false : true;
}
};
template <>
struct NotEqualFunc <float> {
struct NotEqualFunc<float> {
__device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
return std::abs(lhs - rhs) < 1e-9 ? false : true;
}
@ -155,28 +157,52 @@ struct PowerFunc<half2> {
// Functor for RealDiv: element-wise division with overloads covering every
// real/complex operand combination, always producing a complex result when
// either side is complex.
template <typename T>
struct RealDivFunc {
  __device__ __host__ __forceinline__ T operator()(const T &x, const T &y) { return x / y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const T &y) { return x / y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const T &x, const Complex<T> &y) { return x / y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const Complex<T> &y) { return x / y; }
};
// Functor for Div: element-wise division; complex overloads mirror RealDivFunc
// and rely on Complex<T>'s operator/ for mixed real/complex operands.
template <typename T>
struct DivFunc {
  __device__ __host__ __forceinline__ T operator()(const T &x, const T &y) { return x / y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const T &y) { return x / y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const T &x, const Complex<T> &y) { return x / y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const Complex<T> &y) { return x / y; }
};
// Functor for Mul: element-wise product; any operand may be complex, and the
// result is complex whenever a complex operand is involved.
template <typename T>
struct MulFunc {
  __device__ __host__ __forceinline__ T operator()(const T &x, const T &y) { return x * y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const T &y) { return x * y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const T &x, const Complex<T> &y) { return x * y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const Complex<T> &y) { return x * y; }
};
// Functor for Sub: element-wise difference with real/complex mixed overloads.
template <typename T>
struct SubFunc {
  __device__ __host__ __forceinline__ T operator()(const T &x, const T &y) { return x - y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const T &y) { return x - y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const T &x, const Complex<T> &y) { return x - y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const Complex<T> &y) { return x - y; }
};
// Functor for Add: element-wise sum with real/complex mixed overloads.
template <typename T>
struct AddFunc {
  __device__ __host__ __forceinline__ T operator()(const T &x, const T &y) { return x + y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const T &y) { return x + y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const T &x, const Complex<T> &y) { return x + y; }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &x, const Complex<T> &y) { return x + y; }
};
// DivNoNan check if rhs is less than epsilon
template <typename T>
struct DivNoNanFunc {
@ -525,6 +551,13 @@ __global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *
}
}
// Element-wise binary arithmetic over two (possibly complex) input arrays,
// writing a complex result: y[i] = Func(x0[i], x1[i]).
// Uses a grid-stride loop so any launch geometry covers all `nums` elements.
template <typename T1, typename T2, typename T3, typename Func>
__global__ void ElewiseArithComplexKernel(const int nums, const T1 *x0, const T2 *x1, Complex<T3> *y) {
  // Cast once: `pos` is size_t while `nums` is int; comparing them directly is
  // a signed/unsigned mismatch (and UB-adjacent if nums were ever negative).
  const size_t total = static_cast<size_t>(nums);
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total; pos += blockDim.x * gridDim.x) {
    y[pos] = Func()(x0[pos], x1[pos]);
  }
}
template <typename T>
void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
switch (op) {
@ -567,6 +600,30 @@ void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, c
}
}
// Host-side dispatcher: selects the functor matching `op` and launches the
// element-wise complex kernel on `stream`. Unsupported ops fall through and
// do nothing, matching the real-typed ElewiseArithKernel dispatcher.
template <typename T1, typename T2, typename T3>
void ElewiseArithComplexKernel(const int &nums, enum BroadcastOpType op, const T1 *x0, const T2 *x1, Complex<T3> *y,
                               cudaStream_t stream) {
  // 256 threads per block; enough blocks to give one thread per element.
  const int threads = 256;
  const int blocks = (nums + threads - 1) / threads;
  switch (op) {
    case BROADCAST_TYPE_ADD:
      return ElewiseArithComplexKernel<T1, T2, T3, AddFunc<T3>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
    case BROADCAST_TYPE_SUB:
      return ElewiseArithComplexKernel<T1, T2, T3, SubFunc<T3>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
    case BROADCAST_TYPE_MUL:
      return ElewiseArithComplexKernel<T1, T2, T3, MulFunc<T3>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
    case BROADCAST_TYPE_DIV:
      return ElewiseArithComplexKernel<T1, T2, T3, DivFunc<T3>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
    case BROADCAST_TYPE_REALDIV:
      return ElewiseArithComplexKernel<T1, T2, T3, RealDivFunc<T3>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
    default:
      break;
  }
}
template <typename T>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
return ElewiseArithKernel(nums, op, x0, x1, y, stream);
@ -584,6 +641,12 @@ void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, cons
}
}
// Public entry point for element-wise complex arithmetic; simply forwards to
// the kernel-launching dispatcher above.
template <typename T1, typename T2, typename T3>
void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const T1 *x0, const T2 *x1, Complex<T3> *y,
cudaStream_t stream) {
return ElewiseArithComplexKernel(nums, op, x0, x1, y, stream);
}
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const double *x0, const double *x1, double *y,
cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, float *y,
@ -608,6 +671,18 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint6
uint64_t *y, cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y,
cudaStream_t stream);
template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<float> *x0,
const Complex<float> *x1, Complex<float> *y, cudaStream_t stream);
template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<float> *x0, const float *x1,
Complex<float> *y, cudaStream_t stream);
template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const float *x0, const Complex<float> *x1,
Complex<float> *y, cudaStream_t stream);
template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<double> *x0,
const Complex<double> *x1, Complex<double> *y, cudaStream_t stream);
template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<double> *x0, const double *x1,
Complex<double> *y, cudaStream_t stream);
template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const double *x0, const Complex<double> *x1,
Complex<double> *y, cudaStream_t stream);
// Broadcast comparison
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
@ -734,8 +809,8 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint64_t *x0,
const uint64_t *x1, bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
const bool *x1, bool *y, cudaStream_t stream);
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0, const bool *x1,
bool *y, cudaStream_t stream);
// Broadcast Arithmetic
template <typename T, typename Func>
__global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
@ -772,6 +847,41 @@ __global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const siz
}
}
// Broadcast binary arithmetic kernel for complex outputs over up-to-7-D shapes.
// l0..l6 are the dims of x0, r0..r6 the dims of x1, d0..d6 the (broadcast)
// output dims; y[pos] = Func(x0 at broadcast index, x1 at broadcast index).
template <typename T1, typename T2, typename T3, typename Func>
__global__ void BroadcastComplexArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
const size_t l4, const size_t l5, const size_t l6, const size_t r0,
const size_t r1, const size_t r2, const size_t r3, const size_t r4,
const size_t r5, const size_t r6, const size_t d0, const size_t d1,
const size_t d2, const size_t d3, const size_t d4, const size_t d5,
const size_t d6, const T1 *x0, const T2 *x1, Complex<T3> *y) {
// Grid-stride loop over the flattened output of d0*d1*...*d6 elements.
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
// Decompose the flat position into 7 per-dimension output coordinates.
size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
size_t k = pos / (d3 * d4 * d5 * d6) % d2;
size_t l = pos / (d4 * d5 * d6) % d3;
size_t m = pos / (d5 * d6) % d4;
size_t n = pos / d6 % d5;
size_t o = pos % d6;
// Re-flatten against x0's dims; Index() collapses size-1 (broadcast) dims to 0.
size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
// Same re-flattening against x1's dims.
size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
r_index += Index(m, r4) * r5 * r6;
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
y[pos] = Func()(x0[l_index], x1[r_index]);
}
}
template <typename T>
void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
@ -871,6 +981,45 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
}
}
// Host-side launcher for broadcast complex arithmetic: computes the output
// element count, then dispatches on `op` to launch the matching functor's
// kernel on `stream`. All dim vectors are expected to carry 7 entries
// (padded broadcast shapes) — TODO confirm with callers. Unsupported ops are
// silently ignored, matching the real-typed BroadcastArith dispatcher.
template <typename T1, typename T2, typename T3>
void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T1 *x0, const T2 *x1,
Complex<T3> *y, cudaStream_t stream) {
// Total number of output elements drives the launch geometry below.
size_t size = 1;
for (auto d : y_dims) {
size *= d;
}
switch (op) {
case BROADCAST_TYPE_ADD:
return BroadcastComplexArithKernel<T1, T2, T3, AddFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_SUB:
return BroadcastComplexArithKernel<T1, T2, T3, SubFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_MUL:
return BroadcastComplexArithKernel<T1, T2, T3, MulFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_DIV:
return BroadcastComplexArithKernel<T1, T2, T3, DivFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_REALDIV:
return BroadcastComplexArithKernel<T1, T2, T3, RealDivFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
default:
break;
}
}
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const double *x0,
const double *x1, double *y, cudaStream_t stream);
@ -905,8 +1054,30 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint64_t *x0,
const uint64_t *x1, uint64_t *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
const bool *x1, bool *y, cudaStream_t stream);
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0, const bool *x1,
bool *y, cudaStream_t stream);
template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op,
const Complex<float> *x0, const Complex<float> *x1, Complex<float> *y,
cudaStream_t stream);
template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op,
const Complex<float> *x0, const float *x1, Complex<float> *y, cudaStream_t stream);
template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0,
const Complex<float> *x1, Complex<float> *y, cudaStream_t stream);
template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op,
const Complex<double> *x0, const Complex<double> *x1, Complex<double> *y,
cudaStream_t stream);
template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op,
const Complex<double> *x0, const double *x1, Complex<double> *y,
cudaStream_t stream);
template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const double *x0,
const Complex<double> *x1, Complex<double> *y, cudaStream_t stream);
// BroadcastTo
template <typename T>
__global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t i2, const size_t i3, const size_t o0,

View File

@ -19,6 +19,7 @@
#include <vector>
#include "runtime/device/gpu/cuda_common.h"
#include "utils/complex.h"
const float kFloatEplison = 1e-37;
@ -57,6 +58,10 @@ void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *
template <typename T>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
// Element-wise (no broadcasting) binary arithmetic where at least one operand is complex.
// T1/T2 are the operand element types (real or Complex) and the result is always Complex<T3>.
// Runs asynchronously on `stream`.
template <typename T1, typename T2, typename T3>
void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const T1 *x0, const T2 *x1,
                         Complex<T3> *y, cudaStream_t stream);
template <typename T>
void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, bool *y,
@ -67,6 +72,11 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
cudaStream_t stream);
// Broadcasting binary arithmetic where at least one operand is complex.
// x0_dims/x1_dims/y_dims describe the (already rank-aligned) operand and output shapes;
// the result element type is always Complex<T3>. Runs asynchronously on `stream`.
template <typename T1, typename T2, typename T3>
void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T1 *x0, const T2 *x1,
                           Complex<T3> *y, cudaStream_t stream);
template <typename T>
void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr, T *output_addr,

View File

@ -117,6 +117,8 @@ template void Cast(const int input_size, const int8_t *input_addr, float *output
template void Cast(const int input_size, const int8_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -130,6 +132,9 @@ template void Cast(const int input_size, const int16_t *input_addr, float *outpu
template void Cast(const int input_size, const int16_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -143,6 +148,9 @@ template void Cast(const int input_size, const int32_t *input_addr, float *outpu
template void Cast(const int input_size, const int32_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -156,6 +164,8 @@ template void Cast(const int input_size, const int64_t *input_addr, float *outpu
template void Cast(const int input_size, const int64_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -169,6 +179,8 @@ template void Cast(const int input_size, const uint8_t *input_addr, float *outpu
template void Cast(const int input_size, const uint8_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -182,6 +194,8 @@ template void Cast(const int input_size, const uint16_t *input_addr, float *outp
template void Cast(const int input_size, const uint16_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -195,6 +209,8 @@ template void Cast(const int input_size, const uint32_t *input_addr, float *outp
template void Cast(const int input_size, const uint32_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -208,6 +224,8 @@ template void Cast(const int input_size, const uint64_t *input_addr, float *outp
template void Cast(const int input_size, const uint64_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -221,6 +239,8 @@ template void Cast(const int input_size, const half *input_addr, float *output_a
template void Cast(const int input_size, const half *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -234,6 +254,8 @@ template void Cast(const int input_size, const float *input_addr, float *output_
template void Cast(const int input_size, const float *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -247,6 +269,8 @@ template void Cast(const int input_size, const double *input_addr, float *output
template void Cast(const int input_size, const double *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const bool *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const bool *input_addr, int16_t *output_addr, cudaStream_t stream);
@ -260,3 +284,35 @@ template void Cast(const int input_size, const bool *input_addr, float *output_a
template void Cast(const int input_size, const bool *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const bool *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const bool *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const bool *input_addr, Complex<float> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const bool *input_addr, Complex<double> *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<float> *input_addr, Complex<double> *output_addr,
cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const Complex<double> *input_addr, Complex<float> *output_addr,
cudaStream_t stream);

View File

@ -19,6 +19,7 @@
#include <vector>
#include "runtime/device/gpu/cuda_common.h"
#include "utils/complex.h"
template <typename S, typename T>
void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream);

View File

@ -17,23 +17,10 @@
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"
template <typename T>
__global__ static void Split_Complex(const int element_numbers, T *real_part, T *imag_part,
const cufftComplex *complex_element) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
real_part[i] = complex_element[i].x;
imag_part[i] = complex_element[i].y;
}
}
template <typename T>
void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
const cufftHandle &FFT_plan_r2c, cudaStream_t stream) {
cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
cufftExecR2C(FFT_plan_r2c, input_tensor, COMPLEX_FQ);
Split_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, output_real, output_imag, COMPLEX_FQ);
void FFT3D(int Nfft, T *input_tensor, Complex<T> *output_tensor, const cufftHandle &FFT_plan_r2c, cudaStream_t stream) {
cufftExecR2C(FFT_plan_r2c, input_tensor, reinterpret_cast<cufftComplex *>(output_tensor));
return;
}
template void FFT3D<float>(int Nfft, float *input_tensor, float *complex_fq, float *output_real,
float *output_imag, const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
template void FFT3D<float>(int Nfft, float *input_tensor, Complex<float> *output_tensor,
const cufftHandle &FFT_plan_r2c, cudaStream_t stream);

View File

@ -17,10 +17,10 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_FFT_3D_IMPL_H_
#include <cufft.h>
#include "utils/complex.h"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
void FFT3D(int Nfft, T *input_tensor, Complex<T> *output_tensor, const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
#endif

View File

@ -17,23 +17,11 @@
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"
template <typename T>
__global__ static void Merge_Complex(const int element_numbers, T *real_part, T *imag_part,
cufftComplex *complex_element) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
complex_element[i].x = real_part[i];
complex_element[i].y = imag_part[i];
}
}
template <typename T>
void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
const cufftHandle &FFT_plan_c2r, cudaStream_t stream) {
cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
Merge_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, input_real, input_imag, COMPLEX_FQ);
cufftExecC2R(FFT_plan_c2r, COMPLEX_FQ, output_tensor);
void IFFT3D(int Nfft, Complex<T> *input_tensor, T *output_tensor, const cufftHandle &FFT_plan_c2r,
cudaStream_t stream) {
cufftExecC2R(FFT_plan_c2r, reinterpret_cast<cufftComplex *>(input_tensor), output_tensor);
return;
}
template void IFFT3D<float>(int Nfft, float *input_real, float *input_imag, float *complex_fq,
float *output_tensor, const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
template void IFFT3D<float>(int Nfft, Complex<float> *input_tensor, float *output_tensor,
const cufftHandle &FFT_plan_c2r, cudaStream_t stream);

View File

@ -17,10 +17,10 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_IFFT_3D_IMPL_H_
#include <cufft.h>
#include "utils/complex.h"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
void IFFT3D(int Nfft, Complex<T> *input_tensor, T *output_tensor, const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
#endif

View File

@ -306,6 +306,34 @@ __global__ void AbsKernel(const half *input, half *output, const size_t count) {
return;
}
// Element-wise magnitude of a complex tensor: output[i] = |input[i]|.
// Uses a grid-stride loop so any launch configuration covers all `count` elements.
template <typename T>
__global__ void AbsKernel(const Complex<T> *input, T *output, const size_t count) {
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < count; pos += stride) {
    output[pos] = abs(input[pos]);
  }
}
// Extracts the real part of each complex element: output[i] = Re(input[i]).
// Grid-stride loop: safe for any grid/block configuration.
template <typename T>
__global__ void RealKernel(const Complex<T> *input, T *output, const size_t count) {
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < count; pos += stride) {
    output[pos] = input[pos].real();
  }
}
// Extracts the imaginary part of each complex element: output[i] = Im(input[i]).
// Grid-stride loop: safe for any grid/block configuration.
template <typename T>
__global__ void ImagKernel(const Complex<T> *input, T *output, const size_t count) {
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < count; pos += stride) {
    output[pos] = input[pos].imag();
  }
}
// Element-wise complex conjugate: (a + bi) -> (a - bi).
// Grid-stride loop; the conjugate is built explicitly from real()/imag()
// rather than relying on a conj() overload for the project Complex type.
template <typename T>
__global__ void ConjKernel(const Complex<T> *input, Complex<T> *output, const size_t count) {
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < count; pos += stride) {
    output[pos] = Complex<T>(input[pos].real(), -input[pos].imag());
  }
}
template <typename T>
__global__ void FloorKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = floorf(input[i]);
@ -478,6 +506,26 @@ void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream
return;
}
// Host launcher for AbsKernel (complex input, real magnitude output).
// Asynchronous with respect to the host; work is enqueued on `cuda_stream`.
template <typename T>
void Abs(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  AbsKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
}
// Host launcher for RealKernel: writes Re(input[i]) into output.
// Asynchronous with respect to the host; work is enqueued on `cuda_stream`.
template <typename T>
void Real(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  RealKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
}
// Host launcher for ImagKernel: writes Im(input[i]) into output.
// Asynchronous with respect to the host; work is enqueued on `cuda_stream`.
template <typename T>
void Imag(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  ImagKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
}
// Host launcher for ConjKernel: element-wise complex conjugate.
// Asynchronous with respect to the host; work is enqueued on `cuda_stream`.
template <typename T>
void Conj(const Complex<T> *input, Complex<T> *output, const size_t count, cudaStream_t cuda_stream) {
  ConjKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
}
template <typename T>
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
@ -671,3 +719,15 @@ template void Floor<int>(const int *input, int *output, const size_t count, cuda
template void Rint<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template void Round<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template void Sign<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
// complex64: explicit instantiations for Complex<float> inputs.
template void Real<float>(const Complex<float> *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Imag<float>(const Complex<float> *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Conj<float>(const Complex<float> *input, Complex<float> *output, const size_t count,
                          cudaStream_t cuda_stream);
// complex128: explicit instantiations for Complex<double> inputs.
template void Real<double>(const Complex<double> *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Imag<double>(const Complex<double> *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Conj<double>(const Complex<double> *input, Complex<double> *output, const size_t count,
                           cudaStream_t cuda_stream);

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
#include "runtime/device/gpu/cuda_common.h"
#include "utils/complex.h"
template <typename T>
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
@ -64,5 +65,10 @@ template <typename T>
void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Sign(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
// Extracts the real part of each complex element into `output`; runs on `cuda_stream`.
template <typename T>
void Real(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream);
// Extracts the imaginary part of each complex element into `output`; runs on `cuda_stream`.
template <typename T>
void Imag(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream);
// Computes the element-wise complex conjugate of `input` into `output`; runs on `cuda_stream`.
template <typename T>
void Conj(const Complex<T> *input, Complex<T> *output, const size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_

View File

@ -0,0 +1,46 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/broadcast_complex_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Registers three kernel variants for one binary op so it accepts
// (complex, complex), (complex, real) and (real, complex) operand pairs.
// T0_* is the complex type, T1_* the matching real type; the output dtype
// is always the complex one (T0_MS_DTYPE).
#define MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(OPNAME, T0_MS_DTYPE, T1_MS_DTYPE, T0_DTYPE, T1_DTYPE) \
  MS_REG_GPU_KERNEL_THREE(OPNAME,                                                                 \
                          KernelAttr().AddInputAttr(T0_MS_DTYPE).AddInputAttr(T0_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
                          BroadcastComplexOpGpuKernel, T0_DTYPE, T0_DTYPE, T0_DTYPE)              \
  MS_REG_GPU_KERNEL_THREE(OPNAME,                                                                 \
                          KernelAttr().AddInputAttr(T0_MS_DTYPE).AddInputAttr(T1_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
                          BroadcastComplexOpGpuKernel, T0_DTYPE, T1_DTYPE, T0_DTYPE)              \
  MS_REG_GPU_KERNEL_THREE(OPNAME,                                                                 \
                          KernelAttr().AddInputAttr(T1_MS_DTYPE).AddInputAttr(T0_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
                          BroadcastComplexOpGpuKernel, T1_DTYPE, T0_DTYPE, T0_DTYPE)

template <typename T>
using Complex = mindspore::utils::Complex<T>;

// Arithmetic ops supporting mixed complex/real operands, for both
// complex64 (Complex<float>/float) and complex128 (Complex<double>/double).
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Add, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Add, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Sub, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Sub, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Mul, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Mul, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Div, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Div, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(RealDiv, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(RealDiv, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,152 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_COMPLEX_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_COMPLEX_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include <complex>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
// GPU kernel for binary arithmetic (Add/Sub/Mul/Div/RealDiv) where at least one
// operand is complex. T and S are the two input element types (real or Complex);
// G is the output element type (always a Complex instantiation in the
// registrations that use this kernel).
template <typename T, typename S, typename G>
class BroadcastComplexOpGpuKernel : public GpuKernel {
 public:
  BroadcastComplexOpGpuKernel() { ResetResource(); }
  ~BroadcastComplexOpGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Dispatches either the broadcasting or the plain element-wise CUDA
  // implementation, depending on what Init() determined from the shapes.
  // Work is enqueued asynchronously on the stream passed via stream_ptr.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *lhs = GetDeviceAddress<T>(inputs, 0);
    S *rhs = GetDeviceAddress<S>(inputs, 1);
    G *output = GetDeviceAddress<G>(outputs, 0);
    if (need_broadcast_) {
      BroadcastComplexArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      ElewiseComplexArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }

  // Reads operand/output shapes from the kernel node, decides whether
  // broadcasting is needed, pads all shapes to MAX_DIMS (right-aligned,
  // leading dims filled with 1) and computes element counts.
  bool Init(const CNodePtr &kernel_node) override {
    GetOpType(kernel_node);
    auto shape1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    auto shape2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
    auto shape3 = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
    need_broadcast_ = AnfAlgo::IsTensorBroadcast(shape1, shape2);
    // NOTE(review): only shape1's rank is validated here; shape2/shape3 with
    // rank > MAX_DIMS would index past the resized vectors below — presumably
    // ruled out upstream, but worth confirming.
    if (need_broadcast_ && shape1.size() > MAX_DIMS) {
      MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than " << MAX_DIMS;
    }
    lhs_shape_.resize(MAX_DIMS, 1);
    rhs_shape_.resize(MAX_DIMS, 1);
    output_shape_.resize(MAX_DIMS, 1);
    for (size_t i = 0; i < shape3.size(); i++) {
      if (need_broadcast_) {
        output_shape_[i] = shape3[i];
      }
      output_num_ *= shape3[i];
    }
    // Right-align each input shape against the output shape so missing
    // leading dimensions broadcast as size 1.
    int lhs_offset = shape3.size() - shape1.size();
    for (size_t j = 0; j < shape1.size(); j++) {
      if (need_broadcast_) {
        lhs_shape_[j + lhs_offset] = shape1[j];
      }
      input1_num_ *= shape1[j];
    }
    int rhs_offset = shape3.size() - shape2.size();
    for (size_t k = 0; k < shape2.size(); k++) {
      if (need_broadcast_) {
        rhs_shape_[k + rhs_offset] = shape2[k];
      }
      input2_num_ *= shape2[k];
    }
    InitSizeLists();
    return true;
  }

  // Restores the kernel to its default-constructed state so it can be re-Init'ed.
  void ResetResource() noexcept override {
    op_type_ = BROADCAST_TYPE_INVALID;
    need_broadcast_ = false;
    input1_num_ = 1;
    input2_num_ = 1;
    output_num_ = 1;
    lhs_shape_.clear();
    rhs_shape_.clear();
    output_shape_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitResource() override { return; }
  // Publishes byte sizes of the two inputs and the output; no workspace is used.
  void InitSizeLists() override {
    input_size_list_.push_back(input1_num_ * sizeof(T));
    input_size_list_.push_back(input2_num_ * sizeof(S));
    output_size_list_.push_back(output_num_ * sizeof(G));
  }

 private:
  // Maps the CNode's primitive name to a BroadcastOpType; throws for
  // unsupported operations.
  void GetOpType(const CNodePtr &kernel_node) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {{"RealDiv", BROADCAST_TYPE_REALDIV},
                                                                                  {"Mul", BROADCAST_TYPE_MUL},
                                                                                  {"Sub", BROADCAST_TYPE_SUB},
                                                                                  {"Add", BROADCAST_TYPE_ADD},
                                                                                  {"Div", BROADCAST_TYPE_DIV}};
    auto iter = kBroadcastArithmetricTypeMap.find(kernel_name);
    if (iter != kBroadcastArithmetricTypeMap.end()) {
      op_type_ = iter->second;
      return;
    }
    MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
  }

  BroadcastOpType op_type_;
  bool need_broadcast_;
  size_t input1_num_;   // element count of input 0
  size_t input2_num_;   // element count of input 1
  size_t output_num_;   // element count of the output
  // Shapes padded to MAX_DIMS; only filled in when broadcasting is needed.
  std::vector<size_t> lhs_shape_;
  std::vector<size_t> rhs_shape_;
  std::vector<size_t> output_shape_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};  // class BroadcastComplexOpGpuKernel
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_COMPLEX_GPU_KERNEL_H_

View File

@ -0,0 +1,34 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/complex.h"
#include "backend/kernel_compiler/gpu/math/unary_op_complex_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Real / Imag are complex-to-real kernels: the template argument is the
// *real* element type (float for Complex64, double for Complex128), and the
// registered output dtype is the matching real dtype.
MS_REG_GPU_KERNEL_ONE(Real, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpComplexGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Real, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpComplexGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpComplexGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpComplexGpuKernel, double)
// Conj is complex-to-complex: the output dtype matches the input dtype.
MS_REG_GPU_KERNEL_ONE(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
                      UnaryOpComplexGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
                      UnaryOpComplexGpuKernel, double)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,160 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARY_COMPLEX_OP_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARY_COMPLEX_OP_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class UnaryOpComplexGpuKernel : public GpuKernel {
public:
UnaryOpComplexGpuKernel() { ResetResource(); }
~UnaryOpComplexGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
Complex<T> *input_addr = GetDeviceAddress<Complex<T>>(inputs, 0);
if (is_c2r_op_) {
T *output_addr = GetDeviceAddress<T>(outputs, 0);
switch (unary_op_type_) {
case UNARY_OP_REAL: {
Real(input_addr, output_addr, inputs[0]->size / sizeof(Complex<T>),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_IMAG: {
Imag(input_addr, output_addr, inputs[0]->size / sizeof(Complex<T>),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
default: {
MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
}
}
} else {
Complex<T> *output_addr = GetDeviceAddress<Complex<T>>(outputs, 0);
switch (unary_op_type_) {
case UNARY_OP_CONJ: {
Conj(input_addr, output_addr, inputs[0]->size / sizeof(Complex<T>),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
default: {
MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
}
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
GetOpType(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "UnaryOpGpuKernel input is null";
InitSizeLists();
return true;
}
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
output_size_ *= input_shape[i];
}
if (is_c2r_op_) {
output_size_ /= 2;
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = sizeof(Complex<T>);
output_size_ = sizeof(Complex<T>);
workspace_size_ = 0;
is_null_input_ = false;
is_c2r_op_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
private:
void GetOpType(const CNodePtr &kernel_node) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
static std::map<std::string, UnaryOptype> kComplexSupportedC2RTypeMap = {{"Real", UNARY_OP_REAL},
{"Imag", UNARY_OP_IMAG}};
auto iter = kComplexSupportedC2RTypeMap.find(kernel_name);
if (iter != kComplexSupportedC2RTypeMap.end()) {
unary_op_type_ = iter->second;
is_c2r_op_ = true;
return;
}
static std::map<std::string, UnaryOptype> kComplexSupportedC2CTypeMap = {{"Conj", UNARY_OP_CONJ}};
iter = kComplexSupportedC2CTypeMap.find(kernel_name);
if (iter != kComplexSupportedC2RTypeMap.end()) {
unary_op_type_ = iter->second;
is_c2r_op_ = false;
return;
}
MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
}
private:
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
bool is_null_input_;
bool is_c2r_op_;
UnaryOptype unary_op_type_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARY_COMPLEX_OP_GPU_KERNEL_H_

View File

@ -51,6 +51,9 @@ enum UnaryOptype {
UNARY_OP_RINT,
UNARY_OP_ROUND,
UNARY_OP_SIGN,
UNARY_OP_REAL,
UNARY_OP_IMAG,
UNARY_OP_CONJ,
UNARY_OP_INVALID_TYPE = 255
};
@ -66,7 +69,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
{"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR},
{"Rint", UNARY_OP_RINT}, {"Round", UNARY_OP_ROUND},
{"Sign", UNARY_OP_SIGN}};
{"Real", UNARY_OP_REAL}, {"Imag", UNARY_OP_IMAG},
{"Sign", UNARY_OP_SIGN}, {"Conj", UNARY_OP_CONJ}};
template <typename T>
class UnaryOpGpuKernel : public GpuKernel {

View File

@ -17,9 +17,7 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
FFT3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
FFT3DGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FFT3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
FFT3DGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -27,6 +27,8 @@
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class FFT3DGpuKernel : public GpuKernel {
public:
FFT3DGpuKernel() = default;
@ -52,22 +54,17 @@ class FFT3DGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto input_tensor = GetDeviceAddress<T>(inputs, 0);
auto complex_fq = GetDeviceAddress<T>(workspace, 0);
auto output_real = GetDeviceAddress<T>(outputs, 0);
auto output_imag = GetDeviceAddress<T>(outputs, 1);
auto output_tensor = GetDeviceAddress<Complex<T>>(outputs, 0);
cufftSetStream(FFT_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
FFT3D<T>(Nfft, input_tensor, complex_fq, output_real, output_imag, FFT_plan_r2c,
reinterpret_cast<cudaStream_t>(stream_ptr));
FFT3D<T>(Nfft, input_tensor, output_tensor, FFT_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(Nall * sizeof(T));
workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
output_size_list_.push_back(Nfft * sizeof(T));
output_size_list_.push_back(Nfft * sizeof(T));
output_size_list_.push_back(Nfft * sizeof(Complex<T>));
}
private:

View File

@ -17,9 +17,7 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
IFFT3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
IFFT3DGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IFFT3D, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
IFFT3DGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -27,6 +27,8 @@
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class IFFT3DGpuKernel : public GpuKernel {
public:
IFFT3DGpuKernel() = default;
@ -51,22 +53,17 @@ class IFFT3DGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto input_real = GetDeviceAddress<T>(inputs, 0);
auto input_imag = GetDeviceAddress<T>(inputs, 1);
auto complex_fq = GetDeviceAddress<T>(workspace, 0);
auto input_tensor = GetDeviceAddress<Complex<T>>(inputs, 0);
auto output_tensor = GetDeviceAddress<T>(outputs, 0);
cufftSetStream(FFT_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
IFFT3D<T>(Nfft, input_real, input_imag, complex_fq, output_tensor, FFT_plan_c2r,
reinterpret_cast<cudaStream_t>(stream_ptr));
IFFT3D<T>(Nfft, input_tensor, output_tensor, FFT_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(Nfft * sizeof(T));
input_size_list_.push_back(Nfft * sizeof(T));
workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
input_size_list_.push_back(Nfft * sizeof(Complex<T>));
output_size_list_.push_back(Nall * sizeof(T));
}

View File

@ -99,6 +99,10 @@ static irpb::DataType GetNumberDataType(const TypePtr &type) {
return irpb::DT_BASE_UINT;
case kNumberTypeFloat:
return irpb::DT_BASE_FLOAT;
case kNumberTypeComplex64:
return irpb::DT_COMPLEX64;
case kNumberTypeComplex128:
return irpb::DT_COMPLEX128;
default:
MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
}

View File

@ -89,6 +89,8 @@ enum DataType {
DT_ANYTHING = 40; // type anything
DT_REFKEY = 41; // type refkey
DT_REF = 42; // type ref
DT_COMPLEX64 = 43; // type generate complex64
DT_COMPLEX128 = 44; // type generate complex128
}
// Value definition for attribute value or parameter default value
@ -256,7 +258,7 @@ message ModelProto {
// The parameterized graph that is evaluated to execute the model.
optional GraphProto graph = 4;
// metadata info of opeartors
// metadata info of operators
optional OperatorSetProto metadata_operators = 5;
};

View File

@ -20,6 +20,7 @@
#include <limits>
#ifdef ENABLE_GPU
#include <thrust/complex.h>
#include <cublas_v2.h>
#endif
#include "base/float16.h"
#if defined(__CUDACC__)
@ -80,7 +81,11 @@ struct alignas(sizeof(T) * 2) Complex {
HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
HOST_DEVICE inline explicit operator float16() const { return static_cast<float16>(real_); }
#if defined(__CUDACC__)
HOST_DEVICE inline explicit operator half() const { return static_cast<half>(real_); }
#else
inline explicit operator float16() const { return static_cast<float16>(real_); }
#endif
HOST_DEVICE inline constexpr Complex<T> &operator=(const T &real) {
real_ = real;
@ -100,12 +105,14 @@ struct alignas(sizeof(T) * 2) Complex {
HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
real_ *= real;
imag_ *= real;
return *this;
}
// Note: check division by zero before use it.
HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
real_ /= real;
imag_ /= real;
return *this;
}

View File

@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
MatrixInverse, IndexAdd, Erfinv)
MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@ -509,7 +509,10 @@ __all__ = [
"BufferAppend",
"BufferGetItem",
"BufferSample",
"Erfinv"
"Erfinv",
"Conj",
"Real",
"Imag"
]
__sponge__ = [

View File

@ -86,6 +86,23 @@ class _MathBinaryOp(_BinaryOp):
def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
"""Staticmethod of infer dtype for _MathBinaryOp."""
args_type = {"x": x_dtype, "y": y_dtype}
complex_types = [mstype.tensor_type(mstype.complex64), mstype.tensor_type(mstype.complex128)]
if x_dtype in complex_types or y_dtype in complex_types:
tpye_infer_dict = {
(mstype.complex64, mstype.complex64): mstype.tensor_type(mstype.complex64),
(mstype.complex64, mstype.float32): mstype.tensor_type(mstype.complex64),
(mstype.float32, mstype.complex64): mstype.tensor_type(mstype.complex64),
(mstype.complex128, mstype.complex128): mstype.tensor_type(mstype.complex128),
(mstype.complex128, mstype.float64): mstype.tensor_type(mstype.complex128),
(mstype.float64, mstype.complex128): mstype.tensor_type(mstype.complex128),
}
if (x_dtype.element_type(), y_dtype.element_type()) not in tpye_infer_dict.keys():
raise TypeError('Complex math binary op expecting Tensor [complex64, complex64],'
+ '[complex64, float32], [float32, complex64], [complex128, complex128],'
+ '[complex128, float64], [float64, complex128],'
+ f'but got : [{format(x_dtype)},{format(y_dtype)}].')
return tpye_infer_dict.get((x_dtype.element_type(), y_dtype.element_type()))
validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name)
return x_dtype
@ -4439,7 +4456,6 @@ class Abs(Primitive):
"""Initialize Abs"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class Sign(PrimitiveWithInfer):
r"""
Performs sign on the tensor element-wise.
@ -5257,3 +5273,125 @@ class Erfinv(Primitive):
def __init__(self):
"""Initialize Erfinv"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
class Conj(PrimitiveWithInfer):
    """
    Returns a tensor that is the element-wise complex conjugate of the input.

    Inputs:
        - **input** (Tensor, complex) - The input tensor. types: complex64, complex128.

    Outputs:
        Tensor, has the same complex dtype and shape as the input.

    Raises:
        TypeError: If the dtype of input is not one of: complex64, complex128.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
        >>> conj = ops.Conj()
        >>> output = conj(x)
        >>> print(output)
        1.3-0.4j
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Conj"""
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        # Conjugation is element-wise, so the shape is unchanged.
        return input_shape

    def infer_dtype(self, input_dtype):
        # Only complex inputs are valid; the output keeps the input dtype.
        validator.check_tensor_dtype_valid('input_tensor', input_dtype,
                                           [mstype.complex64, mstype.complex128], self.name)
        return input_dtype
class Real(PrimitiveWithInfer):
    """
    Returns a Tensor that is the real part of the input.

    Inputs:
        - **input** (Tensor, complex) - The input tensor. types: complex64, complex128.

    Outputs:
        Tensor, has the float type.

    Raises:
        TypeError: If the dtype of input is not one of: complex64, complex128.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
        >>> real_op = ops.Real()
        >>> output = real_op(x)
        >>> print(output)
        1.3
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Real"""
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        # Element-wise extraction preserves the input shape.
        return input_shape

    def infer_dtype(self, input_dtype):
        # complex64 -> float32, complex128 -> float64; the validator rejects
        # every other dtype before we reach the mapping below.
        validator.check_tensor_dtype_valid('input_tensor', input_dtype,
                                           [mstype.complex64, mstype.complex128], self.name)
        if input_dtype == mstype.tensor_type(mstype.complex64):
            return mstype.float32
        return mstype.float64
class Imag(PrimitiveWithInfer):
    """
    Returns a new tensor containing imaginary value of the input.

    Inputs:
        - **input** (Tensor, complex) - The input tensor. types: complex64, complex128.

    Outputs:
        Tensor, has the float type.

    Raises:
        TypeError: If the dtype of input is not one of: complex64, complex128.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
        >>> imag_op = ops.Imag()
        >>> output = imag_op(x)
        >>> print(output)
        0.4
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Imag"""
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        # Element-wise extraction preserves the input shape.
        return input_shape

    def infer_dtype(self, input_dtype):
        # complex64 -> float32, complex128 -> float64; the validator rejects
        # every other dtype before we reach the mapping below.
        validator.check_tensor_dtype_valid('input_tensor', input_dtype,
                                           [mstype.complex64, mstype.complex128], self.name)
        if input_dtype == mstype.tensor_type(mstype.complex64):
            return mstype.float32
        return mstype.float64

View File

@ -2988,9 +2988,7 @@ class FFT3D(PrimitiveWithInfer):
- **input_tensor** (Tensor, float32) - [fftx, ffty, fftz]
Outputs:
- **output_real** (float32) - the real part of the output tensor after
undergoing fast Fourier transform.
- **output_imag** (float32) - the imaginary part of the output tensor after
- **output_tensor** (complex64) - the real part of the output tensor after
undergoing fast Fourier transform.
Supported Platforms:
@ -3001,27 +2999,24 @@ class FFT3D(PrimitiveWithInfer):
def __init__(self):
self.init_prim_io_names(
inputs=['input_tensor'],
outputs=['output_real', 'output_imag'])
outputs=['output_tensor'])
def infer_shape(self, input_shape):
self.add_prim_attr('fftx', input_shape[0])
self.add_prim_attr('ffty', input_shape[1])
self.add_prim_attr('fftz', input_shape[2])
return [input_shape[0], input_shape[1], int(input_shape[2]/2)+1],\
[input_shape[0], input_shape[1], int(input_shape[2]/2)+1]
return [input_shape[0], input_shape[1], int(input_shape[2]/2)+1]
def infer_dtype(self, input_dtype):
validator.check_tensor_dtype_valid('input_tensor', input_dtype, mstype.number_type, self.name)
return input_dtype, input_dtype
validator.check_tensor_dtype_valid('input_tensor', input_dtype, [mstype.float32], self.name)
return mstype.complex64
class IFFT3D(PrimitiveWithInfer):
"""
Inverse FFT with Three-Dimensional Input.
Inputs:
- **input_real** (Tensor, float32) - [fftx, ffty, fftz]
- **input_imag** (Tensor, float32) - [fftx, ffty, fftz]
- **input_tensor** (Tensor, complex64) - [fftx, ffty, fftz]
Outputs:
- **output_tensor** (float32) - returns the tensor after undergoing
@ -3034,21 +3029,18 @@ class IFFT3D(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
self.init_prim_io_names(
inputs=['input_real', 'input_imag'],
inputs=['input_tensor'],
outputs=['output_tensor'])
def infer_shape(self, input_shape1, input_shape2):
for i in range(len(input_shape1)):
validator.check_int(input_shape1[i], input_shape2[i], Rel.EQ, "input_shape", self.name)
self.add_prim_attr('fftx', input_shape1[0])
self.add_prim_attr('ffty', input_shape1[1])
self.add_prim_attr('fftz', input_shape1[2])
return [input_shape1[0], input_shape1[1], (input_shape1[2]-1)*2]
def infer_shape(self, input_shape):
self.add_prim_attr('fftx', input_shape[0])
self.add_prim_attr('ffty', input_shape[1])
self.add_prim_attr('fftz', input_shape[2])
return [input_shape[0], input_shape[1], (input_shape[2]-1)*2]
def infer_dtype(self, input_real_dtype, input_imag_dtype):
validator.check_tensor_dtype_valid('input_real', input_real_dtype, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_imag', input_imag_dtype, mstype.number_type, self.name)
return input_real_dtype
def infer_dtype(self, input_dtype):
validator.check_tensor_dtype_valid('input_tensor', input_dtype, [mstype.complex64], self.name)
return mstype.float32
class NeighborListUpdate(PrimitiveWithInfer):
"""

View File

@ -144,13 +144,13 @@ TEST_F(TestComplex, test_arithmetic) {
test_arithmetic_mul<Complex<float>, Complex<float>, Complex<float>>(
Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(-3.6963, 4.9284));
test_arithmetic_mul<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11,
Complex<float>(1.2321, 2.22));
Complex<float>(1.2321, 2.4642));
test_arithmetic_mul<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
Complex<float>(1.2321, 2.22));
Complex<float>(1.2321, 2.4642));
test_arithmetic_div<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22),
Complex<float>(1.11, 2.22), Complex<float>(1, 0));
test_arithmetic_div<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(1, 2.22));
test_arithmetic_div<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(1, 2));
test_arithmetic_div<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
Complex<float>(0.2, -0.4));
}