forked from mindspore-Ecosystem/mindspore
!5382 [MS][LITE][Develop]optimize gather
Merge pull request !5382 from chenjianping/lite_dev2
This commit is contained in:
commit
29070d60a1
|
@ -37,6 +37,10 @@ int GatherCPUKernel::Init() {
|
||||||
return ReSize();
|
return ReSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GatherCPUKernel::~GatherCPUKernel() {
|
||||||
|
context_->allocator->Free(indices_data_);
|
||||||
|
}
|
||||||
|
|
||||||
int GatherCPUKernel::ReSize() { return RET_OK; }
|
int GatherCPUKernel::ReSize() { return RET_OK; }
|
||||||
|
|
||||||
int GatherCPUKernel::DoGather(int task_id) {
|
int GatherCPUKernel::DoGather(int task_id) {
|
||||||
|
@ -45,7 +49,6 @@ int GatherCPUKernel::DoGather(int task_id) {
|
||||||
auto out_tensor = out_tensors_.at(0);
|
auto out_tensor = out_tensors_.at(0);
|
||||||
|
|
||||||
auto input_ptr = reinterpret_cast<float *>(input_tensor->Data());
|
auto input_ptr = reinterpret_cast<float *>(input_tensor->Data());
|
||||||
auto indices_ptr = reinterpret_cast<float *>(indices_tensor->Data());
|
|
||||||
auto output_ptr = reinterpret_cast<float *>(out_tensor->Data());
|
auto output_ptr = reinterpret_cast<float *>(out_tensor->Data());
|
||||||
|
|
||||||
auto input_int32 = reinterpret_cast<int32_t *>(input_tensor->Data());
|
auto input_int32 = reinterpret_cast<int32_t *>(input_tensor->Data());
|
||||||
|
@ -57,13 +60,6 @@ int GatherCPUKernel::DoGather(int task_id) {
|
||||||
auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
|
auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
|
||||||
|
|
||||||
const int limit = in_shape[axis];
|
const int limit = in_shape[axis];
|
||||||
for (int i = 0; i < indices_element_size; ++i) {
|
|
||||||
indices_data_[i] = static_cast<int>(indices_ptr[i]);
|
|
||||||
if (indices_data_[i] >= limit) {
|
|
||||||
MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int outer_size = 1, inner_size = 1;
|
int outer_size = 1, inner_size = 1;
|
||||||
for (int i = 0; i < axis; ++i) {
|
for (int i = 0; i < axis; ++i) {
|
||||||
|
@ -106,12 +102,23 @@ int GatherCPUKernel::Run() {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto indices_tensor = in_tensors_.at(1);
|
auto indices_tensor = in_tensors_.at(1);
|
||||||
indices_data_ = reinterpret_cast<int *>(context_->allocator->Malloc(indices_tensor->ElementsNum() * sizeof(int)));
|
indices_data_ = reinterpret_cast<int *>(context_->allocator->Malloc(indices_tensor->Size()));
|
||||||
if (indices_data_ == nullptr) {
|
if (indices_data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Memory allocation failed";
|
MS_LOG(ERROR) << "Memory allocation failed";
|
||||||
context_->allocator->Free(indices_data_);
|
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
auto in_shape = in_tensors_.at(0)->shape();
|
||||||
|
int indices_element_size = indices_tensor->ElementsNum();
|
||||||
|
auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;;
|
||||||
|
auto indices_ptr = reinterpret_cast<float *>(indices_tensor->Data());
|
||||||
|
const int limit = in_shape[axis];
|
||||||
|
for (int i = 0; i < indices_element_size; ++i) {
|
||||||
|
indices_data_[i] = static_cast<int>(indices_ptr[i]);
|
||||||
|
if (indices_data_[i] >= limit) {
|
||||||
|
MS_LOG(ERROR) << " indice data: " << indices_data_[i] << " is not in [ 0, " << limit - 1 << " ]";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherRun, this, op_parameter_->thread_num_);
|
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherRun, this, op_parameter_->thread_num_);
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";
|
MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";
|
||||||
|
|
|
@ -28,7 +28,7 @@ class GatherCPUKernel : public LiteKernel {
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||||
const mindspore::lite::PrimitiveC *primitive)
|
const mindspore::lite::PrimitiveC *primitive)
|
||||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||||
~GatherCPUKernel() override = default;
|
~GatherCPUKernel() override;
|
||||||
|
|
||||||
int Init() override;
|
int Init() override;
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
|
|
Loading…
Reference in New Issue