# Patch notes (free text ahead of the first "diff --git" header).
#
# What the patch does:
#   * anf_exporter.cc: copies the abstract tensor's format onto the exported parameter tensor.
#   * graphdef_transform.cc: registers EltwiseFormatTransPass in formatTransOptimizer, directly after
#     formatTransPass and before FormatTransFusionPass.
#   * eltwise_format_trans_pass.{h,cc}: Run() now visits every Eltwise node, counts the Nchw2Nhwc /
#     Nhwc2Nchw nodes among its input nodes and output nodes, skips the node when those transposes are
#     of mixed kinds on one side or too few (even neighbour count: more than half; odd neighbour count:
#     at least half rounded down), and otherwise inserts a transpose before every input tensor and after
#     every output tensor. DoModelInputFormatTrans / DoNodeInoutFormatTrans are removed and
#     FormatTransNodeType is taken from format_trans_pass.h instead of being redefined.
#   * weight_format_pass.cc: for DeConv2D, a CHWK weight is converted with kCHWK2KHWC; KHWC is no longer
#     accepted by that branch.
#
# Review amendments (added lines only, replaced one-for-one so hunk sizes are unchanged):
#   * insertion-failure logs use node_name (captured before any insertion) instead of (*iter)->name,
#     because iter has just been reassigned from the failed call;
#   * the log for the kAfter insertion now says "after"; both logs gained separators;
#   * insertion loops use size_t indices to match inputIndex.size() / outputIndex.size();
#   * EltwiseFormatTransPass is allocated with new (std::nothrow) like the neighbouring passes;
#   * schema::Format(...) is spelled as a static_cast; "Unknow error" typo fixed.
#
# Layout note: this patch was rebuilt from a copy whose line breaks and indentation had been collapsed.
# Indentation and the position of blank context lines are inferred -- verify against the target tree
# before applying.
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index ee70a8c2ff3..a704d6b8f3a 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -285,6 +285,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr input_anod
   auto typePtr = abstractTensor->element()->GetTypeTrack();
   MS_ASSERT(typePtr != nullptr);
   paramTensor->dataType = typePtr->type_id();
+  paramTensor->format = static_cast<schema::Format>(abstractTensor->format());
   if (!utils::isa(abstractTensor->BuildShape())) {
     MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
     return RET_ERROR;
diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
index 0f564aeabdb..83637346240 100644
--- a/mindspore/lite/tools/converter/graphdef_transform.cc
+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
@@ -51,6 +51,7 @@
 
 // #include "tools/converter/legacy_optimizer/node/weight_format_pass.h"
 #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
+#include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h"
 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
 #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
@@ -113,7 +114,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     }
   }
 
-
   {
     Optimizer unusedOpRemoveOptimizer;
     unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
@@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
         return status;
       }
       if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
