fix output shape error in deconv operator and support parameter passing when doing infershape during inference

This commit is contained in:
AGroupofProbiotocs 2020-10-22 09:03:14 +08:00
parent 9b2b062642
commit f533dd6579
5 changed files with 35 additions and 8 deletions

View File

@ -22,6 +22,7 @@
#include "src/runtime/kernel/arm/int8/add_int8.h"
#include "src/runtime/kernel/arm/int8/mul_int8.h"
#include "src/runtime/runtime_api.h"
#include "src/populate_parameter.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -40,6 +41,31 @@ int ArithmeticCPUKernel::Init() {
return ReSize();
}
int ArithmeticCPUKernel::PreProcess() {
if (!InferShapeDone()) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (ret != 0) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
MS_LOG(ERROR) << "InferShape fail!";
return ret;
}
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(kernel::PopulateArithmetic(primitive_));
ret = ReSize();
if (ret != 0) {
MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
return ret;
}
}
auto outputs = this->out_tensors();
for (auto *output : outputs) {
MS_ASSERT(output != nullptr);
output->MallocData();
}
return RET_OK;
}
int ArithmeticCPUKernel::ReSize() {
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
data_type_ = kDataTypeFloat;

View File

@ -163,6 +163,7 @@ class ArithmeticCPUKernel : public LiteKernel {
~ArithmeticCPUKernel() override;
int Init() override;
int PreProcess() override;
int ReSize() override;
int Run() override;
int DoArithmetic(int task_id);

View File

@ -81,7 +81,7 @@ int QuantizedAddCPUKernel::Run() {
input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData());
output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData());
elements_num_ = in_tensors_.at(0)->ElementsNum();
elements_num_ = out_tensors_.at(0)->ElementsNum();
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {

View File

@ -106,7 +106,7 @@ int MulInt8CPUKernel::Run() {
input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData());
output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData());
elements_num_ = in_tensors_.at(0)->ElementsNum();
elements_num_ = out_tensors_.at(0)->ElementsNum();
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));

View File

@ -87,8 +87,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "kernels") {
if (onnx_node_attr.ints().size() != 2) {
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
@ -101,8 +101,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "auto_pad") {
attr->padMode = GetOnnxPadMode(onnx_node_attr);
} else if (onnx_node_attr.name() == "pads") {
@ -119,8 +119,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
return RET_ERROR;
}
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
} else if (onnx_node_attr.name() == "order") {
if (onnx_node_attr.s() == "NHWC") {
attr->format = schema::Format::Format_NHWC;