!13301 [MS][LITE]reverse op bug

From: @fuzhiye
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2021-03-15 14:08:16 +08:00 committed by Gitee
commit edaffb66af
2 changed files with 15 additions and 1 deletions

View File

@ -37,6 +37,9 @@ int ReverseCPUKernel::Stride(int index) {
} }
int ReverseCPUKernel::ReSize() { int ReverseCPUKernel::ReSize() {
// trans negative to positive axis
UpdateAxisInfo();
data_size_ = in_tensors_.at(0)->ElementsNum(); data_size_ = in_tensors_.at(0)->ElementsNum();
thread_sz_count_ = MSMIN(op_parameter_->thread_num_, data_size_); thread_sz_count_ = MSMIN(op_parameter_->thread_num_, data_size_);
thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
@ -134,6 +137,16 @@ int ReverseCPUKernel::Run() {
return RET_OK; return RET_OK;
} }
void ReverseCPUKernel::UpdateAxisInfo() {
auto reverse_param = reinterpret_cast<ReverseParameter *>(op_parameter_);
int in_shape_len = in_tensors_.front()->shape().size();
for (int i = 0; i < reverse_param->num_axis_; ++i) {
if (reverse_param->axis_[i] < 0) {
reverse_param->axis_[i] += in_shape_len;
}
}
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ReverseV2, LiteKernelCreator<ReverseCPUKernel>)
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -28,7 +28,7 @@ class ReverseCPUKernel : public LiteKernel {
ReverseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, ReverseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {} : LiteKernel(parameter, inputs, outputs, ctx) {}
~ReverseCPUKernel() { ~ReverseCPUKernel() override {
if (tmp_ != nullptr) { if (tmp_ != nullptr) {
free(tmp_); free(tmp_);
tmp_ = nullptr; tmp_ = nullptr;
@ -42,6 +42,7 @@ class ReverseCPUKernel : public LiteKernel {
int DoReverse(int task_id); int DoReverse(int task_id);
private: private:
void UpdateAxisInfo();
int thread_sz_count_ = 0; int thread_sz_count_ = 0;
int thread_sz_stride_ = 0; int thread_sz_stride_ = 0;
int data_size_ = 0; int data_size_ = 0;