Raise error when Cpu kernel ScatterNdUpdate indices out of range
This commit is contained in:
parent
7ee68462d3
commit
0719db6fca
mindspore/ccsrc/plugin/device/cpu/kernel
tests/st/ops/cpu
|
@ -28,7 +28,7 @@ constexpr size_t kMinIndiceRank = 2;
|
||||||
constexpr char kKernelName[] = "ScatterNdUpdate";
|
constexpr char kKernelName[] = "ScatterNdUpdate";
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
|
bool Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
|
||||||
MS_EXCEPTION_IF_NULL(params);
|
MS_EXCEPTION_IF_NULL(params);
|
||||||
T *x = params->x_;
|
T *x = params->x_;
|
||||||
int *indices = params->indices_;
|
int *indices = params->indices_;
|
||||||
|
@ -41,20 +41,30 @@ void Compute(const ComputeParams<T> *params, const size_t start, const size_t en
|
||||||
|
|
||||||
for (int i = SizeToInt(start); i < SizeToInt(end); ++i) {
|
for (int i = SizeToInt(start); i < SizeToInt(end); ++i) {
|
||||||
int offset = 0;
|
int offset = 0;
|
||||||
|
std::vector<int> local_indices;
|
||||||
for (int j = 0; j < params->indices_unit_rank_; ++j) {
|
for (int j = 0; j < params->indices_unit_rank_; ++j) {
|
||||||
auto index = indices[i * params->indices_unit_rank_ + j];
|
auto index = indices[i * params->indices_unit_rank_ + j];
|
||||||
|
(void)local_indices.emplace_back(index);
|
||||||
if (index < 0) {
|
if (index < 0) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kKernelName
|
MS_LOG(ERROR) << "For '" << kKernelName
|
||||||
<< "', each element in 'indices' should be greater than or equal to 0, but got " << index;
|
<< "', each element in 'indices' should be greater than or equal to 0, but got " << index;
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
offset += index * out_strides->at(j) * params->unit_size_;
|
offset += index * out_strides->at(j) * params->unit_size_;
|
||||||
}
|
}
|
||||||
auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset, updates + params->unit_size_ * i,
|
if (offset * sizeof(T) > params->x_mem_size_) {
|
||||||
|
MS_LOG(ERROR) << "For '" << kKernelName
|
||||||
|
<< "', indices out of range for input_x. Please check the indices which is " << local_indices;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset * sizeof(T), updates + params->unit_size_ * i,
|
||||||
params->unit_size_ * sizeof(T));
|
params->unit_size_ * sizeof(T));
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
|
MS_LOG(ERROR) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -153,18 +163,24 @@ void ScatterUpdateCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
|
||||||
|
|
||||||
std::vector<common::Task> tasks;
|
std::vector<common::Task> tasks;
|
||||||
size_t start = 0;
|
size_t start = 0;
|
||||||
|
int status = 0;
|
||||||
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||||
size_t once_compute_size = (num_units_ + max_thread_num - 1) / max_thread_num;
|
size_t once_compute_size = (num_units_ + max_thread_num - 1) / max_thread_num;
|
||||||
while (start < num_units_) {
|
while (start < num_units_) {
|
||||||
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
|
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
|
||||||
auto task = [¶ms, start, end]() {
|
auto task = [¶ms, start, end, &status]() {
|
||||||
Compute<T>(¶ms, start, end);
|
if (!Compute<T>(¶ms, start, end)) {
|
||||||
|
status = -1;
|
||||||
|
}
|
||||||
return common::SUCCESS;
|
return common::SUCCESS;
|
||||||
};
|
};
|
||||||
(void)tasks.emplace_back(task);
|
(void)tasks.emplace_back(task);
|
||||||
start += once_compute_size;
|
start += once_compute_size;
|
||||||
}
|
}
|
||||||
ParallelLaunch(tasks);
|
ParallelLaunch(tasks);
|
||||||
|
if (status == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "Some errors occurred! The error message is as above";
|
||||||
|
}
|
||||||
(void)memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
|
(void)memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,9 +32,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||||
def test_op1(dtype):
|
def test_op1(dtype):
|
||||||
"""
|
"""
|
||||||
Feature: ALL TO ALL
|
Feature: Op ScatterNdUpdate
|
||||||
Description: test cases for updating float values
|
Description: test ScatterNdUpdate
|
||||||
Expectation: the result match scipy
|
Expectation: success
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class ScatterNdUpdate(nn.Cell):
|
class ScatterNdUpdate(nn.Cell):
|
||||||
|
@ -64,9 +64,9 @@ def test_op1(dtype):
|
||||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
||||||
def test_op2(dtype):
|
def test_op2(dtype):
|
||||||
"""
|
"""
|
||||||
Feature: ALL TO ALL
|
Feature: Op ScatterNdUpdate
|
||||||
Description: test cases for updating int values
|
Description: test ScatterNdUpdate
|
||||||
Expectation: the result match scipy
|
Expectation: success
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class ScatterNdUpdate(nn.Cell):
|
class ScatterNdUpdate(nn.Cell):
|
||||||
|
@ -96,9 +96,9 @@ def test_op2(dtype):
|
||||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
||||||
def test_op3(dtype):
|
def test_op3(dtype):
|
||||||
"""
|
"""
|
||||||
Feature: ALL TO ALL
|
Feature: Op ScatterNdUpdate
|
||||||
Description: test cases for updating int values
|
Description: test ScatterNdUpdate
|
||||||
Expectation: the result match scipy
|
Expectation: success
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class ScatterNdUpdate(nn.Cell):
|
class ScatterNdUpdate(nn.Cell):
|
||||||
|
@ -132,9 +132,9 @@ def test_op3(dtype):
|
||||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||||
def test_op4(dtype):
|
def test_op4(dtype):
|
||||||
"""
|
"""
|
||||||
Feature: ALL TO ALL
|
Feature: Op ScatterNdUpdate
|
||||||
Description: test cases for updating single float value
|
Description: test ScatterNdUpdate
|
||||||
Expectation: the result match scipy
|
Expectation: success
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class ScatterNdUpdate(nn.Cell):
|
class ScatterNdUpdate(nn.Cell):
|
||||||
|
@ -154,3 +154,32 @@ def test_op4(dtype):
|
||||||
print("x:\n", out)
|
print("x:\n", out)
|
||||||
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
|
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
|
||||||
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))
|
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||||
|
def test_op5(dtype):
|
||||||
|
"""
|
||||||
|
Feature: Op ScatterNdUpdate
|
||||||
|
Description: test ScatterNdUpdate with index out of range
|
||||||
|
Expectation: raise RuntimeError
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ScatterNdUpdate(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(ScatterNdUpdate, self).__init__()
|
||||||
|
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||||
|
self.x = Parameter(Tensor(np.ones([1, 4, 1], dtype=dtype)), name="x")
|
||||||
|
|
||||||
|
def construct(self, indices, update):
|
||||||
|
return self.scatter_nd_update(self.x, indices, update)
|
||||||
|
|
||||||
|
indices = Tensor(np.array([[0, 2], [3, 2], [1, 3]]), mstype.int32)
|
||||||
|
update = Tensor(np.array([[1], [1], [1]], dtype=dtype))
|
||||||
|
|
||||||
|
scatter_nd_update = ScatterNdUpdate()
|
||||||
|
with pytest.raises(RuntimeError) as errinfo:
|
||||||
|
scatter_nd_update(indices, update)
|
||||||
|
assert "Some errors occurred! The error message is as above" in str(errinfo.value)
|
||||||
|
|
Loading…
Reference in New Issue