forked from mindspore-Ecosystem/mindspore
[assistant][ops][I5EWOV] Fix SquaredDifference GPU with DataType
parent c64083fe5f · commit 075e568f94
@@ -11,8 +11,8 @@ mindspore.ops.SquaredDifference

     out_{i} = (x_{i} - y_{i}) * (x_{i} - y_{i}) = (x_{i} - y_{i})^2

 Inputs:
-    - **x** (Union[Tensor, Number, bool]) - The first input, which is a number, a bool, or a Tensor whose data type is float16, float32, int32, or bool.
-    - **y** (Union[Tensor, Number, bool]) - The second input, which is a number, or a bool when the first input is a Tensor whose data type is float16, float32, int32, or bool.
+    - **x** (Union[Tensor, Number, bool]) - The first input, which is a number, a bool, or a Tensor.
+    - **y** (Union[Tensor, Number, bool]) - The second input, which is a number, a Tensor, or a bool when the first input is a Tensor.

 Outputs:
     Tensor, whose shape is the same as the shape after broadcasting, and whose data type is the one of higher precision or higher digits among the two inputs.
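For a concrete sense of the formula above, a worked elementwise example (the values are chosen here purely for illustration): with x = [1.0, 2.0, 3.0] and y = [2.0, 4.0, 6.0],

\[
  out = \bigl((1-2)^2,\ (2-4)^2,\ (3-6)^2\bigr) = (1,\ 4,\ 9).
\]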
@@ -942,10 +942,35 @@ struct AbsGradFunc<half2> {

 template <typename T>
 struct SquaredDifferenceFunc {
-  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
     T diff = lhs - rhs;
     return diff * diff;
   }
+  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
+    Complex<T> diff = lhs - rhs;
+    Complex<T> conj_diff(diff.real(), -diff.imag());
+    return conj_diff * diff;
+  }
 };
+
+template <>
+struct SquaredDifferenceFunc<Complex<double>> {
+  __device__ __host__ __forceinline__ Complex<double> operator()(const Complex<double> &lhs,
+                                                                 const Complex<double> &rhs) {
+    Complex<double> diff = lhs - rhs;
+    Complex<double> conj_diff(diff.real(), -diff.imag());
+    return conj_diff * diff;
+  }
+};
+
+template <>
+struct SquaredDifferenceFunc<Complex<float>> {
+  __device__ __host__ __forceinline__ Complex<float> operator()(const Complex<float> &lhs,
+                                                                const Complex<float> &rhs) {
+    Complex<float> diff = lhs - rhs;
+    Complex<float> conj_diff(diff.real(), -diff.imag());
+    return conj_diff * diff;
+  }
+};

 template <typename T>
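The new complex overloads compute conj(lhs - rhs) * (lhs - rhs). A minimal host-side sketch of the same arithmetic, using std::complex in place of mindspore::utils::Complex (the function name and test values are illustrative, not part of the commit):

// Host-side sketch: conj(x - y) * (x - y) is the squared magnitude of the
// difference, carried in the real part of a complex result.
#include <complex>
#include <cstdio>

template <typename T>
std::complex<T> SquaredDifference(const std::complex<T> &lhs, const std::complex<T> &rhs) {
  std::complex<T> diff = lhs - rhs;
  std::complex<T> conj_diff(diff.real(), -diff.imag());  // complex conjugate of the difference
  return conj_diff * diff;                               // (a - bi)(a + bi) = a^2 + b^2
}

int main() {
  std::complex<float> x(3.0f, 4.0f), y(0.0f, 0.0f);
  std::complex<float> out = SquaredDifference(x, y);
  // diff = 3 + 4i, so out = 3^2 + 4^2 = 25, with imaginary part 0
  std::printf("out = (%g, %g)\n", out.real(), out.imag());
  return 0;
}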
@@ -1214,6 +1239,9 @@ void ElewiseArithComplexKernel(const int &nums, enum BroadcastOpType op, const T
     case BROADCAST_TYPE_XLOGY:
       return ElewiseArithComplexKernel<T1, T2, T3, XLogyFunc<Complex<T3>>>
         <<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_SQUARED_DIFFERENCE:
+      return ElewiseArithComplexKernel<T1, T2, T3, SquaredDifferenceFunc<Complex<T3>>>
+        <<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
     default:
       break;
   }
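The (nums + 255) / 256 in the launch configuration is ceil division: the smallest number of 256-thread blocks that covers every element. A tiny sketch of the pattern (kThreadsPerBlock and GridSize are illustrative names, not from this diff):

// Ceil division: the smallest block count with blocks * kThreadsPerBlock >= nums;
// threads past the end must still bounds-check their index inside the kernel.
constexpr int kThreadsPerBlock = 256;
inline int GridSize(int nums) { return (nums + kThreadsPerBlock - 1) / kThreadsPerBlock; }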
@@ -1660,6 +1688,11 @@ void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector
         x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
         x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
         y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_SQUARED_DIFFERENCE:
+      return BroadcastComplexArithKernel<T1, T2, T3, SquaredDifferenceFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_COMPLEX:
       return BroadcastComplexArithKernel<T1, T2, T3, ComplexFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
         x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
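BroadcastComplexArithKernel takes exactly seven dimensions per operand, which suggests shapes are normalized to a fixed rank of 7 before launch. A hypothetical sketch of such a normalization (PadTo7D is not part of this commit):

#include <array>
#include <cstddef>
#include <vector>

// Right-align the real dimensions into a rank-7 shape, padding the leading
// dimensions with 1 so broadcasting semantics are unchanged; assumes
// dims.size() <= 7. E.g. {2, 3} -> {1, 1, 1, 1, 1, 2, 3}.
std::array<size_t, 7> PadTo7D(const std::vector<size_t> &dims) {
  std::array<size_t, 7> padded;
  padded.fill(1);
  for (size_t i = 0; i < dims.size(); ++i) {
    padded[7 - dims.size() + i] = dims[i];
  }
  return padded;
}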
@@ -107,17 +107,47 @@ bool SquaredDifferenceOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr>
   }
   return true;
 }
 
+template <typename T>
+bool SquaredDifferenceOpGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &inputs,
+                                                          const std::vector<AddressPtr> &workspace,
+                                                          const std::vector<AddressPtr> &outputs) {
+  T *lhs = GetDeviceAddress<T>(inputs, 0);
+  T *rhs = GetDeviceAddress<T>(inputs, 1);
+  T *output = GetDeviceAddress<T>(outputs, 0);
+  if (need_broadcast_) {
+    BroadcastComplexArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
+                          reinterpret_cast<cudaStream_t>(stream_ptr_));
+  } else {
+    ElewiseComplexArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr_));
+  }
+  return true;
+}
+
+#define DTYPE_REGISTER_ATTR(INPUT1, INPUT2, OUTPUT, T)                            \
+  {                                                                               \
+    KernelAttr().AddInputAttr(INPUT1).AddInputAttr(INPUT2).AddOutputAttr(OUTPUT), \
+      &SquaredDifferenceOpGpuKernelMod::LaunchKernel<T>                           \
+  }
+
+#define COMPLEX_REGISTER_ATTR(INPUT1, INPUT2, OUTPUT, T)                          \
+  {                                                                               \
+    KernelAttr().AddInputAttr(INPUT1).AddInputAttr(INPUT2).AddOutputAttr(OUTPUT), \
+      &SquaredDifferenceOpGpuKernelMod::LaunchComplexKernel<T>                    \
+  }
+
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
 const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SquaredDifferenceOpGpuKernelMod::GetFuncList() const {
   static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
+    DTYPE_REGISTER_ATTR(kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, float),
+    DTYPE_REGISTER_ATTR(kNumberTypeFloat64, kNumberTypeFloat64, kNumberTypeFloat64, double),
+    COMPLEX_REGISTER_ATTR(kNumberTypeComplex64, kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>),
+    COMPLEX_REGISTER_ATTR(kNumberTypeComplex128, kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>),
+    DTYPE_REGISTER_ATTR(kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, half),
+    DTYPE_REGISTER_ATTR(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
+    DTYPE_REGISTER_ATTR(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int)};
 
   return func_list;
 }
 MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SquaredDifference, SquaredDifferenceOpGpuKernelMod);
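The two macros build {KernelAttr, member-function-pointer} pairs so each supported dtype dispatches to the right templated launcher. A self-contained sketch of that registration pattern, reduced to its essentials (Kernel, TypeId, RunFunc, and REGISTER_ATTR here are all illustrative stand-ins, not MindSpore APIs):

#include <cstdio>
#include <utility>
#include <vector>

struct Kernel {
  template <typename T>
  bool LaunchKernel() {
    std::printf("launched kernel with sizeof(T) = %zu\n", sizeof(T));
    return true;
  }
};

enum TypeId { kFloat32, kFloat64 };
using RunFunc = bool (Kernel::*)();

// Mirrors the shape of DTYPE_REGISTER_ATTR: a brace-enclosed
// {type tag, member-function pointer} pair.
#define REGISTER_ATTR(TYPE, T) \
  { TYPE, &Kernel::LaunchKernel<T> }

int main() {
  static const std::vector<std::pair<TypeId, RunFunc>> func_list = {
      REGISTER_ATTR(kFloat32, float), REGISTER_ATTR(kFloat64, double)};
  Kernel kernel;
  for (const auto &entry : func_list) {
    if (entry.first == kFloat64) {
      (kernel.*entry.second)();  // dispatch to LaunchKernel<double>
    }
  }
  return 0;
}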
@@ -21,6 +21,7 @@
 #include <vector>
 #include <string>
 #include <map>
+#include <complex>
 #include <utility>
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"

@@ -57,6 +58,9 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod,
   template <typename T>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs);
+  template <typename T>
+  bool LaunchComplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                           const std::vector<AddressPtr> &outputs);
   BroadcastOpType op_type_{BROADCAST_TYPE_SQUARED_DIFFERENCE};
   bool need_broadcast_;
   size_t output_num_;
@@ -2016,10 +2016,9 @@ class SquaredDifference(Primitive):
         out_{i} = (x_{i} - y_{i}) * (x_{i} - y_{i}) = (x_{i} - y_{i})^2

     Inputs:
-        - **x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
-          or a tensor whose data type is float16, float32, int32 or bool.
+        - **x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor.
         - **y** (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input
-          is a tensor or a tensor whose data type is float16, float32, int32 or bool.
+          is a tensor, or a tensor.

     Outputs:
         Tensor, the shape is the same as the one after broadcasting,
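Tying the docstring back to the kernel change: for the complex types this commit enables, the CUDA functor computes conj(x - y) * (x - y), so the result is effectively

\[
  out_i = \overline{(x_i - y_i)}\,(x_i - y_i) = \lvert x_i - y_i \rvert^2,
\]

returned as a complex value with zero imaginary part.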