!34959 Fix bug of ScatterNd* operator in CPU.

Merge pull request !34959 from hezhenhao1/add_scatter_mul
This commit is contained in:
i-robot 2022-05-26 07:06:05 +00:00 committed by Gitee
commit b7aa9ce900
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 22 additions and 8 deletions

View File

@ -151,8 +151,8 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
}
target = output;
}
bool out_bound = false;
auto task = [this, &compute_func, &target, &indices, &updates, &out_bound](size_t start, size_t end) {
int64_t invalid_index_pos = -1;
auto task = [this, &compute_func, &target, &indices, &updates, &invalid_index_pos](size_t start, size_t end) {
size_t pre_batch_idx = -1;
for (size_t upd_idx = start, out_idx = 0; upd_idx < end; ++upd_idx, ++out_idx) {
size_t batch_idx = upd_idx / inner_size_;
@ -161,14 +161,16 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
if (batch_idx != pre_batch_idx) {
pre_batch_idx = batch_idx;
out_idx = upd_idx % inner_size_;
size_t index_idx = batch_idx * slice_size_;
for (size_t i = 0; i < slice_size_; i++) {
auto index = indices[batch_idx * slice_size_ + i];
auto index = indices[index_idx + i];
out_idx += batch_strides_[i] * index * inner_size_;
if (index < 0 && index >= static_cast<S>(input_shape_[i])) {
out_bound = true;
if (index < 0 || index >= static_cast<S>(input_shape_[i])) {
invalid_index_pos = SizeToLong(index_idx);
break;
}
}
if (out_bound) {
if (invalid_index_pos != -1) {
break;
}
}
@ -181,8 +183,20 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
constexpr size_t min_block_size = 128;
auto block_size = std::max(min_block_size, element_size / GetActorMgrInnerThreadPool()->GetKernelThreadNum());
ParallelLaunch(task, element_size, block_size, this);
if (out_bound) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input 'indices' is out of bounds.";
if (invalid_index_pos != -1) {
std::stringstream indices_ss;
std::stringstream input_shape_ss;
for (size_t i = 0; i < slice_size_; i++) {
if (i > 0) {
indices_ss << ", ";
input_shape_ss << ", ";
}
indices_ss << std::to_string(indices[invalid_index_pos + i]);
input_shape_ss << std::to_string(input_shape_[i]);
}
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices'["
<< indices_ss.str() << "] is out of range[" + input_shape_ss.str() + "].";
return false;
}
if (!is_tensor_scatter_arithmetic_) {
if (auto ret = memcpy_s(output, outputs[kIndex0]->size, input, inputs[kIndex0]->size); ret != EOK) {