!11012 [MS][LITE]Fix gatherNd

From: @gongdaguo
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-13 17:02:45 +08:00 committed by Gitee
commit afdb4e99c7
4 changed files with 20 additions and 6 deletions

View File

@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/fp32/gatherNd_fp32.h"
#include <string.h>
#include <limits>
#include <vector>
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
@ -65,11 +66,17 @@ int GatherNdCPUKernel::ReSize() {
return RET_ERROR;
}
(void)memset(in_offset_, 0, count_ * sizeof(int));
thread_sz_count_ = MSMIN(thread_count_, count_);
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
return RET_OK;
}
void GatherNdCPUKernel::InitOffset() {
MS_ASSERT(in_offset_ != nullptr);
auto indices_tensor = in_tensors_.at(1);
auto indices_shape = indices_tensor->shape();
auto in_shape = in_tensors_.front()->shape();
int indices_rank = indices_shape.size();
int in_rank = in_shape.size();
int idx_lastshape = indices_shape[indices_rank - 1];
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->MutableData());
@ -89,8 +96,6 @@ int GatherNdCPUKernel::ReSize() {
in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride.at(k);
}
}
return RET_OK;
}
int GatherNdCPUKernel::DoGatherNd(int task_id) {
@ -120,6 +125,7 @@ int GatherNdRun(void *cdata, int task_id) {
int GatherNdCPUKernel::Run() {
in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
InitOffset();
auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdRun, this, thread_sz_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";

View File

@ -41,6 +41,7 @@ class GatherNdCPUKernel : public LiteKernel {
int DoGatherNd(int task_id);
private:
void InitOffset();
int thread_sz_count_;
int thread_sz_stride_;
int count_;

View File

@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/int8/gatherNd_int8.h"
#include <string.h>
#include <limits>
#include <vector>
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
@ -50,7 +51,6 @@ int GatherNdInt8CPUKernel::ReSize() {
in_offset_ = nullptr;
}
auto in_quant_args = in_tensors_.at(0)->quant_params();
auto ind_quant_args = in_tensors_.at(1)->quant_params();
auto out_quant_args = out_tensors_.at(0)->quant_params();
param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
param_.zp_in_ = in_quant_args.front().zeroPoint;
@ -73,10 +73,16 @@ int GatherNdInt8CPUKernel::ReSize() {
return RET_ERROR;
}
(void)memset(in_offset_, 0, count_ * sizeof(int));
thread_sz_count_ = MSMIN(thread_count_, count_);
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
return RET_OK;
}
void GatherNdInt8CPUKernel::InitOffset() {
auto ind_quant_args = in_tensors_.at(1)->quant_params();
auto indices_tensor = in_tensors_.at(1);
auto indices_shape = indices_tensor->shape();
int indices_rank = indices_shape.size();
auto in_shape = in_tensors_.front()->shape();
int in_rank = in_shape.size();
int idx_lastshape = indices_shape.at(indices_rank - 1);
@ -99,7 +105,6 @@ int GatherNdInt8CPUKernel::ReSize() {
in_offset_[j] += tmp * in_stride[k];
}
}
return RET_OK;
}
int GatherNdInt8CPUKernel::DoGatherNd(int task_id) {
@ -129,6 +134,7 @@ int GatherNdInt8Run(void *cdata, int task_id) {
int GatherNdInt8CPUKernel::Run() {
in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->MutableData());
out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->MutableData());
InitOffset();
auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdInt8Run, this, thread_sz_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";

View File

@ -36,6 +36,7 @@ class GatherNdInt8CPUKernel : public LiteKernel {
int DoGatherNd(int task_id);
private:
void InitOffset();
int thread_count_;
int thread_sz_count_;
int thread_sz_stride_;