forked from mindspore-Ecosystem/mindspore
add additional type support to ScatterAdd
This commit is contained in:
parent
8ddb10fd8a
commit
43755a71e7
|
@ -39,5 +39,19 @@ MS_REG_GPU_KERNEL_ONE(ScatterAdd,
|
|||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
ScatterAddKernel, int)
|
||||
// ScatterAdd GPU registration for signed 8-bit data: three inputs typed
// (Int8, Int32, Int8) and one Int8 output, bound to ScatterAddKernel with
// element type int8_t.
MS_REG_GPU_KERNEL_ONE(
  ScatterAdd,
  KernelAttr()
    .AddInputAttr(kNumberTypeInt8)   // NOTE(review): presumably the tensor being updated - confirm against ScatterAddKernel
    .AddInputAttr(kNumberTypeInt32)  // NOTE(review): presumably the indices - confirm against ScatterAddKernel
    .AddInputAttr(kNumberTypeInt8)   // NOTE(review): presumably the update values - confirm against ScatterAddKernel
    .AddOutputAttr(kNumberTypeInt8),
  ScatterAddKernel, int8_t)
|
||||
// ScatterAdd GPU registration for unsigned 8-bit data: three inputs typed
// (UInt8, Int32, UInt8) and one UInt8 output, bound to ScatterAddKernel with
// element type uint8_t.
MS_REG_GPU_KERNEL_ONE(
  ScatterAdd,
  KernelAttr()
    .AddInputAttr(kNumberTypeUInt8)  // NOTE(review): presumably the tensor being updated - confirm against ScatterAddKernel
    .AddInputAttr(kNumberTypeInt32)  // NOTE(review): presumably the indices - confirm against ScatterAddKernel
    .AddInputAttr(kNumberTypeUInt8)  // NOTE(review): presumably the update values - confirm against ScatterAddKernel
    .AddOutputAttr(kNumberTypeUInt8),
  ScatterAddKernel, uint8_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,3 +42,8 @@ template void CalScatterAdd<half>(const size_t &inner_size, const size_t &indice
|
|||
const half *updates, half *input, cudaStream_t cuda_stream);
|
||||
// Explicit instantiations of CalScatterAdd for the integer element types.
// In every instantiation the indices are `const int *`; `updates` and `input`
// use the element type named in the template argument.
// NOTE(review): the uint8 kernel registration spells its type `uint8_t` while
// this file instantiates `unsigned char` - presumably the same type on the
// supported toolchains; confirm.
template void CalScatterAdd<int>(const size_t &inner_size,
                                 const size_t &indices_size,
                                 const int *indices,
                                 const int *updates,
                                 int *input,
                                 cudaStream_t cuda_stream);

template void CalScatterAdd<unsigned char>(const size_t &inner_size,
                                           const size_t &indices_size,
                                           const int *indices,
                                           const unsigned char *updates,
                                           unsigned char *input,
                                           cudaStream_t cuda_stream);

template void CalScatterAdd<int8_t>(const size_t &inner_size,
                                    const size_t &indices_size,
                                    const int *indices,
                                    const int8_t *updates,
                                    int8_t *input,
                                    cudaStream_t cuda_stream);
|
||||
|
|
|
@ -269,6 +269,38 @@ def test_scatter_add_disordered_dynamic_int32():
|
|||
[492., 496., 500., 504.]]).astype(np.int32)
|
||||
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_int8():
    """Run scatter_add_d_net on int8 data with repeated, unordered indices.

    The 3x4 input holds 45..34 (a flipped arange), the indices tensor has
    shape (2, 2, 3) with every row index appearing several times, and the
    updates hold 63..110 in shape (2, 2, 3, 4). The mathematically exact sums
    (e.g. 464) do not fit in int8.
    NOTE(review): the expected values assume the kernel's int8 accumulation
    wraps modulo 2**8 - confirm against the kernel's add implementation.
    """
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8)))
    indices = Tensor(np.array([[[0, 1, 2],
                                [2, 1, 0]],
                               [[0, 0, 0],
                                [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int8))
    output = scatter_add_d_net(inputx, indices, updates)
    # Build the reference from integers, not floats: these sums are outside
    # the int8 range, and an out-of-range float -> int8 cast is undefined
    # (recent NumPy warns "invalid value encountered in cast"). Narrowing
    # int32 -> int8 wraps the exact sums instead.
    expected = np.array([[464, 468, 472, 476],
                         [187, 188, 189, 190],
                         [492, 496, 500, 504]], dtype=np.int32).astype(np.int8)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_uint8():
    """Run scatter_add_d_net on uint8 data with repeated, unordered indices.

    The 3x4 input holds 45..34 (a flipped arange), the indices tensor has
    shape (2, 2, 3) with every row index appearing several times, and the
    updates hold 63..110 in shape (2, 2, 3, 4). The exact sums of the first
    and last rows (464..476 and 492..504) do not fit in uint8.
    NOTE(review): the expected values assume the kernel's uint8 accumulation
    wraps modulo 2**8 - confirm against the kernel's add implementation.
    """
    inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8)))
    indices = Tensor(np.array([[[0, 1, 2],
                                [2, 1, 0]],
                               [[0, 0, 0],
                                [2, 2, 2]]]).astype(np.int32))
    updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.uint8))
    output = scatter_add_d_net(inputx, indices, updates)
    # Build the reference from integers, not floats: some sums are outside
    # the uint8 range, and an out-of-range float -> uint8 cast is undefined
    # (recent NumPy warns "invalid value encountered in cast"). Narrowing
    # int32 -> uint8 wraps the exact sums instead.
    expected = np.array([[464, 468, 472, 476],
                         [187, 188, 189, 190],
                         [492, 496, 500, 504]], dtype=np.int32).astype(np.uint8)
    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue