!3713 detect model fix bug

Merge pull request !3713 from ghzl/deconv-adapter
This commit is contained in:
mindspore-ci-bot 2020-08-01 17:27:53 +08:00 committed by Gitee
commit 9eb0a7697f
4 changed files with 29 additions and 10 deletions

View File

@ -80,6 +80,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Activation(const_cast<schema::Primitive *>(srcPrim)); return new lite::Activation(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Conv2D: case schema::PrimitiveType_Conv2D:
return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim)); return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_DeConv2D:
return new lite::DeConv2D(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Reduce: case schema::PrimitiveType_Reduce:
return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim)); return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Pooling: case schema::PrimitiveType_Pooling:

View File

@ -81,7 +81,7 @@ STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pat
} }
auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get(); auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get();
MS_ASSERT(baNodeBiasTensor != nullptr); MS_ASSERT(baNodeBiasTensor != nullptr);
if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) { if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
// dont fusion, return // dont fusion, return
return RET_OK; return RET_OK;
} }
@ -215,7 +215,9 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
<< ". or bias tensor is a scaler"; << ". or bias tensor is a scaler";
return RET_ERROR; return RET_ERROR;
} }
if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1;
if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)" MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)"
<< biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum; << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum;
return RET_ERROR; return RET_ERROR;
@ -234,6 +236,11 @@ STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath,
MS_LOG(ERROR) << "memset_s newBiasData failed"; MS_LOG(ERROR) << "memset_s newBiasData failed";
return RET_ERROR; return RET_ERROR;
} }
} else if (bias_const) {
auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
for (size_t i = 0; i < kernelNum; i++) {
newBiasData[i] = *biasData;
}
} else { } else {
if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) { if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) {
MS_LOG(ERROR) << "memcpy_s newBiasData failed"; MS_LOG(ERROR) << "memcpy_s newBiasData failed";

View File

@ -153,6 +153,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
weightTensor->format = schema::Format_KHWC; weightTensor->format = schema::Format_KHWC;
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
weightTensor->format = schema::Format_CHWK; weightTensor->format = schema::Format_CHWK;
} else if (opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_KHWC;
} else { } else {
MS_LOG(ERROR) << "unsupport format"; MS_LOG(ERROR) << "unsupport format";
return -1; return -1;
@ -356,18 +358,18 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition // todo(00445839): consider varible weight condition
} }
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
return 0; status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_HWKC) { // from tf } else if (weightTensor->format == schema::Format_KHWC) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kHWKC2KCHW); status = RET_OK;
} else { } else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1; return -1;
} }
if (status == 0) { if (status == 0) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
weightTensor->format = schema::Format_KCHW; weightTensor->format = schema::Format_KHWC;
} else { } else {
MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition // todo(00445839): consider varible weight condition

View File

@ -27,8 +27,16 @@ STATUS TfliteAddParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp
schema::CNodeT *op, schema::CNodeT *op,
TensorCache *tensor_cache, TensorCache *tensor_cache,
bool quantizedModel) { bool quantizedModel) {
// MS_LOGD("parse TfliteAddParser"); MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr(new schema::AddT()); std::unique_ptr<schema::AddT> attr(new schema::AddT());
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()};
if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
return RET_ERROR;
}
if (op != nullptr) { if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>(); op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Add; op->primitive->value.type = schema::PrimitiveType_Add;