From 644c83a645c42c7bf5e88b06091c3d583e574cea Mon Sep 17 00:00:00 2001 From: hezhenhao1 Date: Thu, 26 May 2022 10:26:53 +0800 Subject: [PATCH] Fix bug of ScatterNd* operator in CPU. --- .../scatter_nd_arithmetic_cpu_kernel.cc | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc index c17718a39a7..92076e7ea1f 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/scatter_nd_arithmetic_cpu_kernel.cc @@ -151,8 +151,8 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vector= static_cast(input_shape_[i])) { - out_bound = true; + if (index < 0 || index >= static_cast(input_shape_[i])) { + invalid_index_pos = SizeToLong(index_idx); + break; } } - if (out_bound) { + if (invalid_index_pos != -1) { break; } } @@ -181,8 +183,20 @@ bool ScatterNdArithmeticCpuKernelMod::LaunchKernel(const std::vectorGetKernelThreadNum()); ParallelLaunch(task, element_size, block_size, this); - if (out_bound) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input 'indices' is out of bounds."; + if (invalid_index_pos != -1) { + std::stringstream indices_ss; + std::stringstream input_shape_ss; + for (size_t i = 0; i < slice_size_; i++) { + if (i > 0) { + indices_ss << ", "; + input_shape_ss << ", "; + } + indices_ss << std::to_string(indices[invalid_index_pos + i]); + input_shape_ss << std::to_string(input_shape_[i]); + } + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the " << invalid_index_pos << "-th value of 'indices'[" + << indices_ss.str() << "] is out of range[" + input_shape_ss.str() + "]."; + return false; } if (!is_tensor_scatter_arithmetic_) { if (auto ret = memcpy_s(output, outputs[kIndex0]->size, input, inputs[kIndex0]->size); ret != EOK) {