forked from mindspore-Ecosystem/mindspore
fix output shape error in deconv operator and support parameter passing when doing infershape during inference
This commit is contained in:
parent
9b2b062642
commit
f533dd6579
|
@ -22,6 +22,7 @@
|
|||
#include "src/runtime/kernel/arm/int8/add_int8.h"
|
||||
#include "src/runtime/kernel/arm/int8/mul_int8.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/populate_parameter.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -40,6 +41,31 @@ int ArithmeticCPUKernel::Init() {
|
|||
return ReSize();
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::PreProcess() {
|
||||
if (!InferShapeDone()) {
|
||||
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
|
||||
auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
|
||||
if (ret != 0) {
|
||||
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
|
||||
MS_LOG(ERROR) << "InferShape fail!";
|
||||
return ret;
|
||||
}
|
||||
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(kernel::PopulateArithmetic(primitive_));
|
||||
ret = ReSize();
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
auto outputs = this->out_tensors();
|
||||
for (auto *output : outputs) {
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->MallocData();
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::ReSize() {
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
|
||||
data_type_ = kDataTypeFloat;
|
||||
|
|
|
@ -163,6 +163,7 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
~ArithmeticCPUKernel() override;
|
||||
|
||||
int Init() override;
|
||||
int PreProcess() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoArithmetic(int task_id);
|
||||
|
|
|
@ -81,7 +81,7 @@ int QuantizedAddCPUKernel::Run() {
|
|||
input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData());
|
||||
output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData());
|
||||
|
||||
elements_num_ = in_tensors_.at(0)->ElementsNum();
|
||||
elements_num_ = out_tensors_.at(0)->ElementsNum();
|
||||
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
|
||||
|
||||
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
|
||||
|
|
|
@ -106,7 +106,7 @@ int MulInt8CPUKernel::Run() {
|
|||
input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData());
|
||||
output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData());
|
||||
|
||||
elements_num_ = in_tensors_.at(0)->ElementsNum();
|
||||
elements_num_ = out_tensors_.at(0)->ElementsNum();
|
||||
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
|
||||
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
|
||||
input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
|
|
|
@ -87,8 +87,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "kernels") {
|
||||
if (onnx_node_attr.ints().size() != 2) {
|
||||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
|
@ -101,8 +101,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "auto_pad") {
|
||||
attr->padMode = GetOnnxPadMode(onnx_node_attr);
|
||||
} else if (onnx_node_attr.name() == "pads") {
|
||||
|
@ -119,8 +119,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
} else if (onnx_node_attr.name() == "order") {
|
||||
if (onnx_node_attr.s() == "NHWC") {
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
|
|
Loading…
Reference in New Issue