diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu index 17adb738e1e..861074264c3 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu @@ -42,6 +42,19 @@ struct PowerFunc { __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); } }; +template <> +struct PowerFunc { + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { + return __float2half(pow(__half2float(lhs), __half2float(rhs))); + } +}; + +template <> +struct PowerFunc { + // invalid branch + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; } +}; + __device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } template @@ -131,8 +144,20 @@ template void Broadcast(const int &l0, const int &l1, const int &l2, const int & const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, const float *input0, const float *input1, float *output, cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const half *input0, const half *input1, bool *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const half *input0, const half *input1, half *output, + cudaStream_t stream); template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, bool *output, cudaStream_t stream); template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, float *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, + bool *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, + half *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc index 1761597c7b8..491b2040e6a 100644 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc @@ -18,6 +18,7 @@ namespace mindspore { namespace kernel { +// fp32 MS_REG_GPU_KERNEL_TWO( Greater, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), @@ -36,5 +37,25 @@ MS_REG_GPU_KERNEL_TWO( MS_REG_GPU_KERNEL_TWO( Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), BroadcastOpGpuKernel, float, float) + +// fp16 +MS_REG_GPU_KERNEL_TWO( + Greater, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half, bool) +MS_REG_GPU_KERNEL_TWO( + Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half, bool) +MS_REG_GPU_KERNEL_TWO( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h index dfb0487ee4b..03f4abb473d 100644 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h @@ -65,15 +65,20 @@ class BroadcastOpGpuKernel : public GpuKernel { MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; } - for (size_t i = 0; i < shape1.size(); i++) { - lhs_shape_[i] = shape1[i]; - rhs_shape_[i] = shape2[i]; + for (size_t i = 0; i < shape3.size(); i++) { output_shape_[i] = shape3[i]; - - input1_num_ *= shape1[i]; - input2_num_ *= shape2[i]; output_num_ *= shape3[i]; } + int offset = shape3.size() - shape1.size(); + for (size_t i = 0; i < shape1.size(); i++) { + lhs_shape_[i + offset] = shape1[i]; + input1_num_ *= shape1[i]; + } + offset = shape3.size() - shape2.size(); + for (size_t i = 0; i < shape2.size(); i++) { + rhs_shape_[i + offset] = shape2[i]; + input2_num_ *= shape2[i]; + } InitSizeLists(); return true; @@ -105,6 +110,9 @@ class BroadcastOpGpuKernel : public GpuKernel { } bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { + if (lhs.size() != rhs.size()) { + return true; + } for (size_t i = 0; i < lhs.size(); i++) { if (lhs[i] != rhs[i]) { return true; diff --git a/tests/st/ops/gpu/test_broadcast_op.py b/tests/st/ops/gpu/test_broadcast_op.py index 2baa72ad6fc..272389672ae 100644 --- a/tests/st/ops/gpu/test_broadcast_op.py +++ b/tests/st/ops/gpu/test_broadcast_op.py @@ -79,3 +79,33 @@ def test_broadcast(): output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np)) output_np = np.power(x1_np, x2_np) assert np.allclose(output_ms.asnumpy(), output_np) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_broadcast_diff_dims(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + x1_np = np.random.rand(2).astype(np.float32) + x2_np = np.random.rand(2, 1).astype(np.float32) + + output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.minimum(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.maximum(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np)) + output_np = x1_np > x2_np + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np)) + output_np = x1_np < x2_np + assert np.allclose(output_ms.asnumpy(), output_np) + + output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np)) + output_np = np.power(x1_np, x2_np) + assert np.allclose(output_ms.asnumpy(), output_np)