diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc
index fb36a2788bf..c3324a56c07 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc
@@ -32,5 +32,12 @@ MS_REG_GPU_KERNEL_ONE(ScatterAdd,
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16),
                       ScatterAddKernel, half)
+MS_REG_GPU_KERNEL_ONE(ScatterAdd,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      ScatterAddKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h
index 5ebdcd001e7..395c9e0b399 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h
@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class ScatterAddKernel : public GpuKernel {
  public:
-  ScatterAddKernel() : input_size_(0), inner_size_(0), indices_size_(0), updates_size_(0) {}
+  ScatterAddKernel() : input_size_(0), inner_size_(0), indices_size_(0), updates_size_(0), use_locking_(true) {}
   ~ScatterAddKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,7 +40,10 @@ class ScatterAddKernel : public GpuKernel {
     int *indices = GetDeviceAddress<int>(inputs, 1);
     T *updates = GetDeviceAddress<T>(inputs, 2);
     T *output = GetDeviceAddress<T>(outputs, 0);
-    CalScatterAdd(input_size_, inner_size_, indices_size_, input, indices, updates, output,
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+    CalScatterAdd(input_size_, inner_size_, indices_size_, use_locking_, input, indices, updates, output,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -69,6 +72,7 @@ class ScatterAddKernel : public GpuKernel {
       indices_size_ *= indices_shape[i];
     }
     updates_size_ = indices_size_ * inner_size_;
+    use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
    InitSizeLists();
     return true;
   }
@@ -86,6 +90,7 @@ class ScatterAddKernel : public GpuKernel {
   int inner_size_;
   int indices_size_;
   int updates_size_;
+  bool use_locking_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
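Note on the `Launch` change above: the output buffer is now seeded with a device-to-device copy of the input before `CalScatterAdd` runs, because the rewritten kernel only writes the positions addressed by `indices` instead of walking every output element. A rough NumPy model of the resulting semantics (reference sketch only; the function name is illustrative, not MindSpore API):

```python
import numpy as np

def scatter_add_reference(inputx, indices, updates):
    # Mirrors the cudaMemcpyAsync: output starts as a copy of the input.
    output = inputx.copy()
    # Each update slice is accumulated at its target row; np.add.at sums
    # duplicate indices correctly, like the MsAtomicAdd (use_locking) path.
    np.add.at(output, indices.ravel(), updates.reshape((-1,) + inputx.shape[1:]))
    return output
```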
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu
index 65ea9ae8395..04b3767a3aa 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu
@@ -14,32 +14,39 @@
  * limitations under the License.
  */
 
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 #include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh"
 
 template <typename T>
-__global__ void ScatterAdd(const int input_size, const int inner_size, const int indices_size, const T *input,
-                           const int *indices, const T *updates, T *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
-    output[pos] = input[pos];
+__global__ void ScatterAdd(const int input_size, const int inner_size, const int indices_size, const int updates_size,
+                           const bool use_locking, const T *input, const int *indices, const T *updates, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
     const size_t index = pos / inner_size;
     const size_t offset = pos % inner_size;
-    for (size_t i = 0; i < indices_size; i++) {
-      const T value = updates[i*inner_size+offset];
-      output[pos] += (indices[i] == index ? value : static_cast<T>(0.0));
+    const size_t current_pos = indices[index] * inner_size + offset;
+    if (use_locking) {
+      MsAtomicAdd(&output[current_pos], updates[pos]);
+    } else {
+      output[current_pos] += updates[pos];
     }
   }
 }
 
 template <typename T>
-void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
-                   const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
-  ScatterAdd<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, inner_size, indices_size, input,
-                                                                      indices, updates, output);
+void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const bool &use_locking,
+                   const T *input, const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
+  const int updates_size = inner_size * indices_size;
+  ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(
+    input_size, inner_size, indices_size, updates_size, use_locking, input, indices, updates, output);
 }
 
 template void CalScatterAdd<float>(const int &input_size, const int &inner_size, const int &indices_size,
-                                   const float *input, const int *indices, const float *updates, float *output,
-                                   cudaStream_t cuda_stream);
-template void CalScatterAdd<half>(const int &input_size, const int &inner_size, const int &indices_size,
-                                  const half *input, const int *indices, const half *updates, half *output,
-                                  cudaStream_t cuda_stream);
+                                   const bool &use_locking, const float *input, const int *indices,
+                                   const float *updates, float *output, cudaStream_t cuda_stream);
+template void CalScatterAdd<half>(const int &input_size, const int &inner_size, const int &indices_size,
+                                  const bool &use_locking, const half *input, const int *indices, const half *updates,
+                                  half *output, cudaStream_t cuda_stream);
+template void CalScatterAdd<int>(const int &input_size, const int &inner_size, const int &indices_size,
+                                 const bool &use_locking, const int *input, const int *indices, const int *updates,
+                                 int *output, cudaStream_t cuda_stream);
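The kernel now assigns one thread per element of `updates`: `pos / inner_size` picks the `indices` entry, `pos % inner_size` the offset within that slice, and the destination is `indices[index] * inner_size + offset`. With `use_locking` set, `MsAtomicAdd` serializes concurrent writes so duplicate indices accumulate correctly; without it, duplicate indices race and the result is undefined. A serial Python model of the addressing (illustrative only; the sequential loop is race-free by construction, which is why it matches both branches):

```python
def scatter_add_flat(output, inner_size, indices, updates):
    # output and updates are flat arrays; one iteration = one CUDA thread.
    for pos in range(len(updates)):
        index = pos // inner_size                      # which indices entry owns this element
        offset = pos % inner_size                      # position inside the slice
        current_pos = indices[index] * inner_size + offset
        output[current_pos] += updates[pos]            # MsAtomicAdd on the GPU when use_locking
    return output
```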
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh
index 4cd4d8503dd..5450a0ef7f9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh
@@ -20,7 +20,7 @@
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
-void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
-                   const int *indices, const T *updates, T *output, cudaStream_t cuda_stream);
+void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const bool &use_locking,
+                   const T *input, const int *indices, const T *updates, T *output, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_
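For sizing, the host launcher derives `updates_size = inner_size * indices_size` and launches the grid over that range rather than over `input_size`. Checking the arithmetic against the `test_scatter_add_large_shape_float32` case added below (plain NumPy sanity check, not part of the patch):

```python
import numpy as np

inputx_shape = (4, 2, 3, 4)
indices_shape = (2, 2)
inner_size = int(np.prod(inputx_shape[1:]))  # 2 * 3 * 4 = 24
indices_size = int(np.prod(indices_shape))   # 4
updates_size = indices_size * inner_size     # 96, matching np.arange(96) in the test
```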
diff --git a/tests/st/ops/gpu/test_scatter_add_op.py b/tests/st/ops/gpu/test_scatter_add_op.py
index b5ba65e0c82..12b5a0325ca 100644
--- a/tests/st/ops/gpu/test_scatter_add_op.py
+++ b/tests/st/ops/gpu/test_scatter_add_op.py
@@ -24,9 +24,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 # all cases tested against dchip
 
 class TestScatterAddNet(nn.Cell):
-    def __init__(self, inputx, indices, updates):
+    def __init__(self, lock, inputx, indices, updates):
         super(TestScatterAddNet, self).__init__()
-        self.scatter_add = P.ScatterAdd()
+        self.scatter_add = P.ScatterAdd(use_locking=lock)
         self.inputx = Parameter(inputx, name="inputx")
         self.indices = Parameter(indices, name="indices")
         self.updates = Parameter(updates, name="updates")
@@ -36,7 +36,13 @@ class TestScatterAddNet(nn.Cell):
         return out
 
 def scatter_add_net(inputx, indices, updates):
-    net = TestScatterAddNet(inputx, indices, updates)
+    lock = True
+    net = TestScatterAddNet(lock, inputx, indices, updates)
+    return net()
+
+def scatter_add_use_locking_false_net(inputx, indices, updates):
+    lock = False
+    net = TestScatterAddNet(lock, inputx, indices, updates)
     return net()
 
 @pytest.mark.level0
@@ -51,6 +57,52 @@ def test_scatter_add_small_float32():
                          [12., 14., 16.]])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatter_add_large_shape_float32():
+    inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
+    indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
+    updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
+    output = scatter_add_net(inputx, indices, updates)
+    expected = np.array([[[[1., 2., 3., 4.],
+                           [5., 6., 7., 8.],
+                           [9., 10., 11., 12.]],
+                          [[13., 14., 15., 16.],
+                           [17., 18., 19., 20.],
+                           [21., 22., 23., 24.]]],
+                         [[[73., 74., 75., 76.],
+                           [77., 78., 79., 80.],
+                           [81., 82., 83., 84.]],
+                          [[85., 86., 87., 88.],
+                           [89., 90., 91., 92.],
+                           [93., 94., 95., 96.]]],
+                         [[[25., 26., 27., 28.],
+                           [29., 30., 31., 32.],
+                           [33., 34., 35., 36.]],
+                          [[37., 38., 39., 40.],
+                           [41., 42., 43., 44.],
+                           [45., 46., 47., 48.]]],
+                         [[[49., 50., 51., 52.],
+                           [53., 54., 55., 56.],
+                           [57., 58., 59., 60.]],
+                          [[61., 62., 63., 64.],
+                           [65., 66., 67., 68.],
+                           [69., 70., 71., 72.]]]])
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatter_add_small_float32_use_locking_false():
+    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
+    indices = Tensor(np.array([1, 0]).astype(np.int32))
+    updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
+    output = scatter_add_use_locking_false_net(inputx, indices, updates)
+    expected = np.array([[3., 4., 5.],
+                         [0., 1., 2.]])
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -112,3 +164,35 @@ def test_scatter_add_disordered_float16():
                          [187., 188., 189., 190.],
                          [492., 496., 500., 504.]])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatter_add_large_int32():
+    inputx = Tensor(np.zeros((2, 3, 4)).astype(np.int32))
+    indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))
+    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))
+    output = scatter_add_net(inputx, indices, updates)
+    expected = np.array([[[138., 140., 142., 144.],
+                          [146., 148., 150., 152.],
+                          [154., 156., 158., 160.]],
+                         [[186., 188., 190., 192.],
+                          [194., 196., 198., 200.],
+                          [202., 204., 206., 208.]]]).astype(np.int32)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatter_add_disordered_int32():
+    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
+    indices = Tensor(np.array([[[0, 1, 2],
+                                [2, 1, 0]],
+                               [[0, 0, 0],
+                                [2, 2, 2]]]).astype(np.int32))
+    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))
+    output = scatter_add_net(inputx, indices, updates)
+    expected = np.array([[464., 468., 472., 476.],
+                         [187., 188., 189., 190.],
+                         [492., 496., 500., 504.]]).astype(np.int32)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
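Outside the test harness, the new int32 registration and the `use_locking` flag can be exercised directly; a minimal sketch, assuming the same GPU context the tests configure (the small int32 case mirrors `test_scatter_add_small_float32_use_locking_false`):

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self, inputx):
        super(Net, self).__init__()
        self.scatter_add = P.ScatterAdd(use_locking=True)
        self.inputx = Parameter(inputx, name="inputx")  # ScatterAdd updates a Parameter in place

    def construct(self, indices, updates):
        return self.scatter_add(self.inputx, indices, updates)

net = Net(Tensor(np.zeros((2, 3)).astype(np.int32)))
out = net(Tensor(np.array([1, 0]).astype(np.int32)),
          Tensor(np.arange(6).reshape((2, 3)).astype(np.int32)))
print(out.asnumpy())  # [[3 4 5]
                      #  [0 1 2]]
```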