!35960 fix bug in ScatterNdFunctor, make it able to modify input Tensor

Merge pull request !35960 from zhujingxuan/fix_bug_scatter
i-robot 2022-06-15 12:13:12 +00:00 committed by Gitee
commit b8961f9345
4 changed files with 69 additions and 84 deletions
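
LaunchKernel previously copied the input buffer to the output and then ran CalScatterNdFunctor on the output only, so the input Tensor (a Parameter) was never updated in place. It now scatters into the input first and then copies input to output, sized by inputs[0]->size rather than the input_size_ member, which is removed. Init additionally rejects bool inputs for any functor other than scatter_nd_update, the kernel registrations gain uint16/uint32 entries and corrected uint64 ones, matching CalScatterNdFunctor instantiations are added in the CUDA source, and the tests drop the use_locking parametrization while widening dtype and index-type coverage.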


@@ -62,12 +62,13 @@ bool ScatterNdFunctorGPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &i
     "For 'ScatterNdFunctorGPUKernelMod', cudaMemcpyAsync failed in ScatterNdFunctorGpuFwdKernel::LaunchKernel.")
   }
+  CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices, updates,
+                      input, device_id_, cuda_stream);
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-    cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream),
+    cudaMemcpyAsync(output, input, inputs[0]->size, cudaMemcpyDeviceToDevice, cuda_stream),
     "For 'ScatterNdFunctorGPUKernelMod', cudaMemcpyAsync output failed")
-  CalScatterNdFunctor(scatter_nd_functor_type_, unit_size_, num_units_, index_depth_, indices_stride, indices, updates,
-                      output, device_id_, cuda_stream);
   return true;
 }
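
The in-place contract this hunk restores can be checked from Python with any of the scatter_nd_* ops. A minimal sketch, assuming a GPU context; the net, shapes, and values are illustrative and not part of this PR:

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, nn, ops, Tensor, Parameter

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class ScatterNdUpdateNet(nn.Cell):
    """Hypothetical repro net; not part of this PR's test file."""

    def __init__(self):
        super(ScatterNdUpdateNet, self).__init__()
        self.scatter = ops.ScatterNdUpdate(use_locking=True)
        self.inputx = Parameter(Tensor(np.zeros((4, 2), np.float32)), name="inputx")

    def construct(self, indices, updates):
        return self.scatter(self.inputx, indices, updates)


net = ScatterNdUpdateNet()
indices = Tensor(np.array([[0], [2]]), mstype.int32)
updates = Tensor(np.array([[1.0, 1.0], [2.0, 2.0]]), mstype.float32)
out = net(indices, updates)
# Before the fix only `out` held the scattered rows and `net.inputx` kept its
# zeros; with the fix the Parameter is updated in place, matching the output.
assert np.allclose(net.inputx.asnumpy(), out.asnumpy())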
@@ -85,6 +86,11 @@ bool ScatterNdFunctorGPUKernelMod::Init(const BaseOperatorPtr &base_operator,
   if (!MatchKernelFunc(base_operator, inputs, outputs)) {
     return false;
   }
+  if (scatter_nd_functor_type_ != SCATTER_ND_FUNC_UPDATE && (inputs[kIndex0]->GetDtype() == kNumberTypeBool)) {
+    const auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
+    return false;
+  }
   return true;
 }
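
Since add, sub, div, mul, max, and min have no arithmetic meaning for booleans, Init now fails fast when the input dtype is bool and the functor is anything other than scatter_nd_update; the bool registrations that remain below are valid only for scatter_nd_update.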
@@ -119,8 +125,6 @@ int ScatterNdFunctorGPUKernelMod::Resize(const BaseOperatorPtr &base_operator,
                   << indices_shape.size();
   }
-  input_size_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies{});
   unit_size_ = 1;
   for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) {
     unit_size_ *= SizeToInt(updates_shape[i]);
@@ -168,20 +172,26 @@ const std::vector<std::pair<KernelAttr, KernelRunFunc>> &ScatterNdFunctorGPUKern
     DTYPE_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt64, int64_t, int),
     DTYPE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
     // Data type: uint64
-    DTYPE_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t, int),
-    DTYPE_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t, int64_t),
-    // Data type: int
-    DTYPE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int, int),
-    DTYPE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt32, int, int64_t),
+    DTYPE_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, kNumberTypeUInt64, uint64_t, int),
+    DTYPE_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, kNumberTypeUInt64, uint64_t, int64_t),
+    // Data type: int32_t
+    DTYPE_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t, int),
+    DTYPE_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt32, int32_t, int64_t),
+    // Data type: uint32_t
+    DTYPE_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, kNumberTypeUInt32, uint32_t, int),
+    DTYPE_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, kNumberTypeUInt32, uint32_t, int64_t),
     // Data type: int16_t
     DTYPE_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, kNumberTypeInt16, int16_t, int),
     DTYPE_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, kNumberTypeInt16, int16_t, int64_t),
-    // Data type: uint8_t
-    DTYPE_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, kNumberTypeUInt8, uint8_t, int),
-    DTYPE_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt8, uint8_t, int64_t),
+    // Data type: uint16_t
+    DTYPE_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, kNumberTypeUInt16, uint16_t, int),
+    DTYPE_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, kNumberTypeUInt16, uint16_t, int64_t),
     // Data type: int8_t
     DTYPE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, kNumberTypeInt8, int8_t, int),
     DTYPE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, kNumberTypeInt8, int8_t, int64_t),
+    // Data type: uint8_t
+    DTYPE_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, kNumberTypeUInt8, uint8_t, int),
+    DTYPE_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt8, uint8_t, int64_t),
     // Data type: bool, only for scatter_nd_update
     DTYPE_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, kNumberTypeBool, bool, int),
     DTYPE_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, kNumberTypeBool, bool, int64_t),
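
The old uint64 entries passed kNumberTypeUInt64 where the index dtype belongs and kNumberTypeInt64 in the input/output slots, so no well-formed uint64 scatter could match; the corrected entries take int32/int64 indices like every other dtype. The uint32 and uint16 entries are new, and the uint8 block moves so each unsigned type follows its signed counterpart.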


@@ -58,7 +58,6 @@ class ScatterNdFunctorGPUKernelMod : public NativeGpuKernelMod, public MatchKern
                     const std::vector<AddressPtr> &outputs);
   ScatterNdFunctorType scatter_nd_functor_type_;
-  size_t input_size_{0};
   size_t unit_size_{0};
   size_t num_units_{0};


@@ -174,6 +174,18 @@ template CUDA_LIB_EXPORT void CalScatterNdFunctor<int32_t, int32_t>(enum Scatter
                                                                     const int32_t *out_strides, const int32_t *indices,
                                                                     const int32_t *updates, int32_t *input,
                                                                     uint32_t device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalScatterNdFunctor<uint32_t, int64_t>(enum ScatterNdFunctorType func_type,
+                                                                     const size_t &unit_size, const size_t &num_units,
+                                                                     const size_t &index_depth,
+                                                                     const int64_t *out_strides, const int64_t *indices,
+                                                                     const uint32_t *updates, uint32_t *input,
+                                                                     uint32_t device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalScatterNdFunctor<uint32_t, int32_t>(enum ScatterNdFunctorType func_type,
+                                                                     const size_t &unit_size, const size_t &num_units,
+                                                                     const size_t &index_depth,
+                                                                     const int32_t *out_strides, const int32_t *indices,
+                                                                     const uint32_t *updates, uint32_t *input,
+                                                                     uint32_t device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void CalScatterNdFunctor<int16_t, int64_t>(enum ScatterNdFunctorType func_type,
                                                                     const size_t &unit_size, const size_t &num_units,
                                                                     const size_t &index_depth,
@@ -186,6 +198,18 @@ template CUDA_LIB_EXPORT void CalScatterNdFunctor<int16_t, int32_t>(enum Scatter
                                                                     const int32_t *out_strides, const int32_t *indices,
                                                                     const int16_t *updates, int16_t *input,
                                                                     uint32_t device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalScatterNdFunctor<uint16_t, int64_t>(enum ScatterNdFunctorType func_type,
+                                                                     const size_t &unit_size, const size_t &num_units,
+                                                                     const size_t &index_depth,
+                                                                     const int64_t *out_strides, const int64_t *indices,
+                                                                     const uint16_t *updates, uint16_t *input,
+                                                                     uint32_t device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalScatterNdFunctor<uint16_t, int32_t>(enum ScatterNdFunctorType func_type,
+                                                                     const size_t &unit_size, const size_t &num_units,
+                                                                     const size_t &index_depth,
+                                                                     const int32_t *out_strides, const int32_t *indices,
+                                                                     const uint16_t *updates, uint16_t *input,
+                                                                     uint32_t device_id, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void CalScatterNdFunctor<uint8_t, int64_t>(enum ScatterNdFunctorType func_type,
                                                                     const size_t &unit_size, const size_t &num_units,
                                                                     const size_t &index_depth,
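
Each registration added in the kernel mod needs a matching explicit instantiation of the CalScatterNdFunctor template here, one per (data type, index type) pair; these uint32_t and uint16_t instantiations back the new uint32/uint16 entries above.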


@@ -48,10 +48,10 @@ np_func_map = {
 class TestScatterNdFuncNet(nn.Cell):
-    def __init__(self, func, lock, inputx, indices, updates):
+    def __init__(self, func, inputx, indices, updates):
         super(TestScatterNdFuncNet, self).__init__()
-        self.scatter_func = func_map[func](use_locking=lock)
+        self.scatter_func = func_map[func](use_locking=True)
         self.inputx = Parameter(inputx, name="inputx")
         self.indices = Parameter(indices, name="indices")
         self.updates = Parameter(updates, name="updates")
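
Hard-coding use_locking=True instead of parametrizing it halves the test matrix without losing the locked code path; the lock argument disappears from the helper and every test below accordingly.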
@@ -74,8 +74,8 @@ def scatter_nd_func_np(func, inputx, indices, updates):
     return result


-def compare_scatter_nd_func(func, lock, inputx, indices, updates):
-    output = TestScatterNdFuncNet(func, lock, inputx, indices, updates)()
+def compare_scatter_nd_func(func, inputx, indices, updates):
+    output = TestScatterNdFuncNet(func, inputx, indices, updates)()
     expected = scatter_nd_func_np(func, inputx, indices, updates)
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@@ -83,13 +83,12 @@ def compare_scatter_nd_func(func, lock, inputx, indices, updates):
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('lock', [True, False])
 @pytest.mark.parametrize('func', ['update', 'add', 'sub', 'div', 'mul', 'max', 'min'])
 @pytest.mark.parametrize('data_type',
-                         [mstype.uint8, mstype.int8, mstype.int16, mstype.int32, mstype.float16, mstype.float32,
-                          mstype.float64])
-@pytest.mark.parametrize('index_type', [mstype.int32])
-def test_scatter_nd_func_small(lock, func, data_type, index_type):
+                         [mstype.uint8, mstype.int8, mstype.uint16, mstype.int16, mstype.uint32, mstype.int32,
+                          mstype.uint64, mstype.int64, mstype.float16, mstype.float32, mstype.float64])
+@pytest.mark.parametrize('index_type', [mstype.int32, mstype.int64])
+def test_scatter_nd_func_small(func, data_type, index_type):
     """
     Feature: ALL To ALL
     Description: test cases for small input of ScatterNd* like functions
@@ -99,14 +98,13 @@ def test_scatter_nd_func_small(lock, func, data_type, index_type):
     indices = Tensor(np.array([[0, 0], [1, 1]]), index_type)
     updates = Tensor(np.array([1.0, 2.2]), data_type)
-    compare_scatter_nd_func(func, lock, inputx, indices, updates)
+    compare_scatter_nd_func(func, inputx, indices, updates)


 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('lock', [True, False])
-def test_scatter_nd_func_small_update(lock):
+def test_scatter_nd_func_small_update():
     """
     Feature: ALL To ALL
     Description: test cases for bool input of ScatterNdUpdate
@@ -116,63 +114,18 @@ def test_scatter_nd_func_small_update(lock):
     indices = Tensor(np.array([[False], [True], [False], [True]]), mstype.int32)
     updates = Tensor(np.array([9, 10, 11, 12]), mstype.bool_)
-    compare_scatter_nd_func("update", lock, inputx, indices, updates)
+    compare_scatter_nd_func("update", inputx, indices, updates)


 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('lock', [True, False])
 @pytest.mark.parametrize('func', ['update', 'add', 'sub', 'div', 'mul', 'max', 'min'])
 @pytest.mark.parametrize('data_type',
-                         [mstype.uint8, mstype.int8, mstype.int16, mstype.int32, mstype.float16, mstype.float32,
-                          mstype.float64])
-@pytest.mark.parametrize('index_type', [mstype.int32])
-def test_scatter_nd_func_small_int(lock, func, data_type, index_type):
-    """
-    Feature: ALL To ALL
-    Description: test cases for int input of ScatterNd* like functions
-    Expectation: the result match to numpy implementation
-    """
-    inputx = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), data_type)
-    indices = Tensor(np.array([[4], [3], [1], [7]]), index_type)
-    updates = Tensor(np.array([9, 10, 11, 12]), data_type)
-    compare_scatter_nd_func(func, lock, inputx, indices, updates)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('lock', [True, False])
-@pytest.mark.parametrize('func', ['update', 'add', 'sub', 'div', 'mul', 'max', 'min'])
-@pytest.mark.parametrize('data_type',
-                         [mstype.uint8, mstype.int8, mstype.int16, mstype.int32, mstype.float16, mstype.float32,
-                          mstype.float64])
-@pytest.mark.parametrize('index_type', [mstype.int32])
-def test_scatter_nd_func_small_negative(lock, func, data_type, index_type):
-    """
-    Feature: ALL To ALL
-    Description: test cases for negative input of ScatterNd* like functions
-    Expectation: the result match to numpy implementation
-    """
-    inputx = Tensor(np.array([-1, -2, -3, -4, -5, -6, -7, -8]), data_type)
-    indices = Tensor(np.array([[4], [3], [1], [7]]), index_type)
-    updates = Tensor(np.array([9, -10, 11, -12]), data_type)
-    compare_scatter_nd_func(func, lock, inputx, indices, updates)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('lock', [True, False])
-@pytest.mark.parametrize('func', ['update', 'add', 'sub', 'div', 'mul', 'max', 'min'])
-@pytest.mark.parametrize('data_type',
-                         [mstype.uint8, mstype.int8, mstype.int16, mstype.int32, mstype.float16, mstype.float32,
-                          mstype.float64])
-@pytest.mark.parametrize('index_type', [mstype.int32])
-def test_scatter_nd_func_multi_dims(lock, func, data_type, index_type):
+                         [mstype.uint8, mstype.int8, mstype.uint16, mstype.int16, mstype.uint32, mstype.int32,
+                          mstype.uint64, mstype.int64, mstype.float16, mstype.float32, mstype.float64])
+@pytest.mark.parametrize('index_type', [mstype.int32, mstype.int64])
+def test_scatter_nd_func_multi_dims(func, data_type, index_type):
     """
     Feature: ALL To ALL
     Description: test cases for multi-dims input of ScatterNd* like functions
@@ -190,19 +143,18 @@ def test_scatter_nd_func_multi_dims(lock, func, data_type, index_type):
         data_type,
     )
-    compare_scatter_nd_func(func, lock, inputx, indices, updates)
+    compare_scatter_nd_func(func, inputx, indices, updates)


 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('lock', [True, False])
 @pytest.mark.parametrize('func', ['update', 'add', 'sub', 'div', 'mul', 'max', 'min'])
 @pytest.mark.parametrize('data_type',
-                         [mstype.uint8, mstype.int8, mstype.int16, mstype.int32, mstype.float16, mstype.float32,
-                          mstype.float64])
-@pytest.mark.parametrize('index_type', [mstype.int32])
-def test_scatter_nd_func_one_value(lock, func, data_type, index_type):
+                         [mstype.uint8, mstype.int8, mstype.uint16, mstype.int16, mstype.uint32, mstype.int32,
+                          mstype.uint64, mstype.int64, mstype.float16, mstype.float32, mstype.float64])
+@pytest.mark.parametrize('index_type', [mstype.int32, mstype.int64])
+def test_scatter_nd_func_one_value(func, data_type, index_type):
     """
     Feature: ALL To ALL
     Description: test cases for one value modification of ScatterNd* like functions
@@ -212,7 +164,7 @@ def test_scatter_nd_func_one_value(lock, func, data_type, index_type):
     indices = Tensor(np.array([[0, 1]]), index_type)
     updates = Tensor(np.array([1.0]), data_type)
-    compare_scatter_nd_func(func, lock, inputx, indices, updates)
+    compare_scatter_nd_func(func, inputx, indices, updates)
@pytest.mark.level0
@@ -232,4 +184,4 @@ def test_scatter_nd_div_division_by_zero(data_type, index_type):
     indices = Tensor(np.array([[0, 0], [1, 1]]), index_type)
     updates = Tensor(np.array([0, 2]), data_type)
-    compare_scatter_nd_func('div', False, inputx, indices, updates)
+    compare_scatter_nd_func('div', inputx, indices, updates)