forked from mindspore-Ecosystem/mindspore
Refactoring operator of gather by nnacl
This commit is contained in:
parent
ba5b751418
commit
ac25fed634
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
#include "backend/kernel_compiler/cpu/gather_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "nnacl/gather_parameter.h"
|
||||
#include "nnacl/base/gather_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -32,75 +34,59 @@ void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
|
||||
}
|
||||
|
||||
int GatherV2CPUKernel::GatherLaunch(int8_t *input_data, int8_t *output_data, size_t size) {
|
||||
int in_rank = input_shape_.size();
|
||||
int indices_element_size = 1;
|
||||
const int limit = input_shape_.at(axis_);
|
||||
size_t data_size = sizeof(kNumberTypeFloat32);
|
||||
int outer_size = 1, inner_size = 1;
|
||||
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
outer_size *= input_shape_.at(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < in_rank; ++i) {
|
||||
inner_size *= input_shape_.at(i);
|
||||
}
|
||||
for (size_t i = 0; i < indices_shape_.size(); i++) {
|
||||
indices_element_size *= indices_shape_.at(i);
|
||||
}
|
||||
int stride = UP_DIV(outer_size, size);
|
||||
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
int8_t *int8_in = input_data;
|
||||
int8_t *int8_out = output_data;
|
||||
int count = MSMIN(stride, static_cast<int>(outer_size - stride * i));
|
||||
if (count <= 0) {
|
||||
return;
|
||||
}
|
||||
auto thread_stride = stride * i;
|
||||
int8_in += thread_stride * limit * inner_size * data_size;
|
||||
int8_out += thread_stride * indices_element_size * inner_size * data_size;
|
||||
auto error_code =
|
||||
Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, sizeof(float));
|
||||
if (error_code != 0) {
|
||||
MS_LOG(ERROR) << "GatherRun error task_id[" << i << "] error_code[" << error_code << "]";
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto buff_size = outputs[0]->size;
|
||||
size_t dim0 = input_shape_[0];
|
||||
size_t dim1 = input_shape_[1];
|
||||
size_t dim2 = input_shape_[2];
|
||||
if (axis_ == 3) {
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
for (size_t k = 0; k < dim2; ++k) {
|
||||
CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (axis_ == 2) {
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size);
|
||||
}
|
||||
}
|
||||
} else if (axis_ == 1) {
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size);
|
||||
}
|
||||
} else if (axis_ == 0) {
|
||||
CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
|
||||
}
|
||||
int8_t *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
|
||||
indices_data_ = reinterpret_cast<int32_t *>(inputs[1]->addr);
|
||||
int8_t *output_addr = reinterpret_cast<int8_t *>(outputs[0]->addr);
|
||||
size_t size = (outputs[0]->size > 0) ? static_cast<size_t>(outputs[0]->size / sizeof(int8_t)) : 1;
|
||||
|
||||
GatherLaunch(input_tensor, output_addr, size);
|
||||
return true;
|
||||
}
|
||||
|
||||
void GatherV2CPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
|
||||
size_t dim2, float **output_addr, size_t *buff_size) {
|
||||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
size_t elem_num = inputs[1]->size / 4;
|
||||
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_);
|
||||
for (size_t i = 0; i < elem_num; ++i) {
|
||||
if (indices_addr[i] < 0) {
|
||||
MS_LOG(EXCEPTION) << "The indices value is less than 0.";
|
||||
}
|
||||
size_t index = IntToSize(indices_addr[i]);
|
||||
if (index >= input_shape_[LongToSize(axis_)]) {
|
||||
auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memset failed.";
|
||||
}
|
||||
} else {
|
||||
size_t pos = 0;
|
||||
if (axis_ == 3) {
|
||||
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index);
|
||||
} else if (axis_ == 2) {
|
||||
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0);
|
||||
} else if (axis_ == 1) {
|
||||
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0);
|
||||
} else if (axis_ == 0) {
|
||||
pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0);
|
||||
}
|
||||
auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy failed.";
|
||||
}
|
||||
}
|
||||
*output_addr += num;
|
||||
*buff_size -= num * sizeof(float);
|
||||
}
|
||||
} // namespace kernel
|
||||
|
||||
void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape.size() > 4) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/base/gather_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -33,12 +34,12 @@ class GatherV2CPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
|
||||
float **output_addr, size_t *buff_size);
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
int GatherLaunch(int8_t *int8_in, int8_t *int8_out, size_t size);
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> indices_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
int *indices_data_ = nullptr;
|
||||
int64_t axis_;
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue