forked from mindspore-Ecosystem/mindspore
Add TensorScatterUpdate support for inputs of different types
This commit is contained in:
parent
fcac556d58
commit
e3820ad2cb
|
@ -25,7 +25,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
.AddOutputAttr(kNumberTypeFloat16),
|
||||||
TensorScatterUpdateGpuFwdKernel, half, int)
|
TensorScatterUpdateGpuFwdKernel, half, int)
|
||||||
|
|
||||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
KernelAttr()
|
KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
@ -33,7 +32,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
TensorScatterUpdateGpuFwdKernel, float, int)
|
TensorScatterUpdateGpuFwdKernel, float, int)
|
||||||
|
|
||||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
KernelAttr()
|
KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
@ -41,7 +39,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
.AddOutputAttr(kNumberTypeFloat64),
|
.AddOutputAttr(kNumberTypeFloat64),
|
||||||
TensorScatterUpdateGpuFwdKernel, double, int)
|
TensorScatterUpdateGpuFwdKernel, double, int)
|
||||||
|
|
||||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
KernelAttr()
|
KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeInt8)
|
.AddInputAttr(kNumberTypeInt8)
|
||||||
|
@ -49,7 +46,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
.AddInputAttr(kNumberTypeInt8)
|
.AddInputAttr(kNumberTypeInt8)
|
||||||
.AddOutputAttr(kNumberTypeInt8),
|
.AddOutputAttr(kNumberTypeInt8),
|
||||||
TensorScatterUpdateGpuFwdKernel, char, int)
|
TensorScatterUpdateGpuFwdKernel, char, int)
|
||||||
|
|
||||||
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
KernelAttr()
|
KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeUInt8)
|
.AddInputAttr(kNumberTypeUInt8)
|
||||||
|
@ -57,5 +53,19 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
.AddInputAttr(kNumberTypeUInt8)
|
.AddInputAttr(kNumberTypeUInt8)
|
||||||
.AddOutputAttr(kNumberTypeUInt8),
|
.AddOutputAttr(kNumberTypeUInt8),
|
||||||
TensorScatterUpdateGpuFwdKernel, uchar, int)
|
TensorScatterUpdateGpuFwdKernel, uchar, int)
|
||||||
|
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeInt32),
|
||||||
|
TensorScatterUpdateGpuFwdKernel, int, int)
|
||||||
|
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
TensorScatterUpdateGpuFwdKernel, double, int64_t)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -48,12 +48,11 @@ __global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *ou
|
||||||
|
|
||||||
template <typename T, typename S>
|
template <typename T, typename S>
|
||||||
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
|
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||||
S *work_shape, cudaStream_t stream) {
|
S *indices_stride, S *work_shape, cudaStream_t stream) {
|
||||||
TensorScatterUpdateKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(input, indices, update, output,
|
TensorScatterUpdateKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(
|
||||||
block_size, input_size, output_size,
|
input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride,
|
||||||
indices_dim_0, indices_dim_1,
|
work_shape);
|
||||||
indices_stride, work_shape);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,13 +76,18 @@ template void TensorScatterUpdate<char, int>(char *input, int *indices, char *up
|
||||||
const size_t &output_size, const size_t &indices_dim_0,
|
const size_t &output_size, const size_t &indices_dim_0,
|
||||||
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
||||||
cudaStream_t stream);
|
cudaStream_t stream);
|
||||||
template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
|
|
||||||
const size_t &block_size, const size_t &input_size,
|
|
||||||
const size_t &output_size, const size_t &indices_dim_0,
|
|
||||||
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
|
||||||
cudaStream_t stream);
|
|
||||||
template void TensorScatterUpdate<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
|
template void TensorScatterUpdate<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
|
||||||
unsigned char *output, const size_t &block_size,
|
unsigned char *output, const size_t &block_size,
|
||||||
const size_t &input_size, const size_t &output_size,
|
const size_t &input_size, const size_t &output_size,
|
||||||
const size_t &indices_dim_0, const size_t &indices_dim_1,
|
const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||||
int *indices_stride, int *work_shape, cudaStream_t stream);
|
int *indices_stride, int *work_shape, cudaStream_t stream);
|
||||||
|
template void TensorScatterUpdate<int, int>(int *input, int *indices, int *update, int *output,
|
||||||
|
const size_t &block_size, const size_t &input_size,
|
||||||
|
const size_t &output_size, const size_t &indices_dim_0,
|
||||||
|
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
||||||
|
cudaStream_t stream);
|
||||||
|
template void TensorScatterUpdate<double, int64_t>(double *input, int64_t *indices, double *update, double *output,
|
||||||
|
const size_t &block_size, const size_t &input_size,
|
||||||
|
const size_t &output_size, const size_t &indices_dim_0,
|
||||||
|
const size_t &indices_dim_1, int64_t *indices_stride,
|
||||||
|
int64_t *work_shape, cudaStream_t stream);
|
||||||
|
|
|
@ -21,6 +21,6 @@
|
||||||
|
|
||||||
template <typename T, typename S>
|
template <typename T, typename S>
|
||||||
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
|
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||||
S *work_shape, cudaStream_t stream);
|
S *indices_stride, S *work_shape, cudaStream_t stream);
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
@ -32,6 +33,10 @@ def scatter_net(x, indices, update):
|
||||||
scatter = Net()
|
scatter = Net()
|
||||||
return scatter(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
|
return scatter(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
def test_scatter():
|
def test_scatter():
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
@ -77,3 +82,43 @@ def test_scatter():
|
||||||
[15, 99, 17, 44, 19],
|
[15, 99, 17, 44, 19],
|
||||||
[100, 21, 22, 23, 55]]).astype(np.float32)
|
[100, 21, 22, 23, 55]]).astype(np.float32)
|
||||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||||
|
|
||||||
|
arr_input = np.arange(25).reshape(5, 5).astype(np.float64)
|
||||||
|
arr_indices = np.array([[[0, 0],
|
||||||
|
[1, 1],
|
||||||
|
[2, 2],
|
||||||
|
[3, 3],
|
||||||
|
[4, 4]],
|
||||||
|
[[0, 4],
|
||||||
|
[1, 3],
|
||||||
|
[2, 2],
|
||||||
|
[3, 1],
|
||||||
|
[4, 0]]]).astype(np.int64)
|
||||||
|
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float64)
|
||||||
|
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||||
|
expected = np.array([[11, 1, 2, 3, 66],
|
||||||
|
[5, 22, 7, 77, 9],
|
||||||
|
[10, 11, 33, 13, 14],
|
||||||
|
[15, 99, 17, 44, 19],
|
||||||
|
[100, 21, 22, 23, 55]]).astype(np.float64)
|
||||||
|
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||||
|
|
||||||
|
arr_input = np.arange(25).reshape(5, 5).astype(np.int32)
|
||||||
|
arr_indices = np.array([[[0, 0],
|
||||||
|
[1, 1],
|
||||||
|
[2, 2],
|
||||||
|
[3, 3],
|
||||||
|
[4, 4]],
|
||||||
|
[[0, 4],
|
||||||
|
[1, 3],
|
||||||
|
[2, 2],
|
||||||
|
[3, 1],
|
||||||
|
[4, 0]]]).astype(np.int32)
|
||||||
|
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.int32)
|
||||||
|
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||||
|
expected = np.array([[11, 1, 2, 3, 66],
|
||||||
|
[5, 22, 7, 77, 9],
|
||||||
|
[10, 11, 33, 13, 14],
|
||||||
|
[15, 99, 17, 44, 19],
|
||||||
|
[100, 21, 22, 23, 55]]).astype(np.int32)
|
||||||
|
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||||
|
|
Loading…
Reference in New Issue