Raise error when Cpu kernel ScatterNdUpdate indices out of range

This commit is contained in:
huanghui 2022-04-12 11:33:57 +08:00
parent 7ee68462d3
commit 0719db6fca
2 changed files with 64 additions and 19 deletions

View File

@ -28,7 +28,7 @@ constexpr size_t kMinIndiceRank = 2;
constexpr char kKernelName[] = "ScatterNdUpdate";
template <typename T>
void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
bool Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
MS_EXCEPTION_IF_NULL(params);
T *x = params->x_;
int *indices = params->indices_;
@ -41,20 +41,30 @@ void Compute(const ComputeParams<T> *params, const size_t start, const size_t en
for (int i = SizeToInt(start); i < SizeToInt(end); ++i) {
int offset = 0;
std::vector<int> local_indices;
for (int j = 0; j < params->indices_unit_rank_; ++j) {
auto index = indices[i * params->indices_unit_rank_ + j];
(void)local_indices.emplace_back(index);
if (index < 0) {
MS_LOG(EXCEPTION) << "For '" << kKernelName
<< "', each element in 'indices' should be greater than or equal to 0, but got " << index;
MS_LOG(ERROR) << "For '" << kKernelName
<< "', each element in 'indices' should be greater than or equal to 0, but got " << index;
return false;
}
offset += index * out_strides->at(j) * params->unit_size_;
}
auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset, updates + params->unit_size_ * i,
if (offset * sizeof(T) > params->x_mem_size_) {
MS_LOG(ERROR) << "For '" << kKernelName
<< "', indices out of range for input_x. Please check the indices which is " << local_indices;
return false;
}
auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset * sizeof(T), updates + params->unit_size_ * i,
params->unit_size_ * sizeof(T));
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
MS_LOG(ERROR) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
return false;
}
}
return true;
}
} // namespace
@ -153,18 +163,24 @@ void ScatterUpdateCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
std::vector<common::Task> tasks;
size_t start = 0;
int status = 0;
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
size_t once_compute_size = (num_units_ + max_thread_num - 1) / max_thread_num;
while (start < num_units_) {
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
auto task = [&params, start, end]() {
Compute<T>(&params, start, end);
auto task = [&params, start, end, &status]() {
if (!Compute<T>(&params, start, end)) {
status = -1;
}
return common::SUCCESS;
};
(void)tasks.emplace_back(task);
start += once_compute_size;
}
ParallelLaunch(tasks);
if (status == -1) {
MS_LOG(EXCEPTION) << "Some errors occurred! The error message is as above";
}
(void)memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
}

View File

@ -32,9 +32,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_op1(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating float values
Expectation: the result match scipy
Feature: Op ScatterNdUpdate
Description: test ScatterNdUpdate
Expectation: success
"""
class ScatterNdUpdate(nn.Cell):
@ -64,9 +64,9 @@ def test_op1(dtype):
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
def test_op2(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating int values
Expectation: the result match scipy
Feature: Op ScatterNdUpdate
Description: test ScatterNdUpdate
Expectation: success
"""
class ScatterNdUpdate(nn.Cell):
@ -96,9 +96,9 @@ def test_op2(dtype):
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
def test_op3(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating int values
Expectation: the result match scipy
Feature: Op ScatterNdUpdate
Description: test ScatterNdUpdate
Expectation: success
"""
class ScatterNdUpdate(nn.Cell):
@ -132,9 +132,9 @@ def test_op3(dtype):
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_op4(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating single float value
Expectation: the result match scipy
Feature: Op ScatterNdUpdate
Description: test ScatterNdUpdate
Expectation: success
"""
class ScatterNdUpdate(nn.Cell):
@ -154,3 +154,32 @@ def test_op4(dtype):
print("x:\n", out)
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_op5(dtype):
"""
Feature: Op ScatterNdUpdate
Description: test ScatterNdUpdate with index out of range
Expectation: raise RuntimeError
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.ones([1, 4, 1], dtype=dtype)), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0, 2], [3, 2], [1, 3]]), mstype.int32)
update = Tensor(np.array([[1], [1], [1]], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
with pytest.raises(RuntimeError) as errinfo:
scatter_nd_update(indices, update)
assert "Some errors occurred! The error message is as above" in str(errinfo.value)