[assistant][ops][I5EWOV] Fix SquaredDifference GPU with DataType

This commit is contained in:
zheng_pengfei 2022-10-26 10:40:27 +08:00
parent c64083fe5f
commit 075e568f94
5 changed files with 72 additions and 6 deletions

View File

@ -11,8 +11,8 @@ mindspore.ops.SquaredDifference
out_{i} = (x_{i} - y_{i}) * (x_{i} - y_{i}) = (x_{i} - y_{i})^2
输入:
- **x** (Union[Tensor, Number, bool]) - 第一个输入为数值型或为bool数据类型float16、float32、int32或bool的Tensor。
- **y** (Union[Tensor, Number, bool]) - 第二个输入,通常为数值型,如果第一个输入是数据类型为float16、float32、int32或bool的Tensor时第二个输入是bool。
- **x** (Union[Tensor, Number, bool]) - 第一个输入为数值型或为bool或为Tensor。
- **y** (Union[Tensor, Number, bool]) - 第二个输入,通常为数值型,或为Tensor当第一个输入是Tensor时或为bool。
输出:
Tensorshape与广播后的shape相同数据类型为两个输入中精度较高或数字较高的类型。

View File

@ -942,10 +942,35 @@ struct AbsGradFunc<half2> {
template <typename T>
struct SquaredDifferenceFunc {
  // Element-wise squared difference: (lhs - rhs)^2.
  // __host__ is included so the functor is also usable from host-side code.
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    T diff = lhs - rhs;
    return diff * diff;
  }
  // Complex overload: returns conj(diff) * diff where diff = lhs - rhs.
  // NOTE(review): dead for Complex<float>/Complex<double>, which are handled by the
  // full specializations below; kept for any other Complex<T> instantiation.
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
    Complex<T> diff = lhs - rhs;
    Complex<T> conj_diff(diff.real(), -diff.imag());
    return conj_diff * diff;
  }
};
// Specialization for double-precision complex operands.
template <>
struct SquaredDifferenceFunc<Complex<double>> {
  // Returns conj(d) * d where d = lhs - rhs.
  __device__ __host__ __forceinline__ Complex<double> operator()(const Complex<double> &lhs,
                                                                 const Complex<double> &rhs) {
    Complex<double> d = lhs - rhs;
    Complex<double> d_conjugate(d.real(), -d.imag());
    return d_conjugate * d;
  }
};
// Specialization for single-precision complex operands.
template <>
struct SquaredDifferenceFunc<Complex<float>> {
  // Returns conj(d) * d where d = lhs - rhs.
  __device__ __host__ __forceinline__ Complex<float> operator()(const Complex<float> &lhs,
                                                                const Complex<float> &rhs) {
    Complex<float> d = lhs - rhs;
    Complex<float> d_conjugate(d.real(), -d.imag());
    return d_conjugate * d;
  }
};
template <typename T>
@ -1214,6 +1239,9 @@ void ElewiseArithComplexKernel(const int &nums, enum BroadcastOpType op, const T
case BROADCAST_TYPE_XLOGY:
return ElewiseArithComplexKernel<T1, T2, T3, XLogyFunc<Complex<T3>>>
<<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_SQUARED_DIFFERENCE:
return ElewiseArithComplexKernel<T1, T2, T3, SquaredDifferenceFunc<Complex<T3>>>
<<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
default:
break;
}
@ -1660,6 +1688,11 @@ void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_SQUARED_DIFFERENCE:
return BroadcastComplexArithKernel<T1, T2, T3, SquaredDifferenceFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_COMPLEX:
return BroadcastComplexArithKernel<T1, T2, T3, ComplexFunc<T3>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],

View File

@ -107,17 +107,47 @@ bool SquaredDifferenceOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr>
}
return true;
}
template <typename T>
bool SquaredDifferenceOpGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *lhs = GetDeviceAddress<T>(inputs, 0);
T *rhs = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastComplexArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr_));
} else {
ElewiseComplexArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr_));
}
return true;
}
// Builds a {KernelAttr, run-function} registration entry that routes the given
// input/output dtypes to LaunchKernel<T> (real-number path).
#define DTYPE_REGISTER_ATTR(INPUT1, INPUT2, OUTPUT, T) \
{ \
KernelAttr().AddInputAttr(INPUT1).AddInputAttr(INPUT2).AddOutputAttr(OUTPUT), \
&SquaredDifferenceOpGpuKernelMod::LaunchKernel<T> \
}
// Builds a {KernelAttr, run-function} registration entry that routes the given
// input/output dtypes to LaunchComplexKernel<T> (complex-number path).
#define COMPLEX_REGISTER_ATTR(INPUT1, INPUT2, OUTPUT, T) \
{ \
KernelAttr().AddInputAttr(INPUT1).AddInputAttr(INPUT2).AddOutputAttr(OUTPUT), \
&SquaredDifferenceOpGpuKernelMod::LaunchComplexKernel<T> \
}
// Shorthand for MindSpore's complex template, used in the registration list below.
template <typename T>
using Complex = mindspore::utils::Complex<T>;
// Returns the supported (dtype signature -> launch function) table for this kernel.
// Real dtypes dispatch to LaunchKernel<T>; complex64/complex128 dispatch to
// LaunchComplexKernel<T>. Kept as a function-local static so it is built once.
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SquaredDifferenceOpGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
DTYPE_REGISTER_ATTR(kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, float),
DTYPE_REGISTER_ATTR(kNumberTypeFloat64, kNumberTypeFloat64, kNumberTypeFloat64, double),
COMPLEX_REGISTER_ATTR(kNumberTypeComplex64, kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>),
COMPLEX_REGISTER_ATTR(kNumberTypeComplex128, kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>),
DTYPE_REGISTER_ATTR(kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, half),
DTYPE_REGISTER_ATTR(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
DTYPE_REGISTER_ATTR(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int)};
return func_list;
}
// Register this kernel mod with the GPU kernel factory under op name "SquaredDifference".
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SquaredDifference, SquaredDifferenceOpGpuKernelMod);

View File

@ -21,6 +21,7 @@
#include <vector>
#include <string>
#include <map>
#include <complex>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
@ -57,6 +58,9 @@ class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod,
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
template <typename T>
bool LaunchComplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
BroadcastOpType op_type_{BROADCAST_TYPE_SQUARED_DIFFERENCE};
bool need_broadcast_;
size_t output_num_;

View File

@ -2016,10 +2016,9 @@ class SquaredDifference(Primitive):
out_{i} = (x_{i} - y_{i}) * (x_{i} - y_{i}) = (x_{i} - y_{i})^2
Inputs:
- **x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
or a tensor whose data type is float16, float32, int32 or bool.
- **x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool, or a tensor.
- **y** (Union[Tensor, Number, bool]) - The second input is a number, or a bool when the first input
is a tensor or a tensor whose data type is float16, float32, int32 or bool.
is a tensor, or a tensor.
Outputs:
Tensor, the shape is the same as the one after broadcasting,