train fusion online

yefeng 2021-09-23 16:11:07 +08:00
parent e586287da5
commit 660b335be1
26 changed files with 675 additions and 376 deletions

View File

@@ -190,6 +190,7 @@ set(TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/graph_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_utils.cc
@@ -201,6 +202,11 @@ set(TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/optimizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/fusion/fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/fusion/fusion_pattern.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/meta_graph_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc
)
if(MSLITE_ENABLE_V0)
set(TRAIN_SRC

View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/train/graph_fusion.h"
#include "tools/converter/optimizer.h"
#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
namespace mindspore {
namespace lite {
STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
if (graph == nullptr) {
MS_LOG(ERROR) << "graph is nullptr.";
return RET_ERROR;
}
Optimizer fusion_optimizer;
fusion_optimizer.AddPass(new (std::nothrow) MatMulBiasAddFusionPass());
auto status = fusion_optimizer.Run(graph);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "graph fusion failed.";
return RET_ERROR;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
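
For orientation, a minimal sketch of driving GraphFusion directly; the helper name and parameter are illustrative assumptions, and the commit's actual call site is TrainExport::TrainModelFusion below.

// Hypothetical helper; assumes the caller owns a populated schema::MetaGraphT.
int RunGraphFusion(schema::MetaGraphT *meta_graph) {
  GraphFusion graph_fusion;
  if (graph_fusion.Run(meta_graph) != RET_OK) {
    MS_LOG(ERROR) << "graph fusion failed.";
    return RET_ERROR;
  }
  return RET_OK;
}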

View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/optimizer.h"
#include "inner/model_generated.h"
#include "include/errorcode.h"
namespace mindspore {
namespace lite {
class GraphFusion {
public:
GraphFusion() = default;
~GraphFusion() = default;
STATUS Run(schema::MetaGraphT *graph);
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_TRAIN_GRAPH_FUSION_H_

View File

@@ -27,6 +27,7 @@
#include "src/train/train_utils.h"
#include "src/common/quant_utils.h"
#include "tools/common/storage.h"
#include "src/train/graph_fusion.h"
namespace mindspore {
namespace lite {
@@ -530,6 +531,15 @@ int TrainExport::IsInputTensor(const schema::TensorT &t) {
return ((t.data.size() == 0) && (total_dims != 0));
}
int TrainExport::TrainModelFusion() {
GraphFusion graph_fusion;
auto status = graph_fusion.Run(meta_graph_);
if (status != RET_OK) {
return RET_ERROR;
}
return RET_OK;
}
TrainExport::~TrainExport() { delete meta_graph_; }
} // namespace lite
} // namespace mindspore

View File

@@ -47,6 +47,7 @@ class TrainExport {
void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
int LoadModel(void *buf, size_t buf_size);
int AddTransformNode();
int TrainModelFusion();
protected:
virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);

View File

@@ -1036,6 +1036,13 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
MS_LOG(ERROR) << "cannot export Network";
return status;
}
if (model_type == MT_INFERENCE) {
status = texport.TrainModelFusion();
if (status != RET_OK) {
MS_LOG(ERROR) << "TrainModelFusion failed.";
return status;
}
}
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << file_name;

View File

@@ -36,6 +36,7 @@
#include "tools/converter/quantizer/bitpacking.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "src/ops/ops_utils.h"
#include "src/weight_decoder.h"
#include "tools/common/node_util.h"

View File

@@ -2,6 +2,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
file(GLOB CONVERTER_COMMON_SRC
${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/meta_graph_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/node_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/storage.cc

View File

@@ -20,6 +20,7 @@
#include <ctime>
#include <utility>
#include <set>
#include "tools/common/meta_graph_utils.h"
#include "schema/inner/model_generated.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/bitpacking.h"
@@ -74,315 +75,6 @@ OpDefCopyer GetSimpleOpCopyer() {
};
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
std::vector<uint32_t> inputIndexes;
if (inputIndexIdx == -1) {
inputIndexes = node.inputIndex;
} else {
MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
}
std::set<size_t> inputNodeIdx;
for (uint32_t inputIdx : inputIndexes) {
auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
return ret;
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
const int outputIndexIdx) {
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
}
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
std::replace_if(
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
for (auto &subGraph : graphT->subGraph) {
std::replace_if(
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
}
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
std::vector<uint32_t> outputIndexes;
if (outputIndexIdx == -1) {
outputIndexes = node.outputIndex;
} else {
MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
}
std::set<size_t> outputNodeIdx;
for (uint32_t outputIdx : outputIndexes) {
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
return ret;
}
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> preNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto outputIndexes = oldNode->outputIndex;
if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
preNodeIdx.emplace_back(i);
}
}
return preNodeIdx;
}
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> postNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto inputIndexes = oldNode->inputIndex;
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
postNodeIdx.emplace_back(i);
}
}
return postNodeIdx;
}
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
if (inNode->name == node->name) {
nodeIdx = i;
break;
}
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
return RET_ERROR;
}
if (outputTensorIdxes.size() != 1) {
MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
<< "should has 1 output, in fact: " << outputTensorIdxes.size();
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
RemoveTensor(graphT, outputTensorIdxes);
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graph != nullptr);
return IsolateOneWayNode(graph, nodeIdx, removeTensor);
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
if (graphT->nodes.size() <= nodeIdx) {
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
return RET_PARAM_INVALID;
}
CNodeT *node = graphT->nodes.at(nodeIdx).get();
if (node == nullptr) {
MS_LOG(ERROR) << "node is null";
return RET_NULL_PTR;
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
return RET_ERROR;
}
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
if (!outputTensorIdxes.empty()) {
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
}
if (removeTensor) {
// all of the node's output tensors are now useless; remove them
auto status = RemoveTensor(graphT, outputTensorIdxes);
if (status != RET_OK) {
MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
return RET_ERROR;
}
}
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
bool isSubNode = false;
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
if (inNode->name == node->name) {
isSubNode = true;
nodeIdx = i;
break;
}
}
if (!isSubNode) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
return RET_PARAM_INVALID;
} else {
return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
}
}
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
MS_ASSERT(graphT != nullptr);
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
uint32_t deleteIdx = *iter;
if (!forceDelete) {
if (GetRefCount(graphT, deleteIdx) > 1) {
iter++;
continue;
}
}
// update graph input indices
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update graph output indices
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
for (auto &subgraph : graphT->subGraph) {
// update subgraph input indices
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update subgraph output indices
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
// update subgraph tensor indices
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
if (*idx > deleteIdx) {
(*idx)--;
}
}
}
// update each node's input and output indices
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
UpdateNodeIndex((*node_iter).get(), deleteIdx);
}
// update the remaining to-delete indices
for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
if (*selfIt > deleteIdx) {
(*selfIt)--;
}
}
graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
iter = toDeleteTensorIdxes.erase(iter);
}
return RET_OK;
}
STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
MS_ASSERT(node != nullptr);
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
if (*inIdxIt == deleteIdx) {
inIdxIt = node->inputIndex.erase(inIdxIt);
} else {
if (*inIdxIt > deleteIdx) {
(*inIdxIt)--;
}
inIdxIt++;
}
}
// update nodes output indexes
for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
if (*outIdxIt == deleteIdx) {
outIdxIt = node->outputIndex.erase(outIdxIt);
} else {
if (*outIdxIt > deleteIdx) {
(*outIdxIt)--;
}
outIdxIt++;
}
}
return RET_OK;
}
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor, STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
InsertPlace place) { InsertPlace place) {
if (nodeIdx >= graphT->nodes.size()) { if (nodeIdx >= graphT->nodes.size()) {
@@ -672,33 +364,6 @@ std::string GetModelName(const std::string &modelFile) {
return modelName;
}
int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
for (auto &subgraph : meta_graphT->subGraph) {
std::vector<uint32_t> subgraph_indices{};
subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
subgraph_indices.insert(subgraph_indices.end(), subgraph->outputIndices.begin(), subgraph->outputIndices.end());
for (auto &node_idx : subgraph->nodeIndices) {
auto &node = meta_graphT->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
if (IsContain(subgraph_indices, input_idx)) {
continue;
} else {
subgraph_indices.push_back(input_idx);
}
}
for (auto &output_idx : node->outputIndex) {
if (IsContain(subgraph_indices, output_idx)) {
continue;
} else {
subgraph_indices.push_back(output_idx);
}
}
}
subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
}
return RET_OK;
}
std::vector<int> GetTransposePerm(MetaGraphT *graph, const std::unique_ptr<CNodeT> &cnode) {
MS_ASSERT(graph != nullptr && cnode != nullptr);
std::vector<int> perm;

View File

@@ -48,34 +48,6 @@ OpDefCopyer GetSimpleOpCopyer();
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int inputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int outputIndexIdx = -1);
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<schema::TensorT> tensor,
InsertPlace place = kBefore);
@@ -102,8 +74,6 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType);
STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT);
std::string GetModelName(const std::string &modelFile);
std::vector<int> GetTransposePerm(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &cnode);

View File

@@ -0,0 +1,374 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/common/meta_graph_utils.h"
#include <vector>
#include <set>
#include "inner/model_generated.h"
#include "src/common/utils.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
namespace {
size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(graphT->allTensors.size() > tensorIdx);
size_t refCount = 0;
for (auto &node : graphT->nodes) {
MS_ASSERT(node != nullptr);
if (IsContain(node->inputIndex, tensorIdx)) {
refCount++;
}
}
return refCount;
}
} // namespace
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> postNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto inputIndexes = oldNode->inputIndex;
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
postNodeIdx.emplace_back(i);
}
}
return postNodeIdx;
}
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> preNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto outputIndexes = oldNode->outputIndex;
if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
preNodeIdx.emplace_back(i);
}
}
return preNodeIdx;
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
const int inputIndexIdx) {
std::vector<uint32_t> inputIndexes;
if (inputIndexIdx == -1) {
inputIndexes = node.inputIndex;
} else {
MS_ASSERT(node.inputIndex.size() > static_cast<uint32_t>(inputIndexIdx));
inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
}
std::set<size_t> inputNodeIdx;
for (uint32_t inputIdx : inputIndexes) {
auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
return ret;
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
const int outputIndexIdx) {
std::vector<uint32_t> outputIndexes;
if (outputIndexIdx == -1) {
outputIndexes = node.outputIndex;
} else {
MS_ASSERT(node.outputIndex.size() > static_cast<uint32_t>(outputIndexIdx));
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
}
std::set<size_t> outputNodeIdx;
for (uint32_t outputIdx : outputIndexes) {
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
return ret;
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
const int outputIndexIdx) {
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
}
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
std::replace_if(
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
for (auto &subGraph : graphT->subGraph) {
std::replace_if(
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
}
}
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx) {
MS_ASSERT(node != nullptr);
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
if (*inIdxIt == deleteIdx) {
inIdxIt = node->inputIndex.erase(inIdxIt);
} else {
if (*inIdxIt > deleteIdx) {
(*inIdxIt)--;
}
inIdxIt++;
}
}
// update nodes output indexes
for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
if (*outIdxIt == deleteIdx) {
outIdxIt = node->outputIndex.erase(outIdxIt);
} else {
if (*outIdxIt > deleteIdx) {
(*outIdxIt)--;
}
outIdxIt++;
}
}
return RET_OK;
}
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
MS_ASSERT(graphT != nullptr);
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
uint32_t deleteIdx = *iter;
if (!forceDelete) {
if (GetRefCount(graphT, deleteIdx) > 1) {
iter++;
continue;
}
}
// update graph input indices
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update graph output indices
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
for (auto &subgraph : graphT->subGraph) {
// update subgraph input indices
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update subgraph output indices
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
// update subgraph tensor indices
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
if (*idx > deleteIdx) {
(*idx)--;
}
}
}
// update each node's input and output indices
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
UpdateNodeIndex((*node_iter).get(), deleteIdx);
}
// update the remaining to-delete indices
for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
if (*selfIt > deleteIdx) {
(*selfIt)--;
}
}
graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
iter = toDeleteTensorIdxes.erase(iter);
}
return RET_OK;
}
STATUS IsolateNode(schema::MetaGraphT *graphT, schema::CNodeT *node) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
if (inNode->name == node->name) {
nodeIdx = i;
break;
}
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
return RET_ERROR;
}
if (outputTensorIdxes.size() != 1) {
MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
<< "should has 1 output, in fact: " << outputTensorIdxes.size();
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
RemoveTensor(graphT, outputTensorIdxes);
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
if (graphT->nodes.size() <= nodeIdx) {
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
return RET_PARAM_INVALID;
}
schema::CNodeT *node = graphT->nodes.at(nodeIdx).get();
if (node == nullptr) {
MS_LOG(ERROR) << "node is null";
return RET_NULL_PTR;
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
return RET_ERROR;
}
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
if (!outputTensorIdxes.empty()) {
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
}
if (removeTensor) {
// all of the node's output tensors are now useless; remove them
auto status = RemoveTensor(graphT, outputTensorIdxes);
if (status != RET_OK) {
MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
return RET_ERROR;
}
}
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graph != nullptr);
return IsolateOneWayNode(graph, nodeIdx, removeTensor);
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
bool isSubNode = false;
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
if (inNode->name == node->name) {
isSubNode = true;
nodeIdx = i;
break;
}
}
if (!isSubNode) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
return RET_PARAM_INVALID;
} else {
return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
}
}
int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
for (auto &subgraph : meta_graphT->subGraph) {
std::vector<uint32_t> subgraph_indices{};
subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
subgraph_indices.insert(subgraph_indices.end(), subgraph->outputIndices.begin(), subgraph->outputIndices.end());
for (auto &node_idx : subgraph->nodeIndices) {
auto &node = meta_graphT->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
if (IsContain(subgraph_indices, input_idx)) {
continue;
} else {
subgraph_indices.push_back(input_idx);
}
}
for (auto &output_idx : node->outputIndex) {
if (IsContain(subgraph_indices, output_idx)) {
continue;
} else {
subgraph_indices.push_back(output_idx);
}
}
}
subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
}
return RET_OK;
}
} // namespace mindspore::lite
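
A short sketch of how the relocated index helpers compose; the function, graph reference, and node index are hypothetical, for illustration only.

// Hypothetical: inspect the neighborhood of one node in a MetaGraphT.
void InspectNode(const schema::MetaGraphT &graph, size_t node_idx) {
  auto producers = GetInputNodeIdx(graph, node_idx);      // nodes writing any input of the node
  auto consumers = GetOutputNodeIdx(graph, node_idx);     // nodes reading any output of the node
  auto first_only = GetInputNodeIdx(graph, node_idx, 0);  // producers of input 0 alone
  MS_LOG(INFO) << producers.size() << " producers, " << consumers.size() << " consumers, "
               << first_only.size() << " feeding input 0";
}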

View File

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTIL_H
#define MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTIL_H
#include <vector>
#include "inner/model_generated.h"
#include "include/errorcode.h"
namespace mindspore::lite {
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int inputIndexIdx = -1);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int outputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTIL_H
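
RemoveTensor compacts allTensors, so every index above a deleted slot shifts down by one across graph, subgraph, and node index lists; a sketch of that invariant with hypothetical values (the graph pointer is assumed).

// Hypothetical: delete tensors 4 and 7 in one call. After slot 4 is erased,
// the pending entry 7 is renumbered to 6 before it is erased in turn, and
// UpdateNodeIndex() applies the same shift to every node's input/output lists.
std::vector<uint32_t> to_delete = {4, 7};
if (RemoveTensor(graph, to_delete) != RET_OK) {
  MS_LOG(ERROR) << "RemoveTensor failed.";
}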

View File

@@ -30,6 +30,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/meta_graph_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc

View File

@@ -26,7 +26,7 @@
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
-#include "tools/common/graph_util.h"
+#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "nnacl/op_base.h"
@@ -131,7 +131,7 @@ STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pat
for (auto index : graph->outputIndex) {
auto subGraphOutputNodeIdxes = GetLinkedPreIdx(*graph, index);
for (auto subGraphOutputNodeIdx : subGraphOutputNodeIdxes) {
-MS_ASSERT(subGraph->nodes.size() > subGraphOutputNodeIdx);
+MS_ASSERT(graph->nodes.size() > subGraphOutputNodeIdx);
nodeQueue.push(subGraphOutputNodeIdx);
}
}
@@ -141,7 +141,7 @@ STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pat
if (IsContain(sinkIdes, nodeIdx)) {
continue;
}
-MS_ASSERT(subGraph->nodes.size() > nodeIdx);
+MS_ASSERT(graph->nodes.size() > nodeIdx);
auto &node = graph->nodes.at(nodeIdx);
sinkIdes.emplace_back(nodeIdx);
@@ -151,7 +151,7 @@ STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pat
}
auto preNodeIdxes = GetInputNodeIdx(*graph, nodeIdx);
for (auto preNodeIdx : preNodeIdxes) {
-MS_ASSERT((subGraph->nodes.size() > preNodeIdx));
+MS_ASSERT(graph->nodes.size() > preNodeIdx);
nodeQueue.push(preNodeIdx);
}
}
@@ -277,7 +277,7 @@ bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std:
if (preNodeIdxInner == preNodeIdx) {
continue;
}
-MS_ASSERT(subGraph->nodes.size() > preNodeIdxInner);
+MS_ASSERT(graph->nodes.size() > preNodeIdxInner);
if (MatchTree(graph, preNodeIdxInner, target->right, sinkIdes, pathSinkIdes)) {
return true;  // ignore follow match, pick the first match
}

View File

@@ -22,7 +22,6 @@
#include <memory>
#include <string>
#include <vector>
#include "tools/common/node_util.h"
#include "tools/converter/optimizer.h"
#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h"

View File

@@ -17,7 +17,6 @@
#include <set>
#include <utility>
#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "include/errorcode.h"

View File

@@ -0,0 +1,92 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
#include "schema/inner/model_generated.h"
#include "tools/common/meta_graph_utils.h"
namespace {
constexpr int kNumBiasMatchPathLen = 2;
constexpr const char *MulName = "MATMUL";
constexpr const char *BiasName = "BIASADD";
} // namespace
namespace mindspore {
namespace lite {
STATUS MatMulBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }
STATUS MatMulBiasAddFusionPass::DefinePattern() {
auto mul_op = std::make_shared<PatternOp>();
mul_op->id = MulName;
mul_op->types = {schema::PrimitiveType_MatMul};
auto bias_op = std::make_shared<PatternOp>();
bias_op->id = BiasName;
bias_op->types = {schema::PrimitiveType_BiasAdd};
bias_op->left = mul_op;
std::unique_ptr<FusionPattern> fusion_pattern(new (std::nothrow) FusionPattern("MatMulBiasAddFusion"));
if (fusion_pattern == nullptr) {
MS_LOG(ERROR) << "new fusion_pattern failed";
return RET_ERROR;
}
fusion_pattern->AddPatternOp(mul_op);
fusion_pattern->AddPatternOp(bias_op);
fusion_pattern->Finish();
this->patterns.emplace_back(fusion_pattern.release());
return RET_OK;
}
STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pattern_name,
std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
MS_ASSERT(graph != nullptr);
if (matched_path.size() != kNumBiasMatchPathLen) {
MS_LOG(ERROR) << "MatMul-BiasAdd-Fusion should have two NodeIndex in matchedPair";
return RET_PARAM_INVALID;
}
auto mul_path = matched_path[MulName];
auto bias_path = matched_path[BiasName];
auto mul_index = mul_path->nodeIdx;
auto bias_index = bias_path->nodeIdx;
auto &mul_node = graph->nodes.at(mul_index);
auto &bias_node = graph->nodes.at(bias_index);
if (mul_node->inputIndex.size() != 2 || bias_node->inputIndex.size() != 2) {
MS_LOG(DEBUG) << "cannot fuse.";
return RET_NO_CHANGE;
}
auto bias_tensor_index = bias_node->inputIndex.at(1);
mul_node->inputIndex.push_back(bias_tensor_index);
mul_node->outputIndex = {bias_node->outputIndex};
graph->nodes.erase(bias_index + graph->nodes.begin());
auto it = find(graph->subGraph.at(0)->nodeIndices.begin(), graph->subGraph.at(0)->nodeIndices.end(),
static_cast<uint32_t>(bias_index));
if (it == graph->subGraph.at(0)->nodeIndices.end()) {
MS_LOG(ERROR) << "can not find node in subgraph.";
return RET_ERROR;
}
graph->subGraph.at(0)->nodeIndices.erase(it);
for (size_t i = 0; i < graph->subGraph.at(0)->nodeIndices.size(); i++) {
if (graph->subGraph.at(0)->nodeIndices.at(i) > static_cast<uint32_t>(bias_index)) {
graph->subGraph.at(0)->nodeIndices.at(i)--;
}
}
return RET_OK;
}
MatMulBiasAddFusionPass::~MatMulBiasAddFusionPass() = default;
} // namespace lite
} // namespace mindspore
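
Schematically, DoFusion rewrites the matched pair as follows; the tensor numbers are hypothetical.

// before: t2 = MatMul(t0, t1);  t4 = BiasAdd(t2, t3)
// after:  t4 = MatMul(t0, t1, t3)
// The bias tensor t3 is appended as MatMul's third input, MatMul takes over
// BiasAdd's output index, and the BiasAdd node is erased, with the subgraph's
// remaining node indices shifted down to close the gap.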

View File

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H
#define MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H
#include <string>
#include <unordered_map>
#include <memory>
#include <algorithm>
#include <utility>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
namespace mindspore {
namespace lite {
class MatMulBiasAddFusionPass : public FusionPass {
public:
MatMulBiasAddFusionPass() = default;
~MatMulBiasAddFusionPass() override;
STATUS DefinePattern() override;
STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
STATUS Run(MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H

View File

@@ -21,9 +21,11 @@
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "nnacl/op_base.h"
#include "src/common/utils.h"
namespace mindspore {
namespace lite {

View File

@@ -20,6 +20,7 @@
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

View File

@@ -18,6 +18,7 @@
#include <queue>
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/log_util.h"

View File

@@ -18,7 +18,6 @@
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/log_util.h"

View File

@@ -21,7 +21,6 @@
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/log_util.h"

View File

@@ -20,7 +20,7 @@
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
-#include "tools/common/graph_util.h"
+#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/log_util.h"

View File

@@ -22,6 +22,7 @@
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/common/tensor_util.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "tools/common/node_util.h"
#include "src/common/quant_utils.h"
#include "src/common/log_util.h"

View File

@@ -23,6 +23,7 @@
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "src/common/log_util.h"
#include "tools/common/meta_graph_utils.h"
namespace mindspore {
namespace lite {