Raise an error when CPU kernel ScatterNdUpdate indices are out of range

Author: huanghui
Date:   2022-04-12 11:33:57 +08:00
parent 7ee68462d3
commit 0719db6fca
2 changed files with 64 additions and 19 deletions
mindspore/ccsrc/plugin/device/cpu/kernel
tests/st/ops/cpu

mindspore/ccsrc/plugin/device/cpu/kernel

@@ -28,7 +28,7 @@ constexpr size_t kMinIndiceRank = 2;
 constexpr char kKernelName[] = "ScatterNdUpdate";
 
 template <typename T>
-void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
+bool Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
   MS_EXCEPTION_IF_NULL(params);
   T *x = params->x_;
   int *indices = params->indices_;
@@ -41,20 +41,30 @@ void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
   for (int i = SizeToInt(start); i < SizeToInt(end); ++i) {
     int offset = 0;
+    std::vector<int> local_indices;
     for (int j = 0; j < params->indices_unit_rank_; ++j) {
       auto index = indices[i * params->indices_unit_rank_ + j];
+      (void)local_indices.emplace_back(index);
       if (index < 0) {
-        MS_LOG(EXCEPTION) << "For '" << kKernelName
+        MS_LOG(ERROR) << "For '" << kKernelName
                       << "', each element in 'indices' should be greater than or equal to 0, but got " << index;
+        return false;
       }
       offset += index * out_strides->at(j) * params->unit_size_;
     }
-    auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset, updates + params->unit_size_ * i,
+    if (offset * sizeof(T) > params->x_mem_size_) {
+      MS_LOG(ERROR) << "For '" << kKernelName
+                    << "', indices out of range for input_x. Please check the indices which is " << local_indices;
+      return false;
+    }
+    auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset * sizeof(T), updates + params->unit_size_ * i,
                         params->unit_size_ * sizeof(T));
     if (ret != 0) {
-      MS_LOG(EXCEPTION) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
+      MS_LOG(ERROR) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
+      return false;
     }
   }
+  return true;
 }
 }  // namespace
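The heart of the change is in this hunk: Compute now returns bool, logs with MS_LOG(ERROR) instead of throwing from a worker thread, and rejects any index whose byte offset would fall outside x_mem_size_ before calling memcpy_s. Below is a minimal standalone sketch of that bounds check; ScatterOneRow, unit_size, and out_strides are hypothetical simplified stand-ins for the kernel's ComputeParams fields, not MindSpore API.

#include <cstring>
#include <iostream>
#include <vector>

// Sketch of the patched bounds check: scatter one row of updates into x,
// rejecting negative or out-of-range indices instead of corrupting memory.
template <typename T>
bool ScatterOneRow(std::vector<T> *x, const std::vector<int> &index,
                   const std::vector<int> &out_strides, int unit_size, const T *update) {
  int offset = 0;
  for (size_t j = 0; j < index.size(); ++j) {
    if (index[j] < 0) {
      std::cerr << "each element in 'indices' should be >= 0, but got " << index[j] << "\n";
      return false;
    }
    offset += index[j] * out_strides[j] * unit_size;
  }
  const size_t x_mem_size = x->size() * sizeof(T);
  if (offset * sizeof(T) > x_mem_size) {  // the new out-of-range guard
    std::cerr << "indices out of range for input_x\n";
    return false;
  }
  std::memcpy(x->data() + offset, update, unit_size * sizeof(T));
  return true;
}

int main() {
  std::vector<float> x(4, 1.0f);  // x with shape [1, 4, 1], as in the new test
  const float update = 9.0f;
  // out_strides over the first two dims of [1, 4, 1] are {4, 1}.
  std::cout << ScatterOneRow(&x, {0, 2}, {4, 1}, 1, &update)    // 1: in range
            << ScatterOneRow(&x, {3, 2}, {4, 1}, 1, &update);   // 0: rejected
}

With x of shape [1, 4, 1] and unit size 1, the index {3, 2} from the new test maps to element offset 3 * 4 + 2 = 14, i.e. byte offset 56 in a 16-byte buffer, so it is rejected before the copy.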
@@ -153,18 +163,24 @@ void ScatterUpdateCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
   std::vector<common::Task> tasks;
   size_t start = 0;
+  int status = 0;
   auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
   size_t once_compute_size = (num_units_ + max_thread_num - 1) / max_thread_num;
   while (start < num_units_) {
     size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
-    auto task = [&params, start, end]() {
-      Compute<T>(&params, start, end);
+    auto task = [&params, start, end, &status]() {
+      if (!Compute<T>(&params, start, end)) {
+        status = -1;
+      }
       return common::SUCCESS;
     };
     (void)tasks.emplace_back(task);
     start += once_compute_size;
   }
   ParallelLaunch(tasks);
+  if (status == -1) {
+    MS_LOG(EXCEPTION) << "Some errors occurred! The error message is as above";
+  }
   (void)memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
 }
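Note why LaunchKernel raises the exception only after ParallelLaunch: each task must still return common::SUCCESS so the thread pool drains every chunk, so a shared status flag records the failure and is checked once on the main thread. A sketch of that fan-out/aggregate pattern follows, using plain std::thread instead of MindSpore's thread pool; names like num_units and chunk are illustrative, and the sketch uses std::atomic<int> to keep the flag write well-defined where the patch uses a plain int.

#include <algorithm>
#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

// Pretend unit 7 has an out-of-range index, so one chunk fails.
bool Compute(size_t start, size_t end) {
  for (size_t i = start; i < end; ++i) {
    if (i == 7) return false;
  }
  return true;
}

int main() {
  const size_t num_units = 16;   // stand-in for num_units_
  const size_t num_threads = 4;  // stand-in for GetSyncRunThreadNum()
  const size_t chunk = (num_units + num_threads - 1) / num_threads;
  std::atomic<int> status{0};  // shared failure flag written by workers

  std::vector<std::thread> workers;
  for (size_t start = 0; start < num_units; start += chunk) {
    const size_t end = std::min(start + chunk, num_units);
    workers.emplace_back([start, end, &status]() {
      // Record the failure but let the task finish normally, so every
      // chunk still runs to completion, as in the patched LaunchKernel.
      if (!Compute(start, end)) status = -1;
    });
  }
  for (auto &w : workers) w.join();

  if (status == -1) {  // single aggregation point after the parallel region
    std::cerr << "Some errors occurred! The error message is as above\n";
    return 1;
  }
  std::cout << "all chunks ok\n";
}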

tests/st/ops/cpu

@@ -32,9 +32,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 @pytest.mark.parametrize('dtype', [np.float32, np.float64])
 def test_op1(dtype):
     """
-    Feature: ALL TO ALL
-    Description: test cases for updating float values
-    Expectation: the result match scipy
+    Feature: Op ScatterNdUpdate
+    Description: test ScatterNdUpdate
+    Expectation: success
     """
     class ScatterNdUpdate(nn.Cell):
@@ -64,9 +64,9 @@ def test_op1(dtype):
 @pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
 def test_op2(dtype):
     """
-    Feature: ALL TO ALL
-    Description: test cases for updating int values
-    Expectation: the result match scipy
+    Feature: Op ScatterNdUpdate
+    Description: test ScatterNdUpdate
+    Expectation: success
     """
     class ScatterNdUpdate(nn.Cell):
@@ -96,9 +96,9 @@ def test_op2(dtype):
 @pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
 def test_op3(dtype):
     """
-    Feature: ALL TO ALL
-    Description: test cases for updating int values
-    Expectation: the result match scipy
+    Feature: Op ScatterNdUpdate
+    Description: test ScatterNdUpdate
+    Expectation: success
     """
     class ScatterNdUpdate(nn.Cell):
@@ -132,9 +132,9 @@ def test_op3(dtype):
 @pytest.mark.parametrize('dtype', [np.float32, np.float64])
 def test_op4(dtype):
     """
-    Feature: ALL TO ALL
-    Description: test cases for updating single float value
-    Expectation: the result match scipy
+    Feature: Op ScatterNdUpdate
+    Description: test ScatterNdUpdate
+    Expectation: success
     """
     class ScatterNdUpdate(nn.Cell):
@@ -154,3 +154,32 @@ def test_op4(dtype):
     print("x:\n", out)
     expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
     assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('dtype', [np.float32, np.float64])
+def test_op5(dtype):
+    """
+    Feature: Op ScatterNdUpdate
+    Description: test ScatterNdUpdate with index out of range
+    Expectation: raise RuntimeError
+    """
+    class ScatterNdUpdate(nn.Cell):
+        def __init__(self):
+            super(ScatterNdUpdate, self).__init__()
+            self.scatter_nd_update = P.ScatterNdUpdate()
+            self.x = Parameter(Tensor(np.ones([1, 4, 1], dtype=dtype)), name="x")
+
+        def construct(self, indices, update):
+            return self.scatter_nd_update(self.x, indices, update)
+
+    indices = Tensor(np.array([[0, 2], [3, 2], [1, 3]]), mstype.int32)
+    update = Tensor(np.array([[1], [1], [1]], dtype=dtype))
+    scatter_nd_update = ScatterNdUpdate()
+    with pytest.raises(RuntimeError) as errinfo:
+        scatter_nd_update(indices, update)
+    assert "Some errors occurred! The error message is as above" in str(errinfo.value)