forked from mindspore-Ecosystem/mindspore
add gpu complex ops
This commit is contained in:
parent 2ee8bf46ca
commit dad375abb9
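This commit registers complex64/complex128 variants of the GPU Cast kernel, adds complex overloads to the element-wise and broadcast arithmetic functors, rewrites the SPONGE FFT3D/IFFT3D wrappers to pass complex buffers straight to cuFFT, and adds complex Abs/Real/Imag/Conj unary kernels. All of the new code leans on mindspore::utils::Complex<T> from "utils/complex.h", which is not part of this diff. A minimal sketch of the interface the kernels below assume (member names and operator set are assumptions, not the real header):

// Minimal sketch (assumed, NOT the actual utils/complex.h): a trivially
// copyable complex value type usable from both host and device code.
template <typename T>
struct Complex {
  T re{0}, im{0};
  __host__ __device__ Complex() = default;
  __host__ __device__ Complex(T r, T i) : re(r), im(i) {}
  __host__ __device__ T real() const { return re; }
  __host__ __device__ T imag() const { return im; }
};

template <typename T>
__host__ __device__ Complex<T> operator+(const Complex<T> &a, const Complex<T> &b) {
  return Complex<T>(a.re + b.re, a.im + b.im);
}

template <typename T>
__host__ __device__ Complex<T> operator*(const Complex<T> &a, const Complex<T> &b) {
  // (a.re + i*a.im) * (b.re + i*b.im)
  return Complex<T>(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re);
}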
@@ -21,6 +21,7 @@
 "mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc" "containerOutOfBounds"
 "mindspore/mindspore/core/ops/strided_slice.cc" "zerodivcond"
 "mindspore/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc" "useStlAlgorithm"
+"mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro"

 # MindData
 "mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"
@@ -36,25 +36,41 @@ namespace mindspore {
 namespace kernel {
 constexpr char kAxis[] = "axis";
 constexpr char kTypeInt32[] = "Int32";
-const std::unordered_map<std::string, TypeId> type_id_maps = {
-  {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
-  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
-  {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8},
-  {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32},
-  {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
-  {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
-  {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
-  {"bool", TypeId::kNumberTypeBool}, {"complex64", TypeId::kNumberTypeComplex64}};
+const std::unordered_map<std::string, TypeId> type_id_maps = {{"float", TypeId::kNumberTypeFloat32},
+                                                              {"float16", TypeId::kNumberTypeFloat16},
+                                                              {"float32", TypeId::kNumberTypeFloat32},
+                                                              {"float64", TypeId::kNumberTypeFloat64},
+                                                              {"int", TypeId::kNumberTypeInt},
+                                                              {"int8", TypeId::kNumberTypeInt8},
+                                                              {"int16", TypeId::kNumberTypeInt16},
+                                                              {"int32", TypeId::kNumberTypeInt32},
+                                                              {"int64", TypeId::kNumberTypeInt64},
+                                                              {"uint", TypeId::kNumberTypeUInt},
+                                                              {"uint8", TypeId::kNumberTypeUInt8},
+                                                              {"uint16", TypeId::kNumberTypeUInt16},
+                                                              {"uint32", TypeId::kNumberTypeUInt32},
+                                                              {"uint64", TypeId::kNumberTypeUInt64},
+                                                              {"bool", TypeId::kNumberTypeBool},
+                                                              {"complex64", TypeId::kNumberTypeComplex64},
+                                                              {"complex128", TypeId::kNumberTypeComplex128}};

-const std::map<TypeId, std::string> type_id_str_map = {
-  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
-  {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"},
-  {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"},
-  {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"},
-  {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
-  {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
-  {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
-  {TypeId::kNumberTypeBool, "bool"}, {TypeId::kNumberTypeComplex64, "complex64"}};
+const std::map<TypeId, std::string> type_id_str_map = {{TypeId::kNumberTypeFloat32, "float32"},
+                                                       {TypeId::kNumberTypeFloat16, "float16"},
+                                                       {TypeId::kNumberTypeFloat, "float"},
+                                                       {TypeId::kNumberTypeFloat64, "float64"},
+                                                       {TypeId::kNumberTypeInt, "int"},
+                                                       {TypeId::kNumberTypeInt8, "int8"},
+                                                       {TypeId::kNumberTypeInt16, "int16"},
+                                                       {TypeId::kNumberTypeInt32, "int32"},
+                                                       {TypeId::kNumberTypeInt64, "int64"},
+                                                       {TypeId::kNumberTypeUInt, "uint"},
+                                                       {TypeId::kNumberTypeUInt8, "uint8"},
+                                                       {TypeId::kNumberTypeUInt16, "uint16"},
+                                                       {TypeId::kNumberTypeUInt32, "uint32"},
+                                                       {TypeId::kNumberTypeUInt64, "uint64"},
+                                                       {TypeId::kNumberTypeBool, "bool"},
+                                                       {TypeId::kNumberTypeComplex64, "complex64"},
+                                                       {TypeId::kNumberTypeComplex128, "complex128"}};

 const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
   {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
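These tables map both ways between dtype strings and TypeId values; the change here only adds the complex64/complex128 entries and re-wraps the initializers one entry per line. A sketch of the typical lookup these tables serve (assumed usage; DtypeToTypeId is a hypothetical helper, and kTypeUnknown is assumed to be the sentinel value):

TypeId DtypeToTypeId(const std::string &dtype) {
  auto iter = type_id_maps.find(dtype);
  return iter == type_id_maps.end() ? TypeId::kTypeUnknown : iter->second;
}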
@@ -42,6 +42,12 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutput
                       int8_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       int8_t, bool)
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, int8_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, int8_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       int16_t, int8_t)
@@ -67,6 +73,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutpu
                       CastGpuKernel, int16_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       int16_t, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, int16_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, int16_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       int32_t, int8_t)
@@ -92,6 +102,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutpu
                       CastGpuKernel, int32_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       int32_t, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, int32_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, int32_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       int64_t, int8_t)
@@ -117,6 +131,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutpu
                       CastGpuKernel, int64_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       int64_t, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, int64_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, int64_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       uint8_t, int8_t)
@@ -142,6 +160,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutpu
                       CastGpuKernel, uint8_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       uint8_t, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, uint8_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, uint8_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       uint16_t, int8_t)
@@ -167,6 +189,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutp
                       CastGpuKernel, uint16_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       uint16_t, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, uint16_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, uint16_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       uint32_t, int8_t)
@@ -192,6 +218,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutp
                       CastGpuKernel, uint32_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       uint32_t, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, uint32_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, uint32_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       uint64_t, int8_t)
@@ -217,6 +247,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutp
                       CastGpuKernel, uint64_t, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       uint64_t, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, uint64_t, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, uint64_t, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       half, int8_t)
@@ -242,6 +276,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOut
                       CastGpuKernel, half, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       half, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, half, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, half, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       float, int8_t)
@@ -267,6 +305,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       CastGpuKernel, float, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       float, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, float, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, float, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       double, int8_t)
@@ -292,6 +334,10 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOut
                       CastGpuKernel, double, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       double, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, double, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, double, Complex<double>)

 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastGpuKernel,
                       bool, int8_t)
@@ -317,5 +363,64 @@ MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutput
                       bool, half)
 MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastGpuKernel,
                       bool, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, bool, Complex<float>)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, bool, Complex<double>)
+
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt8),
+                      CastGpuKernel, Complex<float>, int8_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt16),
+                      CastGpuKernel, Complex<float>, int16_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32),
+                      CastGpuKernel, Complex<float>, int32_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt64),
+                      CastGpuKernel, Complex<float>, int64_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt8),
+                      CastGpuKernel, Complex<float>, uint8_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt16),
+                      CastGpuKernel, Complex<float>, uint16_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt32),
+                      CastGpuKernel, Complex<float>, uint32_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeUInt64),
+                      CastGpuKernel, Complex<float>, uint64_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
+                      CastGpuKernel, Complex<float>, float)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat64),
+                      CastGpuKernel, Complex<float>, double)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat16),
+                      CastGpuKernel, Complex<float>, half)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeBool),
+                      CastGpuKernel, Complex<float>, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex128),
+                      CastGpuKernel, Complex<float>, Complex<double>)
+
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt8),
+                      CastGpuKernel, Complex<double>, int8_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt16),
+                      CastGpuKernel, Complex<double>, int16_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32),
+                      CastGpuKernel, Complex<double>, int32_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt64),
+                      CastGpuKernel, Complex<double>, int64_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt8),
+                      CastGpuKernel, Complex<double>, uint8_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt16),
+                      CastGpuKernel, Complex<double>, uint16_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt32),
+                      CastGpuKernel, Complex<double>, uint32_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeUInt64),
+                      CastGpuKernel, Complex<double>, uint64_t)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat32),
+                      CastGpuKernel, Complex<double>, float)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
+                      CastGpuKernel, Complex<double>, double)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat16),
+                      CastGpuKernel, Complex<double>, half)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeBool),
+                      CastGpuKernel, Complex<double>, bool)
+MS_REG_GPU_KERNEL_TWO(Cast, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex64),
+                      CastGpuKernel, Complex<double>, Complex<float>)
+
 } // namespace kernel
 } // namespace mindspore
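Each MS_REG_GPU_KERNEL_TWO line above binds one (input dtype, output dtype) pair to CastGpuKernel, so complex support is the cross product of the existing types with Complex<float> and Complex<double>. The conversion rule itself lives in the CUDA implementation (cast_impl.cu, further below); by the usual convention, which this sketch assumes rather than quotes, a real-to-complex cast zero-fills the imaginary part and a complex-to-real cast keeps only the real part:

// Assumed cast semantics (hypothetical helpers, not from this commit):
template <typename S, typename T>
__device__ __forceinline__ void CastItem(const S &in, Complex<T> *out) {
  *out = Complex<T>(static_cast<T>(in), T(0));  // real -> complex: imag = 0
}
template <typename S, typename T>
__device__ __forceinline__ void CastItem(const Complex<S> &in, T *out) {
  *out = static_cast<T>(in.real());  // complex -> real: drop the imaginary part
}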
@@ -37,14 +37,14 @@ struct EqualFunc {
 };

 template <>
-struct EqualFunc <half> {
+struct EqualFunc<half> {
   __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
     return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? true : false;
   }
 };

 template <>
-struct EqualFunc <float> {
+struct EqualFunc<float> {
   __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
     return std::abs(lhs - rhs) < 1e-9 ? true : false;
   }
@@ -56,15 +56,16 @@ struct GreaterEqualFunc {
 };

 template <>
-struct GreaterEqualFunc <half> {
+struct GreaterEqualFunc<half> {
   __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
-    return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ?
-             true : (__half2float(lhs) > __half2float(rhs) ? true : false);
+    return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9
+             ? true
+             : (__half2float(lhs) > __half2float(rhs) ? true : false);
   }
 };

 template <>
-struct GreaterEqualFunc <float> {
+struct GreaterEqualFunc<float> {
   __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
     return std::abs(lhs - rhs) < 1e-9 ? true : (lhs > rhs ? true : false);
   }
@@ -76,15 +77,16 @@ struct LessEqualFunc {
 };

 template <>
-struct LessEqualFunc <half> {
+struct LessEqualFunc<half> {
   __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
-    return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ?
-             true : (__half2float(lhs) < __half2float(rhs) ? true : false);
+    return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9
+             ? true
+             : (__half2float(lhs) < __half2float(rhs) ? true : false);
   }
 };

 template <>
-struct LessEqualFunc <float> {
+struct LessEqualFunc<float> {
   __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
     return std::abs(lhs - rhs) < 1e-9 ? true : (lhs < rhs ? true : false);
   }
@@ -96,14 +98,14 @@ struct NotEqualFunc {
 };

 template <>
-struct NotEqualFunc <half> {
+struct NotEqualFunc<half> {
   __device__ __host__ __forceinline__ bool operator()(const half &lhs, const half &rhs) {
     return std::abs(__half2float(lhs) - __half2float(rhs)) < 1e-9 ? false : true;
   }
 };

 template <>
-struct NotEqualFunc <float> {
+struct NotEqualFunc<float> {
   __device__ __host__ __forceinline__ bool operator()(const float &lhs, const float &rhs) {
     return std::abs(lhs - rhs) < 1e-9 ? false : true;
   }
@@ -155,28 +157,52 @@ struct PowerFunc<half2> {
 template <typename T>
 struct RealDivFunc {
   __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const T &lhs, const Complex<T> &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
+    return (lhs / rhs);
+  }
 };

 template <typename T>
 struct DivFunc {
   __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const T &lhs, const Complex<T> &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
+    return (lhs / rhs);
+  }
 };

 template <typename T>
 struct MulFunc {
   __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const T &rhs) { return (lhs * rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const T &lhs, const Complex<T> &rhs) { return (lhs * rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
+    return (lhs * rhs);
+  }
 };

 template <typename T>
 struct SubFunc {
   __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const T &rhs) { return (lhs - rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const T &lhs, const Complex<T> &rhs) { return (lhs - rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
+    return (lhs - rhs);
+  }
 };

 template <typename T>
 struct AddFunc {
   __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const T &rhs) { return (lhs + rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const T &lhs, const Complex<T> &rhs) { return (lhs + rhs); }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
+    return (lhs + rhs);
+  }
 };

 // DivNoNan check if rhs is less than epsilon
 template <typename T>
 struct DivNoNanFunc {
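The mixed real/complex operator() overloads above all defer to the arithmetic operators of Complex<T>. For division, the textbook formula is presumably what utils/complex.h implements (a sketch under that assumption; a production version would typically scale the operands to avoid overflow in c^2 + d^2):

// (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
template <typename T>
__host__ __device__ Complex<T> operator/(const Complex<T> &a, const Complex<T> &b) {
  T denom = b.real() * b.real() + b.imag() * b.imag();
  return Complex<T>((a.real() * b.real() + a.imag() * b.imag()) / denom,
                    (a.imag() * b.real() - a.real() * b.imag()) / denom);
}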
@@ -525,6 +551,13 @@ __global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *
   }
 }

+template <typename T1, typename T2, typename T3, typename Func>
+__global__ void ElewiseArithComplexKernel(const int nums, const T1 *x0, const T2 *x1, Complex<T3> *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
+    y[pos] = Func()(x0[pos], x1[pos]);
+  }
+}
+
 template <typename T>
 void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
   switch (op) {
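The grid-stride loop in ElewiseArithComplexKernel lets any launch size cover nums elements; the dispatch below always launches (nums + 255) / 256 blocks of 256 threads, i.e. one thread per element rounded up to a whole block:

// Equivalent launch computation, spelled out:
inline int NumBlocks(int nums, int threads = 256) {
  return (nums + threads - 1) / threads;  // same as (nums + 255) / 256 for threads == 256
}
// NumBlocks(1000) == 4: 1024 thread slots; the pos < nums guard idles the extra 24.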
@@ -567,6 +600,30 @@ void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, c
   }
 }

+template <typename T1, typename T2, typename T3>
+void ElewiseArithComplexKernel(const int &nums, enum BroadcastOpType op, const T1 *x0, const T2 *x1, Complex<T3> *y,
+                               cudaStream_t stream) {
+  switch (op) {
+    case BROADCAST_TYPE_ADD:
+      return ElewiseArithComplexKernel<T1, T2, T3, AddFunc<T3>>
+        <<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_SUB:
+      return ElewiseArithComplexKernel<T1, T2, T3, SubFunc<T3>>
+        <<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_MUL:
+      return ElewiseArithComplexKernel<T1, T2, T3, MulFunc<T3>>
+        <<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_DIV:
+      return ElewiseArithComplexKernel<T1, T2, T3, DivFunc<T3>>
+        <<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_REALDIV:
+      return ElewiseArithComplexKernel<T1, T2, T3, RealDivFunc<T3>>
+        <<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    default:
+      break;
+  }
+}
+
 template <typename T>
 void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
   return ElewiseArithKernel(nums, op, x0, x1, y, stream);
@@ -584,6 +641,12 @@ void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, cons
   }
 }

+template <typename T1, typename T2, typename T3>
+void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const T1 *x0, const T2 *x1, Complex<T3> *y,
+                         cudaStream_t stream) {
+  return ElewiseArithComplexKernel(nums, op, x0, x1, y, stream);
+}
+
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const double *x0, const double *x1, double *y,
                            cudaStream_t stream);
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, float *y,
@@ -608,6 +671,18 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint6
                            uint64_t *y, cudaStream_t stream);
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y,
                            cudaStream_t stream);
+template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<float> *x0,
+                                  const Complex<float> *x1, Complex<float> *y, cudaStream_t stream);
+template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<float> *x0, const float *x1,
+                                  Complex<float> *y, cudaStream_t stream);
+template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const float *x0, const Complex<float> *x1,
+                                  Complex<float> *y, cudaStream_t stream);
+template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<double> *x0,
+                                  const Complex<double> *x1, Complex<double> *y, cudaStream_t stream);
+template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const Complex<double> *x0, const double *x1,
+                                  Complex<double> *y, cudaStream_t stream);
+template void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const double *x0, const Complex<double> *x1,
+                                  Complex<double> *y, cudaStream_t stream);
 // Broadcast comparison
 __device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }

@@ -734,8 +809,8 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
                            const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint64_t *x0,
                            const uint64_t *x1, bool *y, cudaStream_t stream);
 template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
-                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
-                           const bool *x1, bool *y, cudaStream_t stream);
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0, const bool *x1,
+                           bool *y, cudaStream_t stream);
 // Broadcast Arithmetic
 template <typename T, typename Func>
 __global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
@@ -772,6 +847,41 @@ __global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const siz
   }
 }

+template <typename T1, typename T2, typename T3, typename Func>
+__global__ void BroadcastComplexArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
+                                            const size_t l4, const size_t l5, const size_t l6, const size_t r0,
+                                            const size_t r1, const size_t r2, const size_t r3, const size_t r4,
+                                            const size_t r5, const size_t r6, const size_t d0, const size_t d1,
+                                            const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                                            const size_t d6, const T1 *x0, const T2 *x1, Complex<T3> *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
+       pos += blockDim.x * gridDim.x) {
+    size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
+    size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
+    size_t k = pos / (d3 * d4 * d5 * d6) % d2;
+    size_t l = pos / (d4 * d5 * d6) % d3;
+    size_t m = pos / (d5 * d6) % d4;
+    size_t n = pos / d6 % d5;
+    size_t o = pos % d6;
+
+    size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(k, l2) * l3 * l4 * l5 * l6;
+    l_index += Index(l, l3) * l4 * l5 * l6;
+    l_index += Index(m, l4) * l5 * l6;
+    l_index += Index(n, l5) * l6;
+    l_index += Index(o, l6);
+    size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(k, r2) * r3 * r4 * r5 * r6;
+    r_index += Index(l, r3) * r4 * r5 * r6;
+    r_index += Index(m, r4) * r5 * r6;
+    r_index += Index(n, r5) * r6;
+    r_index += Index(o, r6);
+    y[pos] = Func()(x0[l_index], x1[r_index]);
+  }
+}
+
 template <typename T>
 void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                     const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
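BroadcastComplexArithKernel linearizes each 7-D output position into coordinates (i..o), then maps every coordinate back into the two inputs, with Index(i, dim) collapsing any broadcast (size-1) dimension to 0. A host-side sketch of the same offset computation, handy for reasoning about or unit-testing the indexing:

#include <cstddef>
#include <vector>

// Offset of the input element that broadcasts to out_coord, given the input's
// dims (both length 7 here, padded with leading 1s as the kernel expects).
size_t BroadcastOffset(const std::vector<size_t> &out_coord, const std::vector<size_t> &in_dims) {
  size_t offset = 0, stride = 1;
  for (int d = static_cast<int>(in_dims.size()) - 1; d >= 0; --d) {
    size_t c = (in_dims[d] == 1) ? 0 : out_coord[d];  // same rule as Index(i, dim)
    offset += c * stride;
    stride *= in_dims[d];
  }
  return offset;
}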
@@ -871,6 +981,45 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
   }
 }

+template <typename T1, typename T2, typename T3>
+void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T1 *x0, const T2 *x1,
+                           Complex<T3> *y, cudaStream_t stream) {
+  size_t size = 1;
+  for (auto d : y_dims) {
+    size *= d;
+  }
+  switch (op) {
+    case BROADCAST_TYPE_ADD:
+      return BroadcastComplexArithKernel<T1, T2, T3, AddFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_SUB:
+      return BroadcastComplexArithKernel<T1, T2, T3, SubFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_MUL:
+      return BroadcastComplexArithKernel<T1, T2, T3, MulFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_DIV:
+      return BroadcastComplexArithKernel<T1, T2, T3, DivFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_REALDIV:
+      return BroadcastComplexArithKernel<T1, T2, T3, RealDivFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    default:
+      break;
+  }
+}
+
 template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                              const std::vector<size_t> &y_dims, enum BroadcastOpType op, const double *x0,
                              const double *x1, double *y, cudaStream_t stream);
@@ -905,8 +1054,30 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
                              const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint64_t *x0,
                              const uint64_t *x1, uint64_t *y, cudaStream_t stream);
 template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
-                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
-                             const bool *x1, bool *y, cudaStream_t stream);
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0, const bool *x1,
+                             bool *y, cudaStream_t stream);
+template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                                    const std::vector<size_t> &y_dims, enum BroadcastOpType op,
+                                    const Complex<float> *x0, const Complex<float> *x1, Complex<float> *y,
+                                    cudaStream_t stream);
+template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                                    const std::vector<size_t> &y_dims, enum BroadcastOpType op,
+                                    const Complex<float> *x0, const float *x1, Complex<float> *y, cudaStream_t stream);
+template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                                    const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0,
+                                    const Complex<float> *x1, Complex<float> *y, cudaStream_t stream);
+template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                                    const std::vector<size_t> &y_dims, enum BroadcastOpType op,
+                                    const Complex<double> *x0, const Complex<double> *x1, Complex<double> *y,
+                                    cudaStream_t stream);
+template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                                    const std::vector<size_t> &y_dims, enum BroadcastOpType op,
+                                    const Complex<double> *x0, const double *x1, Complex<double> *y,
+                                    cudaStream_t stream);
+template void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                                    const std::vector<size_t> &y_dims, enum BroadcastOpType op, const double *x0,
+                                    const Complex<double> *x1, Complex<double> *y, cudaStream_t stream);
+
 // BroadcastTo
 template <typename T>
 __global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t i2, const size_t i3, const size_t o0,
@@ -19,6 +19,7 @@

 #include <vector>
 #include "runtime/device/gpu/cuda_common.h"
+#include "utils/complex.h"

 const float kFloatEplison = 1e-37;

@@ -57,6 +58,10 @@ void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *
 template <typename T>
 void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);

+template <typename T1, typename T2, typename T3>
+void ElewiseComplexArith(const int &nums, enum BroadcastOpType op, const T1 *x0, const T2 *x1,
+                         Complex<T3> *y, cudaStream_t stream);
+
 template <typename T>
 void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                   const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, bool *y,
@@ -67,6 +72,11 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
                     const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
                     cudaStream_t stream);

+template <typename T1, typename T2, typename T3>
+void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T1 *x0, const T2 *x1,
+                           Complex<T3> *y, cudaStream_t stream);
+
 template <typename T>
 void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
                  const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr, T *output_addr,
@@ -117,6 +117,8 @@ template void Cast(const int input_size, const int8_t *input_addr, float *output
 template void Cast(const int input_size, const int8_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int8_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int8_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int8_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int8_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const int16_t *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int16_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -130,6 +132,9 @@ template void Cast(const int input_size, const int16_t *input_addr, float *outpu
 template void Cast(const int input_size, const int16_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int16_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int16_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int16_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int16_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
+

 template void Cast(const int input_size, const int32_t *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int32_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -143,6 +148,9 @@ template void Cast(const int input_size, const int32_t *input_addr, float *outpu
 template void Cast(const int input_size, const int32_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int32_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int32_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int32_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int32_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);
+

 template void Cast(const int input_size, const int64_t *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int64_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -156,6 +164,8 @@ template void Cast(const int input_size, const int64_t *input_addr, float *outpu
 template void Cast(const int input_size, const int64_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int64_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const int64_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int64_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const int64_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const uint8_t *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint8_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -169,6 +179,8 @@ template void Cast(const int input_size, const uint8_t *input_addr, float *outpu
 template void Cast(const int input_size, const uint8_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint8_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint8_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint8_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint8_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const uint16_t *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint16_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -182,6 +194,8 @@ template void Cast(const int input_size, const uint16_t *input_addr, float *outp
 template void Cast(const int input_size, const uint16_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint16_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint16_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint16_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint16_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const uint32_t *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint32_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -195,6 +209,8 @@ template void Cast(const int input_size, const uint32_t *input_addr, float *outp
 template void Cast(const int input_size, const uint32_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint32_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint32_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint32_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint32_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const uint64_t *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint64_t *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -208,6 +224,8 @@ template void Cast(const int input_size, const uint64_t *input_addr, float *outp
 template void Cast(const int input_size, const uint64_t *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint64_t *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const uint64_t *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint64_t *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const uint64_t *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const half *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const half *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -221,6 +239,8 @@ template void Cast(const int input_size, const half *input_addr, float *output_a
 template void Cast(const int input_size, const half *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const half *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const half *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const half *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const half *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const float *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const float *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -234,6 +254,8 @@ template void Cast(const int input_size, const float *input_addr, float *output_
 template void Cast(const int input_size, const float *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const float *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const float *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const float *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const float *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const double *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const double *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -247,6 +269,8 @@ template void Cast(const int input_size, const double *input_addr, float *output
 template void Cast(const int input_size, const double *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const double *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const double *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const double *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const double *input_addr, Complex<double> *output_addr, cudaStream_t stream);

 template void Cast(const int input_size, const bool *input_addr, int8_t *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const bool *input_addr, int16_t *output_addr, cudaStream_t stream);
@@ -260,3 +284,35 @@ template void Cast(const int input_size, const bool *input_addr, float *output_a
 template void Cast(const int input_size, const bool *input_addr, double *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const bool *input_addr, half *output_addr, cudaStream_t stream);
 template void Cast(const int input_size, const bool *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const bool *input_addr, Complex<float> *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const bool *input_addr, Complex<double> *output_addr, cudaStream_t stream);
+
+template void Cast(const int input_size, const Complex<float> *input_addr, int8_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, int16_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, int32_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, int64_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, uint8_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, uint16_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, uint32_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, uint64_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, float *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, double *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, half *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<float> *input_addr, Complex<double> *output_addr,
+                   cudaStream_t stream);
+
+template void Cast(const int input_size, const Complex<double> *input_addr, int8_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, int16_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, int32_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, int64_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, uint8_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, uint16_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, uint32_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, uint64_t *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, float *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, double *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, half *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, bool *output_addr, cudaStream_t stream);
+template void Cast(const int input_size, const Complex<double> *input_addr, Complex<float> *output_addr,
+                   cudaStream_t stream);
@@ -19,6 +19,7 @@

 #include <vector>
 #include "runtime/device/gpu/cuda_common.h"
+#include "utils/complex.h"

 template <typename S, typename T>
 void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream);
@@ -17,23 +17,10 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"

-template <typename T>
-__global__ static void Split_Complex(const int element_numbers, T *real_part, T *imag_part,
-                                     const cufftComplex *complex_element) {
-  int i = blockDim.x * blockIdx.x + threadIdx.x;
-  if (i < element_numbers) {
-    real_part[i] = complex_element[i].x;
-    imag_part[i] = complex_element[i].y;
-  }
-}
-
 template <typename T>
-void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
-           const cufftHandle &FFT_plan_r2c, cudaStream_t stream) {
-  cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
-  cufftExecR2C(FFT_plan_r2c, input_tensor, COMPLEX_FQ);
-  Split_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, output_real, output_imag, COMPLEX_FQ);
+void FFT3D(int Nfft, T *input_tensor, Complex<T> *output_tensor, const cufftHandle &FFT_plan_r2c, cudaStream_t stream) {
+  cufftExecR2C(FFT_plan_r2c, input_tensor, reinterpret_cast<cufftComplex *>(output_tensor));
   return;
 }

-template void FFT3D<float>(int Nfft, float *input_tensor, float *complex_fq, float *output_real,
-                           float *output_imag, const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
+template void FFT3D<float>(int Nfft, float *input_tensor, Complex<float> *output_tensor,
+                           const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
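FFT3D is now a thin wrapper over cufftExecR2C that writes directly into a Complex<T> buffer, instead of filling a cufftComplex scratch array and splitting it into separate real/imaginary tensors. A sketch of a caller (assumed, not part of this commit; nx/ny/nz and the device pointers are placeholders). Note that an R2C transform produces only the non-redundant half-spectrum, nx * ny * (nz / 2 + 1) complex values:

cufftHandle plan;
cufftPlan3d(&plan, nx, ny, nz, CUFFT_R2C);   // real-to-complex 3-D plan
cufftSetStream(plan, stream);
const int nfft = nx * ny * (nz / 2 + 1);     // complex output element count
FFT3D<float>(nfft, d_real_in, d_complex_out, plan, stream);
cufftDestroy(plan);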
@@ -17,10 +17,10 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_FFT_3D_IMPL_H_

 #include <cufft.h>
+#include "utils/complex.h"
 #include "runtime/device/gpu/cuda_common.h"

 template <typename T>
-void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
-           const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
+void FFT3D(int Nfft, T *input_tensor, Complex<T> *output_tensor, const cufftHandle &FFT_plan_r2c, cudaStream_t stream);

 #endif
@@ -17,23 +17,11 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"

-template <typename T>
-__global__ static void Merge_Complex(const int element_numbers, T *real_part, T *imag_part,
-                                     cufftComplex *complex_element) {
-  int i = blockDim.x * blockIdx.x + threadIdx.x;
-  if (i < element_numbers) {
-    complex_element[i].x = real_part[i];
-    complex_element[i].y = imag_part[i];
-  }
-}
-
 template <typename T>
-void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
-            const cufftHandle &FFT_plan_c2r, cudaStream_t stream) {
-  cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
-  Merge_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, input_real, input_imag, COMPLEX_FQ);
-  cufftExecC2R(FFT_plan_c2r, COMPLEX_FQ, output_tensor);
+void IFFT3D(int Nfft, Complex<T> *input_tensor, T *output_tensor, const cufftHandle &FFT_plan_c2r,
+            cudaStream_t stream) {
+  cufftExecC2R(FFT_plan_c2r, reinterpret_cast<cufftComplex *>(input_tensor), output_tensor);
   return;
 }

-template void IFFT3D<float>(int Nfft, float *input_real, float *input_imag, float *complex_fq,
-                            float *output_tensor, const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
+template void IFFT3D<float>(int Nfft, Complex<float> *input_tensor, float *output_tensor,
+                            const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
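IFFT3D likewise just forwards to cufftExecC2R. One cuFFT property worth remembering here (documented cuFFT behavior, not something this diff changes): transforms are unnormalized, so a round trip IFFT3D(FFT3D(x)) yields the input scaled by the number of real elements, and normalization is left to the caller. A sketch:

// Hypothetical normalization pass after IFFT3D (n = nx * ny * nz real elements).
__global__ void Normalize(float *data, int n, float inv_n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= inv_n;
}
// Normalize<<<(n + 255) / 256, 256, 0, stream>>>(d_out, n, 1.0f / n);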
@@ -17,10 +17,10 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_IFFT_3D_IMPL_H_

 #include <cufft.h>
+#include "utils/complex.h"
 #include "runtime/device/gpu/cuda_common.h"

 template <typename T>
-void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
-            const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
+void IFFT3D(int Nfft, Complex<T> *input_tensor, T *output_tensor, const cufftHandle &FFT_plan_c2r, cudaStream_t stream);

 #endif
@@ -306,6 +306,34 @@ __global__ void AbsKernel(const half *input, half *output, const size_t count) {
  return;
}
template <typename T>
__global__ void AbsKernel(const Complex<T> *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = abs(input[i]);
  }
  return;
}
template <typename T>
__global__ void RealKernel(const Complex<T> *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = input[i].real();
  }
  return;
}
template <typename T>
__global__ void ImagKernel(const Complex<T> *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = input[i].imag();
  }
  return;
}
template <typename T>
__global__ void ConjKernel(const Complex<T> *input, Complex<T> *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = Complex<T>(input[i].real(), -input[i].imag());
  }
  return;
}
template <typename T>
__global__ void FloorKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = floorf(input[i]);

@@ -478,6 +506,26 @@ void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream
  return;
}
template <typename T>
void Abs(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  AbsKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
template <typename T>
void Real(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  RealKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
template <typename T>
void Imag(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  ImagKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
template <typename T>
void Conj(const Complex<T> *input, Complex<T> *output, const size_t count, cudaStream_t cuda_stream) {
  ConjKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
template <typename T>
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;

@@ -671,3 +719,15 @@ template void Floor<int>(const int *input, int *output, const size_t count, cuda
template void Rint<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template void Round<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template void Sign<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);

// complex64
template void Real<float>(const Complex<float> *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Imag<float>(const Complex<float> *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Conj<float>(const Complex<float> *input, Complex<float> *output, const size_t count,
                          cudaStream_t cuda_stream);

// complex128
template void Real<double>(const Complex<double> *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Imag<double>(const Complex<double> *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Conj<double>(const Complex<double> *input, Complex<double> *output, const size_t count,
                           cudaStream_t cuda_stream);

@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_

#include "runtime/device/gpu/cuda_common.h"
#include "utils/complex.h"
template <typename T>
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>

@@ -64,5 +65,10 @@ template <typename T>
void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Sign(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

template <typename T>
void Real(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Imag(const Complex<T> *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Conj(const Complex<T> *input, Complex<T> *output, const size_t count, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_

@@ -0,0 +1,46 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/math/broadcast_complex_gpu_kernel.h"

namespace mindspore {
namespace kernel {

#define MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(OPNAME, T0_MS_DTYPE, T1_MS_DTYPE, T0_DTYPE, T1_DTYPE) \
  MS_REG_GPU_KERNEL_THREE(OPNAME, \
                          KernelAttr().AddInputAttr(T0_MS_DTYPE).AddInputAttr(T0_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
                          BroadcastComplexOpGpuKernel, T0_DTYPE, T0_DTYPE, T0_DTYPE) \
  MS_REG_GPU_KERNEL_THREE(OPNAME, \
                          KernelAttr().AddInputAttr(T0_MS_DTYPE).AddInputAttr(T1_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
                          BroadcastComplexOpGpuKernel, T0_DTYPE, T1_DTYPE, T0_DTYPE) \
  MS_REG_GPU_KERNEL_THREE(OPNAME, \
                          KernelAttr().AddInputAttr(T1_MS_DTYPE).AddInputAttr(T0_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
                          BroadcastComplexOpGpuKernel, T1_DTYPE, T0_DTYPE, T0_DTYPE)

template <typename T>
using Complex = mindspore::utils::Complex<T>;
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Add, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Add, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Sub, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Sub, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Mul, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Mul, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Div, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(Div, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(RealDiv, kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float);
MS_REG_BROADCAST_COMPLEX_GPU_KERNEL(RealDiv, kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double);
}  // namespace kernel
}  // namespace mindspore

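These registrations let mixed complex/real operands dispatch to one kernel per op. A minimal sketch of what they enable from Python, assuming a GPU device target (tensor values are illustrative, not from the commit):

import numpy as np
import mindspore as ms
import mindspore.ops as ops

# Hypothetical check of the broadcast registrations above: a complex64 tensor
# combined with a float32 tensor should broadcast and yield complex64 output.
x = ms.Tensor(np.array([[1 + 2j], [3 - 1j]], np.complex64))  # shape (2, 1)
y = ms.Tensor(np.array([0.5, 1.5], np.float32))              # shape (2,)
out = ops.Add()(x, y)  # broadcasts to shape (2, 2), dtype complex64
print(out.dtype, out.shape)
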
@@ -0,0 +1,152 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_COMPLEX_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_COMPLEX_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include <complex>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
template <typename T, typename S, typename G>
class BroadcastComplexOpGpuKernel : public GpuKernel {
 public:
  BroadcastComplexOpGpuKernel() { ResetResource(); }
  ~BroadcastComplexOpGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *lhs = GetDeviceAddress<T>(inputs, 0);
    S *rhs = GetDeviceAddress<S>(inputs, 1);

    G *output = GetDeviceAddress<G>(outputs, 0);
    if (need_broadcast_) {
      BroadcastComplexArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      ElewiseComplexArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    }

    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    GetOpType(kernel_node);
    auto shape1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    auto shape2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
    auto shape3 = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
    need_broadcast_ = AnfAlgo::IsTensorBroadcast(shape1, shape2);
    if (need_broadcast_ && shape1.size() > MAX_DIMS) {
      MS_LOG(EXCEPTION) << "Broadcast operation does not support dimensions greater than " << MAX_DIMS;
    }

    lhs_shape_.resize(MAX_DIMS, 1);
    rhs_shape_.resize(MAX_DIMS, 1);
    output_shape_.resize(MAX_DIMS, 1);
    for (size_t i = 0; i < shape3.size(); i++) {
      if (need_broadcast_) {
        output_shape_[i] = shape3[i];
      }
      output_num_ *= shape3[i];
    }
    int lhs_offset = shape3.size() - shape1.size();
    for (size_t j = 0; j < shape1.size(); j++) {
      if (need_broadcast_) {
        lhs_shape_[j + lhs_offset] = shape1[j];
      }
      input1_num_ *= shape1[j];
    }
    int rhs_offset = shape3.size() - shape2.size();
    for (size_t k = 0; k < shape2.size(); k++) {
      if (need_broadcast_) {
        rhs_shape_[k + rhs_offset] = shape2[k];
      }
      input2_num_ *= shape2[k];
    }

    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    op_type_ = BROADCAST_TYPE_INVALID;
    need_broadcast_ = false;
    input1_num_ = 1;
    input2_num_ = 1;
    output_num_ = 1;
    lhs_shape_.clear();
    rhs_shape_.clear();
    output_shape_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitResource() override { return; }
  void InitSizeLists() override {
    input_size_list_.push_back(input1_num_ * sizeof(T));
    input_size_list_.push_back(input2_num_ * sizeof(S));
    output_size_list_.push_back(output_num_ * sizeof(G));
  }

 private:
  void GetOpType(const CNodePtr &kernel_node) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    static std::map<std::string, BroadcastOpType> kBroadcastArithmeticTypeMap = {{"RealDiv", BROADCAST_TYPE_REALDIV},
                                                                                 {"Mul", BROADCAST_TYPE_MUL},
                                                                                 {"Sub", BROADCAST_TYPE_SUB},
                                                                                 {"Add", BROADCAST_TYPE_ADD},
                                                                                 {"Div", BROADCAST_TYPE_DIV}};

    auto iter = kBroadcastArithmeticTypeMap.find(kernel_name);
    if (iter != kBroadcastArithmeticTypeMap.end()) {
      op_type_ = iter->second;
      return;
    }

    MS_LOG(EXCEPTION) << "Operation " << kernel_name << " is not supported.";
  }

  BroadcastOpType op_type_;
  bool need_broadcast_;
  size_t input1_num_;
  size_t input2_num_;
  size_t output_num_;
  std::vector<size_t> lhs_shape_;
  std::vector<size_t> rhs_shape_;
  std::vector<size_t> output_shape_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BROADCAST_COMPLEX_GPU_KERNEL_H_

@@ -0,0 +1,34 @@
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "utils/complex.h"
#include "backend/kernel_compiler/gpu/math/unary_op_complex_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Real, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpComplexGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Real, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpComplexGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
                      UnaryOpComplexGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Imag, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
                      UnaryOpComplexGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
                      UnaryOpComplexGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Conj, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
                      UnaryOpComplexGpuKernel, double)
}  // namespace kernel
}  // namespace mindspore

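A minimal sketch of what these registrations expose from Python, assuming a GPU device target: Real and Imag map complex64/complex128 to float32/float64 (the template parameter is the real component type), while Conj keeps the complex dtype. Values are illustrative:

import numpy as np
import mindspore as ms
import mindspore.ops as ops

x = ms.Tensor(np.array([1.3 + 0.4j, -2.0 - 1.0j], np.complex64))
print(ops.Real()(x))  # [ 1.3 -2. ]           -> float32
print(ops.Imag()(x))  # [ 0.4 -1. ]           -> float32
print(ops.Conj()(x))  # [1.3-0.4j -2.+1.j]    -> complex64
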
@@ -0,0 +1,160 @@
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARY_COMPLEX_OP_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARY_COMPLEX_OP_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h"

namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class UnaryOpComplexGpuKernel : public GpuKernel {
 public:
  UnaryOpComplexGpuKernel() { ResetResource(); }
  ~UnaryOpComplexGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    Complex<T> *input_addr = GetDeviceAddress<Complex<T>>(inputs, 0);
    if (is_c2r_op_) {
      T *output_addr = GetDeviceAddress<T>(outputs, 0);
      switch (unary_op_type_) {
        case UNARY_OP_REAL: {
          Real(input_addr, output_addr, inputs[0]->size / sizeof(Complex<T>),
               reinterpret_cast<cudaStream_t>(stream_ptr));
          break;
        }
        case UNARY_OP_IMAG: {
          Imag(input_addr, output_addr, inputs[0]->size / sizeof(Complex<T>),
               reinterpret_cast<cudaStream_t>(stream_ptr));
          break;
        }
        default: {
          MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
        }
      }
    } else {
      Complex<T> *output_addr = GetDeviceAddress<Complex<T>>(outputs, 0);
      switch (unary_op_type_) {
        case UNARY_OP_CONJ: {
          Conj(input_addr, output_addr, inputs[0]->size / sizeof(Complex<T>),
               reinterpret_cast<cudaStream_t>(stream_ptr));
          break;
        }
        default: {
          MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
        }
      }
    }
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    GetOpType(kernel_node);
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 input.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output.";
      return false;
    }
    auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(input_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "UnaryOpComplexGpuKernel input is null";
      InitSizeLists();
      return true;
    }
    for (size_t i = 0; i < input_shape.size(); i++) {
      input_size_ *= input_shape[i];
      output_size_ *= input_shape[i];
    }
    if (is_c2r_op_) {
      // Complex-to-real ops keep only one of the two components, so each output
      // element is half the size of a Complex<T> element.
      output_size_ /= 2;
    }
    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    input_size_ = sizeof(Complex<T>);
    output_size_ = sizeof(Complex<T>);
    workspace_size_ = 0;
    is_null_input_ = false;
    is_c2r_op_ = false;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
  }

 private:
  void GetOpType(const CNodePtr &kernel_node) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    static std::map<std::string, UnaryOptype> kComplexSupportedC2RTypeMap = {{"Real", UNARY_OP_REAL},
                                                                             {"Imag", UNARY_OP_IMAG}};
    auto iter = kComplexSupportedC2RTypeMap.find(kernel_name);
    if (iter != kComplexSupportedC2RTypeMap.end()) {
      unary_op_type_ = iter->second;
      is_c2r_op_ = true;
      return;
    }
    static std::map<std::string, UnaryOptype> kComplexSupportedC2CTypeMap = {{"Conj", UNARY_OP_CONJ}};
    iter = kComplexSupportedC2CTypeMap.find(kernel_name);
    if (iter != kComplexSupportedC2CTypeMap.end()) {
      unary_op_type_ = iter->second;
      is_c2r_op_ = false;
      return;
    }

    MS_LOG(EXCEPTION) << "Operation " << kernel_name << " is not supported.";
  }

 private:
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  bool is_null_input_;
  bool is_c2r_op_;
  UnaryOptype unary_op_type_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARY_COMPLEX_OP_GPU_KERNEL_H_

@@ -51,6 +51,9 @@ enum UnaryOptype {
  UNARY_OP_RINT,
  UNARY_OP_ROUND,
  UNARY_OP_SIGN,
  UNARY_OP_REAL,
  UNARY_OP_IMAG,
  UNARY_OP_CONJ,
  UNARY_OP_INVALID_TYPE = 255
};

@@ -66,7 +69,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
  {"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
  {"Abs", UNARY_OP_ABS},     {"Floor", UNARY_OP_FLOOR},
  {"Rint", UNARY_OP_RINT},   {"Round", UNARY_OP_ROUND},
  {"Sign", UNARY_OP_SIGN}};
  {"Real", UNARY_OP_REAL},   {"Imag", UNARY_OP_IMAG},
  {"Sign", UNARY_OP_SIGN},   {"Conj", UNARY_OP_CONJ}};

template <typename T>
class UnaryOpGpuKernel : public GpuKernel {

@@ -17,9 +17,7 @@

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  FFT3D,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  FFT3DGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FFT3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
                      FFT3DGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore

@@ -27,6 +27,8 @@
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class FFT3DGpuKernel : public GpuKernel {
 public:
  FFT3DGpuKernel() = default;

@@ -52,22 +54,17 @@ class FFT3DGpuKernel : public GpuKernel {
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto input_tensor = GetDeviceAddress<T>(inputs, 0);
    auto complex_fq = GetDeviceAddress<T>(workspace, 0);
    auto output_real = GetDeviceAddress<T>(outputs, 0);
    auto output_imag = GetDeviceAddress<T>(outputs, 1);
    auto output_tensor = GetDeviceAddress<Complex<T>>(outputs, 0);

    cufftSetStream(FFT_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
    FFT3D<T>(Nfft, input_tensor, complex_fq, output_real, output_imag, FFT_plan_r2c,
             reinterpret_cast<cudaStream_t>(stream_ptr));
    FFT3D<T>(Nfft, input_tensor, output_tensor, FFT_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(Nall * sizeof(T));
    workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
    output_size_list_.push_back(Nfft * sizeof(T));
    output_size_list_.push_back(Nfft * sizeof(T));
    output_size_list_.push_back(Nfft * sizeof(Complex<T>));
  }

 private:

@@ -17,9 +17,7 @@

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  IFFT3D,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  IFFT3DGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IFFT3D, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeFloat32),
                      IFFT3DGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore

@@ -27,6 +27,8 @@
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class IFFT3DGpuKernel : public GpuKernel {
 public:
  IFFT3DGpuKernel() = default;

@@ -51,22 +53,17 @@ class IFFT3DGpuKernel : public GpuKernel {

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto input_real = GetDeviceAddress<T>(inputs, 0);
    auto input_imag = GetDeviceAddress<T>(inputs, 1);
    auto complex_fq = GetDeviceAddress<T>(workspace, 0);
    auto input_tensor = GetDeviceAddress<Complex<T>>(inputs, 0);
    auto output_tensor = GetDeviceAddress<T>(outputs, 0);

    cufftSetStream(FFT_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
    IFFT3D<T>(Nfft, input_real, input_imag, complex_fq, output_tensor, FFT_plan_c2r,
              reinterpret_cast<cudaStream_t>(stream_ptr));
    IFFT3D<T>(Nfft, input_tensor, output_tensor, FFT_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(Nfft * sizeof(T));
    input_size_list_.push_back(Nfft * sizeof(T));
    workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
    input_size_list_.push_back(Nfft * sizeof(Complex<T>));
    output_size_list_.push_back(Nall * sizeof(T));
  }

@@ -99,6 +99,10 @@ static irpb::DataType GetNumberDataType(const TypePtr &type) {
      return irpb::DT_BASE_UINT;
    case kNumberTypeFloat:
      return irpb::DT_BASE_FLOAT;
    case kNumberTypeComplex64:
      return irpb::DT_COMPLEX64;
    case kNumberTypeComplex128:
      return irpb::DT_COMPLEX128;
    default:
      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  }

@@ -89,6 +89,8 @@ enum DataType {
  DT_ANYTHING = 40;   // type anything
  DT_REFKEY = 41;     // type refkey
  DT_REF = 42;        // type ref
  DT_COMPLEX64 = 43;  // type complex64
  DT_COMPLEX128 = 44; // type complex128
}

// Value definition for attribute value or parameter default value

@@ -256,7 +258,7 @@ message ModelProto {
  // The parameterized graph that is evaluated to execute the model.
  optional GraphProto graph = 4;

  // metadata info of opeartors
  // metadata info of operators
  optional OperatorSetProto metadata_operators = 5;
};

@@ -20,6 +20,7 @@
#include <limits>
#ifdef ENABLE_GPU
#include <thrust/complex.h>
#include <cublas_v2.h>
#endif
#include "base/float16.h"
#if defined(__CUDACC__)

@@ -80,7 +81,11 @@ struct alignas(sizeof(T) * 2) Complex {
  HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
  HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
  HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
  HOST_DEVICE inline explicit operator float16() const { return static_cast<float16>(real_); }
#if defined(__CUDACC__)
  HOST_DEVICE inline explicit operator half() const { return static_cast<half>(real_); }
#else
  inline explicit operator float16() const { return static_cast<float16>(real_); }
#endif

  HOST_DEVICE inline constexpr Complex<T> &operator=(const T &real) {
    real_ = real;

@@ -100,12 +105,14 @@ struct alignas(sizeof(T) * 2) Complex {

  HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
    real_ *= real;
    imag_ *= real;
    return *this;
  }

  // Note: check for division by zero before using this operator.
  HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
    real_ /= real;
    imag_ /= real;
    return *this;
  }

@@ -59,7 +59,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                       Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                       Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
                       MatrixInverse, IndexAdd, Erfinv)
                       MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag)

from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
                         RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,

@@ -509,7 +509,10 @@ __all__ = [
    "BufferAppend",
    "BufferGetItem",
    "BufferSample",
    "Erfinv"
    "Erfinv",
    "Conj",
    "Real",
    "Imag"
]

__sponge__ = [

@@ -86,6 +86,23 @@ class _MathBinaryOp(_BinaryOp):
    def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
        """Staticmethod of infer dtype for _MathBinaryOp."""
        args_type = {"x": x_dtype, "y": y_dtype}
        complex_types = [mstype.tensor_type(mstype.complex64), mstype.tensor_type(mstype.complex128)]
        if x_dtype in complex_types or y_dtype in complex_types:
            type_infer_dict = {
                (mstype.complex64, mstype.complex64): mstype.tensor_type(mstype.complex64),
                (mstype.complex64, mstype.float32): mstype.tensor_type(mstype.complex64),
                (mstype.float32, mstype.complex64): mstype.tensor_type(mstype.complex64),
                (mstype.complex128, mstype.complex128): mstype.tensor_type(mstype.complex128),
                (mstype.complex128, mstype.float64): mstype.tensor_type(mstype.complex128),
                (mstype.float64, mstype.complex128): mstype.tensor_type(mstype.complex128),
            }
            if (x_dtype.element_type(), y_dtype.element_type()) not in type_infer_dict.keys():
                raise TypeError('Complex math binary op expects one of the dtype pairs [complex64, complex64], '
                                '[complex64, float32], [float32, complex64], [complex128, complex128], '
                                '[complex128, float64], [float64, complex128], '
                                f'but got: [{format(x_dtype)}, {format(y_dtype)}].')
            return type_infer_dict.get((x_dtype.element_type(), y_dtype.element_type()))

        validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name)
        return x_dtype

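The complex branch is effectively a small promotion table: a complex tensor paired with its matching real dtype keeps the complex dtype, and any other pairing raises TypeError. A sketch of the expected behavior, assuming the GPU registrations above (values illustrative):

import numpy as np
import mindspore as ms
import mindspore.ops as ops

c64 = ms.Tensor(np.array([1 + 1j], np.complex64))
f32 = ms.Tensor(np.array([2.0], np.float32))
f64 = ms.Tensor(np.array([2.0], np.float64))

print(ops.Mul()(c64, f32).dtype)  # complex64: (complex64, float32) is in the table
try:
    ops.Mul()(c64, f64)           # (complex64, float64) is not a listed pair
except TypeError as e:
    print("rejected:", e)
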
@@ -4439,7 +4456,6 @@ class Abs(Primitive):
        """Initialize Abs"""
        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])


class Sign(PrimitiveWithInfer):
    r"""
    Performs sign on the tensor element-wise.

@@ -5257,3 +5273,125 @@ class Erfinv(Primitive):
    def __init__(self):
        """Initialize Erfinv"""
        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])

class Conj(PrimitiveWithInfer):
    """
    Returns a tensor that is the conjugate of the input, element-wise.

    Inputs:
        - **input** (Tensor, complex) - The input tensor. types: complex64, complex128.

    Outputs:
        Tensor, has the same dtype as the input.

    Raises:
        TypeError: If the dtype of input is not one of: complex64, complex128.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
        >>> conj = ops.Conj()
        >>> output = conj(x)
        >>> print(output)
        1.3-0.4j
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        return input_shape

    def infer_dtype(self, input_dtype):
        validator.check_tensor_dtype_valid('input_tensor', input_dtype,
                                           [mstype.complex64, mstype.complex128], self.name)
        return input_dtype

class Real(PrimitiveWithInfer):
    """
    Returns a Tensor that is the real part of the input.

    Inputs:
        - **input** (Tensor, complex) - The input tensor. types: complex64, complex128.

    Outputs:
        Tensor, has the float type.

    Raises:
        TypeError: If the dtype of input is not one of: complex64, complex128.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
        >>> real = ops.Real()
        >>> output = real(x)
        >>> print(output)
        1.3
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        return input_shape

    def infer_dtype(self, input_dtype):
        validator.check_tensor_dtype_valid('input_tensor', input_dtype,
                                           [mstype.complex64, mstype.complex128], self.name)
        if input_dtype == mstype.tensor_type(mstype.complex64):
            output_dtype = mstype.float32
        elif input_dtype == mstype.tensor_type(mstype.complex128):
            output_dtype = mstype.float64
        return output_dtype

class Imag(PrimitiveWithInfer):
    """
    Returns a new tensor containing the imaginary part of the input.

    Inputs:
        - **input** (Tensor, complex) - The input tensor. types: complex64, complex128.

    Outputs:
        Tensor, has the float type.

    Raises:
        TypeError: If the dtype of input is not one of: complex64, complex128.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.asarray(np.complex(1.3+0.4j)), mindspore.complex64)
        >>> imag = ops.Imag()
        >>> output = imag(x)
        >>> print(output)
        0.4
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        return input_shape

    def infer_dtype(self, input_dtype):
        validator.check_tensor_dtype_valid('input_tensor', input_dtype,
                                           [mstype.complex64, mstype.complex128], self.name)
        if input_dtype == mstype.tensor_type(mstype.complex64):
            output_dtype = mstype.float32
        elif input_dtype == mstype.tensor_type(mstype.complex128):
            output_dtype = mstype.float64
        return output_dtype

@@ -2988,9 +2988,7 @@ class FFT3D(PrimitiveWithInfer):
        - **input_tensor** (Tensor, float32) - [fftx, ffty, fftz]

    Outputs:
        - **output_real** (float32) - the real part of the output tensor after
          undergoing fast Fourier transform.
        - **output_imag** (float32) - the imaginary part of the output tensor after
        - **output_tensor** (complex64) - the complex output tensor after
          undergoing fast Fourier transform.

    Supported Platforms:

@@ -3001,27 +2999,24 @@ class FFT3D(PrimitiveWithInfer):
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_real', 'output_imag'])
            outputs=['output_tensor'])

    def infer_shape(self, input_shape):
        self.add_prim_attr('fftx', input_shape[0])
        self.add_prim_attr('ffty', input_shape[1])
        self.add_prim_attr('fftz', input_shape[2])
        return [input_shape[0], input_shape[1], int(input_shape[2]/2)+1],\
               [input_shape[0], input_shape[1], int(input_shape[2]/2)+1]
        return [input_shape[0], input_shape[1], int(input_shape[2]/2)+1]

    def infer_dtype(self, input_dtype):
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, mstype.number_type, self.name)
        return input_dtype, input_dtype
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, [mstype.float32], self.name)
        return mstype.complex64

class IFFT3D(PrimitiveWithInfer):
    """
    Inverse FFT with Three-Dimensional Input.

    Inputs:
        - **input_real** (Tensor, float32) - [fftx, ffty, fftz]
        - **input_imag** (Tensor, float32) - [fftx, ffty, fftz]
        - **input_tensor** (Tensor, complex64) - [fftx, ffty, fftz]

    Outputs:
        - **output_tensor** (float32) - returns the tensor after undergoing

@@ -3034,21 +3029,18 @@ class IFFT3D(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_real', 'input_imag'],
            inputs=['input_tensor'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape1, input_shape2):
        for i in range(len(input_shape1)):
            validator.check_int(input_shape1[i], input_shape2[i], Rel.EQ, "input_shape", self.name)
        self.add_prim_attr('fftx', input_shape1[0])
        self.add_prim_attr('ffty', input_shape1[1])
        self.add_prim_attr('fftz', input_shape1[2])
        return [input_shape1[0], input_shape1[1], (input_shape1[2]-1)*2]
    def infer_shape(self, input_shape):
        self.add_prim_attr('fftx', input_shape[0])
        self.add_prim_attr('ffty', input_shape[1])
        self.add_prim_attr('fftz', input_shape[2])
        return [input_shape[0], input_shape[1], (input_shape[2]-1)*2]

    def infer_dtype(self, input_real_dtype, input_imag_dtype):
        validator.check_tensor_dtype_valid('input_real', input_real_dtype, mstype.number_type, self.name)
        validator.check_tensor_dtype_valid('input_imag', input_imag_dtype, mstype.number_type, self.name)
        return input_real_dtype
    def infer_dtype(self, input_dtype):
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, [mstype.complex64], self.name)
        return mstype.float32

class NeighborListUpdate(PrimitiveWithInfer):
    """

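Taken together, the primitives now map a float32 grid to a half-spectrum complex64 tensor and back. A minimal round-trip sketch, assuming a GPU target and that FFT3D/IFFT3D are importable from `mindspore.ops` as registered above (import path and values are assumptions for illustration):

import numpy as np
import mindspore as ms
from mindspore import ops

fftx, ffty, fftz = 4, 4, 8
x = ms.Tensor(np.random.rand(fftx, ffty, fftz).astype(np.float32))

freq = ops.FFT3D()(x)      # complex64, shape (4, 4, 8 // 2 + 1) = (4, 4, 5)
back = ops.IFFT3D()(freq)  # float32, shape (4, 4, (5 - 1) * 2) = (4, 4, 8)
# cuFFT's C2R transform is unnormalized, so `back` is expected to equal
# x scaled by fftx * ffty * fftz rather than x itself.
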
@@ -144,13 +144,13 @@ TEST_F(TestComplex, test_arithmetic) {
  test_arithmetic_mul<Complex<float>, Complex<float>, Complex<float>>(
    Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(-3.6963, 4.9284));
  test_arithmetic_mul<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11,
                                                             Complex<float>(1.2321, 2.22));
                                                             Complex<float>(1.2321, 2.4642));
  test_arithmetic_mul<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
                                                             Complex<float>(1.2321, 2.22));
                                                             Complex<float>(1.2321, 2.4642));

  test_arithmetic_div<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22),
                                                                      Complex<float>(1.11, 2.22), Complex<float>(1, 0));
  test_arithmetic_div<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(1, 2.22));
  test_arithmetic_div<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(1, 2));
  test_arithmetic_div<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22),
                                                             Complex<float>(0.2, -0.4));
}

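The corrected expectations follow from component-wise scalar arithmetic; a quick check with Python's built-in complex type:

# Scalar * complex scales both components; scalar division divides both.
z = complex(1.11, 2.22)
print(z * 1.11)  # (1.2321+2.4642j) -> matches the fixed expected value
print(z / 1.11)  # (1+2j)           -> imag is 2.22 / 1.11 = 2, not 2.22
print(1.11 / z)  # (0.2-0.4j)       -> 1.11 * z.conjugate() / |z|^2, with |z|^2 = 5 * 1.11^2
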