forked from mindspore-Ecosystem/mindspore

optimize gather op

parent d0d932c94f
commit 0f72ace8cf
@@ -92,11 +92,13 @@ bool GatherV2CpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
     indices_element_size *= indices_shape_.at(i);
   }
   auto limit = input_shape_.at(axis);
+  size_t byte_inner_size = inner_size * sizeof(T);
+  size_t byte_out_stride = indices_element_size * byte_inner_size;
   auto task = [&](size_t start, size_t end) {
     int count = SizeToInt(end - start);
-    const int8_t *in = input_tensor + start * limit * inner_size * sizeof(T);
-    int8_t *out = output_addr + start * indices_element_size * inner_size * sizeof(T);
-    int ret = Gather(in, count, inner_size, limit, indices_data, indices_element_size, out, sizeof(T));
+    const int8_t *in = input_tensor + start * limit * byte_inner_size;
+    int8_t *out = output_addr + start * byte_out_stride;
+    int ret = Gather(in, count, byte_inner_size, limit, indices_data, indices_element_size, out, byte_out_stride);
     if (ret != 0) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]";
     }
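Note on the hunk above: the byte strides are now computed once, before the parallel task is defined, so each task derives its pointers with a single multiply instead of re-expanding inner_size * sizeof(T) in every expression. A minimal standalone sketch of the idea (the launcher name and the row split are illustrative, not the MindSpore source):

#include <cstddef>
#include <cstdint>

template <typename T>
void LaunchSketch(const int8_t *input, int8_t *output, size_t outer, size_t limit,
                  size_t inner_size, size_t index_num) {
  // Hoisted out of the task: computed once per launch rather than per thread.
  const size_t byte_inner_size = inner_size * sizeof(T);       // bytes in one inner block
  const size_t byte_out_stride = index_num * byte_inner_size;  // bytes in one gathered row
  auto task = [&](size_t start, size_t end) {
    // One multiply per pointer; the old code recomputed the element-to-byte
    // scaling here for the input, the output, and the Gather call itself.
    const int8_t *in = input + start * limit * byte_inner_size;
    int8_t *out = output + start * byte_out_stride;
    (void)in; (void)out; (void)end;  // a real task would call Gather() here
  };
  task(0, outer);  // a real launch would split [0, outer) across threads
}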
@@ -16,28 +16,28 @@
 #include <stdio.h>
 #include "nnacl/base/gather_base.h"

-int Gather(const void *input, int outer_size, int inner_size, int limit, const int *indices, int indices_element_size,
-           void *output, int data_size) {
+int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
+           int64_t index_num, void *output, int64_t out_stride) {
   if (input == NULL || output == NULL || indices == NULL) {
     return NNACL_NULL_PTR;
   }
   const int8_t *int8_in = (int8_t *)input;
   int8_t *int8_out = (int8_t *)output;

-  for (int m = 0; m < outer_size; ++m) {
-    const int8_t *int8_in_m = int8_in + inner_size * m * limit * data_size;
-    int8_t *int8_out_m = int8_out + inner_size * m * indices_element_size * data_size;
-
-    for (int i = 0; i < indices_element_size; ++i) {
+  int64_t in_stride = inner_size * limit;
+  for (int64_t m = 0; m < outer_size; ++m) {
+    int8_t *int8_out_m = int8_out;
+    for (int64_t i = 0; i < index_num; ++i) {
       int index = indices[i];
-      if (index < -limit || index >= limit) {
-        memset(int8_out_m + i * inner_size * data_size, 0, data_size * inner_size);
-        continue;
-      }
-      index = index < 0 ? index + limit : index;
-      memcpy(int8_out_m + i * inner_size * data_size, int8_in_m + index * inner_size * data_size,
-             data_size * inner_size);
+      if (index < 0 || index >= limit) {
+        memset(int8_out_m, 0, inner_size);
+      } else {
+        memcpy(int8_out_m, int8_in + index * inner_size, inner_size);
+      }
+      int8_out_m += inner_size;
     }
+    int8_in += in_stride;
+    int8_out += out_stride;
   }
   return NNACL_OK;
 }
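Note on the rewritten nnacl kernel: inner_size and out_stride are now byte counts and the per-element data_size parameter is gone, so one memcpy moves a whole inner block per index. A hedged usage sketch with made-up values (mine, not the commit's): gathering rows 2 and 0 of a 4x3 float matrix along axis 0.

#include <cstdint>
#include <cstdio>
#include "nnacl/base/gather_base.h"

int main() {
  float input[4][3] = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}};
  int indices[2] = {2, 0};
  float output[2][3] = {{0}};
  const int64_t byte_inner_size = 3 * sizeof(float);  // one row = 12 bytes
  // outer_size = 1 slice, limit = 4 rows, out_stride = 2 gathered rows in bytes.
  int ret = Gather(input, 1, byte_inner_size, 4, indices, 2, output, 2 * byte_inner_size);
  printf("ret=%d row0[0]=%g row1[0]=%g\n", ret, output[0][0], output[1][0]);  // expect 6 and 0
  return 0;
}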
@@ -17,15 +17,14 @@
 #ifndef MINDSPORE_NNACL_GATHER_BASE_H_
 #define MINDSPORE_NNACL_GATHER_BASE_H_

 #include <string.h>
 #include "nnacl/op_base.h"
 #include "nnacl/errorcode.h"

 #ifdef __cplusplus
 extern "C" {
 #endif
-int Gather(const void *input, int outer_size, int inner_size, int limit, const int *indices, int indices_element_size,
-           void *output, int data_size);
+int Gather(const void *input, int64_t outer_size, int64_t inner_size, int64_t limit, const int *indices,
+           int64_t indices_element_size, void *output, int64_t data_size);
 #ifdef __cplusplus
 }
 #endif
@@ -0,0 +1,132 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/base/gather_base.h"
+#include <limits>
+#include "nnacl/base/gather_base.h"
+
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+
+namespace mindspore::kernel {
+int GatherRun(const void *cdata, int task_id, float, float) {
+  auto gather_kernel = reinterpret_cast<const GatherBaseCPUKernel *>(cdata);
+  auto error_code = gather_kernel->DoGather(task_id);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
+  }
+  return error_code;
+}
+
+int GatherBaseCPUKernel::Prepare() {
+  CHECK_LESS_RETURN(in_tensors_.size(), kInputSize2);
+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_.at(THIRD_INPUT));
+  CHECK_NULL_RETURN(in_tensors_.at(THIRD_INPUT)->data());
+  axis_ = *(reinterpret_cast<int *>(in_tensors_.at(THIRD_INPUT)->data()));
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int GatherBaseCPUKernel::ReSize() { return ChooseThreadCuttingstrategy(); }
+
+int GatherBaseCPUKernel::DoGather(int task_id) const {
+  auto *int8_in = reinterpret_cast<int8_t *>(in_tensors_[FIRST_INPUT]->data());
+  CHECK_NULL_RETURN(int8_in);
+  auto in_shape = in_tensors_[FIRST_INPUT]->shape();
+  MS_CHECK_LT(axis_, static_cast<int>(in_shape.size()), RET_ERROR);
+  const int64_t limit = in_shape.at(axis_);
+  auto *int8_out = reinterpret_cast<int8_t *>(out_tensors_[FIRST_INPUT]->data());
+  CHECK_NULL_RETURN(int8_out);
+  int data_size = static_cast<int>(lite::DataTypeSize(in_tensors_[FIRST_INPUT]->data_type()));
+  auto index_num = in_tensors_[SECOND_INPUT]->ElementsNum();
+  int64_t byte_inner_size = inner_size_ * data_size;
+  int64_t byte_out_stride = index_num * byte_inner_size;
+  int64_t all_count = split_by_index ? index_num : outer_size_;
+  int64_t count = (task_id < static_cast<int>(split_points_.size()) - 1)
+                    ? split_points_[task_id + 1] - split_points_[task_id]
+                    : all_count - split_points_[task_id];
+
+  int ret = RET_OK;
+  if (split_by_index) {
+    int *indices_data = indices_data_ + split_points_[task_id];
+    int8_out += split_points_[task_id] * byte_inner_size;
+    ret = Gather(int8_in, outer_size_, byte_inner_size, limit, indices_data, count, int8_out, byte_out_stride);
+  } else {
+    int8_in += split_points_[task_id] * limit * byte_inner_size;
+    int8_out += split_points_[task_id] * byte_out_stride;
+    ret = Gather(int8_in, count, byte_inner_size, limit, indices_data_, index_num, int8_out, byte_out_stride);
+  }
+  return ret;
+}
+
+int GatherBaseCPUKernel::Run() {
+  bool isIndicesInt32 = in_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32;
+  int ret = AssignIndicesData(isIndicesInt32);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
+    return ret;
+  }
+
+  ret = ParallelLaunch(this->ms_context_, GatherRun, this, thread_count_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
+  }
+  if (!isIndicesInt32) {
+    ms_context_->allocator->Free(indices_data_);
+    indices_data_ = nullptr;
+  }
+  return ret;
+}
+
+int GatherBaseCPUKernel::ChooseThreadCuttingstrategy() {
+  auto in_shape = in_tensors_[FIRST_INPUT]->shape();
+  int in_rank = static_cast<int>(in_shape.size());
+  MS_CHECK_TRUE_MSG(axis_ < in_rank, RET_ERROR, "gather's inputs are invalid.");
+  outer_size_ = 1;
+  for (int i = 0; i < axis_; ++i) {
+    outer_size_ *= in_shape.at(i);
+  }
+  inner_size_ = 1;
+  for (int i = axis_ + 1; i < in_rank; ++i) {
+    inner_size_ *= in_shape.at(i);
+  }
+  int64_t all_count = outer_size_;
+  auto index_num = in_tensors_[SECOND_INPUT]->ElementsNum();
+  if (outer_size_ >= index_num) {
+    split_by_index = false;
+  } else {
+    all_count = index_num;
+    split_by_index = true;
+  }
+  int64_t count_step = MSMAX(all_count / op_parameter_->thread_num_, 1);
+  int64_t count_remaining = MSMAX(all_count - count_step * op_parameter_->thread_num_, 0);
+  split_points_.clear();
+  int64_t split_point = 0;
+  while (split_point < all_count) {
+    split_points_.push_back(split_point);
+    split_point += count_step;
+    if (count_remaining > 0) {
+      split_point += 1;
+      --count_remaining;
+    }
+  }
+  thread_count_ = static_cast<int>(split_points_.size());
+  return RET_OK;
+}
+}  // namespace mindspore::kernel
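Note on ChooseThreadCuttingstrategy above: it splits whichever count is larger (outer dimension or index list) into near-equal chunks, spreading the remainder one item at a time from the front, and stores each task's start offset in split_points_. A standalone sketch of that splitting rule (an illustrative reimplementation, not the source):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> SplitPoints(int64_t all_count, int64_t thread_num) {
  const int64_t count_step = std::max<int64_t>(all_count / thread_num, 1);
  int64_t count_remaining = std::max<int64_t>(all_count - count_step * thread_num, 0);
  std::vector<int64_t> points;
  for (int64_t p = 0; p < all_count;) {
    points.push_back(p);
    p += count_step;
    if (count_remaining > 0) {  // earlier tasks absorb the leftover items
      ++p;
      --count_remaining;
    }
  }
  return points;
}

int main() {
  // 10 items over 4 threads -> chunks of 3, 3, 2, 2, i.e. starts at 0 3 6 8.
  for (int64_t p : SplitPoints(10, 4)) std::cout << p << ' ';
  std::cout << '\n';
}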
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GATHER_BASE_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GATHER_BASE_H_
+
+#include <vector>
+#include "include/errorcode.h"
+#include "src/inner_kernel.h"
+
+namespace mindspore::kernel {
+class GatherBaseCPUKernel : public InnerKernel {
+ public:
+  GatherBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : InnerKernel(parameter, inputs, outputs, ctx) {}
+  ~GatherBaseCPUKernel() = default;
+
+  int Prepare() override;
+  int ReSize() override;
+  int Run() override;
+  int DoGather(int task_id) const;
+
+ protected:
+  virtual int AssignIndicesData(bool isIndicesInt32) = 0;
+  int *indices_data_{nullptr};
+
+ private:
+  int ChooseThreadCuttingstrategy();
+  bool split_by_index{false};  // default by outer-size
+  int axis_ = 0;
+  int thread_count_{0};
+  int64_t outer_size_{0};
+  int64_t inner_size_{0};
+  std::vector<int64_t> split_points_;  // split by outer-size or index-data.
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GATHER_BASE_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,97 +29,18 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Gather;

 namespace mindspore::kernel {
-int GatherFp16CPUKernel::Prepare() {
-  CHECK_LESS_RETURN(in_tensors_.size(), 3);
-  CHECK_LESS_RETURN(out_tensors_.size(), 1);
-  CHECK_NULL_RETURN(in_tensors_[FIRST_INPUT]);
-  CHECK_NULL_RETURN(in_tensors_[SECOND_INPUT]);
-  CHECK_NULL_RETURN(in_tensors_[THIRD_INPUT]);
-  CHECK_NULL_RETURN(out_tensors_[kOutputIndex]);
-  CHECK_NULL_RETURN(in_tensors_[THIRD_INPUT]->data());
-  (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(static_cast<int *>(in_tensors_[THIRD_INPUT]->data()));
-  if (!InferShapeDone()) {
-    return RET_OK;
-  }
-  return ReSize();
-}
-
-int GatherFp16CPUKernel::ReSize() { return RET_OK; }
-
-int GatherFp16CPUKernel::DoGather(int task_id) {
-  auto input_tensor = in_tensors_.at(0);
-  auto indices_tensor = in_tensors_.at(1);
-  auto out_tensor = out_tensors_.at(0);
-  auto in_shape = input_tensor->shape();
-  int in_rank = in_shape.size();
-  int indices_element_size = indices_tensor->ElementsNum();
-  auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
-  MS_CHECK_LT(axis, static_cast<int>(in_shape.size()), RET_ERROR);
-  const int limit = in_shape.at(axis);
-  int outer_size = 1, inner_size = 1;
-  for (int i = 0; i < axis; ++i) {
-    outer_size *= in_shape.at(i);
-  }
-  for (int i = axis + 1; i < in_rank; ++i) {
-    inner_size *= in_shape.at(i);
-  }
-  int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
-  int count = MSMIN(stride, outer_size - stride * task_id);
-  if (count <= 0) {
-    return RET_OK;
-  }
-  auto thread_stride = stride * task_id;
-  int8_t *int8_in = nullptr;
-  if (input_tensor->data_type() == kNumberTypeFloat16) {
-    int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
-  } else {
-    MS_LOG(ERROR) << "input data type error";
-    return RET_ERROR;
-  }
-  int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
-  CHECK_NULL_RETURN(int8_in);
-  CHECK_NULL_RETURN(int8_out);
-  int data_size = lite::DataTypeSize(kNumberTypeFloat16);
-  int8_in += thread_stride * limit * inner_size * data_size;
-  int8_out += thread_stride * indices_element_size * inner_size * data_size;
-  int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
-  return error_code;
-}
-
-int GatherRunFp16(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
-  auto gather_kernel = reinterpret_cast<GatherFp16CPUKernel *>(cdata);
-  auto error_code = gather_kernel->DoGather(task_id);
-  if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
-  }
-  return error_code;
-}
-
-void GatherFp16CPUKernel::FreeIndicesData() {
-  if (!is_indices_int32_) {
-    ms_context_->allocator->Free(indices_data_);
-    indices_data_ = nullptr;
-  }
-}
-
 int GatherFp16CPUKernel::Run() {
-  auto indices_tensor = in_tensors_.at(1);
-  int indices_num = indices_tensor->ElementsNum();
-  is_indices_int32_ = indices_tensor->data_type() == kNumberTypeInt32;
-  int ret = AssignIndicesData(is_indices_int32_, indices_num, indices_tensor);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
-    return ret;
-  }
-  ret = ParallelLaunch(this->ms_context_, GatherRunFp16, this, op_parameter_->thread_num_);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
-  }
-  FreeIndicesData();
-  return ret;
+  CHECK_NULL_RETURN(in_tensors_.at(FIRST_INPUT));
+  CHECK_NULL_RETURN(in_tensors_.at(SECOND_INPUT));
+  CHECK_NULL_RETURN(out_tensors_.at(FIRST_INPUT));
+  MS_CHECK_TRUE_MSG(in_tensors_[FIRST_INPUT]->data_type() == kNumberTypeFloat16, RET_ERROR,
+                    "First input's data-type is not fp16.");
+  return GatherBaseCPUKernel::Run();
 }

-int GatherFp16CPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, const lite::Tensor *indices_tensor) {
+int GatherFp16CPUKernel::AssignIndicesData(bool isIndicesInt32) {
+  auto indices_tensor = in_tensors_[SECOND_INPUT];
+  auto indices_num = indices_tensor->ElementsNum();
   CHECK_NULL_RETURN(indices_tensor->data());
   if (!isIndicesInt32) {
     if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
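Note on the bound above: AssignIndicesData allocates indices_num * sizeof(int) bytes to hold indices converted to int32, so indices_num must stay below INT_MAX / sizeof(int) or the byte count overflows. The same guard in isolation (a hypothetical helper, for illustration only):

#include <cstdint>
#include <limits>

// True when indices_num * sizeof(int) is still representable as int,
// i.e. the conversion buffer size cannot overflow.
bool IndicesCountFits(int64_t indices_num) {
  return indices_num < std::numeric_limits<int>::max() / static_cast<int>(sizeof(int));
}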
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,28 +20,20 @@
 #include <arm_neon.h>
 #include <vector>
 #include "include/errorcode.h"
-#include "src/inner_kernel.h"
-#include "nnacl/gather_parameter.h"
-#include "nnacl/base/gather_base.h"
+#include "src/runtime/kernel/arm/base/gather_base.h"

 namespace mindspore::kernel {
-class GatherFp16CPUKernel : public InnerKernel {
+class GatherFp16CPUKernel : public GatherBaseCPUKernel {
  public:
   GatherFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
-      : InnerKernel(parameter, inputs, outputs, ctx) {}
+      : GatherBaseCPUKernel(parameter, inputs, outputs, ctx) {}
   ~GatherFp16CPUKernel() = default;

-  int Prepare() override;
-  int ReSize() override;
   int Run() override;
-  int DoGather(int task_id);

  private:
-  int *indices_data_ = nullptr;
-  int AssignIndicesData(bool isIndicesInt32, int indices_num, const lite::Tensor *indices_tensor);
-  void FreeIndicesData();
-  bool is_indices_int32_ = false;
+  int AssignIndicesData(bool isIndicesInt32) override;
 };
 }  // namespace mindspore::kernel
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,96 +26,17 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Gather;

 namespace mindspore::kernel {
-namespace {
-constexpr int kSecondInput = 2;
-}
-int GatherCPUKernel::Prepare() {
-  CHECK_LESS_RETURN(in_tensors_.size(), kInputSize2);
-  CHECK_LESS_RETURN(out_tensors_.size(), 1);
-  CHECK_NULL_RETURN(in_tensors_.at(kSecondInput));
-  CHECK_NULL_RETURN(in_tensors_.at(kSecondInput)->data());
-  axis_ = *(reinterpret_cast<int *>(in_tensors_.at(kSecondInput)->data()));
-  if (!InferShapeDone()) {
-    return RET_OK;
-  }
-  return ReSize();
-}
-
-int GatherCPUKernel::ReSize() { return RET_OK; }
-
-int GatherCPUKernel::DoGather(int task_id) const {
-  auto input_tensor = in_tensors_.at(0);
-  auto indices_tensor = in_tensors_.at(1);
-  auto out_tensor = out_tensors_.at(0);
-
-  auto in_shape = input_tensor->shape();
-  int in_rank = in_shape.size();
-  int indices_element_size = indices_tensor->ElementsNum();
-  MS_CHECK_LT(axis_, in_rank, RET_ERROR);
-  const int limit = in_shape.at(axis_);
-
-  int outer_size = 1, inner_size = 1;
-  for (int i = 0; i < axis_; ++i) {
-    outer_size *= in_shape.at(i);
-  }
-  for (int i = axis_ + 1; i < in_rank; ++i) {
-    inner_size *= in_shape.at(i);
-  }
-  int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
-  int count = MSMIN(stride, outer_size - stride * task_id);
-  if (count <= 0) {
-    return RET_OK;
-  }
-  auto thread_stride = stride * task_id;
-
-  int8_t *int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
-  CHECK_NULL_RETURN(int8_in);
-  int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
-  CHECK_NULL_RETURN(int8_out);
-
-  int data_size = static_cast<int>(lite::DataTypeSize(input_tensor->data_type()));
-  int8_in += thread_stride * limit * inner_size * data_size;
-  int8_out += thread_stride * indices_element_size * inner_size * data_size;
-
-  int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
-
-  return error_code;
-}
-
-int GatherRun(const void *cdata, int task_id, float, float) {
-  auto gather_kernel = reinterpret_cast<const GatherCPUKernel *>(cdata);
-  auto error_code = gather_kernel->DoGather(task_id);
-  if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
-  }
-  return error_code;
-}
-
 int GatherCPUKernel::Run() {
   CHECK_NULL_RETURN(in_tensors_.at(FIRST_INPUT));
   CHECK_NULL_RETURN(in_tensors_.at(SECOND_INPUT));
   CHECK_NULL_RETURN(out_tensors_.at(FIRST_INPUT));
-  auto indices_tensor = in_tensors_.at(1);
-  int indices_num = indices_tensor->ElementsNum();
-  bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
-  int ret = AssignIndicesData(isIndicesInt32, indices_num, indices_tensor);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
-    return ret;
-  }
-  ret = ParallelLaunch(this->ms_context_, GatherRun, this, op_parameter_->thread_num_);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
-  }
-  if (!isIndicesInt32) {
-    ms_context_->allocator->Free(indices_data_);
-    indices_data_ = nullptr;
-  }
-  return ret;
+  return GatherBaseCPUKernel::Run();
 }

-int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor) {
+int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32) {
+  auto indices_tensor = in_tensors_[SECOND_INPUT];
+  auto indices_num = indices_tensor->ElementsNum();
   CHECK_NULL_RETURN(indices_tensor->data());
   if (!isIndicesInt32) {
     if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
       MS_LOG(ERROR) << "Input indices_num is invalid, indices_num: " << indices_num;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,27 +19,20 @@

 #include <vector>
 #include "include/errorcode.h"
-#include "src/inner_kernel.h"
-#include "nnacl/gather_parameter.h"
-#include "nnacl/base/gather_base.h"
+#include "src/runtime/kernel/arm/base/gather_base.h"

 namespace mindspore::kernel {
-class GatherCPUKernel : public InnerKernel {
+class GatherCPUKernel : public GatherBaseCPUKernel {
  public:
   GatherCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
-      : InnerKernel(parameter, inputs, outputs, ctx) {}
+      : GatherBaseCPUKernel(parameter, inputs, outputs, ctx) {}
   ~GatherCPUKernel() = default;

-  int Prepare() override;
-  int ReSize() override;
   int Run() override;
-  int DoGather(int task_id) const;

  private:
-  int *indices_data_ = nullptr;
-  int axis_ = 0;
-  int AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor);
+  int AssignIndicesData(bool isIndicesInt32) override;
 };
 }  // namespace mindspore::kernel