From ac25fed6340be9777b480dc9c6442785c63a145c Mon Sep 17 00:00:00 2001 From: xuguoyang Date: Thu, 29 Apr 2021 15:04:56 +0800 Subject: [PATCH] Refactoring operator of gather by nnacl --- .../kernel_compiler/cpu/gather_cpu_kernel.cc | 112 ++++++++---------- .../kernel_compiler/cpu/gather_cpu_kernel.h | 5 +- 2 files changed, 52 insertions(+), 65 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc index 220a90f75ac..72ec849af1a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc @@ -15,6 +15,8 @@ */ #include "backend/kernel_compiler/cpu/gather_cpu_kernel.h" #include "runtime/device/cpu/cpu_device_address.h" +#include "nnacl/gather_parameter.h" +#include "nnacl/base/gather_base.h" namespace mindspore { namespace kernel { @@ -32,75 +34,59 @@ void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { CPUKernelUtils::ExpandDimsTo4(&output_shape_); } +int GatherV2CPUKernel::GatherLaunch(int8_t *input_data, int8_t *output_data, size_t size) { + int in_rank = input_shape_.size(); + int indices_element_size = 1; + const int limit = input_shape_.at(axis_); + size_t data_size = sizeof(kNumberTypeFloat32); + int outer_size = 1, inner_size = 1; + + for (int i = 0; i < axis_; ++i) { + outer_size *= input_shape_.at(i); + } + for (int i = axis_ + 1; i < in_rank; ++i) { + inner_size *= input_shape_.at(i); + } + for (size_t i = 0; i < indices_shape_.size(); i++) { + indices_element_size *= indices_shape_.at(i); + } + int stride = UP_DIV(outer_size, size); + + auto task = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + int8_t *int8_in = input_data; + int8_t *int8_out = output_data; + int count = MSMIN(stride, static_cast(outer_size - stride * i)); + if (count <= 0) { + return; + } + auto thread_stride = stride * i; + int8_in += thread_stride * limit * inner_size * data_size; + int8_out += thread_stride * indices_element_size * inner_size * data_size; + auto error_code = + Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, sizeof(float)); + if (error_code != 0) { + MS_LOG(ERROR) << "GatherRun error task_id[" << i << "] error_code[" << error_code << "]"; + } + } + }; + CPUKernelUtils::ParallelFor(task, size); + + return 0; +} + bool GatherV2CPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto buff_size = outputs[0]->size; - size_t dim0 = input_shape_[0]; - size_t dim1 = input_shape_[1]; - size_t dim2 = input_shape_[2]; - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); - } - } else if (axis_ == 0) { - CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); - } + int8_t *input_tensor = reinterpret_cast(inputs[0]->addr); + indices_data_ = reinterpret_cast(inputs[1]->addr); + int8_t *output_addr = reinterpret_cast(outputs[0]->addr); + size_t size = (outputs[0]->size > 0) ? static_cast(outputs[0]->size / sizeof(int8_t)) : 1; + + GatherLaunch(input_tensor, output_addr, size); return true; } -void GatherV2CPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr, size_t *buff_size) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto indices_addr = reinterpret_cast(inputs[1]->addr); - size_t elem_num = inputs[1]->size / 4; - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); - for (size_t i = 0; i < elem_num; ++i) { - if (indices_addr[i] < 0) { - MS_LOG(EXCEPTION) << "The indices value is less than 0."; - } - size_t index = IntToSize(indices_addr[i]); - if (index >= input_shape_[LongToSize(axis_)]) { - auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memset failed."; - } - } else { - size_t pos = 0; - if (axis_ == 3) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); - } else if (axis_ == 2) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); - } else if (axis_ == 1) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); - } else if (axis_ == 0) { - pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); - } - auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed."; - } - } - *output_addr += num; - *buff_size -= num * sizeof(float); - } -} // namespace kernel - void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); if (input_shape.size() > 4) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h index b98077ed2a3..a1fb590e245 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h @@ -19,6 +19,7 @@ #include #include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/cpu/nnacl/base/gather_base.h" namespace mindspore { namespace kernel { @@ -33,12 +34,12 @@ class GatherV2CPUKernel : public CPUKernel { const std::vector &outputs) override; private: - void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr, size_t *buff_size); void CheckParam(const CNodePtr &kernel_node); + int GatherLaunch(int8_t *int8_in, int8_t *int8_out, size_t size); std::vector input_shape_; std::vector indices_shape_; std::vector output_shape_; + int *indices_data_ = nullptr; int64_t axis_; };