!8146 Improve performance for GPU-ScatterUpdate, add int32 support

Merge pull request !8146 from 34bunny/GPU-ScatterUpdateFix
This commit is contained in:
mindspore-ci-bot 2020-11-04 22:05:52 +08:00 committed by Gitee
commit 4b4ca1a188
5 changed files with 91 additions and 25 deletions

View File

@ -32,5 +32,12 @@ MS_REG_GPU_KERNEL_ONE(ScatterUpdate,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ScatterUpdateKernel, half)
MS_REG_GPU_KERNEL_ONE(ScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterUpdateKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -40,8 +40,10 @@ class ScatterUpdateKernel : public GpuKernel {
int *indices = GetDeviceAddress<int>(inputs, 1);
T *updates = GetDeviceAddress<T>(inputs, 2);
T *output = GetDeviceAddress<T>(outputs, 0);
CalScatterUpdate(input_size_, inner_size_, indices_size_, input, indices, updates, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CalScatterUpdate(inner_size_, indices_size_, indices, updates, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

View File

@ -17,29 +17,27 @@
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh"
template <typename T>
__global__ void ScatterUpdate(const int input_size, const int inner_size, const int indices_size, const T *input,
const int *indices, const T *updates, T *output) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
output[pos] = input[pos];
__global__ void ScatterUpdate(const int inner_size, const int updates_size, const int *indices, const T *updates,
T *output) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const int index = pos / inner_size;
const int offset = pos % inner_size;
for (int i = 0; i < indices_size; i++) {
const int update_pos = i * inner_size + offset;
output[pos] = (indices[i] == index ? updates[update_pos] : output[pos]);
}
const int current_pos = indices[index] * inner_size + offset;
output[current_pos] = updates[pos];
}
}
template <typename T>
void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
ScatterUpdate<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, inner_size, indices_size, input,
indices, updates, output);
void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
cudaStream_t cuda_stream) {
const int updates_size = inner_size * indices_size;
ScatterUpdate<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, indices, updates,
output);
}
template void CalScatterUpdate<float>(const int &input_size, const int &inner_size, const int &indices_size,
const float *input, const int *indices, const float *updates, float *output,
cudaStream_t cuda_stream);
template void CalScatterUpdate<half>(const int &input_size, const int &inner_size, const int &indices_size,
const half *input, const int *indices, const half *updates, half *output,
cudaStream_t cuda_stream);
template void CalScatterUpdate<float>(const int &inner_size, const int &indices_size, const int *indices,
const float *updates, float *output, cudaStream_t cuda_stream);
template void CalScatterUpdate<half>(const int &inner_size, const int &indices_size, const int *indices,
const half *updates, half *output, cudaStream_t cuda_stream);
template void CalScatterUpdate<int>(const int &inner_size, const int &indices_size, const int *indices,
const int *updates, int *output, cudaStream_t cuda_stream);

View File

@ -20,7 +20,7 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
const int *indices, const T *updates, T *output, cudaStream_t cuda_stream);
void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_

View File

@ -75,7 +75,19 @@ def test_scatter_update_float16():
updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float16))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[0., 1., 2.],
[3., 4., 5.]])
[3., 4., 5.]]).astype(np.float16)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_int32():
inputx = Tensor(np.zeros((2, 3)).astype(np.int32))
indices = Tensor(np.array([0, 1]).astype(np.int32))
updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.int32))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[0., 1., 2.],
[3., 4., 5.]]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@ -89,7 +101,7 @@ def test_scatter_update_large_float16():
expected = np.array([[69., 70., 71.],
[66., 67., 68.],
[63., 64., 65.],
[72., 73., 74.]])
[72., 73., 74.]]).astype(np.float16)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@ -102,5 +114,52 @@ def test_scatter_update_disordered_float16():
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[45., 44., 43., 42.],
[63., 64., 65., 66.],
[67., 68., 69., 70.]])
[67., 68., 69., 70.]]).astype(np.float16)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_disordered_int32():
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
indices = Tensor(np.array([1, 2]).astype(np.int32))
updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.int32))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[45., 44., 43., 42.],
[63., 64., 65., 66.],
[67., 68., 69., 70.]]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_large_shape_float16():
inputx = Tensor(np.arange(96).reshape((4, 2, 3, 4)).astype(np.float16))
indices = Tensor(np.array([1, 0]).astype(np.int32))
updates = Tensor(np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.float16)))
output = scatter_update_net(inputx, indices, updates)
expected = np.array([[[[23., 22., 21., 20.],
[19., 18., 17., 16.],
[15., 14., 13., 12.]],
[[11., 10., 9., 8.],
[7., 6., 5., 4.],
[3., 2., 1., 0.]]],
[[[47., 46., 45., 44.],
[43., 42., 41., 40.],
[39., 38., 37., 36.]],
[[35., 34., 33., 32.],
[31., 30., 29., 28.],
[27., 26., 25., 24.]]],
[[[48., 49., 50., 51.],
[52., 53., 54., 55.],
[56., 57., 58., 59.]],
[[60., 61., 62., 63.],
[64., 65., 66., 67.],
[68., 69., 70., 71.]]],
[[[72., 73., 74., 75.],
[76., 77., 78., 79.],
[80., 81., 82., 83.]],
[[84., 85., 86., 87.],
[88., 89., 90., 91.],
[92., 93., 94., 95.]]]]).astype(np.float16)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)