forked from mindspore-Ecosystem/mindspore
!13301 [MS][LITE]reverse op bug
From: @fuzhiye Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
edaffb66af
|
@ -37,6 +37,9 @@ int ReverseCPUKernel::Stride(int index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int ReverseCPUKernel::ReSize() {
|
int ReverseCPUKernel::ReSize() {
|
||||||
|
// trans negative to positive axis
|
||||||
|
UpdateAxisInfo();
|
||||||
|
|
||||||
data_size_ = in_tensors_.at(0)->ElementsNum();
|
data_size_ = in_tensors_.at(0)->ElementsNum();
|
||||||
thread_sz_count_ = MSMIN(op_parameter_->thread_num_, data_size_);
|
thread_sz_count_ = MSMIN(op_parameter_->thread_num_, data_size_);
|
||||||
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
|
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
|
||||||
|
@ -134,6 +137,16 @@ int ReverseCPUKernel::Run() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ReverseCPUKernel::UpdateAxisInfo() {
|
||||||
|
auto reverse_param = reinterpret_cast<ReverseParameter *>(op_parameter_);
|
||||||
|
int in_shape_len = in_tensors_.front()->shape().size();
|
||||||
|
for (int i = 0; i < reverse_param->num_axis_; ++i) {
|
||||||
|
if (reverse_param->axis_[i] < 0) {
|
||||||
|
reverse_param->axis_[i] += in_shape_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
|
||||||
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
|
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -28,7 +28,7 @@ class ReverseCPUKernel : public LiteKernel {
|
||||||
ReverseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
ReverseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||||
: LiteKernel(parameter, inputs, outputs, ctx) {}
|
: LiteKernel(parameter, inputs, outputs, ctx) {}
|
||||||
~ReverseCPUKernel() {
|
~ReverseCPUKernel() override {
|
||||||
if (tmp_ != nullptr) {
|
if (tmp_ != nullptr) {
|
||||||
free(tmp_);
|
free(tmp_);
|
||||||
tmp_ = nullptr;
|
tmp_ = nullptr;
|
||||||
|
@ -42,6 +42,7 @@ class ReverseCPUKernel : public LiteKernel {
|
||||||
int DoReverse(int task_id);
|
int DoReverse(int task_id);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void UpdateAxisInfo();
|
||||||
int thread_sz_count_ = 0;
|
int thread_sz_count_ = 0;
|
||||||
int thread_sz_stride_ = 0;
|
int thread_sz_stride_ = 0;
|
||||||
int data_size_ = 0;
|
int data_size_ = 0;
|
||||||
|
|
Loading…
Reference in New Issue