fix ml_2020_ocr_cn fp16 segmenttation one thread

Merge pull request  from zhaodezan/master
This commit is contained in:
mindspore-ci-bot 2020-10-29 18:36:08 +08:00 committed by Gitee
commit d1f01a8d9c
2 changed files with 27 additions and 0 deletions
mindspore/lite/src/runtime/kernel/arm/fp16

View File

@ -21,6 +21,7 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "src/ops/populate/populate_register.h"
#include "include/errorcode.h" #include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::kernel::KERNEL_ARCH::kCPU;
@ -97,6 +98,31 @@ int ArithmeticFP16CPUKernel::Init() {
return ReSize(); return ReSize();
} }
int ArithmeticFP16CPUKernel::PreProcess() {
if (!InferShapeDone()) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (ret != 0) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
MS_LOG(ERROR) << "InferShape fail!";
return ret;
}
param_ = reinterpret_cast<ArithmeticParameter *>(PopulateArithmetic(primitive_));
ret = ReSize();
if (ret != 0) {
MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
return ret;
}
}
auto outputs = this->out_tensors();
for (auto *output : outputs) {
MS_ASSERT(output != nullptr);
output->MallocData();
}
return RET_OK;
}
int ArithmeticFP16CPUKernel::ReSize() { int ArithmeticFP16CPUKernel::ReSize() {
param_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
param_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();

View File

@ -44,6 +44,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
~ArithmeticFP16CPUKernel() = default; ~ArithmeticFP16CPUKernel() = default;
int Init() override; int Init() override;
int PreProcess() override;
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int DoArithmetic(int task_id); int DoArithmetic(int task_id);