forked from mindspore-Ecosystem/mindspore
support infer datatype and format when shape infer fail
This commit is contained in:
parent
9d55ac62c8
commit
475543aadf
|
@ -40,6 +40,11 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
auto input_shape0 = input0->shape();
|
||||
auto input_shape1 = input1->shape();
|
||||
auto format = input0->GetFormat();
|
||||
output->SetFormat(format);
|
||||
output->set_data_type(input0->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
in_shape0_.resize(5);
|
||||
in_shape1_.resize(5);
|
||||
out_shape_.resize(5);
|
||||
|
@ -94,9 +99,8 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
}
|
||||
output_shape.push_back(out_shape_[i]);
|
||||
}
|
||||
output->SetFormat(format);
|
||||
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input0->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -26,10 +26,12 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_shape(input->shape());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -39,6 +39,11 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
|||
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
|
||||
return RET_FORMAT_ERR;
|
||||
}
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
|
@ -86,9 +91,7 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
|||
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
|
||||
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
|
||||
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -34,6 +34,11 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_LOG(ERROR) << "output size is error";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto concat_prim = this->primitive->value_as_Concat();
|
||||
MS_ASSERT(concat_prim != nullptr);
|
||||
auto input0_shape = inputs_.at(0)->shape();
|
||||
|
@ -74,9 +79,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
auto output_shape = input0_shape;
|
||||
output_shape[axis] = output_axis_dim;
|
||||
outputs_[0]->set_shape(output_shape);
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -66,6 +66,11 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_ASSERT(input_tensor != nullptr);
|
||||
MS_ASSERT(out_tensor != nullptr);
|
||||
|
||||
out_tensor->SetFormat(input_tensor->GetFormat());
|
||||
out_tensor->set_data_type(input_tensor->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto in_shape = input_tensor->shape();
|
||||
int input_h = in_shape.at(1);
|
||||
int input_w = in_shape.at(2);
|
||||
|
@ -78,8 +83,7 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
out_shape.at(2) = output_w;
|
||||
out_shape.at(3) = weight_tensor->shape()[0];
|
||||
out_tensor->set_shape(out_shape);
|
||||
out_tensor->SetFormat(input_tensor->GetFormat());
|
||||
out_tensor->set_data_type(input_tensor->data_type());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -37,7 +37,11 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto gather_prim = this->primitive->value_as_Gather();
|
||||
MS_ASSERT(gather_prim != nullptr);
|
||||
|
||||
|
@ -70,8 +74,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
}
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -158,9 +158,12 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_shape(input->shape());
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -33,6 +33,11 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (input == nullptr || output == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (this->primitive == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
@ -72,8 +77,6 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
}
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -82,6 +82,11 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto reshape_prim = this->primitive->value_as_Reshape();
|
||||
MS_ASSERT(reshape_prim != nullptr);
|
||||
|
||||
|
@ -133,9 +138,6 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
}
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -37,6 +37,15 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
|
||||
auto in_tensor = inputs_.front();
|
||||
auto out_tensor = outputs_.front();
|
||||
auto ret_dtype = out_tensor->set_data_type(kNumberTypeInt32);
|
||||
if (ret_dtype != in_tensor->data_type()) {
|
||||
MS_LOG(ERROR) << "Set datatype fails.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> out_shape;
|
||||
out_shape.push_back(static_cast<int>(in_tensor->shape().size()));
|
||||
|
||||
|
@ -45,18 +54,6 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
MS_LOG(ERROR) << "Set shape fails.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret_dtype = out_tensor->set_data_type(in_tensor->data_type());
|
||||
if (ret_dtype != in_tensor->data_type()) {
|
||||
MS_LOG(ERROR) << "Set datatype fails.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// todo
|
||||
// auto ret_data = out_tensor->MallocData();
|
||||
// if (ret_data != 0) {
|
||||
// MS_LOG(ERROR) << "Allocate memory fails.";
|
||||
// return RET_ERROR;
|
||||
// }
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -32,6 +32,11 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
auto slice_prim = this->primitive->value_as_Slice();
|
||||
std::vector<int32_t> slice_begin(slice_prim->begin()->begin(), slice_prim->begin()->end());
|
||||
|
@ -61,9 +66,6 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
}
|
||||
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -26,9 +26,12 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_shape(input->shape());
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -26,7 +26,11 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
MS_ASSERT(inputs_.size() == kSingleNum);
|
||||
MS_ASSERT(outputs_.size() == kSingleNum);
|
||||
auto transpore_prim = this->primitive->value_as_Transpose();
|
||||
|
@ -46,8 +50,6 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
|||
}
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -32,6 +32,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
|||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "output size is invalid";
|
||||
}
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
|
||||
auto dims = unsqueeze_prim->axis()->data();
|
||||
auto in_shape = input->shape();
|
||||
|
@ -65,9 +70,7 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
|||
}
|
||||
}
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -58,6 +58,7 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
|||
}
|
||||
if (pack_input_ != nullptr) {
|
||||
free(pack_input_);
|
||||
pack_input_ = nullptr;
|
||||
}
|
||||
}
|
||||
bool pre_trans_input_ = false;
|
||||
|
|
|
@ -223,11 +223,11 @@ bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
|
|||
|
||||
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const Context *ctx,
|
||||
OpParameter *op_parameter, const Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(op_parameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
|
||||
int kernel_h = conv_param->kernel_h_;
|
||||
int kernel_w = conv_param->kernel_w_;
|
||||
int stride_h = conv_param->stride_h_;
|
||||
|
@ -239,25 +239,28 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
|
|||
conv_param->output_h_ = outputs.front()->Height();
|
||||
conv_param->output_w_ = outputs.front()->Width();
|
||||
bool use_winograd = false;
|
||||
bool use_sw = false;
|
||||
int out_unit;
|
||||
InputTransformUnitFunc input_trans_func = nullptr;
|
||||
OutputTransformUnitFunc output_trans_func = nullptr;
|
||||
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
|
||||
bool use_sw = CheckIfUseSlideWindow(conv_param);
|
||||
if (primitive != nullptr && primitive->GetInferFlag()) {
|
||||
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
|
||||
use_sw = CheckIfUseSlideWindow(conv_param);
|
||||
}
|
||||
|
||||
kernel::LiteKernel *kernel;
|
||||
if (kernel_h == 1 && kernel_w == 1) {
|
||||
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
|
||||
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||
} else if (use_winograd) {
|
||||
kernel =
|
||||
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
|
||||
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
|
||||
} else if (use_sw) {
|
||||
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||
} else {
|
||||
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||
}
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
|
@ -266,8 +269,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
|
|||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK && ret != RET_INFER_INVALID) {
|
||||
delete kernel;
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
|
|
|
@ -116,6 +116,12 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
|
|||
}
|
||||
} else {
|
||||
primitive->SetInferFlag(false);
|
||||
auto ret = primitive->InferShape(inputs, outputs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InferShape fail! name: " << cNode->name()->str()
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
|
||||
return RET_INFER_ERR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue