forked from mindspore-Ecosystem/mindspore
!34959 Fix bug of ScatterNd* operator in CPU.
Merge pull request !34959 from hezhenhao1/add_scatter_mul
This commit is contained in:
commit
b7aa9ce900
|
@ -151,8 +151,8 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
|
||||||
}
|
}
|
||||||
target = output;
|
target = output;
|
||||||
}
|
}
|
||||||
bool out_bound = false;
|
int64_t invalid_index_pos = -1;
|
||||||
auto task = [this, &compute_func, &target, &indices, &updates, &out_bound](size_t start, size_t end) {
|
auto task = [this, &compute_func, &target, &indices, &updates, &invalid_index_pos](size_t start, size_t end) {
|
||||||
size_t pre_batch_idx = -1;
|
size_t pre_batch_idx = -1;
|
||||||
for (size_t upd_idx = start, out_idx = 0; upd_idx < end; ++upd_idx, ++out_idx) {
|
for (size_t upd_idx = start, out_idx = 0; upd_idx < end; ++upd_idx, ++out_idx) {
|
||||||
size_t batch_idx = upd_idx / inner_size_;
|
size_t batch_idx = upd_idx / inner_size_;
|
||||||
|
@ -161,14 +161,16 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
|
||||||
if (batch_idx != pre_batch_idx) {
|
if (batch_idx != pre_batch_idx) {
|
||||||
pre_batch_idx = batch_idx;
|
pre_batch_idx = batch_idx;
|
||||||
out_idx = upd_idx % inner_size_;
|
out_idx = upd_idx % inner_size_;
|
||||||
|
size_t index_idx = batch_idx * slice_size_;
|
||||||
for (size_t i = 0; i < slice_size_; i++) {
|
for (size_t i = 0; i < slice_size_; i++) {
|
||||||
auto index = indices[batch_idx * slice_size_ + i];
|
auto index = indices[index_idx + i];
|
||||||
out_idx += batch_strides_[i] * index * inner_size_;
|
out_idx += batch_strides_[i] * index * inner_size_;
|
||||||
if (index < 0 && index >= static_cast<S>(input_shape_[i])) {
|
if (index < 0 || index >= static_cast<S>(input_shape_[i])) {
|
||||||
out_bound = true;
|
invalid_index_pos = SizeToLong(index_idx);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (out_bound) {
|
if (invalid_index_pos != -1) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,8 +183,20 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Add
|
||||||
constexpr size_t min_block_size = 128;
|
constexpr size_t min_block_size = 128;
|
||||||
auto block_size = std::max(min_block_size, element_size / GetActorMgrInnerThreadPool()->GetKernelThreadNum());
|
auto block_size = std::max(min_block_size, element_size / GetActorMgrInnerThreadPool()->GetKernelThreadNum());
|
||||||
ParallelLaunch(task, element_size, block_size, this);
|
ParallelLaunch(task, element_size, block_size, this);
|
||||||
if (out_bound) {
|
if (invalid_index_pos != -1) {
|
||||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input 'indices' is out of bounds.";
|
std::stringstream indices_ss;
|
||||||
|
std::stringstream input_shape_ss;
|
||||||
|
for (size_t i = 0; i < slice_size_; i++) {
|
||||||
|
if (i > 0) {
|
||||||
|
indices_ss << ", ";
|
||||||
|
input_shape_ss << ", ";
|
||||||
|
}
|
||||||
|
indices_ss << std::to_string(indices[invalid_index_pos + i]);
|
||||||
|
input_shape_ss << std::to_string(input_shape_[i]);
|
||||||
|
}
|
||||||
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices'["
|
||||||
|
<< indices_ss.str() << "] is out of range[" + input_shape_ss.str() + "].";
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
if (!is_tensor_scatter_arithmetic_) {
|
if (!is_tensor_scatter_arithmetic_) {
|
||||||
if (auto ret = memcpy_s(output, outputs[kIndex0]->size, input, inputs[kIndex0]->size); ret != EOK) {
|
if (auto ret = memcpy_s(output, outputs[kIndex0]->size, input, inputs[kIndex0]->size); ret != EOK) {
|
||||||
|
|
Loading…
Reference in New Issue