forked from mindspore-Ecosystem/mindspore
!8141 Improve performance for GPU-ScatterAdd, add use_locking and add int32 support
Merge pull request !8141 from 34bunny/GPU-ScatterAdd
This commit is contained in:
commit
4c2344ed35
|
@ -32,5 +32,12 @@ MS_REG_GPU_KERNEL_ONE(ScatterAdd,
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
.AddOutputAttr(kNumberTypeFloat16),
|
||||||
ScatterAddKernel, half)
|
ScatterAddKernel, half)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ScatterAdd,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeInt32),
|
||||||
|
ScatterAddKernel, int)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace kernel {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class ScatterAddKernel : public GpuKernel {
|
class ScatterAddKernel : public GpuKernel {
|
||||||
public:
|
public:
|
||||||
ScatterAddKernel() : input_size_(0), inner_size_(0), indices_size_(0), updates_size_(0) {}
|
ScatterAddKernel() : input_size_(0), inner_size_(0), indices_size_(0), updates_size_(0), use_locking_(true) {}
|
||||||
~ScatterAddKernel() override = default;
|
~ScatterAddKernel() override = default;
|
||||||
|
|
||||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||||
|
@ -40,7 +40,10 @@ class ScatterAddKernel : public GpuKernel {
|
||||||
int *indices = GetDeviceAddress<int>(inputs, 1);
|
int *indices = GetDeviceAddress<int>(inputs, 1);
|
||||||
T *updates = GetDeviceAddress<T>(inputs, 2);
|
T *updates = GetDeviceAddress<T>(inputs, 2);
|
||||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||||
CalScatterAdd(input_size_, inner_size_, indices_size_, input, indices, updates, output,
|
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
|
||||||
|
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||||
|
"cudaMemcpyAsync output failed");
|
||||||
|
CalScatterAdd(input_size_, inner_size_, indices_size_, use_locking_, input, indices, updates, output,
|
||||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -69,6 +72,7 @@ class ScatterAddKernel : public GpuKernel {
|
||||||
indices_size_ *= indices_shape[i];
|
indices_size_ *= indices_shape[i];
|
||||||
}
|
}
|
||||||
updates_size_ = indices_size_ * inner_size_;
|
updates_size_ = indices_size_ * inner_size_;
|
||||||
|
use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
|
||||||
InitSizeLists();
|
InitSizeLists();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -86,6 +90,7 @@ class ScatterAddKernel : public GpuKernel {
|
||||||
int inner_size_;
|
int inner_size_;
|
||||||
int indices_size_;
|
int indices_size_;
|
||||||
int updates_size_;
|
int updates_size_;
|
||||||
|
bool use_locking_;
|
||||||
std::vector<size_t> input_size_list_;
|
std::vector<size_t> input_size_list_;
|
||||||
std::vector<size_t> output_size_list_;
|
std::vector<size_t> output_size_list_;
|
||||||
std::vector<size_t> workspace_size_list_;
|
std::vector<size_t> workspace_size_list_;
|
||||||
|
|
|
@ -14,32 +14,39 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
|
||||||
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh"
|
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh"
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ScatterAdd(const int input_size, const int inner_size, const int indices_size, const T *input,
|
__global__ void ScatterAdd(const int input_size, const int inner_size, const int indices_size, const int updates_size,
|
||||||
const int *indices, const T *updates, T *output) {
|
const bool use_locking, const T *input, const int *indices, const T *updates, T *output) {
|
||||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
|
||||||
output[pos] = input[pos];
|
output[pos] = input[pos];
|
||||||
const size_t index = pos / inner_size;
|
const size_t index = pos / inner_size;
|
||||||
const size_t offset = pos % inner_size;
|
const size_t offset = pos % inner_size;
|
||||||
for (size_t i = 0; i < indices_size; i++) {
|
const size_t current_pos = indices[index] * inner_size + offset;
|
||||||
const T value = updates[i*inner_size+offset];
|
if (use_locking) {
|
||||||
output[pos] += (indices[i] == index ? value : static_cast<T>(0.0));
|
MsAtomicAdd(&output[current_pos], updates[pos]);
|
||||||
|
} else {
|
||||||
|
output[current_pos] += updates[pos];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
|
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const bool &use_locking,
|
||||||
const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
|
const T *input, const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
|
||||||
ScatterAdd<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, inner_size, indices_size, input,
|
const int updates_size = inner_size * indices_size;
|
||||||
indices, updates, output);
|
ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(
|
||||||
|
input_size, inner_size, indices_size, updates_size, use_locking, input, indices, updates, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
template void CalScatterAdd<float>(const int &input_size, const int &inner_size, const int &indices_size,
|
template void CalScatterAdd<float>(const int &input_size, const int &inner_size, const int &indices_size,
|
||||||
const float *input, const int *indices, const float *updates, float *output,
|
const bool &use_locking, const float *input, const int *indices,
|
||||||
cudaStream_t cuda_stream);
|
const float *updates, float *output, cudaStream_t cuda_stream);
|
||||||
template void CalScatterAdd<half>(const int &input_size, const int &inner_size, const int &indices_size,
|
template void CalScatterAdd<half>(const int &input_size, const int &inner_size, const int &indices_size,
|
||||||
const half *input, const int *indices, const half *updates, half *output,
|
const bool &use_locking, const half *input, const int *indices, const half *updates,
|
||||||
cudaStream_t cuda_stream);
|
half *output, cudaStream_t cuda_stream);
|
||||||
|
template void CalScatterAdd<int>(const int &input_size, const int &inner_size, const int &indices_size,
|
||||||
|
const bool &use_locking, const int *input, const int *indices, const int *updates,
|
||||||
|
int *output, cudaStream_t cuda_stream);
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include "runtime/device/gpu/cuda_common.h"
|
#include "runtime/device/gpu/cuda_common.h"
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const T *input,
|
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const bool &use_locking,
|
||||||
const int *indices, const T *updates, T *output, cudaStream_t cuda_stream);
|
const T *input, const int *indices, const T *updates, T *output, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_
|
||||||
|
|
|
@ -24,9 +24,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
# all cases tested against dchip
|
# all cases tested against dchip
|
||||||
|
|
||||||
class TestScatterAddNet(nn.Cell):
|
class TestScatterAddNet(nn.Cell):
|
||||||
def __init__(self, inputx, indices, updates):
|
def __init__(self, lock, inputx, indices, updates):
|
||||||
super(TestScatterAddNet, self).__init__()
|
super(TestScatterAddNet, self).__init__()
|
||||||
self.scatter_add = P.ScatterAdd()
|
self.scatter_add = P.ScatterAdd(use_locking=lock)
|
||||||
self.inputx = Parameter(inputx, name="inputx")
|
self.inputx = Parameter(inputx, name="inputx")
|
||||||
self.indices = Parameter(indices, name="indices")
|
self.indices = Parameter(indices, name="indices")
|
||||||
self.updates = Parameter(updates, name="updates")
|
self.updates = Parameter(updates, name="updates")
|
||||||
|
@ -36,7 +36,13 @@ class TestScatterAddNet(nn.Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def scatter_add_net(inputx, indices, updates):
|
def scatter_add_net(inputx, indices, updates):
|
||||||
net = TestScatterAddNet(inputx, indices, updates)
|
lock = True
|
||||||
|
net = TestScatterAddNet(lock, inputx, indices, updates)
|
||||||
|
return net()
|
||||||
|
|
||||||
|
def scatter_add_use_locking_false_net(inputx, indices, updates):
|
||||||
|
lock = False
|
||||||
|
net = TestScatterAddNet(lock, inputx, indices, updates)
|
||||||
return net()
|
return net()
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -51,6 +57,52 @@ def test_scatter_add_small_float32():
|
||||||
[12., 14., 16.]])
|
[12., 14., 16.]])
|
||||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_scatter_add_large_shape_float32():
|
||||||
|
inputx = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
|
||||||
|
indices = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
|
||||||
|
updates = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
|
||||||
|
output = scatter_add_net(inputx, indices, updates)
|
||||||
|
expected = np.array([[[[1., 2., 3., 4.],
|
||||||
|
[5., 6., 7., 8.],
|
||||||
|
[9., 10., 11., 12.]],
|
||||||
|
[[13., 14., 15., 16.],
|
||||||
|
[17., 18., 19., 20.],
|
||||||
|
[21., 22., 23., 24.]]],
|
||||||
|
[[[73., 74., 75., 76.],
|
||||||
|
[77., 78., 79., 80.],
|
||||||
|
[81., 82., 83., 84.]],
|
||||||
|
[[85., 86., 87., 88.],
|
||||||
|
[89., 90., 91., 92.],
|
||||||
|
[93., 94., 95., 96.]]],
|
||||||
|
[[[25., 26., 27., 28.],
|
||||||
|
[29., 30., 31., 32.],
|
||||||
|
[33., 34., 35., 36.]],
|
||||||
|
[[37., 38., 39., 40.],
|
||||||
|
[41., 42., 43., 44.],
|
||||||
|
[45., 46., 47., 48.]]],
|
||||||
|
[[[49., 50., 51., 52.],
|
||||||
|
[53., 54., 55., 56.],
|
||||||
|
[57., 58., 59., 60.]],
|
||||||
|
[[61., 62., 63., 64.],
|
||||||
|
[65., 66., 67., 68.],
|
||||||
|
[69., 70., 71., 72.]]]])
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_scatter_add_small_float32_use_locking_false():
|
||||||
|
inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
|
||||||
|
indices = Tensor(np.array([1, 0]).astype(np.int32))
|
||||||
|
updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
|
||||||
|
output = scatter_add_use_locking_false_net(inputx, indices, updates)
|
||||||
|
expected = np.array([[3., 4., 5.],
|
||||||
|
[0., 1., 2.]])
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
|
@ -112,3 +164,35 @@ def test_scatter_add_disordered_float16():
|
||||||
[187., 188., 189., 190.],
|
[187., 188., 189., 190.],
|
||||||
[492., 496., 500., 504.]])
|
[492., 496., 500., 504.]])
|
||||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_scatter_add_large_int32():
|
||||||
|
inputx = Tensor(np.zeros((2, 3, 4)).astype(np.int32))
|
||||||
|
indices = Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))
|
||||||
|
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))
|
||||||
|
output = scatter_add_net(inputx, indices, updates)
|
||||||
|
expected = np.array([[[138., 140., 142., 144.],
|
||||||
|
[146., 148., 150., 152.],
|
||||||
|
[154., 156., 158., 160.]],
|
||||||
|
[[186., 188., 190., 192.],
|
||||||
|
[194., 196., 198., 200.],
|
||||||
|
[202., 204., 206., 208.]]]).astype(np.int32)
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_scatter_add_disordered_int32():
|
||||||
|
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
|
||||||
|
indices = Tensor(np.array([[[0, 1, 2],
|
||||||
|
[2, 1, 0]],
|
||||||
|
[[0, 0, 0],
|
||||||
|
[2, 2, 2]]]).astype(np.int32))
|
||||||
|
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32))
|
||||||
|
output = scatter_add_net(inputx, indices, updates)
|
||||||
|
expected = np.array([[464., 468., 472., 476.],
|
||||||
|
[187., 188., 189., 190.],
|
||||||
|
[492., 496., 500., 504.]]).astype(np.int32)
|
||||||
|
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||||
|
|
Loading…
Reference in New Issue