forked from mindspore-Ecosystem/mindspore
Add atan2 op for gpu
This commit is contained in:
parent
c92d4f36aa
commit
3f6c54a557
|
@ -278,6 +278,40 @@ struct SquaredDifferenceFunc {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Atan2Func {
|
||||
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return atan2f(lhs, rhs); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Atan2Func<double> {
|
||||
__device__ __host__ __forceinline__ double operator()(const double &lhs, const double &rhs) {
|
||||
return atan2(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Atan2Func<half> {
|
||||
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
|
||||
float l = __half2float(lhs);
|
||||
float r = __half2float(rhs);
|
||||
float res = atan2f(l, r);
|
||||
return __float2half_rn(res);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Atan2Func<half2> {
|
||||
__device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
|
||||
float2 l = __half22float2(lhs);
|
||||
float2 r = __half22float2(rhs);
|
||||
float2 res;
|
||||
res.x = atan2f(l.x, r.x);
|
||||
res.y = atan2f(l.y, r.y);
|
||||
return __float22half2_rn(res);
|
||||
}
|
||||
};
|
||||
|
||||
// Element-wise Comparison
|
||||
template <typename T, typename Func>
|
||||
__global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) {
|
||||
|
@ -354,6 +388,8 @@ void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, c
|
|||
return ElewiseArithKernel<T, ModFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_FLOORMOD:
|
||||
return ElewiseArithKernel<T, FloorModFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
case BROADCAST_TYPE_ATAN2:
|
||||
return ElewiseArithKernel<T, Atan2Func<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -595,6 +631,11 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
|
|||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
case BROADCAST_TYPE_ATAN2:
|
||||
return BroadcastArithKernel<T, Atan2Func<T>><<<(size + 255) / 256, 256, 0, stream>>>(
|
||||
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
|
||||
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
|
||||
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ enum BroadcastOpType {
|
|||
BROADCAST_TYPE_SQUARED_DIFFERENCE = 14,
|
||||
BROADCAST_TYPE_MOD = 15,
|
||||
BROADCAST_TYPE_FLOORMOD = 16,
|
||||
BROADCAST_TYPE_ATAN2 = 17,
|
||||
BROADCAST_TYPE_INVALID = 0xffffffff,
|
||||
};
|
||||
|
||||
|
|
|
@ -60,6 +60,10 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
FloorMod,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
BroadcastOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
BroadcastOpGpuKernel, double)
|
||||
|
||||
// fp32
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -118,6 +122,10 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
FloorMod,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastOpGpuKernel, float)
|
||||
|
||||
// fp16
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
@ -176,6 +184,10 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
FloorMod,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BroadcastOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BroadcastOpGpuKernel, half)
|
||||
|
||||
// int32
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
|
|
@ -147,7 +147,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
|
|||
{"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
|
||||
{"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
|
||||
{"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN}, {"Mod", BROADCAST_TYPE_MOD},
|
||||
{"FloorMod", BROADCAST_TYPE_FLOORMOD},
|
||||
{"FloorMod", BROADCAST_TYPE_FLOORMOD}, {"Atan2", BROADCAST_TYPE_ATAN2},
|
||||
};
|
||||
|
||||
iter = kBroadcastArithmetricTypeMap.find(kernel_name);
|
||||
|
|
|
@ -4084,7 +4084,7 @@ class Atan2(_MathBinaryOp):
|
|||
TypeError: If `input_x` or `input_y` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([0, 1]), mindspore.float32)
|
||||
|
|
|
@ -87,6 +87,10 @@ def test_nobroadcast():
|
|||
output_np = np.mod(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Atan2()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = np.arctan2(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -146,6 +150,10 @@ def test_nobroadcast_fp16():
|
|||
output_np = np.mod(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Atan2()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = np.arctan2(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -213,6 +221,10 @@ def test_broadcast():
|
|||
output_np = np.mod(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Atan2()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = np.arctan2(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -280,6 +292,10 @@ def test_broadcast_diff_dims():
|
|||
output_np = np.mod(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Atan2()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = np.arctan2(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -339,6 +355,10 @@ def test_broadcast_fp16():
|
|||
output_np = np.mod(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Atan2()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = np.arctan2(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
|
Loading…
Reference in New Issue