forked from mindspore-Ecosystem/mindspore
optimize ScatterNdUpdate cpu kernel
This commit is contained in:
parent
ccb2e8851c
commit
87d6b62488
|
@ -16,10 +16,39 @@
|
|||
|
||||
#include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h"
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
template <typename T>
|
||||
void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
|
||||
MS_EXCEPTION_IF_NULL(params);
|
||||
T *x = params->x_;
|
||||
int *indices = params->indices_;
|
||||
T *updates = params->updates_;
|
||||
std::vector<int> *out_strides = params->out_strides_;
|
||||
MS_EXCEPTION_IF_NULL(out_strides);
|
||||
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
int offset = 0;
|
||||
for (int j = 0; j < params->indices_unit_rank_; ++j) {
|
||||
auto index = indices[i * params->indices_unit_rank_ + j];
|
||||
if (index < 0) {
|
||||
MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index;
|
||||
}
|
||||
offset += index * out_strides->at(j) * params->unit_size_;
|
||||
}
|
||||
auto ret =
|
||||
memcpy_s(x + offset, params->x_mem_size_, updates + params->unit_size_ * i, params->unit_size_ * sizeof(T));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
Check(kernel_node);
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
@ -46,9 +75,9 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
unit_size_ *= SizeToInt(updates_shape[i]);
|
||||
}
|
||||
num_units_ = 1;
|
||||
num_units_ *= SizeToInt(updates_shape[indices_shape.size() - 2]);
|
||||
num_units_ *= updates_shape[indices_shape.size() - 2];
|
||||
for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) {
|
||||
num_units_ *= SizeToInt(updates_shape[i]);
|
||||
num_units_ *= updates_shape[i];
|
||||
}
|
||||
int out_stride = 1;
|
||||
out_strides_.push_back(out_stride);
|
||||
|
@ -56,8 +85,6 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
out_stride *= shape[i + 1];
|
||||
out_strides_.push_back(out_stride);
|
||||
}
|
||||
shape_ = shape;
|
||||
output_unit_offsets_.reserve(num_units_);
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
|
@ -79,29 +106,29 @@ template <typename T>
|
|||
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto indices = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
ComputeParams<T> params;
|
||||
params.x_ = x;
|
||||
params.indices_ = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
params.updates_ = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
params.x_mem_size_ = inputs[0]->size;
|
||||
params.unit_size_ = unit_size_;
|
||||
params.indices_unit_rank_ = indices_unit_rank_;
|
||||
params.out_strides_ = &out_strides_;
|
||||
|
||||
for (int i = 0; i < num_units_; ++i) {
|
||||
int offset = 0;
|
||||
for (int j = 0; j < indices_unit_rank_; ++j) {
|
||||
auto index = indices[i * indices_unit_rank_ + j];
|
||||
if (index < 0) {
|
||||
MS_LOG(EXCEPTION) << "Error, Indices exist element which less than 0. element=" << index;
|
||||
}
|
||||
offset += index * out_strides_[j] * unit_size_;
|
||||
}
|
||||
output_unit_offsets_[i] = offset;
|
||||
const size_t thread_num = 24;
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(thread_num);
|
||||
size_t start = 0;
|
||||
size_t once_compute_size = (num_units_ + thread_num - 1) / thread_num;
|
||||
while (start < num_units_) {
|
||||
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
|
||||
threads.emplace_back(std::thread(Compute<T>, ¶ms, start, end));
|
||||
start += once_compute_size;
|
||||
}
|
||||
|
||||
auto mem_size = inputs[0]->size;
|
||||
for (int i = 0; i < num_units_; i++) {
|
||||
auto ret = memcpy_s(x + output_unit_offsets_[i], mem_size, updates + unit_size_ * i, unit_size_ * sizeof(T));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
for (size_t i = 0; i < threads.size(); ++i) {
|
||||
threads[i].join();
|
||||
}
|
||||
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size);
|
||||
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
|
|
|
@ -24,6 +24,17 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
struct ComputeParams {
|
||||
T *x_{nullptr};
|
||||
int *indices_{nullptr};
|
||||
T *updates_{nullptr};
|
||||
int unit_size_{0};
|
||||
int indices_unit_rank_{0};
|
||||
std::vector<int> *out_strides_{nullptr};
|
||||
size_t x_mem_size_{0};
|
||||
};
|
||||
|
||||
class ScatterNdUpdateCPUKernel : public CPUKernel {
|
||||
public:
|
||||
ScatterNdUpdateCPUKernel() = default;
|
||||
|
@ -41,10 +52,8 @@ class ScatterNdUpdateCPUKernel : public CPUKernel {
|
|||
void Check(const CNodePtr &kernel_node);
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
int unit_size_{0};
|
||||
int num_units_{0};
|
||||
size_t num_units_{0};
|
||||
int indices_unit_rank_{0};
|
||||
std::vector<size_t> shape_;
|
||||
std::vector<int> output_unit_offsets_;
|
||||
std::vector<int> out_strides_;
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue