diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc index bd0ab155418..9ac962ce7db 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc @@ -16,10 +16,39 @@ #include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h" #include +#include #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { namespace kernel { +namespace { +template +void Compute(const ComputeParams *params, const size_t start, const size_t end) { + MS_EXCEPTION_IF_NULL(params); + T *x = params->x_; + int *indices = params->indices_; + T *updates = params->updates_; + std::vector *out_strides = params->out_strides_; + MS_EXCEPTION_IF_NULL(out_strides); + + for (size_t i = start; i < end; ++i) { + int offset = 0; + for (int j = 0; j < params->indices_unit_rank_; ++j) { + auto index = indices[i * params->indices_unit_rank_ + j]; + if (index < 0) { + MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index; + } + offset += index * out_strides->at(j) * params->unit_size_; + } + auto ret = + memcpy_s(x + offset, params->x_mem_size_, updates + params->unit_size_ * i, params->unit_size_ * sizeof(T)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } + } +} +} // namespace + void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { Check(kernel_node); auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); @@ -46,9 +75,9 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { unit_size_ *= SizeToInt(updates_shape[i]); } num_units_ = 1; - num_units_ *= SizeToInt(updates_shape[indices_shape.size() - 2]); + num_units_ *= updates_shape[indices_shape.size() - 2]; for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) { - num_units_ *= SizeToInt(updates_shape[i]); + num_units_ *= updates_shape[i]; } int out_stride = 1; out_strides_.push_back(out_stride); @@ -56,8 +85,6 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { out_stride *= shape[i + 1]; out_strides_.push_back(out_stride); } - shape_ = shape; - output_unit_offsets_.reserve(num_units_); dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); } @@ -79,29 +106,29 @@ template void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { auto x = reinterpret_cast(inputs[0]->addr); - auto indices = reinterpret_cast(inputs[1]->addr); - auto updates = reinterpret_cast(inputs[2]->addr); + ComputeParams params; + params.x_ = x; + params.indices_ = reinterpret_cast(inputs[1]->addr); + params.updates_ = reinterpret_cast(inputs[2]->addr); + params.x_mem_size_ = inputs[0]->size; + params.unit_size_ = unit_size_; + params.indices_unit_rank_ = indices_unit_rank_; + params.out_strides_ = &out_strides_; - for (int i = 0; i < num_units_; ++i) { - int offset = 0; - for (int j = 0; j < indices_unit_rank_; ++j) { - auto index = indices[i * indices_unit_rank_ + j]; - if (index < 0) { - MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index; - } - offset += index * out_strides_[j] * unit_size_; - } - output_unit_offsets_[i] = offset; + const size_t thread_num = 24; + std::vector threads; + threads.reserve(thread_num); + size_t start = 0; + size_t once_compute_size = (num_units_ + thread_num - 1) / thread_num; + while (start < num_units_) { + size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size); + threads.emplace_back(std::thread(Compute, ¶ms, start, end)); + start += once_compute_size; } - - auto mem_size = inputs[0]->size; - for (int i = 0; i < num_units_; i++) { - auto ret = memcpy_s(x + output_unit_offsets_[i], mem_size, updates + unit_size_ * i, unit_size_ * sizeof(T)); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; - } + for (size_t i = 0; i < threads.size(); ++i) { + threads[i].join(); } - auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size); + auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h index f2a45cecd8f..c8606512f2d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h @@ -24,6 +24,17 @@ namespace mindspore { namespace kernel { +template +struct ComputeParams { + T *x_{nullptr}; + int *indices_{nullptr}; + T *updates_{nullptr}; + int unit_size_{0}; + int indices_unit_rank_{0}; + std::vector *out_strides_{nullptr}; + size_t x_mem_size_{0}; +}; + class ScatterNdUpdateCPUKernel : public CPUKernel { public: ScatterNdUpdateCPUKernel() = default; @@ -41,10 +52,8 @@ class ScatterNdUpdateCPUKernel : public CPUKernel { void Check(const CNodePtr &kernel_node); TypeId dtype_{kTypeUnknown}; int unit_size_{0}; - int num_units_{0}; + size_t num_units_{0}; int indices_unit_rank_{0}; - std::vector shape_; - std::vector output_unit_offsets_; std::vector out_strides_; };