forked from mindspore-Ecosystem/mindspore
!11012 [MS][LITE]Fix gatherNd
From: @gongdaguo Reviewed-by: Signed-off-by:
This commit is contained in:
commit
afdb4e99c7
|
@ -17,6 +17,7 @@
|
|||
#include "src/runtime/kernel/arm/fp32/gatherNd_fp32.h"
|
||||
#include <string.h>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
@ -65,11 +66,17 @@ int GatherNdCPUKernel::ReSize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
(void)memset(in_offset_, 0, count_ * sizeof(int));
|
||||
|
||||
thread_sz_count_ = MSMIN(thread_count_, count_);
|
||||
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void GatherNdCPUKernel::InitOffset() {
|
||||
MS_ASSERT(in_offset_ != nullptr);
|
||||
auto indices_tensor = in_tensors_.at(1);
|
||||
auto indices_shape = indices_tensor->shape();
|
||||
auto in_shape = in_tensors_.front()->shape();
|
||||
int indices_rank = indices_shape.size();
|
||||
int in_rank = in_shape.size();
|
||||
int idx_lastshape = indices_shape[indices_rank - 1];
|
||||
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->MutableData());
|
||||
|
@ -89,8 +96,6 @@ int GatherNdCPUKernel::ReSize() {
|
|||
in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride.at(k);
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GatherNdCPUKernel::DoGatherNd(int task_id) {
|
||||
|
@ -120,6 +125,7 @@ int GatherNdRun(void *cdata, int task_id) {
|
|||
int GatherNdCPUKernel::Run() {
|
||||
in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
|
||||
out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
|
||||
InitOffset();
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdRun, this, thread_sz_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";
|
||||
|
|
|
@ -41,6 +41,7 @@ class GatherNdCPUKernel : public LiteKernel {
|
|||
int DoGatherNd(int task_id);
|
||||
|
||||
private:
|
||||
void InitOffset();
|
||||
int thread_sz_count_;
|
||||
int thread_sz_stride_;
|
||||
int count_;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "src/runtime/kernel/arm/int8/gatherNd_int8.h"
|
||||
#include <string.h>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "schema/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
@ -50,7 +51,6 @@ int GatherNdInt8CPUKernel::ReSize() {
|
|||
in_offset_ = nullptr;
|
||||
}
|
||||
auto in_quant_args = in_tensors_.at(0)->quant_params();
|
||||
auto ind_quant_args = in_tensors_.at(1)->quant_params();
|
||||
auto out_quant_args = out_tensors_.at(0)->quant_params();
|
||||
param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
|
||||
param_.zp_in_ = in_quant_args.front().zeroPoint;
|
||||
|
@ -73,10 +73,16 @@ int GatherNdInt8CPUKernel::ReSize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
(void)memset(in_offset_, 0, count_ * sizeof(int));
|
||||
|
||||
thread_sz_count_ = MSMIN(thread_count_, count_);
|
||||
thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void GatherNdInt8CPUKernel::InitOffset() {
|
||||
auto ind_quant_args = in_tensors_.at(1)->quant_params();
|
||||
auto indices_tensor = in_tensors_.at(1);
|
||||
auto indices_shape = indices_tensor->shape();
|
||||
int indices_rank = indices_shape.size();
|
||||
auto in_shape = in_tensors_.front()->shape();
|
||||
int in_rank = in_shape.size();
|
||||
int idx_lastshape = indices_shape.at(indices_rank - 1);
|
||||
|
@ -99,7 +105,6 @@ int GatherNdInt8CPUKernel::ReSize() {
|
|||
in_offset_[j] += tmp * in_stride[k];
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GatherNdInt8CPUKernel::DoGatherNd(int task_id) {
|
||||
|
@ -129,6 +134,7 @@ int GatherNdInt8Run(void *cdata, int task_id) {
|
|||
int GatherNdInt8CPUKernel::Run() {
|
||||
in_ptr_ = reinterpret_cast<int8_t *>(in_tensors_.front()->MutableData());
|
||||
out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.front()->MutableData());
|
||||
InitOffset();
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, GatherNdInt8Run, this, thread_sz_count_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";
|
||||
|
|
|
@ -36,6 +36,7 @@ class GatherNdInt8CPUKernel : public LiteKernel {
|
|||
int DoGatherNd(int task_id);
|
||||
|
||||
private:
|
||||
void InitOffset();
|
||||
int thread_count_;
|
||||
int thread_sz_count_;
|
||||
int thread_sz_stride_;
|
||||
|
|
Loading…
Reference in New Issue