Fix bug of ScatterNd* operator in CPU.
This commit is contained in:
parent
07ff4be83a
commit
644c83a645
|
@ -151,8 +151,8 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
|
|||
}
|
||||
target = output;
|
||||
}
|
||||
bool out_bound = false;
|
||||
auto task = [this, &compute_func, &target, &indices, &updates, &out_bound](size_t start, size_t end) {
|
||||
int64_t invalid_index_pos = -1;
|
||||
auto task = [this, &compute_func, &target, &indices, &updates, &invalid_index_pos](size_t start, size_t end) {
|
||||
size_t pre_batch_idx = -1;
|
||||
for (size_t upd_idx = start, out_idx = 0; upd_idx < end; ++upd_idx, ++out_idx) {
|
||||
size_t batch_idx = upd_idx / inner_size_;
|
||||
|
@ -161,14 +161,16 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
|
|||
if (batch_idx != pre_batch_idx) {
|
||||
pre_batch_idx = batch_idx;
|
||||
out_idx = upd_idx % inner_size_;
|
||||
size_t index_idx = batch_idx * slice_size_;
|
||||
for (size_t i = 0; i < slice_size_; i++) {
|
||||
auto index = indices[batch_idx * slice_size_ + i];
|
||||
auto index = indices[index_idx + i];
|
||||
out_idx += batch_strides_[i] * index * inner_size_;
|
||||
if (index < 0 && index >= static_cast<S>(input_shape_[i])) {
|
||||
out_bound = true;
|
||||
if (index < 0 || index >= static_cast<S>(input_shape_[i])) {
|
||||
invalid_index_pos = SizeToLong(index_idx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (out_bound) {
|
||||
if (invalid_index_pos != -1) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -181,8 +183,20 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
|
|||
constexpr size_t min_block_size = 128;
|
||||
auto block_size = std::max(min_block_size, element_size / GetActorMgrInnerThreadPool()->GetKernelThreadNum());
|
||||
ParallelLaunch(task, element_size, block_size, this);
|
||||
if (out_bound) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input 'indices' is out of bounds.";
|
||||
if (invalid_index_pos != -1) {
|
||||
std::stringstream indices_ss;
|
||||
std::stringstream input_shape_ss;
|
||||
for (size_t i = 0; i < slice_size_; i++) {
|
||||
if (i > 0) {
|
||||
indices_ss << ", ";
|
||||
input_shape_ss << ", ";
|
||||
}
|
||||
indices_ss << std::to_string(indices[invalid_index_pos + i]);
|
||||
input_shape_ss << std::to_string(input_shape_[i]);
|
||||
}
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices'["
|
||||
<< indices_ss.str() << "] is out of range[" + input_shape_ss.str() + "].";
|
||||
return false;
|
||||
}
|
||||
if (!is_tensor_scatter_arithmetic_) {
|
||||
if (auto ret = memcpy_s(output, outputs[kIndex0]->size, input, inputs[kIndex0]->size); ret != EOK) {
|
||||
|
|
Loading…
Reference in New Issue