Support inferring data type and format when shape inference fails

chenjianping 2020-08-18 10:53:11 +08:00
parent 9d55ac62c8
commit 475543aadf
17 changed files with 96 additions and 54 deletions
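Summary of the change: each operator's InferShape used to set the output tensor's shape, data type, and format together at the end, so nothing was propagated when shape inference had to be deferred. After this commit every InferShape sets the output data type and format first and returns early when the infer flag is cleared, leaving only the shape to be filled in later. Below is a minimal standalone sketch of that pattern; the types and names are stand-ins for illustration, not the real mindspore::lite classes.

#include <vector>

// Stand-in types: illustrative only, not the actual mindspore::lite API.
enum Ret { RET_OK = 0 };
enum DataType { kNumberTypeFloat32 };
enum Format { Format_NHWC };

struct Tensor {
  std::vector<int> shape_;
  DataType dtype_ = kNumberTypeFloat32;
  Format format_ = Format_NHWC;
  void set_shape(const std::vector<int> &shape) { shape_ = shape; }
  void set_data_type(DataType type) { dtype_ = type; }
  void SetFormat(Format format) { format_ = format; }
};

struct ExampleOp {
  bool infer_flag_ = true;  // mirrors Primitive::SetInferFlag/GetInferFlag
  bool GetInferFlag() const { return infer_flag_; }

  int InferShape(Tensor *input, Tensor *output) {
    // Data type and format are always propagated, even when shapes are unknown.
    output->set_data_type(input->dtype_);
    output->SetFormat(input->format_);
    if (!GetInferFlag()) {
      return RET_OK;  // shape inference deferred; dtype/format already valid
    }
    // Op-specific shape math runs only when inference is enabled.
    output->set_shape(input->shape_);
    return RET_OK;
  }
};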

View File

@@ -40,6 +40,11 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
  auto input_shape0 = input0->shape();
  auto input_shape1 = input1->shape();
  auto format = input0->GetFormat();
+ output->SetFormat(format);
+ output->set_data_type(input0->data_type());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  in_shape0_.resize(5);
  in_shape1_.resize(5);
  out_shape_.resize(5);
@@ -94,9 +99,8 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
    }
    output_shape.push_back(out_shape_[i]);
  }
- output->SetFormat(format);
  output->set_shape(output_shape);
- output->set_data_type(input0->data_type());
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@@ -26,10 +26,12 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->SetFormat(input->GetFormat());
- output->set_shape(input->shape());
  output->set_data_type(input->data_type());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
+ output->set_shape(input->shape());
  return RET_OK;
}

View File

@@ -39,6 +39,11 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
    MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
    return RET_FORMAT_ERR;
  }
+ outputs[0]->SetFormat(input->GetFormat());
+ outputs[0]->set_data_type(input->data_type());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  auto input_shape = input->shape();
  if (input_shape.size() != kDimension_4d) {
    MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
@@ -86,9 +91,7 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
  output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
  output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
- outputs[0]->SetFormat(input->GetFormat());
  outputs[0]->set_shape(output_shape);
- outputs[0]->set_data_type(input->data_type());
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@@ -34,6 +34,11 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
    MS_LOG(ERROR) << "output size is error";
    return RET_PARAM_INVALID;
  }
+ output->set_data_type(input0->data_type());
+ output->SetFormat(input0->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  auto concat_prim = this->primitive->value_as_Concat();
  MS_ASSERT(concat_prim != nullptr);
  auto input0_shape = inputs_.at(0)->shape();
@@ -74,9 +79,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
  auto output_shape = input0_shape;
  output_shape[axis] = output_axis_dim;
  outputs_[0]->set_shape(output_shape);
- output->set_data_type(input0->data_type());
- output->SetFormat(input0->GetFormat());
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@@ -66,6 +66,11 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
  MS_ASSERT(input_tensor != nullptr);
  MS_ASSERT(out_tensor != nullptr);
+ out_tensor->SetFormat(input_tensor->GetFormat());
+ out_tensor->set_data_type(input_tensor->data_type());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  auto in_shape = input_tensor->shape();
  int input_h = in_shape.at(1);
  int input_w = in_shape.at(2);
@@ -78,8 +83,7 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
  out_shape.at(2) = output_w;
  out_shape.at(3) = weight_tensor->shape()[0];
  out_tensor->set_shape(out_shape);
- out_tensor->SetFormat(input_tensor->GetFormat());
- out_tensor->set_data_type(input_tensor->data_type());
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@@ -37,7 +37,11 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(input != nullptr);
+ output->set_data_type(input->data_type());
+ output->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  auto gather_prim = this->primitive->value_as_Gather();
  MS_ASSERT(gather_prim != nullptr);
@@ -70,8 +74,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
  }
  output->set_shape(out_shape);
- output->set_data_type(input->data_type());
- output->SetFormat(input->GetFormat());
  return RET_OK;
}

View File

@@ -158,9 +158,12 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
- output->set_shape(input->shape());
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
+ output->set_shape(input->shape());
  return RET_OK;
}

View File

@@ -33,6 +33,11 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
  if (input == nullptr || output == nullptr) {
    return RET_NULL_PTR;
  }
+ output->set_data_type(input->data_type());
+ output->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  if (this->primitive == nullptr) {
    return RET_NULL_PTR;
  }
@@ -72,8 +77,6 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
    }
  }
  output->set_shape(out_shape);
- output->set_data_type(input->data_type());
- output->SetFormat(input->GetFormat());
  return RET_OK;
}

View File

@@ -82,6 +82,11 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
+ output->set_data_type(input->data_type());
+ output->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  auto reshape_prim = this->primitive->value_as_Reshape();
  MS_ASSERT(reshape_prim != nullptr);
@@ -133,9 +138,6 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
  }
  output->set_shape(out_shape);
- output->set_data_type(input->data_type());
- output->SetFormat(input->GetFormat());
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@@ -37,6 +37,15 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
  auto in_tensor = inputs_.front();
  auto out_tensor = outputs_.front();
+ auto ret_dtype = out_tensor->set_data_type(kNumberTypeInt32);
+ if (ret_dtype != in_tensor->data_type()) {
+   MS_LOG(ERROR) << "Set datatype fails.";
+   return RET_ERROR;
+ }
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  std::vector<int> out_shape;
  out_shape.push_back(static_cast<int>(in_tensor->shape().size()));
@@ -45,18 +54,6 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
    MS_LOG(ERROR) << "Set shape fails.";
    return RET_ERROR;
  }
- auto ret_dtype = out_tensor->set_data_type(in_tensor->data_type());
- if (ret_dtype != in_tensor->data_type()) {
-   MS_LOG(ERROR) << "Set datatype fails.";
-   return RET_ERROR;
- }
- // todo
- // auto ret_data = out_tensor->MallocData();
- // if (ret_data != 0) {
- //   MS_LOG(ERROR) << "Allocate memory fails.";
- //   return RET_ERROR;
- // }
  return RET_OK;
}

View File

@@ -32,6 +32,11 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
    return RET_PARAM_INVALID;
  }
  auto input = inputs.at(0);
+ outputs[0]->set_data_type(input->data_type());
+ outputs[0]->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  auto input_shape = input->shape();
  auto slice_prim = this->primitive->value_as_Slice();
  std::vector<int32_t> slice_begin(slice_prim->begin()->begin(), slice_prim->begin()->end());
@@ -61,9 +66,6 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
  }
  outputs[0]->set_shape(output_shape);
- outputs[0]->set_data_type(input->data_type());
- outputs[0]->SetFormat(input->GetFormat());
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@@ -26,9 +26,12 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
- output->set_shape(input->shape());
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
+ output->set_shape(input->shape());
  return RET_OK;
}

View File

@@ -26,7 +26,11 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
+ output->set_data_type(input->data_type());
+ output->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  MS_ASSERT(inputs_.size() == kSingleNum);
  MS_ASSERT(outputs_.size() == kSingleNum);
  auto transpore_prim = this->primitive->value_as_Transpose();
@@ -46,8 +50,6 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
  }
  output->set_shape(out_shape);
- output->set_data_type(input->data_type());
- output->SetFormat(input->GetFormat());
  return RET_OK;
}

View File

@@ -32,6 +32,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
  if (outputs_.size() != kSingleNum) {
    MS_LOG(ERROR) << "output size is invalid";
  }
+ output->set_data_type(input->data_type());
+ output->SetFormat(input->GetFormat());
+ if (!GetInferFlag()) {
+   return RET_OK;
+ }
  auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
  auto dims = unsqueeze_prim->axis()->data();
  auto in_shape = input->shape();
@@ -65,9 +70,7 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
    }
  }
- output->SetFormat(input->GetFormat());
  output->set_shape(out_shape);
- output->set_data_type(input->data_type());
  return RET_OK;
}
}  // namespace mindspore::lite

View File

@@ -58,6 +58,7 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
    }
    if (pack_input_ != nullptr) {
      free(pack_input_);
+     pack_input_ = nullptr;
    }
  }
  bool pre_trans_input_ = false;
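The fp16 1x1 convolution header hunk above is a small hardening fix bundled with this commit: pack_input_ is reset to nullptr immediately after it is freed, so any later pass through the same cleanup path sees a null pointer and skips the free instead of releasing the buffer twice.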

View File

@@ -223,11 +223,11 @@ bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                             const std::vector<lite::tensor::Tensor *> &outputs,
-                                            OpParameter *opParameter, const Context *ctx,
+                                            OpParameter *op_parameter, const Context *ctx,
                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
- MS_ASSERT(opParameter != nullptr);
+ MS_ASSERT(op_parameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
- auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
+ auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int stride_h = conv_param->stride_h_;
@@ -239,25 +239,28 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
  conv_param->output_h_ = outputs.front()->Height();
  conv_param->output_w_ = outputs.front()->Width();
  bool use_winograd = false;
+ bool use_sw = false;
  int out_unit;
  InputTransformUnitFunc input_trans_func = nullptr;
  OutputTransformUnitFunc output_trans_func = nullptr;
+ if (primitive != nullptr && primitive->GetInferFlag()) {
    CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
-   bool use_sw = CheckIfUseSlideWindow(conv_param);
+   use_sw = CheckIfUseSlideWindow(conv_param);
+ }
  kernel::LiteKernel *kernel;
  if (kernel_h == 1 && kernel_w == 1) {
-   kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+   kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
  } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
-   kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+   kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
  } else if (use_winograd) {
    kernel =
-     new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
+     new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
  } else if (use_sw) {
-   // kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(opParameter, inputs, outputs, ctx, primitive);
-   kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+   // kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
+   kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
  } else {
-   kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+   kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
  }
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
@@ -266,8 +269,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
  auto ret = kernel->Init();
  if (ret != RET_OK && ret != RET_INFER_INVALID) {
    delete kernel;
-   MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                 << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+   MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
+                 << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
    return nullptr;
  }
  return kernel;
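The kernel-creator hunks follow the same theme from the runtime side: Winograd and sliding-window selection depend on the inferred input/output dimensions, so these heuristics now run only when primitive->GetInferFlag() reports that shape inference succeeded; otherwise use_winograd and use_sw stay false and the generic ConvolutionCPUKernel path is chosen. Note that use_sw is hoisted out of the new if block so it stays visible at the later else-if. The existing Init() check already tolerates RET_INFER_INVALID, so a kernel created before shapes are known is not treated as a fatal error. The opParameter to op_parameter rename is a pure naming-convention cleanup.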

View File

@@ -116,6 +116,12 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
      }
    } else {
      primitive->SetInferFlag(false);
+     auto ret = primitive->InferShape(inputs, outputs);
+     if (ret != RET_OK) {
+       MS_LOG(ERROR) << "InferShape fail! name: " << cNode->name()->str()
+                     << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
+       return RET_INFER_ERR;
+     }
    }
  }
  return RET_OK;
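Finally, the scheduler hunk makes the new operator behavior reachable: when a node's inputs are not ready, the scheduler clears the infer flag but still calls InferShape, so output tensors pick up their data type and format even though their shapes stay undetermined until runtime, and a genuine failure is still surfaced as RET_INFER_ERR. Continuing the stand-in sketch from the top of this page, the deferred path looks roughly like this:

// Deferred path (stand-in types from the sketch above; illustrative only).
ExampleOp op;
Tensor in, out;
op.infer_flag_ = false;    // scheduler side: primitive->SetInferFlag(false)
op.InferShape(&in, &out);  // out now has dtype/format; shape left empty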