diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc index f6282cd5862..6cd9b81aec2 100644 --- a/mindspore/lite/src/model_impl.cc +++ b/mindspore/lite/src/model_impl.cc @@ -80,6 +80,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { return new lite::Activation(const_cast(srcPrim)); case schema::PrimitiveType_Conv2D: return new lite::Conv2D(const_cast(srcPrim)); + case schema::PrimitiveType_DeConv2D: + return new lite::DeConv2D(const_cast(srcPrim)); case schema::PrimitiveType_Reduce: return new lite::Reduce(const_cast(srcPrim)); case schema::PrimitiveType_Pooling: diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc index c353f2ca94e..f51789f4e3e 100644 --- a/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc @@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat } auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get(); MS_ASSERT(baNodeBiasTensor != nullptr); - if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) { + if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) { // dont fusion, return return RET_OK; } @@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr convPath, << ". or bias tensor is a scaler"; return RET_ERROR; } - if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) { + + bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1; + if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) { MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)" << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum; return RET_ERROR; @@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr convPath, MS_LOG(ERROR) << "memset_s newBiasData failed"; return RET_ERROR; } + } else if (bias_const) { + auto *biasData = reinterpret_cast(biasTensor->data.data()); + for (size_t i = 0; i < kernelNum; i++) { + newBiasData[i] = *biasData; + } } else { if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) { MS_LOG(ERROR) << "memcpy_s newBiasData failed"; diff --git a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc index ff173918b74..582e7146862 100644 --- a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc @@ -153,6 +153,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { weightTensor->format = schema::Format_KHWC; } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { weightTensor->format = schema::Format_CHWK; + } else if (opType == schema::PrimitiveType_DeConv2D) { + weightTensor->format = schema::Format_KHWC; } else { MS_LOG(ERROR) << "unsupport format"; return -1; @@ -356,18 +358,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); // todo(00445839): consider varible weight condition } - } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW - if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx - return 0; - } else if (weightTensor->format == schema::Format_HWKC) { // from tf - status = TransFilterFormat(weightTensor.get(), kHWKC2KCHW); + } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC + if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); + } else if (weightTensor->format == schema::Format_KHWC) { // from tf + status = RET_OK; } else { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; return -1; } if (status == 0) { - node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; - weightTensor->format = schema::Format_KCHW; + node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_KHWC; } else { MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); // todo(00445839): consider varible weight condition diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc index 377ecab1676..1a0e3aa01ba 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc @@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr &tfliteOp schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { - // MS_LOGD("parse TfliteAddParser"); + MS_LOG(DEBUG) << "parse TfliteAddParser"; std::unique_ptr attr(new schema::AddT()); + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + return RET_ERROR; + } + if (op != nullptr) { op->primitive = std::make_unique(); op->primitive->value.type = schema::PrimitiveType_Add;