!8141 Improve performance for GPU-ScatterAdd, add use_locking and add int32 support

Merge pull request !8141 from 34bunny/GPU-ScatterAdd
This commit is contained in:
mindspore-ci-bot 2020-11-04 21:07:57 +08:00 committed by Gitee
commit 4c2344ed35
5 changed files with 124 additions and 21 deletions

View File

@ -32,5 +32,12 @@ MS_REG_GPU_KERNEL_ONE(ScatterAdd,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ScatterAddKernel, half)
MS_REG_GPU_KERNEL_ONE(ScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterAddKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -27,7 +27,7 @@ namespace kernel {
template <typename T>
class ScatterAddKernel : public GpuKernel {
public:
ScatterAddKernel() : input_size_(0), inner_size_(0), indices_size_(0), updates_size_(0) {}
ScatterAddKernel() : input_size_(0), inner_size_(0), indices_size_(0), updates_size_(0), use_locking_(true) {}
~ScatterAddKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -40,7 +40,10 @@ class ScatterAddKernel : public GpuKernel {
int *indices = GetDeviceAddress<int>(inputs, 1);
T *updates = GetDeviceAddress<T>(inputs, 2);
T *output = GetDeviceAddress<T>(outputs, 0);
CalScatterAdd(input_size_, inner_size_, indices_size_, input, indices, updates, output,
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CalScatterAdd(input_size_, inner_size_, indices_size_, use_locking_, input, indices, updates, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -69,6 +72,7 @@ class ScatterAddKernel : public GpuKernel {
indices_size_ *= indices_shape[i];
}
updates_size_ = indices_size_ * inner_size_;
use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
InitSizeLists();
return true;
}
@ -86,6 +90,7 @@ class ScatterAddKernel : public GpuKernel {
int inner_size_;
int indices_size_;
int updates_size_;
bool use_locking_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

View File

@ -14,32 +14,39 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh"
template <typename T>
__global__ void ScatterAdd(const int input_size, const int inner_size, const int indices_size, const T *input,
const int *indices, const T *updates, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
__global__ void ScatterAdd(const int input_size, const int inner_size, const int indices_size, const int updates_size,
const bool use_locking, const T *input, const int *indices, const T *updates, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
output[pos] = input[pos];
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
for (size_t i = 0; i < indices_size; i++) {
const T value = updates[i*inner_size+offset];
output[pos] += (indices[i] == index ? value : static_cast<T>(0.0));
const size_t current_pos = indices[index] * inner_size + offset;
if (use_locking) {
MsAtomicAdd(&output[current_pos], updates[pos]);
} else {
output[current_pos] += updates[pos];
}
}
}
template <typename T>
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
ScatterAdd<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, inner_size, indices_size, input,
indices, updates, output);
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const bool &use_locking,
const T *input, const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
const int updates_size = inner_size * indices_size;
ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(
input_size, inner_size, indices_size, updates_size, use_locking, input, indices, updates, output);
}
template void CalScatterAdd<float>(const int &input_size, const int &inner_size, const int &indices_size,
const float *input, const int *indices, const float *updates, float *output,
cudaStream_t cuda_stream);
const bool &use_locking, const float *input, const int *indices,
const float *updates, float *output, cudaStream_t cuda_stream);
template void CalScatterAdd<half>(const int &input_size, const int &inner_size, const int &indices_size,
const half *input, const int *indices, const half *updates, half *output,
cudaStream_t cuda_stream);
const bool &use_locking, const half *input, const int *indices, const half *updates,
half *output, cudaStream_t cuda_stream);
template void CalScatterAdd<int>(const int &input_size, const int &inner_size, const int &indices_size,
const bool &use_locking, const int *input, const int *indices, const int *updates,
int *output, cudaStream_t cuda_stream);

View File

@ -20,7 +20,7 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
const int *indices, const T *updates, T *output, cudaStream_t cuda_stream);
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const bool &use_locking,
const T *input, const int *indices, const T *updates, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_

View File

@ -24,9 +24,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# all cases tested against dchip
class TestScatterAddNet(nn.Cell):
def __init__(self, inputx, indices, updates):
def __init__(self, lock, inputx, indices, updates):
super(TestScatterAddNet, self).__init__()
self.scatter_add = P.ScatterAdd()
self.scatter_add = P.ScatterAdd(use_locking=lock)
self.inputx = Parameter(inputx, name="inputx")
self.indices = Parameter(indices, name="indices")
self.updates = Parameter(updates, name="updates")
@ -36,7 +36,13 @@ class TestScatterAddNet(nn.Cell):
return out
def scatter_add_net(inputx, indices, updates):
net = TestScatterAddNet(inputx, indices, updates)
lock = True
net = TestScatterAddNet(lock, inputx, indices, updates)
return net()
def scatter_add_use_locking_false_net(inputx, indices, updates):
lock = False
net = TestScatterAddNet(lock, inputx, indices, updates)
return net()
@pytest.mark.level0
@ -51,6 +57,52 @@ def test_scatter_add_small_float32():
[12., 14., 16.]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_large_shape_float32():
inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
output = scatter_add_net(inputx, indices, updates)
expected = np.array([[[[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]],
[[[73., 74., 75., 76.],
[77., 78., 79., 80.],
[81., 82., 83., 84.]],
[[85., 86., 87., 88.],
[89., 90., 91., 92.],
[93., 94., 95., 96.]]],
[[[25., 26., 27., 28.],
[29., 30., 31., 32.],
[33., 34., 35., 36.]],
[[37., 38., 39., 40.],
[41., 42., 43., 44.],
[45., 46., 47., 48.]]],
[[[49., 50., 51., 52.],
[53., 54., 55., 56.],
[57., 58., 59., 60.]],
[[61., 62., 63., 64.],
[65., 66., 67., 68.],
[69., 70., 71., 72.]]]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_small_float32_use_locking_false():
inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
indices = Tensor(np.array([1, 0]).astype(np.int32))
updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
output = scatter_add_use_locking_false_net(inputx, indices, updates)
expected = np.array([[3., 4., 5.],
[0., 1., 2.]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -112,3 +164,35 @@ def test_scatter_add_disordered_float16():
[187., 188., 189., 190.],
[492., 496., 500., 504.]])
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_large_int32():
inputx = Tensor(np.zeros((2, 3, 4)).astype(np.int32))
indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))
output = scatter_add_net(inputx, indices, updates)
expected = np.array([[[138., 140., 142., 144.],
[146., 148., 150., 152.],
[154., 156., 158., 160.]],
[[186., 188., 190., 192.],
[194., 196., 198., 200.],
[202., 204., 206., 208.]]]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_int32():
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
indices = Tensor(np.array([[[0, 1, 2],
[2, 1, 0]],
[[0, 0, 0],
[2, 2, 2]]]).astype(np.int32))
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))
output = scatter_add_net(inputx, indices, updates)
expected = np.array([[464., 468., 472., 476.],
[187., 188., 189., 190.],
[492., 496., 500., 504.]]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)