optimize gather op

xuanyue 2022-02-16 16:41:26 +08:00
parent d0d932c94f
commit 0f72ace8cf
9 changed files with 231 additions and 219 deletions

View File

@@ -92,11 +92,13 @@ bool GatherV2CpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inpu
indices_element_size *= indices_shape_.at(i);
}
auto limit = input_shape_.at(axis);
size_t byte_inner_size = inner_size * sizeof(T);
size_t byte_out_stride = indices_element_size * byte_inner_size;
auto task = [&](size_t start, size_t end) {
int count = SizeToInt(end - start);
const int8_t *in = input_tensor + start * limit * inner_size * sizeof(T);
int8_t *out = output_addr + start * indices_element_size * inner_size * sizeof(T);
int ret = Gather(in, count, inner_size, limit, indices_data, indices_element_size, out, sizeof(T));
const int8_t *in = input_tensor + start * limit * byte_inner_size;
int8_t *out = output_addr + start * byte_out_stride;
int ret = Gather(in, count, byte_inner_size, limit, indices_data, indices_element_size, out, byte_out_stride);
if (ret != 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
}
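For reference, the byte strides used in this hunk come straight from the tensor geometry: everything after the gather axis collapses into one contiguous inner block. A minimal sketch of the derivation, assuming a row-major shape vector (the helper name ComputeGatherStrides is illustrative, not part of this change):

#include <cstdint>
#include <vector>

// outer_size: product of dims before axis; inner_size: product of dims after
// axis (in elements); limit: the dim size at the gather axis itself.
struct GatherStrides {
  int64_t outer_size;
  int64_t inner_size;
  int64_t limit;
};

GatherStrides ComputeGatherStrides(const std::vector<int64_t> &shape, int axis) {
  GatherStrides s{1, 1, shape[axis]};
  for (int i = 0; i < axis; ++i) s.outer_size *= shape[i];
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i) s.inner_size *= shape[i];
  return s;
}

// For a float tensor of shape {4, 5, 6} gathered along axis 1 with 3 indices:
// outer_size = 4, limit = 5, inner_size = 6, so byte_inner_size = 6 * sizeof(float) = 24
// and byte_out_stride = 3 * 24 = 72 bytes per output outer row.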

View File

@@ -16,28 +16,28 @@
#include <stdio.h>
#include "nnacl/base/gather_base.h"
int Gather(const void *input, int outer_size, int inner_size, int limit, const int *indices, int indices_element_size,
void *output, int data_size) {
int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
int64_t index_num, void *output, int64_t out_stride) {
if (input == NULL || output == NULL || indices == NULL) {
return NNACL_NULL_PTR;
}
const int8_t *int8_in = (int8_t *)input;
int8_t *int8_out = (int8_t *)output;
for (int m = 0; m < outer_size; ++m) {
const int8_t *int8_in_m = int8_in + inner_size * m * limit * data_size;
int8_t *int8_out_m = int8_out + inner_size * m * indices_element_size * data_size;
for (int i = 0; i < indices_element_size; ++i) {
int64_t in_stride = inner_size * limit;
for (int64_t m = 0; m < outer_size; ++m) {
int8_t *int8_out_m = int8_out;
for (int64_t i = 0; i < index_num; ++i) {
int index = indices[i];
if (index < -limit || index >= limit) {
memset(int8_out_m + i * inner_size * data_size, 0, data_size * inner_size);
continue;
}
index = index < 0 ? index + limit : index;
memcpy(int8_out_m + i * inner_size * data_size, int8_in_m + index * inner_size * data_size,
data_size * inner_size);
if (index < 0 || index >= limit) {
memset(int8_out_m, 0, inner_size);
} else {
memcpy(int8_out_m, int8_in + index * inner_size, inner_size);
}
int8_out_m += inner_size;
}
int8_in += in_stride;
int8_out += out_stride;
}
return NNACL_OK;
}
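Note that inner_size and out_stride are now byte counts, so a single memcpy per index moves the whole inner block regardless of element type. A minimal standalone call sketch, assuming float data (buffer sizes and index values are illustrative):

float in[2 * 3 * 4];                            // shape {2, 3, 4}, gathering along axis 1
float out[2 * 2 * 4];                           // 2 indices selected per outer row
int indices[2] = {2, 0};
int64_t byte_inner_size = 4 * sizeof(float);    // inner block size, in bytes
int64_t byte_out_stride = 2 * byte_inner_size;  // output bytes per outer row
int ret = Gather(in, /*outer_size=*/2, byte_inner_size, /*limit=*/3, indices,
                 /*index_num=*/2, out, byte_out_stride);
// ret is NNACL_OK on success; out-of-range indices yield zero-filled inner blocks.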

View File

@@ -17,15 +17,14 @@
#ifndef MINDSPORE_NNACL_GATHER_BASE_H_
#define MINDSPORE_NNACL_GATHER_BASE_H_
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#ifdef __cplusplus
extern "C" {
#endif
int Gather(const void *input, int outer_size, int inner_size, int limit, const int *indices, int indices_element_size,
void *output, int data_size);
int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
int64_t index_num, void *output, int64_t out_stride);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,132 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/gather_base.h"
#include <limits>
#include "nnacl/base/gather_base.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int GatherRun(const void *cdata, int task_id, float, float) {
auto gather_kernel = reinterpret_cast<const GatherBaseCPUKernel *>(cdata);
auto error_code = gather_kernel->DoGather(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
}
return error_code;
}
int GatherBaseCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), kInputSize2);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_.at(THIRD_INPUT));
CHECK_NULL_RETURN(in_tensors_.at(THIRD_INPUT)->data());
axis_ = *(reinterpret_cast<int *>(in_tensors_.at(THIRD_INPUT)->data()));
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GatherBaseCPUKernel::ReSize() { return ChooseThreadCuttingStrategy(); }
int GatherBaseCPUKernel::DoGather(int task_id) const {
auto *int8_in = reinterpret_cast<int8_t *>(in_tensors_[FIRST_INPUT]->data());
CHECK_NULL_RETURN(int8_in);
auto in_shape = in_tensors_[FIRST_INPUT]->shape();
MS_CHECK_LT(axis_, static_cast<int>(in_shape.size()), RET_ERROR);
const int64_t limit = in_shape.at(axis_);
auto *int8_out = reinterpret_cast<int8_t *>(out_tensors_[FIRST_INPUT]->data());
CHECK_NULL_RETURN(int8_out);
int data_size = static_cast<int>(lite::DataTypeSize(in_tensors_[FIRST_INPUT]->data_type()));
auto index_num = in_tensors_[SECOND_INPUT]->ElementsNum();
int64_t byte_inner_size = inner_size_ * data_size;
int64_t byte_out_stride = index_num * byte_inner_size;
int64_t all_count = split_by_index ? index_num : outer_size_;
int64_t count = (task_id < static_cast<int>(split_points_.size()) - 1)
? split_points_[task_id + 1] - split_points_[task_id]
: all_count - split_points_[task_id];
int ret = RET_OK;
if (split_by_index) {
int *indices_data = indices_data_ + split_points_[task_id];
int8_out += split_points_[task_id] * byte_inner_size;
ret = Gather(int8_in, outer_size_, byte_inner_size, limit, indices_data, count, int8_out, byte_out_stride);
} else {
int8_in += split_points_[task_id] * limit * byte_inner_size;
int8_out += split_points_[task_id] * byte_out_stride;
ret = Gather(int8_in, count, byte_inner_size, limit, indices_data_, index_num, int8_out, byte_out_stride);
}
return ret;
}
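// Worked example of the two splitting modes above (hypothetical sizes:
// outer_size_ = 2, index_num = 8, byte_inner_size = 16, so byte_out_stride = 128):
// - split_by_index: a task owning indices [3, 5) keeps int8_in at the tensor
//   start (it still walks every outer row), shifts int8_out by 3 * 16 bytes
//   into each output row, and relies on byte_out_stride = 128 to hop between rows.
// - split by outer size: a task owning outer row [1, 2) shifts int8_in by
//   1 * limit * 16 bytes and int8_out by 1 * 128 bytes, then gathers the full
//   index list for its rows.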
int GatherBaseCPUKernel::Run() {
bool isIndicesInt32 = in_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32;
int ret = AssignIndicesData(isIndicesInt32);
if (ret != RET_OK) {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}
ret = ParallelLaunch(this->ms_context_, GatherRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
}
if (!isIndicesInt32) {
ms_context_->allocator->Free(indices_data_);
indices_data_ = nullptr;
}
return ret;
}
int GatherBaseCPUKernel::ChooseThreadCuttingStrategy() {
auto in_shape = in_tensors_[FIRST_INPUT]->shape();
int in_rank = static_cast<int>(in_shape.size());
MS_CHECK_TRUE_MSG(axis_ < in_rank, RET_ERROR, "gather's axis is out of range.");
outer_size_ = 1;
for (int i = 0; i < axis_; ++i) {
outer_size_ *= in_shape.at(i);
}
inner_size_ = 1;
for (int i = axis_ + 1; i < in_rank; ++i) {
inner_size_ *= in_shape.at(i);
}
int64_t all_count = outer_size_;
auto index_num = in_tensors_[SECOND_INPUT]->ElementsNum();
if (outer_size_ >= index_num) {
split_by_index = false;
} else {
all_count = index_num;
split_by_index = true;
}
int64_t count_step = MSMAX(all_count / op_parameter_->thread_num_, 1);
int64_t count_remaining = MSMAX(all_count - count_step * op_parameter_->thread_num_, 0);
split_points_.clear();
int64_t split_point = 0;
while (split_point < all_count) {
split_points_.push_back(split_point);
split_point += count_step;
if (count_remaining > 0) {
split_point += 1;
--count_remaining;
}
}
thread_count_ = static_cast<int>(split_points_.size());
return RET_OK;
}
} // namespace mindspore::kernel
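The cutting strategy divides all_count work units into split_points_ so that, when all_count exceeds thread_num_, the first all_count % thread_num_ tasks carry one extra unit; otherwise each unit becomes its own task. A self-contained sketch of the same arithmetic (the function name MakeSplitPoints is illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> MakeSplitPoints(int64_t all_count, int64_t thread_num) {
  int64_t step = std::max<int64_t>(all_count / thread_num, 1);
  int64_t remaining = std::max<int64_t>(all_count - step * thread_num, 0);
  std::vector<int64_t> points;
  for (int64_t p = 0; p < all_count;) {
    points.push_back(p);
    p += step + (remaining > 0 ? 1 : 0);  // early tasks absorb the remainder
    if (remaining > 0) --remaining;
  }
  return points;
}

// Example: all_count = 10, thread_num = 4 gives step = 2, remaining = 2,
// points = {0, 3, 6, 8} and task sizes 3, 3, 2, 2.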

View File

@@ -0,0 +1,52 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GATHER_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GATHER_BASE_H_
#include <vector>
#include "include/errorcode.h"
#include "src/inner_kernel.h"
namespace mindspore::kernel {
class GatherBaseCPUKernel : public InnerKernel {
public:
GatherBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
~GatherBaseCPUKernel() = default;
int Prepare() override;
int ReSize() override;
int Run() override;
int DoGather(int task_id) const;
protected:
virtual int AssignIndicesData(bool isIndicesInt32) = 0;
int *indices_data_{nullptr};
private:
int ChooseThreadCuttingStrategy();
bool split_by_index{false}; // defaults to splitting by outer size
int axis_ = 0;
int thread_count_{0};
int64_t outer_size_{0};
int64_t inner_size_{0};
std::vector<int64_t> split_points_; // task start offsets, split by outer size or by index count.
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GATHER_BASE_H_

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,97 +29,18 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
int GatherFp16CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), 3);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_[FIRST_INPUT]);
CHECK_NULL_RETURN(in_tensors_[SECOND_INPUT]);
CHECK_NULL_RETURN(in_tensors_[THIRD_INPUT]);
CHECK_NULL_RETURN(out_tensors_[kOutputIndex]);
CHECK_NULL_RETURN(in_tensors_[THIRD_INPUT]->data());
(reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(static_cast<int *>(in_tensors_[THIRD_INPUT]->data()));
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GatherFp16CPUKernel::ReSize() { return RET_OK; }
int GatherFp16CPUKernel::DoGather(int task_id) {
auto input_tensor = in_tensors_.at(0);
auto indices_tensor = in_tensors_.at(1);
auto out_tensor = out_tensors_.at(0);
auto in_shape = input_tensor->shape();
int in_rank = in_shape.size();
int indices_element_size = indices_tensor->ElementsNum();
auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
MS_CHECK_LT(axis, static_cast<int>(in_shape.size()), RET_ERROR);
const int limit = in_shape.at(axis);
int outer_size = 1, inner_size = 1;
for (int i = 0; i < axis; ++i) {
outer_size *= in_shape.at(i);
}
for (int i = axis + 1; i < in_rank; ++i) {
inner_size *= in_shape.at(i);
}
int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
int count = MSMIN(stride, outer_size - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto thread_stride = stride * task_id;
int8_t *int8_in = nullptr;
if (input_tensor->data_type() == kNumberTypeFloat16) {
int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
} else {
MS_LOG(ERROR) << "input data type error";
return RET_ERROR;
}
int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
CHECK_NULL_RETURN(int8_in);
CHECK_NULL_RETURN(int8_out);
int data_size = lite::DataTypeSize(kNumberTypeFloat16);
int8_in += thread_stride * limit * inner_size * data_size;
int8_out += thread_stride * indices_element_size * inner_size * data_size;
int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
return error_code;
}
int GatherRunFp16(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto gather_kernel = reinterpret_cast<GatherFp16CPUKernel *>(cdata);
auto error_code = gather_kernel->DoGather(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
}
return error_code;
}
void GatherFp16CPUKernel::FreeIndicesData() {
if (!is_indices_int32_) {
ms_context_->allocator->Free(indices_data_);
indices_data_ = nullptr;
}
}
int GatherFp16CPUKernel::Run() {
auto indices_tensor = in_tensors_.at(1);
int indices_num = indices_tensor->ElementsNum();
is_indices_int32_ = indices_tensor->data_type() == kNumberTypeInt32;
int ret = AssignIndicesData(is_indices_int32_, indices_num, indices_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}
ret = ParallelLaunch(this->ms_context_, GatherRunFp16, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
}
FreeIndicesData();
return ret;
CHECK_NULL_RETURN(in_tensors_.at(FIRST_INPUT));
CHECK_NULL_RETURN(in_tensors_.at(SECOND_INPUT));
CHECK_NULL_RETURN(out_tensors_.at(FIRST_INPUT));
MS_CHECK_TRUE_MSG(in_tensors_[FIRST_INPUT]->data_type() == kNumberTypeFloat16, RET_ERROR,
"First input's data-type is not fp16.");
return GatherBaseCPUKernel::Run();
}
int GatherFp16CPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, const lite::Tensor *indices_tensor) {
int GatherFp16CPUKernel::AssignIndicesData(bool isIndicesInt32) {
auto indices_tensor = in_tensors_[SECOND_INPUT];
auto indices_num = indices_tensor->ElementsNum();
CHECK_NULL_RETURN(indices_tensor->data());
if (!isIndicesInt32) {
if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,28 +20,20 @@
#include <arm_neon.h>
#include <vector>
#include "include/errorcode.h"
#include "src/inner_kernel.h"
#include "nnacl/gather_parameter.h"
#include "nnacl/base/gather_base.h"
#include "src/runtime/kernel/arm/base/gather_base.h"
namespace mindspore::kernel {
class GatherFp16CPUKernel : public InnerKernel {
class GatherFp16CPUKernel : public GatherBaseCPUKernel {
public:
GatherFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
: GatherBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~GatherFp16CPUKernel() = default;
int Prepare() override;
int ReSize() override;
int Run() override;
int DoGather(int task_id);
private:
int *indices_data_ = nullptr;
int AssignIndicesData(bool isIndicesInt32, int indices_num, const lite::Tensor *indices_tensor);
void FreeIndicesData();
bool is_indices_int32_ = false;
int AssignIndicesData(bool isIndicesInt32) override;
};
} // namespace mindspore::kernel

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,96 +26,17 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
namespace {
constexpr int kSecondInput = 2;
}
int GatherCPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), kInputSize2);
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_.at(kSecondInput));
CHECK_NULL_RETURN(in_tensors_.at(kSecondInput)->data());
axis_ = *(reinterpret_cast<int *>(in_tensors_.at(kSecondInput)->data()));
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int GatherCPUKernel::ReSize() { return RET_OK; }
int GatherCPUKernel::DoGather(int task_id) const {
auto input_tensor = in_tensors_.at(0);
auto indices_tensor = in_tensors_.at(1);
auto out_tensor = out_tensors_.at(0);
auto in_shape = input_tensor->shape();
int in_rank = in_shape.size();
int indices_element_size = indices_tensor->ElementsNum();
MS_CHECK_LT(axis_, in_rank, RET_ERROR);
const int limit = in_shape.at(axis_);
int outer_size = 1, inner_size = 1;
for (int i = 0; i < axis_; ++i) {
outer_size *= in_shape.at(i);
}
for (int i = axis_ + 1; i < in_rank; ++i) {
inner_size *= in_shape.at(i);
}
int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
int count = MSMIN(stride, outer_size - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto thread_stride = stride * task_id;
int8_t *int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
CHECK_NULL_RETURN(int8_in);
int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
CHECK_NULL_RETURN(int8_out);
int data_size = static_cast<int>(lite::DataTypeSize(input_tensor->data_type()));
int8_in += thread_stride * limit * inner_size * data_size;
int8_out += thread_stride * indices_element_size * inner_size * data_size;
int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
return error_code;
}
int GatherRun(const void *cdata, int task_id, float, float) {
auto gather_kernel = reinterpret_cast<const GatherCPUKernel *>(cdata);
auto error_code = gather_kernel->DoGather(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
}
return error_code;
}
int GatherCPUKernel::Run() {
CHECK_NULL_RETURN(in_tensors_.at(FIRST_INPUT));
CHECK_NULL_RETURN(in_tensors_.at(SECOND_INPUT));
CHECK_NULL_RETURN(out_tensors_.at(FIRST_INPUT));
auto indices_tensor = in_tensors_.at(1);
int indices_num = indices_tensor->ElementsNum();
bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
int ret = AssignIndicesData(isIndicesInt32, indices_num, indices_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}
ret = ParallelLaunch(this->ms_context_, GatherRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
}
if (!isIndicesInt32) {
ms_context_->allocator->Free(indices_data_);
indices_data_ = nullptr;
}
return ret;
return GatherBaseCPUKernel::Run();
}
int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor) {
int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32) {
auto indices_tensor = in_tensors_[SECOND_INPUT];
auto indices_num = indices_tensor->ElementsNum();
CHECK_NULL_RETURN(indices_tensor->data());
if (!isIndicesInt32) {
if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
MS_LOG(ERROR) << "Input indices_num is invalid, indices_num: " << indices_num;

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,27 +19,20 @@
#include <vector>
#include "include/errorcode.h"
#include "src/inner_kernel.h"
#include "nnacl/gather_parameter.h"
#include "nnacl/base/gather_base.h"
#include "src/runtime/kernel/arm/base/gather_base.h"
namespace mindspore::kernel {
class GatherCPUKernel : public InnerKernel {
class GatherCPUKernel : public GatherBaseCPUKernel {
public:
GatherCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
: GatherBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~GatherCPUKernel() = default;
int Prepare() override;
int ReSize() override;
int Run() override;
int DoGather(int task_id) const;
private:
int *indices_data_ = nullptr;
int axis_ = 0;
int AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor);
int AssignIndicesData(bool isIndicesInt32) override;
};
} // namespace mindspore::kernel