forked from mindspore-Ecosystem/mindspore
add tensor scatter update support different type inputs
This commit is contained in:
parent
fcac556d58
commit
e3820ad2cb
|
@ -25,7 +25,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
|||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
TensorScatterUpdateGpuFwdKernel, half, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -33,7 +32,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
TensorScatterUpdateGpuFwdKernel, float, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
|
@ -41,7 +39,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
|||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
TensorScatterUpdateGpuFwdKernel, double, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
|
@ -49,7 +46,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
|||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
TensorScatterUpdateGpuFwdKernel, char, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
|
@ -57,5 +53,19 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
|||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
TensorScatterUpdateGpuFwdKernel, uchar, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
TensorScatterUpdateGpuFwdKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
TensorScatterUpdateGpuFwdKernel, double, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,12 +48,11 @@ __global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *ou
|
|||
|
||||
template <typename T, typename S>
|
||||
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
|
||||
S *work_shape, cudaStream_t stream) {
|
||||
TensorScatterUpdateKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(input, indices, update, output,
|
||||
block_size, input_size, output_size,
|
||||
indices_dim_0, indices_dim_1,
|
||||
indices_stride, work_shape);
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||
S *indices_stride, S *work_shape, cudaStream_t stream) {
|
||||
TensorScatterUpdateKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(
|
||||
input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride,
|
||||
work_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -77,13 +76,18 @@ template void TensorScatterUpdate<char, int>(char *input, int *indices, char *up
|
|||
const size_t &output_size, const size_t &indices_dim_0,
|
||||
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
||||
cudaStream_t stream);
|
||||
template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
|
||||
const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0,
|
||||
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
||||
cudaStream_t stream);
|
||||
template void TensorScatterUpdate<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
|
||||
unsigned char *output, const size_t &block_size,
|
||||
const size_t &input_size, const size_t &output_size,
|
||||
const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||
int *indices_stride, int *work_shape, cudaStream_t stream);
|
||||
template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
|
||||
const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0,
|
||||
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
||||
cudaStream_t stream);
|
||||
template void TensorScatterUpdate<double, int64_t>(double *input, int64_t *indices, double *update, double *output,
|
||||
const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0,
|
||||
const size_t &indices_dim_1, int64_t *indices_stride,
|
||||
int64_t *work_shape, cudaStream_t stream);
|
||||
|
|
|
@ -21,6 +21,6 @@
|
|||
|
||||
template <typename T, typename S>
|
||||
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
|
||||
S *work_shape, cudaStream_t stream);
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||
S *indices_stride, S *work_shape, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
|
@ -32,6 +33,10 @@ def scatter_net(x, indices, update):
|
|||
scatter = Net()
|
||||
return scatter(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
@ -77,3 +82,43 @@ def test_scatter():
|
|||
[15, 99, 17, 44, 19],
|
||||
[100, 21, 22, 23, 55]]).astype(np.float32)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
arr_input = np.arange(25).reshape(5, 5).astype(np.float64)
|
||||
arr_indices = np.array([[[0, 0],
|
||||
[1, 1],
|
||||
[2, 2],
|
||||
[3, 3],
|
||||
[4, 4]],
|
||||
[[0, 4],
|
||||
[1, 3],
|
||||
[2, 2],
|
||||
[3, 1],
|
||||
[4, 0]]]).astype(np.int64)
|
||||
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float64)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = np.array([[11, 1, 2, 3, 66],
|
||||
[5, 22, 7, 77, 9],
|
||||
[10, 11, 33, 13, 14],
|
||||
[15, 99, 17, 44, 19],
|
||||
[100, 21, 22, 23, 55]]).astype(np.float64)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
arr_input = np.arange(25).reshape(5, 5).astype(np.int32)
|
||||
arr_indices = np.array([[[0, 0],
|
||||
[1, 1],
|
||||
[2, 2],
|
||||
[3, 3],
|
||||
[4, 4]],
|
||||
[[0, 4],
|
||||
[1, 3],
|
||||
[2, 2],
|
||||
[3, 1],
|
||||
[4, 0]]]).astype(np.int32)
|
||||
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.int32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = np.array([[11, 1, 2, 3, 66],
|
||||
[5, 22, 7, 77, 9],
|
||||
[10, 11, 33, 13, 14],
|
||||
[15, 99, 17, 44, 19],
|
||||
[100, 21, 22, 23, 55]]).astype(np.int32)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
|
Loading…
Reference in New Issue