diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index 3d32017c4df..4642412824d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -199,6 +199,10 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x
                          cudaStream_t stream);
 template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, bool *y,
                          cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int8_t *x0, const int8_t *x1, bool *y,
+                         cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, bool *y,
+                         cudaStream_t stream);
 
 // Element-wise ArithMetic
 template <typename T>
@@ -261,6 +265,10 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half
                            cudaStream_t stream);
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, int *y,
                            cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int8_t *x0, const int8_t *x1, int8_t *y,
+                           cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, uint8_t *y,
+                           cudaStream_t stream);
 
 // Broadcast comparation
 __device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
@@ -333,6 +341,12 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
 template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                            const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0,
                            const int *x1, bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int8_t *x0,
+                           const int8_t *x1, bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
+                           const uint8_t *x1, bool *y, cudaStream_t stream);
 
 // Broadcast Arithmetic
 template <typename T>
@@ -448,6 +462,12 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
 template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                              const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0,
                              const int *x1, int *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int8_t *x0,
+                             const int8_t *x1, int8_t *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
+                             const uint8_t *x1, uint8_t *y, cudaStream_t stream);
 
 // BroadcastTo
 template <typename T>
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
index f890d6c341e..55f9456a8dd 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
@@ -147,5 +147,15 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int)
+
+// int8
+MS_REG_GPU_KERNEL_ONE(
+  DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+  BroadcastOpGpuKernel, int8_t)
+
+// uint8
+MS_REG_GPU_KERNEL_ONE(
+  DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+  BroadcastOpGpuKernel, uint8_t)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/tests/st/ops/gpu/test_broadcast_op.py b/tests/st/ops/gpu/test_broadcast_op.py
index 3a4145fade6..d56f3bb4bd7 100644
--- a/tests/st/ops/gpu/test_broadcast_op.py
+++ b/tests/st/ops/gpu/test_broadcast_op.py
@@ -297,3 +297,43 @@ def test_broadcast_fp16():
     x2_np_zero = np.zeros_like(x2_np)
     output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
     assert np.allclose(output_ms.asnumpy(), x2_np_zero)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_divnonan_int8():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    np.random.seed(42)
+    x1_np_int8 = np.random.randint(1, 100, (10, 20)).astype(np.int8)
+    x2_np_int8 = np.random.randint(1, 100, (10, 20)).astype(np.int8)
+
+    output_ms = P.DivNoNan()(Tensor(x1_np_int8), Tensor(x2_np_int8))
+    output_np = x1_np_int8 // x2_np_int8
+    print(output_ms.asnumpy(), output_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    x2_np_zero = np.zeros_like(x2_np_int8)
+    output_ms = P.DivNoNan()(Tensor(x1_np_int8), Tensor(x2_np_zero))
+    assert np.allclose(output_ms.asnumpy(), x2_np_zero)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_divnonan_uint8():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    np.random.seed(42)
+    x1_np_uint8 = np.random.randint(1, 100, (10, 20)).astype(np.uint8)
+    x2_np_uint8 = np.random.randint(1, 100, (10, 20)).astype(np.uint8)
+
+    output_ms = P.DivNoNan()(Tensor(x1_np_uint8), Tensor(x2_np_uint8))
+    output_np = x1_np_uint8 // x2_np_uint8
+    print(output_ms.asnumpy(), output_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    x2_np_zero = np.zeros_like(x2_np_uint8)
+    output_ms = P.DivNoNan()(Tensor(x1_np_uint8), Tensor(x2_np_zero))
+    assert np.allclose(output_ms.asnumpy(), x2_np_zero)