From 9638139e27737e496ebf204c998dcfff5ab7b3a7 Mon Sep 17 00:00:00 2001 From: kai00 Date: Wed, 12 Aug 2020 22:04:46 +0800 Subject: [PATCH] fixed format trans --- mindspore/lite/src/common/anf_importer/import_from_protobuf.cc | 1 + mindspore/lite/tools/common/node_util.cc | 2 +- mindspore/lite/tools/common/node_util.h | 2 +- .../tools/converter/legacy_optimizer/node/weight_format_pass.cc | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc index 0479c40486d..db11cdcaab0 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc @@ -1168,6 +1168,7 @@ int AnfImporterFromProtobuf::Import() { const onnx::GraphProto &graphBuild = onnx_model_->graph(); if (!BuildFuncGraph(dstGraph, graphBuild)) { MS_LOG(ERROR) << "Build funcgraph failed!"; + func_graph_ = nullptr; return RET_ERROR; } func_graph_ = dstGraph; diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 16c146052e9..6f9a4e52e21 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -96,7 +96,7 @@ static const std::vector nhwcOpList = { schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize, - schema::PrimitiveType_BatchNorm}; + schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm}; static const std::vector fp32FullOpList = { schema::PrimitiveType_Concat, schema::PrimitiveType_Add, diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 1f017a54c56..f3db9cfbb50 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -234,7 +234,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kCKHW2KHWC) { p2Buff = - buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c)); + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } else { p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index 06f6d39ca0a..74b146272d1 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -351,7 +351,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { // todo(00445839): consider varible weight condition } } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW - if (graphNode->subGraph->fmkType == converter::FmkType_MS) { + if (fmkType == converter::FmkType_MS) { weightTensor->format = schema::Format_CKHW; } if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms