add tensor scatter update support different type inputs

This commit is contained in:
TFBunny 2021-04-13 14:35:47 -04:00
parent fcac556d58
commit e3820ad2cb
4 changed files with 76 additions and 17 deletions

View File

@ -25,7 +25,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
TensorScatterUpdateGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
@ -33,7 +32,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TensorScatterUpdateGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
@ -41,7 +39,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
TensorScatterUpdateGpuFwdKernel, double, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
@ -49,7 +46,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
TensorScatterUpdateGpuFwdKernel, char, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
@ -57,5 +53,19 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
TensorScatterUpdateGpuFwdKernel, uchar, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
TensorScatterUpdateGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
TensorScatterUpdateGpuFwdKernel, double, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -48,12 +48,11 @@ __global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *ou
template <typename T, typename S>
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
S *work_shape, cudaStream_t stream) {
TensorScatterUpdateKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(input, indices, update, output,
block_size, input_size, output_size,
indices_dim_0, indices_dim_1,
indices_stride, work_shape);
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
S *indices_stride, S *work_shape, cudaStream_t stream) {
TensorScatterUpdateKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(
input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride,
work_shape);
return;
}
@ -77,13 +76,18 @@ template void TensorScatterUpdate<char, int>(char *input, int *indices, char *up
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void TensorScatterUpdate<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
unsigned char *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1,
int *indices_stride, int *work_shape, cudaStream_t stream);
template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void TensorScatterUpdate<double, int64_t>(double *input, int64_t *indices, double *update, double *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, cudaStream_t stream);

View File

@ -21,6 +21,6 @@
template <typename T, typename S>
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
S *work_shape, cudaStream_t stream);
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
S *indices_stride, S *work_shape, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
@ -32,6 +33,10 @@ def scatter_net(x, indices, update):
scatter = Net()
return scatter(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -77,3 +82,43 @@ def test_scatter():
[15, 99, 17, 44, 19],
[100, 21, 22, 23, 55]]).astype(np.float32)
np.testing.assert_allclose(out, expected, rtol=1e-6)
arr_input = np.arange(25).reshape(5, 5).astype(np.float64)
arr_indices = np.array([[[0, 0],
[1, 1],
[2, 2],
[3, 3],
[4, 4]],
[[0, 4],
[1, 3],
[2, 2],
[3, 1],
[4, 0]]]).astype(np.int64)
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float64)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = np.array([[11, 1, 2, 3, 66],
[5, 22, 7, 77, 9],
[10, 11, 33, 13, 14],
[15, 99, 17, 44, 19],
[100, 21, 22, 23, 55]]).astype(np.float64)
np.testing.assert_allclose(out, expected, rtol=1e-6)
arr_input = np.arange(25).reshape(5, 5).astype(np.int32)
arr_indices = np.array([[[0, 0],
[1, 1],
[2, 2],
[3, 3],
[4, 4]],
[[0, 4],
[1, 3],
[2, 2],
[3, 1],
[4, 0]]]).astype(np.int32)
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.int32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = np.array([[11, 1, 2, 3, 66],
[5, 22, 7, 77, 9],
[10, 11, 33, 13, 14],
[15, 99, 17, 44, 19],
[100, 21, 22, 23, 55]]).astype(np.int32)
np.testing.assert_allclose(out, expected, rtol=1e-6)