support infer datatype and format when shape infer fail

This commit is contained in:
chenjianping 2020-08-18 10:53:11 +08:00
parent 9d55ac62c8
commit 475543aadf
17 changed files with 96 additions and 54 deletions

View File

@ -40,6 +40,11 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
auto input_shape0 = input0->shape();
auto input_shape1 = input1->shape();
auto format = input0->GetFormat();
output->SetFormat(format);
output->set_data_type(input0->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
in_shape0_.resize(5);
in_shape1_.resize(5);
out_shape_.resize(5);
@ -94,9 +99,8 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
}
output_shape.push_back(out_shape_[i]);
}
output->SetFormat(format);
output->set_shape(output_shape);
output->set_data_type(input0->data_type());
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -26,10 +26,12 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->SetFormat(input->GetFormat());
output->set_shape(input->shape());
output->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());
return RET_OK;
}

View File

@ -39,6 +39,11 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
if (input_shape.size() != kDimension_4d) {
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
@ -86,9 +91,7 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -34,6 +34,11 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG(ERROR) << "output size is error";
return RET_PARAM_INVALID;
}
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto concat_prim = this->primitive->value_as_Concat();
MS_ASSERT(concat_prim != nullptr);
auto input0_shape = inputs_.at(0)->shape();
@ -74,9 +79,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
auto output_shape = input0_shape;
output_shape[axis] = output_axis_dim;
outputs_[0]->set_shape(output_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -66,6 +66,11 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(out_tensor != nullptr);
out_tensor->SetFormat(input_tensor->GetFormat());
out_tensor->set_data_type(input_tensor->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
auto in_shape = input_tensor->shape();
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
@ -78,8 +83,7 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
out_shape.at(2) = output_w;
out_shape.at(3) = weight_tensor->shape()[0];
out_tensor->set_shape(out_shape);
out_tensor->SetFormat(input_tensor->GetFormat());
out_tensor->set_data_type(input_tensor->data_type());
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -37,7 +37,11 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(input != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto gather_prim = this->primitive->value_as_Gather();
MS_ASSERT(gather_prim != nullptr);
@ -70,8 +74,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}

View File

@ -158,9 +158,12 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());
return RET_OK;
}

View File

@ -33,6 +33,11 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
if (input == nullptr || output == nullptr) {
return RET_NULL_PTR;
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
if (this->primitive == nullptr) {
return RET_NULL_PTR;
}
@ -72,8 +77,6 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}

View File

@ -82,6 +82,11 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto reshape_prim = this->primitive->value_as_Reshape();
MS_ASSERT(reshape_prim != nullptr);
@ -133,9 +138,6 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -37,6 +37,15 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front();
auto ret_dtype = out_tensor->set_data_type(kNumberTypeInt32);
if (ret_dtype != in_tensor->data_type()) {
MS_LOG(ERROR) << "Set datatype fails.";
return RET_ERROR;
}
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> out_shape;
out_shape.push_back(static_cast<int>(in_tensor->shape().size()));
@ -45,18 +54,6 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_LOG(ERROR) << "Set shape fails.";
return RET_ERROR;
}
auto ret_dtype = out_tensor->set_data_type(in_tensor->data_type());
if (ret_dtype != in_tensor->data_type()) {
MS_LOG(ERROR) << "Set datatype fails.";
return RET_ERROR;
}
// todo
// auto ret_data = out_tensor->MallocData();
// if (ret_data != 0) {
// MS_LOG(ERROR) << "Allocate memory fails.";
// return RET_ERROR;
// }
return RET_OK;
}

View File

@ -32,6 +32,11 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_PARAM_INVALID;
}
auto input = inputs.at(0);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto input_shape = input->shape();
auto slice_prim = this->primitive->value_as_Slice();
std::vector<int32_t> slice_begin(slice_prim->begin()->begin(), slice_prim->begin()->end());
@ -61,9 +66,6 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
}
outputs[0]->set_shape(output_shape);
outputs[0]->set_data_type(input->data_type());
outputs[0]->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -26,9 +26,12 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());
return RET_OK;
}

View File

@ -26,7 +26,11 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
MS_ASSERT(inputs_.size() == kSingleNum);
MS_ASSERT(outputs_.size() == kSingleNum);
auto transpore_prim = this->primitive->value_as_Transpose();
@ -46,8 +50,6 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
output->set_shape(out_shape);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
return RET_OK;
}

View File

@ -32,6 +32,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "output size is invalid";
}
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {
return RET_OK;
}
auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
auto dims = unsqueeze_prim->axis()->data();
auto in_shape = input->shape();
@ -65,9 +70,7 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
}
output->SetFormat(input->GetFormat());
output->set_shape(out_shape);
output->set_data_type(input->data_type());
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -58,6 +58,7 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
}
if (pack_input_ != nullptr) {
free(pack_input_);
pack_input_ = nullptr;
}
}
bool pre_trans_input_ = false;

View File

@ -223,11 +223,11 @@ bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
OpParameter *op_parameter, const Context *ctx,
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(op_parameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
@ -239,25 +239,28 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
bool use_winograd = false;
bool use_sw = false;
int out_unit;
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
bool use_sw = CheckIfUseSlideWindow(conv_param);
if (primitive != nullptr && primitive->GetInferFlag()) {
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
use_sw = CheckIfUseSlideWindow(conv_param);
}
kernel::LiteKernel *kernel;
if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else if (use_winograd) {
kernel =
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
} else if (use_sw) {
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
@ -266,8 +269,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
auto ret = kernel->Init();
if (ret != RET_OK && ret != RET_INFER_INVALID) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;

View File

@ -116,6 +116,12 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
}
} else {
primitive->SetInferFlag(false);
auto ret = primitive->InferShape(inputs, outputs);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InferShape fail! name: " << cNode->name()->str()
<< ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
return RET_INFER_ERR;
}
}
}
return RET_OK;