Refactoring operator of gather by nnacl

This commit is contained in:
xuguoyang 2021-04-29 15:04:56 +08:00
parent ba5b751418
commit ac25fed634
2 changed files with 52 additions and 65 deletions

View File

@ -15,6 +15,8 @@
*/
#include "backend/kernel_compiler/cpu/gather_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "nnacl/gather_parameter.h"
#include "nnacl/base/gather_base.h"
namespace mindspore {
namespace kernel {
@ -32,75 +34,59 @@ void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) {
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}
int GatherV2CPUKernel::GatherLaunch(int8_t *input_data, int8_t *output_data, size_t size) {
int in_rank = input_shape_.size();
int indices_element_size = 1;
const int limit = input_shape_.at(axis_);
size_t data_size = sizeof(kNumberTypeFloat32);
int outer_size = 1, inner_size = 1;
for (int i = 0; i < axis_; ++i) {
outer_size *= input_shape_.at(i);
}
for (int i = axis_ + 1; i < in_rank; ++i) {
inner_size *= input_shape_.at(i);
}
for (size_t i = 0; i < indices_shape_.size(); i++) {
indices_element_size *= indices_shape_.at(i);
}
int stride = UP_DIV(outer_size, size);
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
int8_t *int8_in = input_data;
int8_t *int8_out = output_data;
int count = MSMIN(stride, static_cast<int>(outer_size - stride * i));
if (count <= 0) {
return;
}
auto thread_stride = stride * i;
int8_in += thread_stride * limit * inner_size * data_size;
int8_out += thread_stride * indices_element_size * inner_size * data_size;
auto error_code =
Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, sizeof(float));
if (error_code != 0) {
MS_LOG(ERROR) << "GatherRun error task_id[" << i << "] error_code[" << error_code << "]";
}
}
};
CPUKernelUtils::ParallelFor(task, size);
return 0;
}
bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto buff_size = outputs[0]->size;
size_t dim0 = input_shape_[0];
size_t dim1 = input_shape_[1];
size_t dim2 = input_shape_[2];
if (axis_ == 3) {
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
for (size_t k = 0; k < dim2; ++k) {
CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size);
}
}
}
} else if (axis_ == 2) {
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size);
}
}
} else if (axis_ == 1) {
for (size_t i = 0; i < dim0; ++i) {
CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size);
}
} else if (axis_ == 0) {
CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
}
int8_t *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
indices_data_ = reinterpret_cast<int32_t *>(inputs[1]->addr);
int8_t *output_addr = reinterpret_cast<int8_t *>(outputs[0]->addr);
size_t size = (outputs[0]->size > 0) ? static_cast<size_t>(outputs[0]->size / sizeof(int8_t)) : 1;
GatherLaunch(input_tensor, output_addr, size);
return true;
}
void GatherV2CPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
size_t dim2, float **output_addr, size_t *buff_size) {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
size_t elem_num = inputs[1]->size / 4;
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_);
for (size_t i = 0; i < elem_num; ++i) {
if (indices_addr[i] < 0) {
MS_LOG(EXCEPTION) << "The indices value is less than 0.";
}
size_t index = IntToSize(indices_addr[i]);
if (index >= input_shape_[LongToSize(axis_)]) {
auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memset failed.";
}
} else {
size_t pos = 0;
if (axis_ == 3) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index);
} else if (axis_ == 2) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0);
} else if (axis_ == 1) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0);
} else if (axis_ == 0) {
pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0);
}
auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed.";
}
}
*output_addr += num;
*buff_size -= num * sizeof(float);
}
} // namespace kernel
void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > 4) {

View File

@ -19,6 +19,7 @@
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/cpu/nnacl/base/gather_base.h"
namespace mindspore {
namespace kernel {
@ -33,12 +34,12 @@ class GatherV2CPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
float **output_addr, size_t *buff_size);
void CheckParam(const CNodePtr &kernel_node);
int GatherLaunch(int8_t *int8_in, int8_t *int8_out, size_t size);
std::vector<size_t> input_shape_;
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
int *indices_data_ = nullptr;
int64_t axis_;
};