optimize ScatterNdUpdate cpu kernel

This commit is contained in:
huanghui 2020-09-07 16:07:29 +08:00
parent ccb2e8851c
commit 87d6b62488
2 changed files with 63 additions and 27 deletions

View File

@ -16,10 +16,39 @@
#include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h"
#include <string>
#include <thread>
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
template <typename T>
void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
MS_EXCEPTION_IF_NULL(params);
T *x = params->x_;
int *indices = params->indices_;
T *updates = params->updates_;
std::vector<int> *out_strides = params->out_strides_;
MS_EXCEPTION_IF_NULL(out_strides);
for (size_t i = start; i < end; ++i) {
int offset = 0;
for (int j = 0; j < params->indices_unit_rank_; ++j) {
auto index = indices[i * params->indices_unit_rank_ + j];
if (index < 0) {
MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index;
}
offset += index * out_strides->at(j) * params->unit_size_;
}
auto ret =
memcpy_s(x + offset, params->x_mem_size_, updates + params->unit_size_ * i, params->unit_size_ * sizeof(T));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
}
}
} // namespace
void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
Check(kernel_node);
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@ -46,9 +75,9 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
unit_size_ *= SizeToInt(updates_shape[i]);
}
num_units_ = 1;
num_units_ *= SizeToInt(updates_shape[indices_shape.size() - 2]);
num_units_ *= updates_shape[indices_shape.size() - 2];
for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) {
num_units_ *= SizeToInt(updates_shape[i]);
num_units_ *= updates_shape[i];
}
int out_stride = 1;
out_strides_.push_back(out_stride);
@ -56,8 +85,6 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
out_stride *= shape[i + 1];
out_strides_.push_back(out_stride);
}
shape_ = shape;
output_unit_offsets_.reserve(num_units_);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
}
@ -79,29 +106,29 @@ template <typename T>
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto indices = reinterpret_cast<int *>(inputs[1]->addr);
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
ComputeParams<T> params;
params.x_ = x;
params.indices_ = reinterpret_cast<int *>(inputs[1]->addr);
params.updates_ = reinterpret_cast<T *>(inputs[2]->addr);
params.x_mem_size_ = inputs[0]->size;
params.unit_size_ = unit_size_;
params.indices_unit_rank_ = indices_unit_rank_;
params.out_strides_ = &out_strides_;
for (int i = 0; i < num_units_; ++i) {
int offset = 0;
for (int j = 0; j < indices_unit_rank_; ++j) {
auto index = indices[i * indices_unit_rank_ + j];
if (index < 0) {
MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index;
const size_t thread_num = 24;
std::vector<std::thread> threads;
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (num_units_ + thread_num - 1) / thread_num;
while (start < num_units_) {
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
threads.emplace_back(std::thread(Compute<T>, &params, start, end));
start += once_compute_size;
}
offset += index * out_strides_[j] * unit_size_;
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
output_unit_offsets_[i] = offset;
}
auto mem_size = inputs[0]->size;
for (int i = 0; i < num_units_; i++) {
auto ret = memcpy_s(x + output_unit_offsets_[i], mem_size, updates + unit_size_ * i, unit_size_ * sizeof(T));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
}
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size);
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}

View File

@ -24,6 +24,17 @@
namespace mindspore {
namespace kernel {
template <typename T>
struct ComputeParams {
T *x_{nullptr};
int *indices_{nullptr};
T *updates_{nullptr};
int unit_size_{0};
int indices_unit_rank_{0};
std::vector<int> *out_strides_{nullptr};
size_t x_mem_size_{0};
};
class ScatterNdUpdateCPUKernel : public CPUKernel {
public:
ScatterNdUpdateCPUKernel() = default;
@ -41,10 +52,8 @@ class ScatterNdUpdateCPUKernel : public CPUKernel {
void Check(const CNodePtr &kernel_node);
TypeId dtype_{kTypeUnknown};
int unit_size_{0};
int num_units_{0};
size_t num_units_{0};
int indices_unit_rank_{0};
std::vector<size_t> shape_;
std::vector<int> output_unit_offsets_;
std::vector<int> out_strides_;
};