!4530 add formatTransOp around eltwise fusion

Merge pull request !4530 from hangq/master
mindspore-ci-bot 2020-08-16 16:38:41 +08:00 committed by Gitee
commit 106cadbaa4
5 changed files with 101 additions and 132 deletions
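In brief: the new pass walks each Eltwise node and checks whether the transpose nodes already surrounding it (Nchw2Nhwc or Nhwc2Nchw) agree on one direction per side and account for a majority of its neighbours; only then does it wrap the Eltwise in format-trans ops so that the later FormatTransFusionPass can cancel the resulting pairs. A minimal standalone sketch of that decision, using simplified stand-in types rather than the converter's schema types:

#include <cstddef>
#include <vector>

enum class TransType { kNone, kNchw2Nhwc, kNhwc2Nchw };  // stand-in for schema::PrimitiveType values

// Mirrors the pre_type/post_type/has_trans_count bookkeeping in
// EltwiseFormatTransPass::Run: each side must be direction-consistent, and the
// transposes must reach a strict majority when the neighbour count is even,
// at least half when it is odd.
bool CanFuseEltwise(const std::vector<TransType> &pre, const std::vector<TransType> &post) {
  size_t has_trans_count = 0;
  auto scan = [&has_trans_count](const std::vector<TransType> &side) {
    TransType side_type = TransType::kNone;
    for (auto t : side) {
      if (t == TransType::kNone) {
        continue;  // this neighbour is not a transpose
      }
      if (side_type == TransType::kNone) {
        side_type = t;  // the first transpose fixes the side's direction
      } else if (side_type != t) {
        return false;  // mixed directions on one side: no fusion
      }
      ++has_trans_count;
    }
    return true;
  };
  if (!scan(pre) || !scan(post)) {
    return false;
  }
  const size_t total = pre.size() + post.size();
  const size_t half = total / 2;
  return (total % 2 == 0) ? has_trans_count > half : has_trans_count >= half;
}

For example, an Eltwise with two inputs and one output, where one input arrives through Nhwc2Nchw and the output feeds Nchw2Nhwc, has has_trans_count = 2 out of 3 neighbours (half = 1, odd count), so the pass rewrites it; two inputs transposing in opposite directions block the fusion immediately.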

View File

@@ -285,6 +285,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
auto typePtr = abstractTensor->element()->GetTypeTrack();
MS_ASSERT(typePtr != nullptr);
paramTensor->dataType = typePtr->type_id();
paramTensor->format = schema::Format(abstractTensor->format());
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
return RET_ERROR;
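The added line carries the layout recorded on the parameter's abstract tensor into the exported schema tensor, next to the dtype copied just above, so later graph passes can reason about layouts on parameter inputs. The conversion is a raw int-to-enum cast; a hedged sketch of the same idea with an explicit range guard, over hypothetical stand-in types rather than the exporter's API:

enum class Format { NCHW = 0, NHWC = 1, Unknown = -1 };  // stand-in for schema::Format

struct ExportedTensor {
  int dataType;
  Format format;
};

// Copy the parser-recorded layout through to the exported tensor, falling back
// to Unknown instead of leaving the field at an arbitrary default value.
void CopyLayout(ExportedTensor *out, int abstract_format) {
  out->format = (abstract_format == 0 || abstract_format == 1)
                    ? static_cast<Format>(abstract_format)
                    : Format::Unknown;
}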

View File

@@ -51,6 +51,7 @@
//
#include "tools/converter/legacy_optimizer/node/weight_format_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
@@ -113,7 +114,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
@@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
return status;
}
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
status = mQuantizer->GenerateQuantParam();
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateQuantParam failed";
@@ -171,6 +171,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new EltwiseFormatTransPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
// if (ctx.quantType == QuantType_AwareTraining) {
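Note the ordering in this optimizer block: EltwiseFormatTransPass is added after the generic FormatTransPass and before FormatTransFusionPass, so the transpose pairs it inserts around Eltwise nodes are still in the graph when the fusion pass cancels redundant Nchw2Nhwc/Nhwc2Nchw pairs and IsolatedNodeRemovePass sweeps up the leftovers. A minimal sketch of that run-in-order, stop-on-failure pipeline shape (a hypothetical stand-in, not the converter's Optimizer class):

#include <memory>
#include <vector>

struct Graph {};  // stand-in for schema::MetaGraphT

struct Pass {
  virtual ~Pass() = default;
  virtual int Run(Graph *graph) = 0;  // 0 on success, non-zero error status otherwise
};

struct Pipeline {
  std::vector<std::unique_ptr<Pass>> passes;
  // Passes run strictly in insertion order and the first failure aborts the
  // run, which is why a pass that produces nodes for another to consume must
  // be added ahead of it.
  int Run(Graph *graph) {
    for (auto &pass : passes) {
      int status = pass->Run(graph);
      if (status != 0) {
        return status;
      }
    }
    return 0;
  }
};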

View File

@@ -21,148 +21,120 @@
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"
#include "utils/log_adapter.h"
#include "src/common/common.h"
#include "src/common/utils.h"
namespace mindspore {
namespace lite {
STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    auto &node = *iter;
    if (node->primitive->value.type != schema::PrimitiveType_Eltwise) {
      continue;
    }
    // Scan the input side: every existing transpose node must agree on one direction.
    auto input_node_indexes = GetInputNodeIdx(*graph, *node);
    auto pre_type = schema::PrimitiveType_NONE;
    size_t has_trans_count = 0;
    auto can_fusion = true;
    for (auto input_node_index : input_node_indexes) {
      MS_ASSERT(graph->nodes.size() > input_node_index);
      auto &pre_node = graph->nodes.at(input_node_index);
      MS_ASSERT(pre_node != nullptr);
      if (pre_type == schema::PrimitiveType_NONE) {
        if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
            pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
          pre_type = pre_node->primitive->value.type;
          has_trans_count++;
        }
      } else {
        if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
            pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
          if (pre_type != pre_node->primitive->value.type) {
            can_fusion = false;
            break;
          } else {
            has_trans_count++;
          }
        }
      }
    }
    if (!can_fusion) {
      continue;
    }
    // Scan the output side the same way.
    auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
    auto post_type = schema::PrimitiveType_NONE;
    for (auto output_node_index : output_node_indexes) {
      MS_ASSERT(graph->nodes.size() > output_node_index);
      auto &post_node = graph->nodes.at(output_node_index);
      MS_ASSERT(post_node != nullptr);
      if (post_type == schema::PrimitiveType_NONE) {
        if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
            post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
          post_type = post_node->primitive->value.type;
          has_trans_count++;
        }
      } else {
        if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
            post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
          if (post_type != post_node->primitive->value.type) {
            can_fusion = false;
            break;
          } else {
            has_trans_count++;
          }
        }
      }
    }
    if (!can_fusion) {
      continue;
    }
    // Fuse only when transposes form a majority of the neighbouring nodes:
    // a strict majority for an even neighbour count, at least half for an odd one.
    auto total_node_count = input_node_indexes.size() + output_node_indexes.size();
    size_t half_count = total_node_count / 2;
    if (total_node_count % 2 == 0) {
      can_fusion = has_trans_count > half_count;
    } else {
      can_fusion = has_trans_count >= half_count;
    }
    if (!can_fusion) {
      continue;
    }
    // Choose insertion directions so the new nodes cancel against the existing transposes.
    FormatTransNodeType pre_insert_trans_type = kNHWC2NCHW;
    FormatTransNodeType post_insert_trans_type = kNHWC2NCHW;
    if (pre_type == schema::PrimitiveType_NONE && post_type != schema::PrimitiveType_NONE) {
      pre_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
      post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
    } else if (pre_type != schema::PrimitiveType_NONE && post_type == schema::PrimitiveType_NONE) {
      pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
      post_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
    } else if (pre_type == schema::PrimitiveType_NONE && post_type == schema::PrimitiveType_NONE) {
      continue;
    } else {
      if (pre_type == post_type) {
        MS_LOG(ERROR) << "Unknown error";
        return RET_ERROR;
      }
      pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
      post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
    }
    STATUS status = RET_OK;
    auto input_tensor_size = (*iter)->inputIndex.size();
    for (size_t i = 0; i < input_tensor_size; i++) {
      iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Insert " << pre_insert_trans_type << " node before " << (*iter)->name << " failed";
        return status;
      }
    }
    auto output_tensor_size = (*iter)->outputIndex.size();
    for (size_t i = 0; i < output_tensor_size; i++) {
      iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Insert " << post_insert_trans_type << " node after " << (*iter)->name << " failed";
        return status;
      }
    }
  }
  return RET_OK;
}
@@ -195,6 +167,5 @@ NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph
void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; }
} // namespace lite
} // namespace mindspore
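The heart of the rewritten Run() above is the direction choice: whichever side already carries transposes fixes the direction, the node inserted before the Eltwise is chosen to cancel against the upstream transposes, and the node inserted after restores the original layout so the rest of the graph is undisturbed. A compact restatement of that branch, reusing the TransType stand-in from the earlier sketch (an illustration only; the pass itself branches on schema::PrimitiveType values):

struct TransPair {
  TransType before;  // direction of the node inserted ahead of the Eltwise
  TransType after;   // direction of the node inserted behind it
};

// Callers must exclude two cases the pass handles separately: both sides kNone
// (the node is skipped) and both sides non-kNone but equal (reported as an error).
TransPair ChooseInsertedTranspose(TransType pre_type, TransType post_type) {
  TransType known = (pre_type != TransType::kNone) ? pre_type : post_type;
  TransType opposite = (known == TransType::kNhwc2Nchw) ? TransType::kNchw2Nhwc
                                                        : TransType::kNhwc2Nchw;
  if (pre_type != TransType::kNone) {
    // Upstream transposes exist: insert their opposite before the Eltwise so
    // each pair cancels, then re-apply the original direction afterwards.
    return {opposite, known};
  }
  // Only downstream transposes exist: mirror the same trick on the other side.
  return {known, opposite};
}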

View File

@@ -20,10 +20,10 @@
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
namespace mindspore {
namespace lite {
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW };
class EltwiseFormatTransPass : public GraphPass {
public:
@@ -38,10 +38,6 @@ class EltwiseFormatTransPass : public GraphPass {
void SetFmk(converter::FmkType fmkType);
private:
STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph);
STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph);
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
FormatTransNodeType nodeType, STATUS *errorCode);

View File

@@ -369,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
status = RET_OK;
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
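The weight-format fix above replaces a wrong assumption for TF-style DeConv2D weights: they arrive as CHWK rather than KHWC, so instead of passing them through unchanged (the old RET_OK branch) they now get an explicit relayout to the KHWC order the runtime expects. A self-contained sketch of what a CHWK-to-KHWC relayout does to a flat float buffer (an illustration of the kCHWK2KHWC step, not the converter's TransFilterFormat template):

#include <vector>

// dims are given in source CHWK order; the result is laid out KHWC.
std::vector<float> TransposeCHWK2KHWC(const std::vector<float> &src,
                                      int C, int H, int W, int K) {
  std::vector<float> dst(src.size());
  for (int c = 0; c < C; ++c) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        for (int k = 0; k < K; ++k) {
          // flat index in CHWK layout -> flat index in KHWC layout
          dst[((k * H + h) * W + w) * C + c] = src[((c * H + h) * W + w) * K + k];
        }
      }
    }
  }
  return dst;
}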