-          this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
+            this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
         status = mQuantizer->GenerateQuantParam();
         if (status != RET_OK) {
           MS_LOG(ERROR) << "GenerateQuantParam failed";
@@ -171,6 +171,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     formatTransPass->SetQuantType(ctx.quantType);
     formatTransPass->SetFmk(ctx.fmk);
     formatTransOptimizer.AddPass(formatTransPass);
+    formatTransOptimizer.AddPass(new (std::nothrow) EltwiseFormatTransPass());
     formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
     formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
     // if (ctx.quantType == QuantType_AwareTraining) {
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc
index 097f3c2eca7..f84d4859aaf 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc
@@ -21,148 +21,120 @@
 #include "tools/common/converter_op_utils.h"
 #include "tools/common/node_util.h"
 #include "utils/log_adapter.h"
-#include "src/common/common.h"
 #include "src/common/utils.h"
 
 namespace mindspore {
 namespace lite {
-#define kMinInputNum 1
-#define kOutputNum 1
 
 STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
-  if (fmkType == converter::FmkType_TF) {
-    return RET_OK;
-  }
   MS_ASSERT(graph != nullptr);
 
-  auto status = DoModelInputFormatTrans(graph);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status;
-    return status;
-  }
-  status = DoNodeInoutFormatTrans(graph);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status;
-    return status;
-  }
-  return RET_OK;
-}
-
-STATUS EltwiseFormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
-  if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) {
-    return RET_OK;
-  }
-  MS_ASSERT(graph != nullptr);
-  // insert trans node in model input tensor
-  if (graph->nodes.empty()) {
-    return RET_OK;
-  }
-  auto graphInputIdxes = graph->inputIndex;
-  for (size_t i = 0; i < graphInputIdxes.size(); i++) {
-    auto inputIdx = graphInputIdxes.at(i);
-    MS_ASSERT(inputIdx < subGraph->allTensors.size());
-    auto &tensor = graph->allTensors.at(inputIdx);
-    if (tensor->dims.size() != kNCHWDimNumber) {
+  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
+    auto &node = *iter;
+    if (node->primitive->value.type != PrimitiveType_Eltwise) {
       continue;
     }
-
-    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-      auto &node = *iter;
-      for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
-        if (node->inputIndex.at(inputIndexIdx) == inputIdx) {
-          STATUS status = RET_OK;
-          iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status);
-          if (status != RET_OK) {
-            MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed";
-            return status;
+    auto node_name = node->name;  // copy kept for logging; iter is reassigned by the inserts below
+    auto input_node_indexes = GetInputNodeIdx(*graph, *node);
+    auto pre_type = schema::PrimitiveType_NONE;
+    size_t has_trans_count = 0;  // transpose neighbours seen on the input side plus the output side
+    auto can_fusion = true;
+    for (auto input_node_index : input_node_indexes) {
+      MS_ASSERT(graph->nodes.size() > input_node_index);
+      auto &pre_node = graph->nodes.at(input_node_index);
+      MS_ASSERT(pre_node != nullptr);
+      if (pre_type == schema::PrimitiveType_NONE) {
+        if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
+            pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
+          pre_type = pre_node->primitive->value.type;
+          has_trans_count++;
+        }
+      } else {
+        if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
+            pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
+          if (pre_type != pre_node->primitive->value.type) {
+            can_fusion = false;
+            break;
+          } else {
+            has_trans_count++;
           }
-          // set first tensor format to nhwc
-          auto &transNode = *(iter - 1);
-          MS_ASSERT(transNode != nullptr);
-          MS_ASSERT(transNode->inputIndex.size() == 1);
-          MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front());
-          auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front());
-          graphInTensor->format = schema::Format_NHWC;
-          // assume parser not reformat shape
-          auto oldDims = graphInTensor->dims;
-          graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]};
-          break;
         }
       }
     }
-  }
-  return RET_OK;
-}
-
-// inference needed inputFormat:
-//           conv     deconv     depth     dedepth
-// fp32      NCHW     NCHW       NCHW      NCHW
-// uint8     NCHW      ?         NCHW        ?
-STATUS EltwiseFormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
-  MS_ASSERT(graph != nullptr);
-  // insert before and after the op cal by nchw/nc4hw4
-  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-    FormatTransNodeType beforeNodeType, afterNodeType;
-    if (fmkType == converter::FmkType_TFLITE) {  // inference by nhwc
-      // if (quantType == QuantType_AwareTrainning) {  // awaretrainning op use
-      // nhwc
-      //   if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) {  // uint8NhwcOp only
-      //   support nhwc
-      //     continue;
-      //   }
-      //   if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
-      //     continue;
-      //   }
-      // } else {
-      //   if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
+    if (!can_fusion) {
       continue;
-      //   }
-      // }
-      // beforeNodeType = kNCHW2NHWC;
-      // afterNodeType = kNHWC2NCHW;
-    } else if (fmkType == converter::FmkType_CAFFE) {  // inference by nchw
-      // if (quantType == QuantType_AwareTrainning) {  // awaretrainning op use nhwc
-      //   if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) {  // uint8NhwcOp only support nhwc
-      //     continue;
-      //   }
-      // } else {
-      //     continue;
-      // }
-      if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
-        continue;
+    }
+    auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
+    auto post_type = schema::PrimitiveType_NONE;
+    for (auto output_node_index : output_node_indexes) {
+      MS_ASSERT(graph->nodes.size() > output_node_index);
+      auto &post_node = graph->nodes.at(output_node_index);
+      MS_ASSERT(post_node != nullptr);
+      if (post_type == schema::PrimitiveType_NONE) {
+        if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
+            post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
+          post_type = post_node->primitive->value.type;
+          has_trans_count++;
+        }
+      } else {
+        if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
+            post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
+          if (post_type != post_node->primitive->value.type) {
+            can_fusion = false;
+            break;
+          } else {
+            has_trans_count++;
+          }
+        }
       }
