forked from mindspore-Ecosystem/mindspore
!4530 add formatTransOp around eltwise fusion
Merge pull request !4530 from hangq/master
This commit is contained in:
commit
106cadbaa4
|
@ -285,6 +285,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
|
|||
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
||||
MS_ASSERT(typePtr != nullptr);
|
||||
paramTensor->dataType = typePtr->type_id();
|
||||
paramTensor->format = schema::Format(abstractTensor->format());
|
||||
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
||||
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -51,6 +51,7 @@
|
|||
//
|
||||
#include "tools/converter/legacy_optimizer/node/weight_format_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
|
||||
|
@ -113,7 +114,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
{
|
||||
Optimizer unusedOpRemoveOptimizer;
|
||||
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
|
||||
|
@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
return status;
|
||||
}
|
||||
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
|
||||
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
|
||||
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
|
||||
status = mQuantizer->GenerateQuantParam();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GenerateQuantParam failed";
|
||||
|
@ -171,6 +171,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
formatTransPass->SetQuantType(ctx.quantType);
|
||||
formatTransPass->SetFmk(ctx.fmk);
|
||||
formatTransOptimizer.AddPass(formatTransPass);
|
||||
formatTransOptimizer.AddPass(new EltwiseFormatTransPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
// if (ctx.quantType == QuantType_AwareTraining) {
|
||||
|
|
|
@ -21,148 +21,120 @@
|
|||
#include "tools/common/converter_op_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#define kMinInputNum 1
|
||||
#define kOutputNum 1
|
||||
|
||||
// Entry point of the pass.
// Returns RET_OK immediately for TF-sourced graphs; otherwise runs
// DoModelInputFormatTrans followed by DoNodeInoutFormatTrans on `graph`,
// and forwards the first non-RET_OK status returned by either step.
STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
  // Graphs whose framework type is TF are returned unchanged.
  if (fmkType == converter::FmkType_TF) {
    return RET_OK;
  }
  MS_ASSERT(graph != nullptr);

  // First step: graph-input handling (details live in the helper).
  const auto inputTransStatus = DoModelInputFormatTrans(graph);
  if (inputTransStatus != RET_OK) {
    MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << inputTransStatus;
    return inputTransStatus;
  }

  // Second step: per-node handling; only reached when the first step succeeded.
  const auto nodeTransStatus = DoNodeInoutFormatTrans(graph);
  if (nodeTransStatus != RET_OK) {
    MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << nodeTransStatus;
    return nodeTransStatus;
  }

  return RET_OK;
}
|
||||
|
||||
STATUS EltwiseFormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
|
||||
if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) {
|
||||
return RET_OK;
|
||||
}
|
||||
MS_ASSERT(graph != nullptr);
|
||||
// insert trans node in model input tensor
|
||||
if (graph->nodes.empty()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto graphInputIdxes = graph->inputIndex;
|
||||
for (size_t i = 0; i < graphInputIdxes.size(); i++) {
|
||||
auto inputIdx = graphInputIdxes.at(i);
|
||||
MS_ASSERT(inputIdx < subGraph->allTensors.size());
|
||||
auto &tensor = graph->allTensors.at(inputIdx);
|
||||
if (tensor->dims.size() != kNCHWDimNumber) {
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
if (node->primitive->value.type != PrimitiveType_Eltwise) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
|
||||
if (node->inputIndex.at(inputIndexIdx) == inputIdx) {
|
||||
STATUS status = RET_OK;
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed";
|
||||
return status;
|
||||
auto node_name = node->name;
|
||||
auto input_node_indexes = GetInputNodeIdx(*graph, *node);
|
||||
auto pre_type = schema::PrimitiveType_NONE;
|
||||
size_t has_trans_count = 0;
|
||||
auto can_fusion = true;
|
||||
for (auto input_node_index : input_node_indexes) {
|
||||
MS_ASSERT(graph->nodes.size() > input_node_index);
|
||||
auto &pre_node = graph->nodes.at(input_node_index);
|
||||
MS_ASSERT(pre_node != nullptr);
|
||||
if (pre_type == schema::PrimitiveType_NONE) {
|
||||
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
|
||||
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
pre_type = pre_node->primitive->value.type;
|
||||
has_trans_count++;
|
||||
}
|
||||
} else {
|
||||
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
|
||||
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
if (pre_type != pre_node->primitive->value.type) {
|
||||
can_fusion = false;
|
||||
break;
|
||||
} else {
|
||||
has_trans_count++;
|
||||
}
|
||||
// set first tensor format to nhwc
|
||||
auto &transNode = *(iter - 1);
|
||||
MS_ASSERT(transNode != nullptr);
|
||||
MS_ASSERT(transNode->inputIndex.size() == 1);
|
||||
MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front());
|
||||
auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front());
|
||||
graphInTensor->format = schema::Format_NHWC;
|
||||
// assume parser not reformat shape
|
||||
auto oldDims = graphInTensor->dims;
|
||||
graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]};
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// inference needed inputFormat:
|
||||
// conv deconv depth dedepth
|
||||
// fp32 NCHW NCHW NCHW NCHW
|
||||
// uint8 NCHW ? NCHW ?
|
||||
STATUS EltwiseFormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
// insert before and after the op cal by nchw/nc4hw4
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
FormatTransNodeType beforeNodeType, afterNodeType;
|
||||
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc
|
||||
// if (quantType == QuantType_AwareTrainning) { // awaretrainning op use
|
||||
// nhwc
|
||||
// if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only
|
||||
// support nhwc
|
||||
// continue;
|
||||
// }
|
||||
// if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
|
||||
// continue;
|
||||
// }
|
||||
// } else {
|
||||
// if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
|
||||
if (!can_fusion) {
|
||||
continue;
|
||||
// }
|
||||
// }
|
||||
// beforeNodeType = kNCHW2NHWC;
|
||||
// afterNodeType = kNHWC2NCHW;
|
||||
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw
|
||||
// if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc
|
||||
// if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc
|
||||
// continue;
|
||||
// }
|
||||
// } else {
|
||||
// continue;
|
||||
// }
|
||||
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
|
||||
continue;
|
||||
}
|
||||
auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
|
||||
auto post_type = schema::PrimitiveType_NONE;
|
||||
for (auto output_node_index : output_node_indexes) {
|
||||
MS_ASSERT(graph->nodes.size() > output_node_index);
|
||||
auto &post_node = graph->nodes.at(output_node_index);
|
||||
MS_ASSERT(post_node != nullptr);
|
||||
if (post_type == schema::PrimitiveType_NONE) {
|
||||
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
|
||||
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
post_type = post_node->primitive->value.type;
|
||||
has_trans_count++;
|
||||
}
|
||||
} else {
|
||||
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
|
||||
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
if (post_type != post_node->primitive->value.type) {
|
||||
can_fusion = false;
|
||||
break;
|
||||
} else {
|
||||
has_trans_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
beforeNodeType = kNCHW2NHWC;
|
||||
afterNodeType = kNHWC2NCHW;
|
||||
} else if (fmkType == converter::FmkType_MS) {
|
||||
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
|
||||
continue;
|
||||
}
|
||||
beforeNodeType = kNCHW2NHWC;
|
||||
afterNodeType = kNHWC2NCHW;
|
||||
}
|
||||
if (!can_fusion) {
|
||||
continue;
|
||||
}
|
||||
auto total_node_count = input_node_indexes.size() + output_node_indexes.size();
|
||||
size_t half_count = total_node_count / 2;
|
||||
if (total_node_count % 2 == 0) {
|
||||
can_fusion = has_trans_count > half_count;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported fmk: " << fmkType;
|
||||
return RET_ERROR;
|
||||
can_fusion = has_trans_count >= half_count;
|
||||
}
|
||||
auto &node = *iter;
|
||||
auto nodeName = node->name;
|
||||
if (node->inputIndex.size() < kMinInputNum) {
|
||||
MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least";
|
||||
return RET_ERROR;
|
||||
if (!can_fusion) {
|
||||
continue;
|
||||
}
|
||||
if (node->outputIndex.size() != kOutputNum) {
|
||||
MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
STATUS status;
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
FormatTransNodeType pre_insert_trans_type = kNHWC2NCHW;
|
||||
FormatTransNodeType post_insert_trans_type = kNHWC2NCHW;
|
||||
if (pre_type == PrimitiveType_NONE && post_type != PrimitiveType_NONE) {
|
||||
pre_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
|
||||
post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
|
||||
} else if (pre_type != PrimitiveType_NONE && post_type == PrimitiveType_NONE) {
|
||||
pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
|
||||
post_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
|
||||
} else if (pre_type == PrimitiveType_NONE && post_type == PrimitiveType_NONE) {
|
||||
continue;
|
||||
} else {
|
||||
if (pre_type == post_type) {
|
||||
MS_LOG(ERROR) << "Unknow error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
|
||||
post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
|
||||
}
|
||||
|
||||
iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
STATUS status = RET_OK;
|
||||
auto input_tensor_size = (*iter)->inputIndex.size();
|
||||
for (auto i = 0; i < input_tensor_size; i++) {
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type << "before " << (*iter)->name << " failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
auto output_tensor_size = (*iter)->outputIndex.size();
|
||||
for (auto i = 0; i < output_tensor_size; i++) {
|
||||
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Insert" << post_insert_trans_type << "Node before " << (*iter)->name << " failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -195,6 +167,5 @@ NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph
|
|||
// Stores the quantization type on this pass instance.
void EltwiseFormatTransPass::SetQuantType(QuantType quantType) {
  // The parameter shadows the member, so the member is addressed through `this`.
  this->quantType = quantType;
}
|
||||
|
||||
// Stores the source-framework type on this pass instance.
void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) {
  // The parameter shadows the member, so the member is addressed through `this`.
  this->fmkType = fmkType;
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,10 +20,10 @@
|
|||
#include "tools/converter/optimizer.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// Kind of format-transpose node to insert; passed as the `nodeType` argument
// of InsertFormatTransNode.
// NOTE(review): the enumerator names suggest NCHW->NHWC and NHWC->NCHW layout
// conversions respectively; confirm against InsertFormatTransNode's implementation.
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW };
|
||||
|
||||
class EltwiseFormatTransPass : public GraphPass {
|
||||
public:
|
||||
|
@ -38,10 +38,6 @@ class EltwiseFormatTransPass : public GraphPass {
|
|||
void SetFmk(converter::FmkType fmkType);
|
||||
|
||||
private:
|
||||
STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph);
|
||||
|
||||
STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph);
|
||||
|
||||
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
|
||||
FormatTransNodeType nodeType, STATUS *errorCode);
|
||||
|
||||
|
|
|
@ -369,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|||
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
|
||||
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
||||
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
|
||||
status = RET_OK;
|
||||
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
||||
return -1;
|
||||
|
|
Loading…
Reference in New Issue