Raise error when Cpu kernel ScatterNdUpdate indices out of range
This commit is contained in:
parent
7ee68462d3
commit
0719db6fca
|
@ -28,7 +28,7 @@ constexpr size_t kMinIndiceRank = 2;
|
|||
constexpr char kKernelName[] = "ScatterNdUpdate";
|
||||
|
||||
template <typename T>
|
||||
void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
|
||||
bool Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
|
||||
MS_EXCEPTION_IF_NULL(params);
|
||||
T *x = params->x_;
|
||||
int *indices = params->indices_;
|
||||
|
@ -41,20 +41,30 @@ void Compute(const ComputeParams<T> *params, const size_t start, const size_t en
|
|||
|
||||
for (int i = SizeToInt(start); i < SizeToInt(end); ++i) {
|
||||
int offset = 0;
|
||||
std::vector<int> local_indices;
|
||||
for (int j = 0; j < params->indices_unit_rank_; ++j) {
|
||||
auto index = indices[i * params->indices_unit_rank_ + j];
|
||||
(void)local_indices.emplace_back(index);
|
||||
if (index < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kKernelName
|
||||
<< "', each element in 'indices' should be greater than or equal to 0, but got " << index;
|
||||
MS_LOG(ERROR) << "For '" << kKernelName
|
||||
<< "', each element in 'indices' should be greater than or equal to 0, but got " << index;
|
||||
return false;
|
||||
}
|
||||
offset += index * out_strides->at(j) * params->unit_size_;
|
||||
}
|
||||
auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset, updates + params->unit_size_ * i,
|
||||
if (offset * sizeof(T) > params->x_mem_size_) {
|
||||
MS_LOG(ERROR) << "For '" << kKernelName
|
||||
<< "', indices out of range for input_x. Please check the indices which is " << local_indices;
|
||||
return false;
|
||||
}
|
||||
auto ret = memcpy_s(x + offset, params->x_mem_size_ - offset * sizeof(T), updates + params->unit_size_ * i,
|
||||
params->unit_size_ * sizeof(T));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
|
||||
MS_LOG(ERROR) << "For '" << kKernelName << "', memcpy_s error. Error no: " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -153,18 +163,24 @@ void ScatterUpdateCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
|
|||
|
||||
std::vector<common::Task> tasks;
|
||||
size_t start = 0;
|
||||
int status = 0;
|
||||
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
size_t once_compute_size = (num_units_ + max_thread_num - 1) / max_thread_num;
|
||||
while (start < num_units_) {
|
||||
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
|
||||
auto task = [¶ms, start, end]() {
|
||||
Compute<T>(¶ms, start, end);
|
||||
auto task = [¶ms, start, end, &status]() {
|
||||
if (!Compute<T>(¶ms, start, end)) {
|
||||
status = -1;
|
||||
}
|
||||
return common::SUCCESS;
|
||||
};
|
||||
(void)tasks.emplace_back(task);
|
||||
start += once_compute_size;
|
||||
}
|
||||
ParallelLaunch(tasks);
|
||||
if (status == -1) {
|
||||
MS_LOG(EXCEPTION) << "Some errors occurred! The error message is as above";
|
||||
}
|
||||
(void)memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
|
||||
}
|
||||
|
||||
|
|
|
@ -32,9 +32,9 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
|||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_op1(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating float values
|
||||
Expectation: the result match scipy
|
||||
Feature: Op ScatterNdUpdate
|
||||
Description: test ScatterNdUpdate
|
||||
Expectation: success
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
|
@ -64,9 +64,9 @@ def test_op1(dtype):
|
|||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
||||
def test_op2(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating int values
|
||||
Expectation: the result match scipy
|
||||
Feature: Op ScatterNdUpdate
|
||||
Description: test ScatterNdUpdate
|
||||
Expectation: success
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
|
@ -96,9 +96,9 @@ def test_op2(dtype):
|
|||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
||||
def test_op3(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating int values
|
||||
Expectation: the result match scipy
|
||||
Feature: Op ScatterNdUpdate
|
||||
Description: test ScatterNdUpdate
|
||||
Expectation: success
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
|
@ -132,9 +132,9 @@ def test_op3(dtype):
|
|||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_op4(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating single float value
|
||||
Expectation: the result match scipy
|
||||
Feature: Op ScatterNdUpdate
|
||||
Description: test ScatterNdUpdate
|
||||
Expectation: success
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
|
@ -154,3 +154,32 @@ def test_op4(dtype):
|
|||
print("x:\n", out)
|
||||
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
|
||||
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_op5(dtype):
|
||||
"""
|
||||
Feature: Op ScatterNdUpdate
|
||||
Description: test ScatterNdUpdate with index out of range
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.ones([1, 4, 1], dtype=dtype)), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[0, 2], [3, 2], [1, 3]]), mstype.int32)
|
||||
update = Tensor(np.array([[1], [1], [1]], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
with pytest.raises(RuntimeError) as errinfo:
|
||||
scatter_nd_update(indices, update)
|
||||
assert "Some errors occurred! The error message is as above" in str(errinfo.value)
|
||||
|
|
Loading…
Reference in New Issue