-      beforeNodeType = kNCHW2NHWC;
-      afterNodeType = kNHWC2NCHW;
-    } else if (fmkType == converter::FmkType_MS) {
-      if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
-        continue;
-      }
-      beforeNodeType = kNCHW2NHWC;
-      afterNodeType = kNHWC2NCHW;
+    }
+    if (!can_fusion) {
+      continue;
+    }
+    auto total_node_count = input_node_indexes.size() + output_node_indexes.size();
+    size_t half_count = total_node_count / 2;
+    if (total_node_count % 2 == 0) {
+      can_fusion = has_trans_count > half_count;   // even total: strictly more than half
     } else {
-      MS_LOG(ERROR) << "Unsupported fmk: " << fmkType;
-      return RET_ERROR;
+      can_fusion = has_trans_count >= half_count;  // odd total: half_count is rounded down
     }
-    auto &node = *iter;
-    auto nodeName = node->name;
-    if (node->inputIndex.size() < kMinInputNum) {
-      MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least";
-      return RET_ERROR;
+    if (!can_fusion) {
+      continue;
     }
-    if (node->outputIndex.size() != kOutputNum) {
-      MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor";
-      return RET_ERROR;
-    }
-    STATUS status;
-    iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed";
-      return RET_ERROR;
+    FormatTransNodeType pre_insert_trans_type = kNHWC2NCHW;
+    FormatTransNodeType post_insert_trans_type = kNHWC2NCHW;
+    if (pre_type == PrimitiveType_NONE && post_type != PrimitiveType_NONE) {
+      pre_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
+      post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
+    } else if (pre_type != PrimitiveType_NONE && post_type == PrimitiveType_NONE) {
+      pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
+      post_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
+    } else if (pre_type == PrimitiveType_NONE && post_type == PrimitiveType_NONE) {
+      continue;
+    } else {
+      if (pre_type == post_type) {
+        MS_LOG(ERROR) << "Unknown error";
+        return RET_ERROR;
+      }
+      pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
+      post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
     }
-    iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
-      return RET_ERROR;
+    STATUS status = RET_OK;
+    auto input_tensor_size = (*iter)->inputIndex.size();
+    for (size_t i = 0; i < input_tensor_size; i++) {
+      iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type, &status);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "Insert " << pre_insert_trans_type << " node before " << node_name << " failed";
+        return status;
+      }
+    }
+    auto output_tensor_size = (*iter)->outputIndex.size();
+    for (size_t i = 0; i < output_tensor_size; i++) {
+      iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type, &status);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "Insert " << post_insert_trans_type << " node after " << node_name << " failed";
+        return status;
+      }
     }
   }
   return RET_OK;
@@ -195,6 +167,5 @@ NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph
 void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
 
 void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; }
-
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h
index 5a5d754ac19..a8de9c97978 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h
@@ -20,10 +20,10 @@
 #include "tools/converter/optimizer.h"
 #include "tools/common/graph_util.h"
 #include "tools/converter/converter_flags.h"
+#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
 
 namespace mindspore {
 namespace lite {
-enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW };
 
 class EltwiseFormatTransPass : public GraphPass {
  public:
@@ -38,10 +38,6 @@ class EltwiseFormatTransPass : public GraphPass {
   void SetFmk(converter::FmkType fmkType);
 
  private:
-  STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph);
-
-  STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph);
-
   NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
                                  FormatTransNodeType nodeType, STATUS *errorCode);
 
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
index 652cdd61c60..a87e960c8c4 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc
@@ -369,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
   } else if (opType == schema::PrimitiveType_DeConv2D) {  // weight should be KHWC
     if (weightTensor->format == schema::Format_KCHW) {  // from caffe or onnx or ms
       status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC);
-    } else if (weightTensor->format == schema::Format_KHWC) {  // from tf
-      status = RET_OK;
+    } else if (weightTensor->format == schema::Format_CHWK) {  // from tf
+      status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC);
     } else {
       MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
       return -1;