update input for ScatterUpdate/Add, fix doc text and fix dynamic shape

TFbunny 2020-12-02 12:53:03 -05:00
parent 0856639fc5
commit 27a602f067
9 changed files with 104 additions and 88 deletions


@@ -40,10 +40,10 @@ class ScatterAddKernel : public GpuKernel {
     int *indices = GetDeviceAddress<int>(inputs, 1);
     T *updates = GetDeviceAddress<T>(inputs, 2);
     T *output = GetDeviceAddress<T>(outputs, 0);
+    CalScatterAdd(inner_size_, indices_size_, indices, updates, input, reinterpret_cast<cudaStream_t>(stream_ptr));
     CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cudaMemcpyAsync output failed");
-    CalScatterAdd(inner_size_, indices_size_, indices, updates, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
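
Note on the reordering above: CalScatterAdd is now launched on input (the Parameter's device buffer) before the device-to-device copy into output, so the Parameter is updated in place and the returned output carries the same accumulated values; previously only output was written and the Parameter kept its old contents. A minimal usage sketch of the resulting behaviour, assuming a GPU build of MindSpore (the Net class below is illustrative, not part of this commit):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.scatter_add = P.ScatterAdd()
        self.var = Parameter(Tensor(np.zeros((2, 3), np.float32)), name="var")

    def construct(self, indices, updates):
        return self.scatter_add(self.var, indices, updates)

net = Net()
out = net(Tensor(np.array([0, 1], np.int32)), Tensor(np.ones((2, 3), np.float32)))
# With this commit the Parameter itself now holds the accumulated values:
# net.var.asnumpy() == out.asnumpy() == np.ones((2, 3), np.float32)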


@@ -40,10 +40,10 @@ class ScatterUpdateKernel : public GpuKernel {
     int *indices = GetDeviceAddress<int>(inputs, 1);
     T *updates = GetDeviceAddress<T>(inputs, 2);
     T *output = GetDeviceAddress<T>(outputs, 0);
+    CalScatterUpdate(inner_size_, indices_size_, indices, updates, input, reinterpret_cast<cudaStream_t>(stream_ptr));
     CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cudaMemcpyAsync output failed");
-    CalScatterUpdate(inner_size_, indices_size_, indices, updates, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }


@@ -19,26 +19,26 @@
 template <typename T>
 __global__ void ScatterAdd(const int inner_size, const int updates_size, const int *indices, const T *updates,
-                           T *output) {
+                           T *input) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
     const size_t index = pos / inner_size;
     const size_t offset = pos % inner_size;
     const size_t current_pos = indices[index] * inner_size + offset;
-    MsAtomicAdd(&output[current_pos], updates[pos]);
+    MsAtomicAdd(&input[current_pos], updates[pos]);
   }
 }

 template <typename T>
-void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
+void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input,
                    cudaStream_t cuda_stream) {
   const int updates_size = inner_size * indices_size;
   ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, indices, updates,
-                                                                        output);
+                                                                        input);
 }

 template void CalScatterAdd<float>(const int &inner_size, const int &indices_size, const int *indices,
-                                   const float *updates, float *output, cudaStream_t cuda_stream);
+                                   const float *updates, float *input, cudaStream_t cuda_stream);
 template void CalScatterAdd<half>(const int &inner_size, const int &indices_size, const int *indices,
-                                  const half *updates, half *output, cudaStream_t cuda_stream);
+                                  const half *updates, half *input, cudaStream_t cuda_stream);
 template void CalScatterAdd<int>(const int &inner_size, const int &indices_size, const int *indices, const int *updates,
-                                 int *output, cudaStream_t cuda_stream);
+                                 int *input, cudaStream_t cuda_stream);
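
The kernel treats updates as one flat array of updates_size = inner_size * indices_size elements: index = pos / inner_size selects the entry of indices (and therefore the destination row of input), offset = pos % inner_size is the position inside that row, and MsAtomicAdd keeps concurrent additions to the same row correct when indices contains duplicates. A host-side numpy stand-in for the same index math (an illustrative sketch, not code from this commit):

import numpy as np

def scatter_add_ref(x, indices, updates):
    inner_size = int(np.prod(x.shape[1:]))         # elements per row of x
    flat_x = x.reshape(-1)                         # view, so writes land in x
    flat_updates = updates.reshape(-1)
    flat_indices = indices.reshape(-1)
    updates_size = inner_size * flat_indices.size  # one CUDA thread per element
    for pos in range(updates_size):
        index = pos // inner_size                  # which entry of indices
        offset = pos % inner_size                  # position inside the row
        current_pos = flat_indices[index] * inner_size + offset
        flat_x[current_pos] += flat_updates[pos]   # MsAtomicAdd in the kernel
    return x

x = np.zeros((2, 3), np.float32)
scatter_add_ref(x, np.array([[0, 1], [0, 1]]), np.arange(12, dtype=np.float32).reshape(2, 2, 3))
# x -> [[6., 8., 10.], [12., 14., 16.]]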


@@ -20,7 +20,7 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
+void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input,
                    cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_


@@ -18,31 +18,31 @@
 template <typename T>
 __global__ void ScatterUpdate(const int inner_size, const int updates_size, const int *indices, const T *updates,
-                              T *output) {
+                              T *input) {
   for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
     const int index = pos / inner_size;
     const int offset = pos % inner_size;
     const int current_pos = indices[index] * inner_size + offset;
-    output[current_pos] = updates[pos];
+    input[current_pos] = updates[pos];
   }
 }

 template <typename T>
-void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
+void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input,
                       cudaStream_t cuda_stream) {
   const int updates_size = inner_size * indices_size;
   ScatterUpdate<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, indices, updates,
-                                                                           output);
+                                                                           input);
 }

 template void CalScatterUpdate<float>(const int &inner_size, const int &indices_size, const int *indices,
-                                      const float *updates, float *output, cudaStream_t cuda_stream);
+                                      const float *updates, float *input, cudaStream_t cuda_stream);
 template void CalScatterUpdate<half>(const int &inner_size, const int &indices_size, const int *indices,
-                                     const half *updates, half *output, cudaStream_t cuda_stream);
+                                     const half *updates, half *input, cudaStream_t cuda_stream);
 template void CalScatterUpdate<int>(const int &inner_size, const int &indices_size, const int *indices,
-                                    const int *updates, int *output, cudaStream_t cuda_stream);
+                                    const int *updates, int *input, cudaStream_t cuda_stream);
 template void CalScatterUpdate<unsigned char>(const int &inner_size, const int &indices_size, const int *indices,
-                                              const unsigned char *updates, unsigned char *output,
+                                              const unsigned char *updates, unsigned char *input,
                                               cudaStream_t cuda_stream);
 template void CalScatterUpdate<int8_t>(const int &inner_size, const int &indices_size, const int *indices,
-                                       const int8_t *updates, int8_t *output, cudaStream_t cuda_stream);
+                                       const int8_t *updates, int8_t *input, cudaStream_t cuda_stream);
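
The update kernel differs from the add kernel above only in the plain assignment where ScatterAdd uses MsAtomicAdd, which changes the behaviour for a repeated index: additions accumulate, while for updates exactly one of the competing writes survives and the winning thread is not ordered. A small host-side numpy illustration (sketch only):

import numpy as np

x_add = np.zeros(3, np.float32)
np.add.at(x_add, [0, 0, 2], [1.0, 2.0, 5.0])  # add semantics: duplicates accumulate -> [3., 0., 5.]

x_upd = np.zeros(3, np.float32)
x_upd[[0, 0, 2]] = [1.0, 2.0, 5.0]            # update semantics: one write survives -> [2., 0., 5.]
# numpy keeps the last entry; in the CUDA kernel the surviving write for a duplicated
# index is whichever thread stores last, so it is not guaranteed to be the later entry.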


@@ -20,7 +20,7 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
+void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input,
                       cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_


@@ -73,13 +73,23 @@ class _ScatterOp_Dynamic(PrimitiveWithCheck):
     """
     Defines Scatter operators with dynamic shape
     """
+    __mindspore_signature__ = (
+        sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
+        sig.make_sig('updates', dtype=sig.sig_dtype.T)
+    )
+
     def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
-        if np.all(np.array(x_shape) != -1):
-            if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
-                raise ValueError(f"For '{prim_name}', "
-                                 f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
-                                 f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
+        # x_shape cannot be dynamic
+        if np.any(np.array(x_shape) == -1):
+            raise ValueError(f"x does not support dynamic shape")
+        # support indices and updates dynamic
+        if np.any(np.array(indices_shape) == -1) or np.any(np.array(updates_shape) == -1):
+            pass
+        elif indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
+            raise ValueError(f"For '{prim_name}', "
+                             f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
+                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")

     @prim_attr_register
     def __init__(self, use_locking=False):
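
Two things change for _ScatterOp_Dynamic here. The added __mindspore_signature__ declares x as sig_rw.RW_WRITE, i.e. an operand written in place (so it is expected to be a Parameter), and puts x and updates in the same dtype group T while indices gets its own group T1. The rewritten check then rejects a dynamic x outright and lets dynamic indices/updates skip the static comparison; for fully static shapes the rule itself is unchanged, roughly:

# Sketch of the static shape rule: updates.shape == indices.shape + x.shape[1:].
x_shape = [2, 3]
indices_shape = [2, 2]
updates_shape = [2, 2, 3]
assert updates_shape == indices_shape + x_shape[1:]  # accepted

# Any -1 in x_shape raises immediately, while a -1 in indices_shape or
# updates_shape (dynamic shape) simply bypasses the comparison above.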
@@ -3176,7 +3186,7 @@ class ScatterUpdate(_ScatterOp_Dynamic):
         Tensor, has the same shape and type as `input_x`.

     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``

     Examples:
         >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])


@@ -56,8 +56,9 @@ class TestScatterAddDynamicNet(nn.Cell):
         self.updates = Parameter(updates, name="updates")

     def construct(self):
-        out = self.test_dynamic(self.inputx)
-        out = self.scatter_add(out, self.indices, self.updates)
+        indices = self.test_dynamic(self.indices)
+        updates = self.test_dynamic(self.updates)
+        out = self.scatter_add(self.inputx, indices, updates)
         return out


 def scatter_add_d_net(inputx, indices, updates):
@@ -66,22 +67,24 @@ def scatter_add_d_net(inputx, indices, updates):
     return net()


 class TestScatterAddDynamicNet2(nn.Cell):
-    def __init__(self):
+    def __init__(self, inputx):
         super(TestScatterAddDynamicNet2, self).__init__()
         self.scatter_add = P.ScatterAdd()
         self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.inputx = Parameter(inputx, name="inputx")

-    def construct(self, inputx, indices, updates):
-        out = self.test_dynamic(inputx)
-        out = self.scatter_add(out, indices, updates)
+    def construct(self, indices, updates):
+        indices = self.test_dynamic(indices)
+        updates = self.test_dynamic(updates)
+        out = self.scatter_add(self.inputx, indices, updates)
         return out


-def scatter_add_d2_net(inputx_1, indices_1, updates_1, inputx_2,
+def scatter_add_d2_net(inputx, indices_1, updates_1,
                        indices_2, updates_2):
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = TestScatterAddDynamicNet2()
-    out1 = net(inputx_1, indices_1, updates_1)
-    out2 = net(inputx_2, indices_2, updates_2)
+    net = TestScatterAddDynamicNet2(inputx)
+    out1 = net(indices_1, updates_1)
+    out2 = net(indices_2, updates_2)
     return (out1, out2)


 @pytest.mark.level0
@@ -96,6 +99,20 @@ def test_scatter_add_small_float32():
                          [12., 14., 16.]])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatter_add_input_updated():
+    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
+    indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
+    updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))
+    lock = True
+    net = TestScatterAddNet(lock, inputx, indices, updates)
+    net()
+    expected = np.array([[6., 8., 10.],
+                         [12., 14., 16.]])
+    np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -274,39 +291,16 @@ def test_scatter_add_input_less_than_1_dynamic_float32():
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_scatter_add_dynamic_two_inputs():
-    inputx_1 = Tensor(np.zeros((2, 3)).astype(np.float32))
+    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
     indices_1 = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
     updates_1 = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))
-    inputx_2 = Tensor(np.ones((4, 2, 3, 4)).astype(np.float32))
-    indices_2 = Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32))
-    updates_2 = Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32))
-    output_1, output_2 = scatter_add_d2_net(inputx_1, indices_1, updates_1,
-                                            inputx_2, indices_2, updates_2)
+    indices_2 = Tensor(np.array([[0, 0], [1, 1], [1, 0]]).astype(np.int32))
+    updates_2 = Tensor(np.flip(np.arange(18).reshape((3, 2, 3)).astype(np.float32)))
+    output_1, output_2 = scatter_add_d2_net(inputx, indices_1, updates_1,
+                                            indices_2, updates_2)
     expected_1 = np.array([[6., 8., 10.],
                            [12., 14., 16.]])
-    expected_2 = np.array([[[[1., 2., 3., 4.],
-                             [5., 6., 7., 8.],
-                             [9., 10., 11., 12.]],
-                            [[13., 14., 15., 16.],
-                             [17., 18., 19., 20.],
-                             [21., 22., 23., 24.]]],
-                           [[[73., 74., 75., 76.],
-                             [77., 78., 79., 80.],
-                             [81., 82., 83., 84.]],
-                            [[85., 86., 87., 88.],
-                             [89., 90., 91., 92.],
-                             [93., 94., 95., 96.]]],
-                           [[[25., 26., 27., 28.],
-                             [29., 30., 31., 32.],
-                             [33., 34., 35., 36.]],
-                            [[37., 38., 39., 40.],
-                             [41., 42., 43., 44.],
-                             [45., 46., 47., 48.]]],
-                           [[[49., 50., 51., 52.],
-                             [53., 54., 55., 56.],
-                             [57., 58., 59., 60.]],
-                            [[61., 62., 63., 64.],
-                             [65., 66., 67., 68.],
-                             [69., 70., 71., 72.]]]])
+    expected_2 = np.array([[39., 38., 37.],
+                           [36., 35., 34.]])
     np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
     np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
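
The rewritten test reuses one inputx Parameter across both calls, so expected_2 contains the accumulation from the first call plus the second. Deriving it by hand with numpy (a verification sketch, not test code from this commit):

import numpy as np

acc = np.zeros((2, 3), np.float32)
np.add.at(acc, np.array([[0, 1], [0, 1]]).reshape(-1),
          np.arange(12, dtype=np.float32).reshape(-1, 3))                      # first call: [[6, 8, 10], [12, 14, 16]]
np.add.at(acc, np.array([[0, 0], [1, 1], [1, 0]]).reshape(-1),
          np.flip(np.arange(18, dtype=np.float32).reshape(3, 2, 3)).reshape(-1, 3))  # second call on the same Parameter
# acc == [[39., 38., 37.], [36., 35., 34.]], matching expected_2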


@@ -50,8 +50,9 @@ class TestScatterUpdateDynamicNet(nn.Cell):
         self.updates = Parameter(updates, name="updates")

     def construct(self):
-        out = self.test_dynamic(self.inputx)
-        out = self.scatter_update(out, self.indices, self.updates)
+        indices = self.test_dynamic(self.indices)
+        updates = self.test_dynamic(self.updates)
+        out = self.scatter_update(self.inputx, indices, updates)
         return out


 def scatter_update_d_net(inputx, indices, updates):
@@ -60,22 +61,24 @@ def scatter_update_d_net(inputx, indices, updates):
     return net()


 class TestScatterUpdateDynamicNet2(nn.Cell):
-    def __init__(self):
+    def __init__(self, inputx):
         super(TestScatterUpdateDynamicNet2, self).__init__()
         self.scatter_update = P.ScatterUpdate()
         self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.inputx = Parameter(inputx, name="inputx")

-    def construct(self, inputx, indices, updates):
-        out = self.test_dynamic(inputx)
-        out = self.scatter_update(out, indices, updates)
+    def construct(self, indices, updates):
+        indices = self.test_dynamic(indices)
+        updates = self.test_dynamic(updates)
+        out = self.scatter_update(self.inputx, indices, updates)
         return out


-def scatter_update_d2_net(inputx_1, indices_1, updates_1, inputx_2,
+def scatter_update_d2_net(inputx, indices_1, updates_1,
                           indices_2, updates_2):
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    net = TestScatterUpdateDynamicNet2()
-    out1 = net(inputx_1, indices_1, updates_1)
-    out2 = net(inputx_2, indices_2, updates_2)
+    net = TestScatterUpdateDynamicNet2(inputx)
+    out1 = net(indices_1, updates_1)
+    out2 = net(indices_2, updates_2)
     return (out1, out2)


 @pytest.mark.level0
@@ -90,6 +93,19 @@ def test_scatter_update_small_float32():
                          [3., 4., 5.]])
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatter_update_input_updated():
+    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
+    indices = Tensor(np.array([0, 1]).astype(np.int32))
+    updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
+    net = TestScatterUpdateNet(inputx, indices, updates)
+    net()
+    expected = np.array([[0., 1., 2.],
+                         [3., 4., 5.]])
+    np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -328,20 +344,16 @@ def test_scatter_update_disordered_dynamic_int32():
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_scatter_update_two_inputs():
-    inputx_1 = Tensor(np.zeros((2, 3)).astype(np.float32))
+    inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
     indices_1 = Tensor(np.array([0, 1]).astype(np.int32))
     updates_1 = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
-    inputx_2 = Tensor(np.array([[0.214141, 0.415151, 0.51516],
-                                [0.876542, 0.451611, 0.55112],
-                                [0.111244, 0.633333, 0.34444]]).astype(np.float32))
-    indices_2 = Tensor(np.array([1, 0, 2]).astype(np.int32))
-    updates_2 = Tensor(np.arange(34, 43).reshape((3, 3)).astype(np.float32))
-    output_1, output_2 = scatter_update_d2_net(inputx_1, indices_1, updates_1,
-                                               inputx_2, indices_2, updates_2)
+    indices_2 = Tensor(np.array([1]).astype(np.int32))
+    updates_2 = Tensor(np.arange(34, 37).reshape((1, 3)).astype(np.float32))
+    output_1, output_2 = scatter_update_d2_net(inputx, indices_1, updates_1,
+                                               indices_2, updates_2)
     expected_1 = np.array([[0., 1., 2.],
-                           [3., 4., 5.]])
-    expected_2 = np.array([[37., 38., 39.],
-                           [34., 35., 36.],
-                           [40., 41., 42.]], dtype=np.float32)
+                           [3., 4., 5.]], dtype=np.float32)
+    expected_2 = np.array([[0., 1., 2.],
+                           [34., 35., 36.]], dtype=np.float32)
     np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
     np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)
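
Likewise for the update test: the second call sees the Parameter already holding the first call's result, and indices_2 = [1] only overwrites row 1, so row 0 keeps [0., 1., 2.]. A quick numpy check (sketch only, not test code from this commit):

import numpy as np

state = np.arange(6, dtype=np.float32).reshape(2, 3)  # Parameter contents after the first call
state[np.array([1])] = np.arange(34, 37, dtype=np.float32).reshape(1, 3)
# state == [[0., 1., 2.], [34., 35., 36.]], matching expected_2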