From 6279ff4bbb4ebfaeb0ae801273c6253d958637e8 Mon Sep 17 00:00:00 2001
From: xcnick
Date: Tue, 6 Apr 2021 16:15:59 +0800
Subject: [PATCH] add mod op kernel for gpu

---
 .../gpu/cuda_impl/broadcast_impl.cu   | 51 +++++++++++++++++++
 .../gpu/cuda_impl/broadcast_impl.cuh  |  1 +
 .../gpu/math/broadcast_gpu_kernel.cc  | 15 ++++++
 .../gpu/math/broadcast_gpu_kernel.h   |  2 +-
 tests/st/ops/gpu/test_broadcast_op.py | 21 ++++++++
 5 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index c00bfed8bab..0d107621b2d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -176,6 +176,50 @@ struct FloorDivFunc {
   }
 };
 
+template <typename T>
+struct ModFunc {
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
+    T data_div = lhs / rhs;
+    T data_div_min = data_div < 0.0 ? data_div : 0.0;
+    T data_div_max = data_div > 0.0 ? data_div : 0.0;
+    T data_div_max_floor = floorf(data_div_max);
+    T data_div_min_ceil = ceilf(data_div_min);
+    T data_div_res = data_div_max_floor + data_div_min_ceil;
+    return lhs - data_div_res * rhs;
+  }
+};
+
+template <>
+struct ModFunc<half> {
+  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
+    float l = __half2float(lhs);
+    float r = __half2float(rhs);
+    float data_div = l / r;
+    float data_div_min = data_div < 0.0 ? data_div : 0.0;
+    float data_div_max = data_div > 0.0 ? data_div : 0.0;
+    float data_div_max_floor = floorf(data_div_max);
+    float data_div_min_ceil = ceilf(data_div_min);
+    float data_div_res = data_div_max_floor + data_div_min_ceil;
+    return __float2half_rn(l - data_div_res * r);
+  }
+};
+
+template <>
+struct ModFunc<half2> {
+  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
+    float2 l = __half22float2(lhs);
+    float2 r = __half22float2(rhs);
+    float2 data_div;
+    data_div.x = l.x / r.x;
+    data_div.y = l.y / r.y;
+    data_div.x = data_div.x < 0.0 ? ceilf(data_div.x) : floorf(data_div.x);
+    data_div.y = data_div.y < 0.0 ? ceilf(data_div.y) : floorf(data_div.y);
+    data_div.x = l.x - data_div.x * r.x;
+    data_div.y = l.y - data_div.y * r.y;
+    return __float22half2_rn(data_div);
+  }
+};
+
 template <typename T>
 struct AbsGradFunc {
   __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
@@ -272,6 +316,8 @@ void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, c
       return ElewiseArithKernel<T, DivNoNanFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
     case BROADCAST_TYPE_SQUARED_DIFFERENCE:
       return ElewiseArithKernel<T, SquaredDifferenceFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_MOD:
+      return ElewiseArithKernel<T, ModFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
     default:
       break;
   }
@@ -503,6 +549,11 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
         x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
         x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
         y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_MOD:
+      return BroadcastArithKernel<T, ModFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     default:
       break;
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
index f3d0fa51055..9e541076e1a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
@@ -38,6 +38,7 @@ enum BroadcastOpType {
   BROADCAST_TYPE_DIVNONAN = 12,
   BROADCAST_TYPE_EQUAL = 13,
   BROADCAST_TYPE_SQUARED_DIFFERENCE = 14,
+  BROADCAST_TYPE_MOD = 15,
   BROADCAST_TYPE_INVALID = 0xffffffff,
 };
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
index 481becc694b..4a505faf75a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
@@ -53,6 +53,9 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   Pow, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   BroadcastOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(
+  Mod, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  BroadcastOpGpuKernel, double)
 
 // fp32
 MS_REG_GPU_KERNEL_ONE(
@@ -104,6 +107,9 @@ MS_REG_GPU_KERNEL_ONE(
   DivNoNan,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  Mod, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  BroadcastOpGpuKernel, float)
 
 // fp16
 MS_REG_GPU_KERNEL_ONE(
@@ -155,6 +161,9 @@ MS_REG_GPU_KERNEL_ONE(
   DivNoNan,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  Mod, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  BroadcastOpGpuKernel, half)
 
 // int32
 MS_REG_GPU_KERNEL_ONE(
@@ -193,6 +202,9 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int)
 
 // int64
 MS_REG_GPU_KERNEL_ONE(
@@ -231,6 +243,9 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   BroadcastOpGpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  BroadcastOpGpuKernel, int64_t)
 
 // int8
 MS_REG_GPU_KERNEL_ONE(
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
index 538c7299f00..ee3a2bac10d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
@@ -146,7 +146,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
       {"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
       {"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
       {"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
-      {"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN},
+      {"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN}, {"Mod", BROADCAST_TYPE_MOD},
     };
 
     iter = kBroadcastArithmetricTypeMap.find(kernel_name);
diff --git a/tests/st/ops/gpu/test_broadcast_op.py b/tests/st/ops/gpu/test_broadcast_op.py
index d56f3bb4bd7..4d2016efa80 100644
--- a/tests/st/ops/gpu/test_broadcast_op.py
+++ b/tests/st/ops/gpu/test_broadcast_op.py
@@ -79,6 +79,11 @@ def test_nobroadcast():
     output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
     assert np.allclose(output_ms.asnumpy(), x2_np_zero)
 
+    output_ms = P.Mod()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.fmod(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -129,6 +134,10 @@ def test_nobroadcast_fp16():
     output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
     assert np.allclose(output_ms.asnumpy(), x2_np_zero)
 
+    output_ms = P.Mod()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.fmod(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -188,6 +197,10 @@ def test_broadcast():
     output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
     assert np.allclose(output_ms.asnumpy(), x2_np_zero)
 
+    output_ms = P.Mod()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.fmod(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -247,6 +260,10 @@ def test_broadcast_diff_dims():
     output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
     assert np.allclose(output_ms.asnumpy(), x2_np_zero)
 
+    output_ms = P.Mod()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.fmod(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -298,6 +315,10 @@ def test_broadcast_fp16():
     output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
     assert np.allclose(output_ms.asnumpy(), x2_np_zero)
 
+    output_ms = P.Mod()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.fmod(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
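
Note on the semantics: the ModFunc functors above compute a truncated modulo, lhs - trunc(lhs / rhs) * rhs, so the remainder keeps the sign of the dividend. That is why the new test cases compare against np.fmod rather than np.mod. Below is a minimal NumPy sketch of the same formula; the truncated_mod helper and the sample values are illustrative only and are not part of the patch.

import numpy as np

# Truncated modulo, mirroring the CUDA ModFunc: lhs - trunc(lhs / rhs) * rhs.
# trunc() is expressed the same way the functor does it: floor() for positive
# quotients, ceil() for negative ones.
def truncated_mod(lhs, rhs):
    quotient = lhs / rhs
    truncated = np.where(quotient < 0, np.ceil(quotient), np.floor(quotient))
    return lhs - truncated * rhs

x = np.array([5.3, -5.3, 5.3, -5.3], dtype=np.float32)  # illustrative values
y = np.array([2.0, 2.0, -2.0, -2.0], dtype=np.float32)

# The sign follows the dividend, matching np.fmod (and C's fmod), not np.mod.
assert np.allclose(truncated_mod(x, y), np.fmod(x, y))
print(truncated_mod(x, y))  # approximately [ 1.3 -1.3  1.3 -1.3]

The half and half2 specializations apply the same formula after widening to float, which keeps the fp16 kernels consistent with the np.fmod reference used in test_broadcast_fp16.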