forked from mindspore-Ecosystem/mindspore
support infer datatype and format when shape infer fail
This commit is contained in:
parent
9d55ac62c8
commit
475543aadf
|
@ -40,6 +40,11 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
||||||
auto input_shape0 = input0->shape();
|
auto input_shape0 = input0->shape();
|
||||||
auto input_shape1 = input1->shape();
|
auto input_shape1 = input1->shape();
|
||||||
auto format = input0->GetFormat();
|
auto format = input0->GetFormat();
|
||||||
|
output->SetFormat(format);
|
||||||
|
output->set_data_type(input0->data_type());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
in_shape0_.resize(5);
|
in_shape0_.resize(5);
|
||||||
in_shape1_.resize(5);
|
in_shape1_.resize(5);
|
||||||
out_shape_.resize(5);
|
out_shape_.resize(5);
|
||||||
|
@ -94,9 +99,8 @@ int Arithmetic::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
||||||
}
|
}
|
||||||
output_shape.push_back(out_shape_[i]);
|
output_shape.push_back(out_shape_[i]);
|
||||||
}
|
}
|
||||||
output->SetFormat(format);
|
|
||||||
output->set_shape(output_shape);
|
output->set_shape(output_shape);
|
||||||
output->set_data_type(input0->data_type());
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -26,10 +26,12 @@ int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
|
||||||
MS_ASSERT(input != nullptr);
|
MS_ASSERT(input != nullptr);
|
||||||
auto output = outputs_.front();
|
auto output = outputs_.front();
|
||||||
MS_ASSERT(output != nullptr);
|
MS_ASSERT(output != nullptr);
|
||||||
|
|
||||||
output->SetFormat(input->GetFormat());
|
output->SetFormat(input->GetFormat());
|
||||||
output->set_shape(input->shape());
|
|
||||||
output->set_data_type(input->data_type());
|
output->set_data_type(input->data_type());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
output->set_shape(input->shape());
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,11 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
||||||
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
|
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
|
||||||
return RET_FORMAT_ERR;
|
return RET_FORMAT_ERR;
|
||||||
}
|
}
|
||||||
|
outputs[0]->SetFormat(input->GetFormat());
|
||||||
|
outputs[0]->set_data_type(input->data_type());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto input_shape = input->shape();
|
auto input_shape = input->shape();
|
||||||
if (input_shape.size() != kDimension_4d) {
|
if (input_shape.size() != kDimension_4d) {
|
||||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||||
|
@ -86,9 +91,7 @@ int BatchToSpace::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
||||||
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
|
output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3);
|
||||||
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
|
output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index];
|
||||||
|
|
||||||
outputs[0]->SetFormat(input->GetFormat());
|
|
||||||
outputs[0]->set_shape(output_shape);
|
outputs[0]->set_shape(output_shape);
|
||||||
outputs[0]->set_data_type(input->data_type());
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -34,6 +34,11 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
MS_LOG(ERROR) << "output size is error";
|
MS_LOG(ERROR) << "output size is error";
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
output->set_data_type(input0->data_type());
|
||||||
|
output->SetFormat(input0->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto concat_prim = this->primitive->value_as_Concat();
|
auto concat_prim = this->primitive->value_as_Concat();
|
||||||
MS_ASSERT(concat_prim != nullptr);
|
MS_ASSERT(concat_prim != nullptr);
|
||||||
auto input0_shape = inputs_.at(0)->shape();
|
auto input0_shape = inputs_.at(0)->shape();
|
||||||
|
@ -74,9 +79,6 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
auto output_shape = input0_shape;
|
auto output_shape = input0_shape;
|
||||||
output_shape[axis] = output_axis_dim;
|
output_shape[axis] = output_axis_dim;
|
||||||
outputs_[0]->set_shape(output_shape);
|
outputs_[0]->set_shape(output_shape);
|
||||||
output->set_data_type(input0->data_type());
|
|
||||||
output->SetFormat(input0->GetFormat());
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -66,6 +66,11 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
MS_ASSERT(input_tensor != nullptr);
|
MS_ASSERT(input_tensor != nullptr);
|
||||||
MS_ASSERT(out_tensor != nullptr);
|
MS_ASSERT(out_tensor != nullptr);
|
||||||
|
|
||||||
|
out_tensor->SetFormat(input_tensor->GetFormat());
|
||||||
|
out_tensor->set_data_type(input_tensor->data_type());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto in_shape = input_tensor->shape();
|
auto in_shape = input_tensor->shape();
|
||||||
int input_h = in_shape.at(1);
|
int input_h = in_shape.at(1);
|
||||||
int input_w = in_shape.at(2);
|
int input_w = in_shape.at(2);
|
||||||
|
@ -78,8 +83,7 @@ int Conv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
out_shape.at(2) = output_w;
|
out_shape.at(2) = output_w;
|
||||||
out_shape.at(3) = weight_tensor->shape()[0];
|
out_shape.at(3) = weight_tensor->shape()[0];
|
||||||
out_tensor->set_shape(out_shape);
|
out_tensor->set_shape(out_shape);
|
||||||
out_tensor->SetFormat(input_tensor->GetFormat());
|
|
||||||
out_tensor->set_data_type(input_tensor->data_type());
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -37,7 +37,11 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
MS_ASSERT(input != nullptr);
|
MS_ASSERT(input != nullptr);
|
||||||
auto output = outputs_.front();
|
auto output = outputs_.front();
|
||||||
MS_ASSERT(input != nullptr);
|
MS_ASSERT(input != nullptr);
|
||||||
|
output->set_data_type(input->data_type());
|
||||||
|
output->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto gather_prim = this->primitive->value_as_Gather();
|
auto gather_prim = this->primitive->value_as_Gather();
|
||||||
MS_ASSERT(gather_prim != nullptr);
|
MS_ASSERT(gather_prim != nullptr);
|
||||||
|
|
||||||
|
@ -70,8 +74,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
output->set_shape(out_shape);
|
output->set_shape(out_shape);
|
||||||
output->set_data_type(input->data_type());
|
|
||||||
output->SetFormat(input->GetFormat());
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,9 +158,12 @@ int Primitive::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
||||||
MS_ASSERT(input != nullptr);
|
MS_ASSERT(input != nullptr);
|
||||||
auto output = outputs_.front();
|
auto output = outputs_.front();
|
||||||
MS_ASSERT(output != nullptr);
|
MS_ASSERT(output != nullptr);
|
||||||
output->set_shape(input->shape());
|
|
||||||
output->set_data_type(input->data_type());
|
output->set_data_type(input->data_type());
|
||||||
output->SetFormat(input->GetFormat());
|
output->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
output->set_shape(input->shape());
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,11 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
if (input == nullptr || output == nullptr) {
|
if (input == nullptr || output == nullptr) {
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
output->set_data_type(input->data_type());
|
||||||
|
output->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
if (this->primitive == nullptr) {
|
if (this->primitive == nullptr) {
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
@ -72,8 +77,6 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output->set_shape(out_shape);
|
output->set_shape(out_shape);
|
||||||
output->set_data_type(input->data_type());
|
|
||||||
output->SetFormat(input->GetFormat());
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,6 +82,11 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
||||||
MS_ASSERT(input != nullptr);
|
MS_ASSERT(input != nullptr);
|
||||||
auto output = outputs_.front();
|
auto output = outputs_.front();
|
||||||
MS_ASSERT(output != nullptr);
|
MS_ASSERT(output != nullptr);
|
||||||
|
output->set_data_type(input->data_type());
|
||||||
|
output->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto reshape_prim = this->primitive->value_as_Reshape();
|
auto reshape_prim = this->primitive->value_as_Reshape();
|
||||||
MS_ASSERT(reshape_prim != nullptr);
|
MS_ASSERT(reshape_prim != nullptr);
|
||||||
|
|
||||||
|
@ -133,9 +138,6 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
||||||
}
|
}
|
||||||
|
|
||||||
output->set_shape(out_shape);
|
output->set_shape(out_shape);
|
||||||
output->set_data_type(input->data_type());
|
|
||||||
output->SetFormat(input->GetFormat());
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -37,6 +37,15 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
||||||
|
|
||||||
auto in_tensor = inputs_.front();
|
auto in_tensor = inputs_.front();
|
||||||
auto out_tensor = outputs_.front();
|
auto out_tensor = outputs_.front();
|
||||||
|
auto ret_dtype = out_tensor->set_data_type(kNumberTypeInt32);
|
||||||
|
if (ret_dtype != in_tensor->data_type()) {
|
||||||
|
MS_LOG(ERROR) << "Set datatype fails.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
std::vector<int> out_shape;
|
std::vector<int> out_shape;
|
||||||
out_shape.push_back(static_cast<int>(in_tensor->shape().size()));
|
out_shape.push_back(static_cast<int>(in_tensor->shape().size()));
|
||||||
|
|
||||||
|
@ -45,18 +54,6 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
||||||
MS_LOG(ERROR) << "Set shape fails.";
|
MS_LOG(ERROR) << "Set shape fails.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto ret_dtype = out_tensor->set_data_type(in_tensor->data_type());
|
|
||||||
if (ret_dtype != in_tensor->data_type()) {
|
|
||||||
MS_LOG(ERROR) << "Set datatype fails.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo
|
|
||||||
// auto ret_data = out_tensor->MallocData();
|
|
||||||
// if (ret_data != 0) {
|
|
||||||
// MS_LOG(ERROR) << "Allocate memory fails.";
|
|
||||||
// return RET_ERROR;
|
|
||||||
// }
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,11 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
auto input = inputs.at(0);
|
auto input = inputs.at(0);
|
||||||
|
outputs[0]->set_data_type(input->data_type());
|
||||||
|
outputs[0]->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto input_shape = input->shape();
|
auto input_shape = input->shape();
|
||||||
auto slice_prim = this->primitive->value_as_Slice();
|
auto slice_prim = this->primitive->value_as_Slice();
|
||||||
std::vector<int32_t> slice_begin(slice_prim->begin()->begin(), slice_prim->begin()->end());
|
std::vector<int32_t> slice_begin(slice_prim->begin()->begin(), slice_prim->begin()->end());
|
||||||
|
@ -61,9 +66,6 @@ int Slice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs[0]->set_shape(output_shape);
|
outputs[0]->set_shape(output_shape);
|
||||||
outputs[0]->set_data_type(input->data_type());
|
|
||||||
outputs[0]->SetFormat(input->GetFormat());
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -26,9 +26,12 @@ int SoftMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
||||||
MS_ASSERT(input != nullptr);
|
MS_ASSERT(input != nullptr);
|
||||||
auto output = outputs_.front();
|
auto output = outputs_.front();
|
||||||
MS_ASSERT(output != nullptr);
|
MS_ASSERT(output != nullptr);
|
||||||
output->set_shape(input->shape());
|
|
||||||
output->set_data_type(input->data_type());
|
output->set_data_type(input->data_type());
|
||||||
output->SetFormat(input->GetFormat());
|
output->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
output->set_shape(input->shape());
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,11 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
||||||
MS_ASSERT(input != nullptr);
|
MS_ASSERT(input != nullptr);
|
||||||
auto output = outputs_.front();
|
auto output = outputs_.front();
|
||||||
MS_ASSERT(output != nullptr);
|
MS_ASSERT(output != nullptr);
|
||||||
|
output->set_data_type(input->data_type());
|
||||||
|
output->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
MS_ASSERT(inputs_.size() == kSingleNum);
|
MS_ASSERT(inputs_.size() == kSingleNum);
|
||||||
MS_ASSERT(outputs_.size() == kSingleNum);
|
MS_ASSERT(outputs_.size() == kSingleNum);
|
||||||
auto transpore_prim = this->primitive->value_as_Transpose();
|
auto transpore_prim = this->primitive->value_as_Transpose();
|
||||||
|
@ -46,8 +50,6 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
||||||
}
|
}
|
||||||
|
|
||||||
output->set_shape(out_shape);
|
output->set_shape(out_shape);
|
||||||
output->set_data_type(input->data_type());
|
|
||||||
output->SetFormat(input->GetFormat());
|
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
||||||
if (outputs_.size() != kSingleNum) {
|
if (outputs_.size() != kSingleNum) {
|
||||||
MS_LOG(ERROR) << "output size is invalid";
|
MS_LOG(ERROR) << "output size is invalid";
|
||||||
}
|
}
|
||||||
|
output->set_data_type(input->data_type());
|
||||||
|
output->SetFormat(input->GetFormat());
|
||||||
|
if (!GetInferFlag()) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
|
auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
|
||||||
auto dims = unsqueeze_prim->axis()->data();
|
auto dims = unsqueeze_prim->axis()->data();
|
||||||
auto in_shape = input->shape();
|
auto in_shape = input->shape();
|
||||||
|
@ -65,9 +70,7 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
output->SetFormat(input->GetFormat());
|
|
||||||
output->set_shape(out_shape);
|
output->set_shape(out_shape);
|
||||||
output->set_data_type(input->data_type());
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -58,6 +58,7 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
||||||
}
|
}
|
||||||
if (pack_input_ != nullptr) {
|
if (pack_input_ != nullptr) {
|
||||||
free(pack_input_);
|
free(pack_input_);
|
||||||
|
pack_input_ = nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool pre_trans_input_ = false;
|
bool pre_trans_input_ = false;
|
||||||
|
|
|
@ -223,11 +223,11 @@ bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
|
||||||
|
|
||||||
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||||
OpParameter *opParameter, const Context *ctx,
|
OpParameter *op_parameter, const Context *ctx,
|
||||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||||
MS_ASSERT(opParameter != nullptr);
|
MS_ASSERT(op_parameter != nullptr);
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
|
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
|
||||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
|
||||||
int kernel_h = conv_param->kernel_h_;
|
int kernel_h = conv_param->kernel_h_;
|
||||||
int kernel_w = conv_param->kernel_w_;
|
int kernel_w = conv_param->kernel_w_;
|
||||||
int stride_h = conv_param->stride_h_;
|
int stride_h = conv_param->stride_h_;
|
||||||
|
@ -239,25 +239,28 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
|
||||||
conv_param->output_h_ = outputs.front()->Height();
|
conv_param->output_h_ = outputs.front()->Height();
|
||||||
conv_param->output_w_ = outputs.front()->Width();
|
conv_param->output_w_ = outputs.front()->Width();
|
||||||
bool use_winograd = false;
|
bool use_winograd = false;
|
||||||
|
bool use_sw = false;
|
||||||
int out_unit;
|
int out_unit;
|
||||||
InputTransformUnitFunc input_trans_func = nullptr;
|
InputTransformUnitFunc input_trans_func = nullptr;
|
||||||
OutputTransformUnitFunc output_trans_func = nullptr;
|
OutputTransformUnitFunc output_trans_func = nullptr;
|
||||||
|
if (primitive != nullptr && primitive->GetInferFlag()) {
|
||||||
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
|
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
|
||||||
bool use_sw = CheckIfUseSlideWindow(conv_param);
|
use_sw = CheckIfUseSlideWindow(conv_param);
|
||||||
|
}
|
||||||
|
|
||||||
kernel::LiteKernel *kernel;
|
kernel::LiteKernel *kernel;
|
||||||
if (kernel_h == 1 && kernel_w == 1) {
|
if (kernel_h == 1 && kernel_w == 1) {
|
||||||
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||||
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
|
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
|
||||||
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||||
} else if (use_winograd) {
|
} else if (use_winograd) {
|
||||||
kernel =
|
kernel =
|
||||||
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
|
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
|
||||||
} else if (use_sw) {
|
} else if (use_sw) {
|
||||||
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
// kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||||
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||||
} else {
|
} else {
|
||||||
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||||
}
|
}
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||||
|
@ -266,8 +269,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
if (ret != RET_OK && ret != RET_INFER_INVALID) {
|
if (ret != RET_OK && ret != RET_INFER_INVALID) {
|
||||||
delete kernel;
|
delete kernel;
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
|
|
|
@ -116,6 +116,12 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
primitive->SetInferFlag(false);
|
primitive->SetInferFlag(false);
|
||||||
|
auto ret = primitive->InferShape(inputs, outputs);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "InferShape fail! name: " << cNode->name()->str()
|
||||||
|
<< ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
|
||||||
|
return RET_INFER_ERR;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
|
Loading…
Reference in New Issue