add additional type support to ScatterAdd

This commit is contained in:
TFbunny 2020-12-10 17:17:48 -05:00
parent 8ddb10fd8a
commit 43755a71e7
3 changed files with 51 additions and 0 deletions

View File

@ -39,5 +39,19 @@ MS_REG_GPU_KERNEL_ONE(ScatterAdd,
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterAddKernel, int)
MS_REG_GPU_KERNEL_ONE(ScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterAddKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(ScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterAddKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

View File

@ -42,3 +42,8 @@ template void CalScatterAdd<half>(const size_t &inner_size, const size_t &indice
const half *updates, half *input, cudaStream_t cuda_stream);
template void CalScatterAdd<int>(const size_t &inner_size, const size_t &indices_size, const int *indices,
const int *updates, int *input, cudaStream_t cuda_stream);
template void CalScatterAdd<unsigned char>(const size_t &inner_size, const size_t &indices_size, const int *indices,
const unsigned char *updates, unsigned char *input,
cudaStream_t cuda_stream);
template void CalScatterAdd<int8_t>(const size_t &inner_size, const size_t &indices_size, const int *indices,
const int8_t *updates, int8_t *input, cudaStream_t cuda_stream);

View File

@ -269,6 +269,38 @@ def test_scatter_add_disordered_dynamic_int32():
[492., 496., 500., 504.]]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_int8():
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8)))
indices = Tensor(np.array([[[0, 1, 2],
[2, 1, 0]],
[[0, 0, 0],
[2, 2, 2]]]).astype(np.int32))
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int8))
output = scatter_add_d_net(inputx, indices, updates)
expected = np.array([[464., 468., 472., 476.],
[187., 188., 189., 190.],
[492., 496., 500., 504.]]).astype(np.int8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_uint8():
inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8)))
indices = Tensor(np.array([[[0, 1, 2],
[2, 1, 0]],
[[0, 0, 0],
[2, 2, 2]]]).astype(np.int32))
updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.uint8))
output = scatter_add_d_net(inputx, indices, updates)
expected = np.array([[464., 468., 472., 476.],
[187., 188., 189., 190.],
[492., 496., 500., 504.]]).astype(np.uint8)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard