forked from mindspore-Ecosystem/mindspore
train fusion online
This commit is contained in:
parent
e586287da5
commit
660b335be1
|
@ -190,6 +190,7 @@ set(TRAIN_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/graph_fusion.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_utils.cc
|
||||
|
@ -201,6 +202,11 @@ set(TRAIN_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/optimizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/fusion/fusion_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/fusion/fusion_pattern.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/meta_graph_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.cc
|
||||
)
|
||||
if(MSLITE_ENABLE_V0)
|
||||
set(TRAIN_SRC
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/train/graph_fusion.h"
|
||||
#include "tools/converter/optimizer.h"
|
||||
#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
  // Runs the fusion passes (currently MatMul+BiasAdd) over the exported meta
  // graph in place.
  //
  // @param graph  mutable flatbuffer object-API graph; must not be null.
  // @return RET_OK on success — including when no fusible pattern was found
  //         (RET_NO_CHANGE from the optimizer) — RET_ERROR otherwise.
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr.";
    return RET_ERROR;
  }
  // BUG FIX: new (std::nothrow) can return nullptr; previously the result was
  // handed to AddPass unchecked, letting a null pass into the optimizer.
  auto *fusion_pass = new (std::nothrow) MatMulBiasAddFusionPass();
  if (fusion_pass == nullptr) {
    MS_LOG(ERROR) << "new MatMulBiasAddFusionPass failed.";
    return RET_ERROR;
  }
  Optimizer fusion_optimizer;
  fusion_optimizer.AddPass(fusion_pass);  // NOTE(review): assumes the optimizer owns/frees added passes — confirm against Optimizer's dtor
  auto status = fusion_optimizer.Run(graph);
  // RET_NO_CHANGE just means nothing matched the fusion pattern; not an error.
  if (status != RET_OK && status != RET_NO_CHANGE) {
    MS_LOG(ERROR) << "graph fusion failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/optimizer.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// Applies offline-style fusion passes to a schema::MetaGraphT before it is
// serialized during training export (used by TrainExport::TrainModelFusion).
class GraphFusion {
 public:
  GraphFusion() = default;
  ~GraphFusion() = default;
  // Runs all fusion passes on `graph` in place.
  // Returns RET_OK on success (including when nothing was fused),
  // RET_ERROR if `graph` is null or a pass fails.
  STATUS Run(schema::MetaGraphT *graph);
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -27,6 +27,7 @@
|
|||
#include "src/train/train_utils.h"
|
||||
#include "src/common/quant_utils.h"
|
||||
#include "tools/common/storage.h"
|
||||
#include "src/train/graph_fusion.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -530,6 +531,15 @@ int TrainExport::IsInputTensor(const schema::TensorT &t) {
|
|||
return ((t.data.size() == 0) && (total_dims != 0));
|
||||
}
|
||||
|
||||
int TrainExport::TrainModelFusion() {
|
||||
GraphFusion graph_fusion;
|
||||
auto status = graph_fusion.Run(meta_graph_);
|
||||
if (status != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TrainExport::~TrainExport() { delete meta_graph_; }
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,6 +47,7 @@ class TrainExport {
|
|||
void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
|
||||
int LoadModel(void *buf, size_t buf_size);
|
||||
int AddTransformNode();
|
||||
int TrainModelFusion();
|
||||
|
||||
protected:
|
||||
virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);
|
||||
|
|
|
@ -1036,6 +1036,13 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
|
|||
MS_LOG(ERROR) << "cannot export Network";
|
||||
return status;
|
||||
}
|
||||
if (model_type == MT_INFERENCE) {
|
||||
status = texport.TrainModelFusion();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "TrainModelFusion failed.";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to save to " << file_name;
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "src/weight_decoder.h"
|
||||
#include "tools/common/node_util.h"
|
||||
|
|
|
@ -2,6 +2,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
|||
|
||||
file(GLOB CONVERTER_COMMON_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/meta_graph_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/node_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/tensor_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/storage.cc
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <ctime>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
|
@ -74,315 +75,6 @@ OpDefCopyer GetSimpleOpCopyer() {
|
|||
};
|
||||
}
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
|
||||
return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
|
||||
}
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
|
||||
std::vector<uint32_t> inputIndexes;
|
||||
if (inputIndexIdx == -1) {
|
||||
inputIndexes = node.inputIndex;
|
||||
} else {
|
||||
MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
|
||||
inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
|
||||
}
|
||||
std::set<size_t> inputNodeIdx;
|
||||
for (uint32_t inputIdx : inputIndexes) {
|
||||
auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
|
||||
inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
|
||||
}
|
||||
std::vector<size_t> ret;
|
||||
ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
|
||||
const int outputIndexIdx) {
|
||||
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
|
||||
}
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
|
||||
std::replace_if(
|
||||
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
|
||||
for (auto &subGraph : graphT->subGraph) {
|
||||
std::replace_if(
|
||||
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
|
||||
std::vector<uint32_t> outputIndexes;
|
||||
if (outputIndexIdx == -1) {
|
||||
outputIndexes = node.outputIndex;
|
||||
} else {
|
||||
MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
|
||||
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
|
||||
}
|
||||
std::set<size_t> outputNodeIdx;
|
||||
for (uint32_t outputIdx : outputIndexes) {
|
||||
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
|
||||
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
|
||||
}
|
||||
std::vector<size_t> ret;
|
||||
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
|
||||
std::vector<size_t> preNodeIdx;
|
||||
for (size_t i = 0; i < graphT.nodes.size(); i++) {
|
||||
auto &oldNode = graphT.nodes.at(i);
|
||||
if (oldNode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto outputIndexes = oldNode->outputIndex;
|
||||
if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
|
||||
preNodeIdx.emplace_back(i);
|
||||
}
|
||||
}
|
||||
return preNodeIdx;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
|
||||
std::vector<size_t> postNodeIdx;
|
||||
for (size_t i = 0; i < graphT.nodes.size(); i++) {
|
||||
auto &oldNode = graphT.nodes.at(i);
|
||||
if (oldNode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto inputIndexes = oldNode->inputIndex;
|
||||
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
|
||||
postNodeIdx.emplace_back(i);
|
||||
}
|
||||
}
|
||||
return postNodeIdx;
|
||||
}
|
||||
|
||||
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
size_t nodeIdx = 0;
|
||||
for (size_t i = 0; i < graphT->nodes.size(); i++) {
|
||||
auto &inNode = graphT->nodes.at(i);
|
||||
MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
|
||||
if (inNode->name == node->name) {
|
||||
nodeIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto inputTensorIdxes = node->inputIndex;
|
||||
auto outputTensorIdxes = node->outputIndex;
|
||||
if (inputTensorIdxes.empty()) {
|
||||
MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputTensorIdxes.size() != 1) {
|
||||
MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
|
||||
<< "should has 1 output, in fact: " << outputTensorIdxes.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inDataTensorIdx = inputTensorIdxes.front();
|
||||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
for (auto postNodeIdx : postNodeIdxes) {
|
||||
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
|
||||
auto &postNode = graphT->nodes.at(postNodeIdx);
|
||||
MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
|
||||
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RemoveTensor(graphT, outputTensorIdxes);
|
||||
node->inputIndex.clear();
|
||||
node->outputIndex.clear();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
return IsolateOneWayNode(graph, nodeIdx, removeTensor);
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
if (graphT->nodes.size() <= nodeIdx) {
|
||||
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
CNodeT *node = graphT->nodes.at(nodeIdx).get();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "node is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto inputTensorIdxes = node->inputIndex;
|
||||
auto outputTensorIdxes = node->outputIndex;
|
||||
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
|
||||
if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
|
||||
MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (inputTensorIdxes.empty()) {
|
||||
MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inDataTensorIdx = inputTensorIdxes.front();
|
||||
if (!outputTensorIdxes.empty()) {
|
||||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
for (auto postNodeIdx : postNodeIdxes) {
|
||||
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
|
||||
auto &postNode = graphT->nodes.at(postNodeIdx);
|
||||
MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
|
||||
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (removeTensor) {
|
||||
// now all node's outputTensors are useless
|
||||
// remove all node's outputTensors
|
||||
auto status = RemoveTensor(graphT, outputTensorIdxes);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
node->inputIndex.clear();
|
||||
node->outputIndex.clear();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
bool isSubNode = false;
|
||||
size_t nodeIdx = 0;
|
||||
for (size_t i = 0; i < graphT->nodes.size(); i++) {
|
||||
auto &inNode = graphT->nodes.at(i);
|
||||
MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
|
||||
if (inNode->name == node->name) {
|
||||
isSubNode = true;
|
||||
nodeIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isSubNode) {
|
||||
MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
|
||||
return RET_PARAM_INVALID;
|
||||
} else {
|
||||
return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
|
||||
}
|
||||
}
|
||||
|
||||
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
|
||||
uint32_t deleteIdx = *iter;
|
||||
if (!forceDelete) {
|
||||
if (GetRefCount(graphT, deleteIdx) > 1) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// update graph input indices
|
||||
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
|
||||
if (*gInIdx > deleteIdx) {
|
||||
(*gInIdx)--;
|
||||
}
|
||||
}
|
||||
// update graph output indices
|
||||
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
|
||||
if (*gOutIdx > deleteIdx) {
|
||||
(*gOutIdx)--;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &subgraph : graphT->subGraph) {
|
||||
// update subgraph input indices
|
||||
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
|
||||
if (*gInIdx > deleteIdx) {
|
||||
(*gInIdx)--;
|
||||
}
|
||||
}
|
||||
// update subgraph output indices
|
||||
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
|
||||
if (*gOutIdx > deleteIdx) {
|
||||
(*gOutIdx)--;
|
||||
}
|
||||
}
|
||||
// update subgraph output indices
|
||||
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
|
||||
if (*idx > deleteIdx) {
|
||||
(*idx)--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update nodes indexes
|
||||
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
|
||||
// update nodes input indexes
|
||||
UpdateNodeIndex((*node_iter).get(), deleteIdx);
|
||||
}
|
||||
// update deleteTensorIdx
|
||||
for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
|
||||
if (*selfIt > deleteIdx) {
|
||||
(*selfIt)--;
|
||||
}
|
||||
}
|
||||
graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
|
||||
iter = toDeleteTensorIdxes.erase(iter);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
|
||||
MS_ASSERT(node != nullptr);
|
||||
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
|
||||
if (*inIdxIt == deleteIdx) {
|
||||
inIdxIt = node->inputIndex.erase(inIdxIt);
|
||||
} else {
|
||||
if (*inIdxIt > deleteIdx) {
|
||||
(*inIdxIt)--;
|
||||
}
|
||||
inIdxIt++;
|
||||
}
|
||||
}
|
||||
// update nodes output indexes
|
||||
for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
|
||||
if (*outIdxIt == deleteIdx) {
|
||||
outIdxIt = node->outputIndex.erase(outIdxIt);
|
||||
} else {
|
||||
if (*outIdxIt > deleteIdx) {
|
||||
(*outIdxIt)--;
|
||||
}
|
||||
outIdxIt++;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
|
||||
InsertPlace place) {
|
||||
if (nodeIdx >= graphT->nodes.size()) {
|
||||
|
@ -672,33 +364,6 @@ std::string GetModelName(const std::string &modelFile) {
|
|||
return modelName;
|
||||
}
|
||||
|
||||
int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
  // Rebuilds each subgraph's tensorIndices as the union (in first-seen order)
  // of its input indices, its output indices, and every tensor referenced by
  // its nodes.
  //
  // @param meta_graphT  graph whose subgraphs are updated in place.
  // @return RET_OK (the routine has no failure path).
  for (auto &subgraph : meta_graphT->subGraph) {
    std::vector<uint32_t> subgraph_indices{};
    subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
    // BUG FIX: a second assign() here used to OVERWRITE the input indices just
    // collected, dropping every subgraph input tensor from tensorIndices.
    // Append the output indices instead (deduplicated, like the node loop).
    for (auto &output_idx : subgraph->outputIndices) {
      if (!IsContain(subgraph_indices, output_idx)) {
        subgraph_indices.push_back(output_idx);
      }
    }
    for (auto &node_idx : subgraph->nodeIndices) {
      auto &node = meta_graphT->nodes.at(node_idx);
      for (auto &input_idx : node->inputIndex) {
        if (!IsContain(subgraph_indices, input_idx)) {
          subgraph_indices.push_back(input_idx);
        }
      }
      for (auto &output_idx : node->outputIndex) {
        if (!IsContain(subgraph_indices, output_idx)) {
          subgraph_indices.push_back(output_idx);
        }
      }
    }
    subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
  }
  return RET_OK;
}
|
||||
|
||||
std::vector<int> GetTransposePerm(MetaGraphT *graph, const std::unique_ptr<CNodeT> &cnode) {
|
||||
MS_ASSERT(graph != nullptr && cnode != nullptr);
|
||||
std::vector<int> perm;
|
||||
|
|
|
@ -48,34 +48,6 @@ OpDefCopyer GetSimpleOpCopyer();
|
|||
|
||||
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
int inputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
int outputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
|
||||
|
||||
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
|
||||
|
||||
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);
|
||||
|
||||
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);
|
||||
|
||||
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);
|
||||
|
||||
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<schema::TensorT> tensor,
|
||||
InsertPlace place = kBefore);
|
||||
|
||||
|
@ -102,8 +74,6 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|||
|
||||
STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType);
|
||||
|
||||
STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT);
|
||||
|
||||
std::string GetModelName(const std::string &modelFile);
|
||||
|
||||
std::vector<int> GetTransposePerm(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &cnode);
|
||||
|
|
|
@ -0,0 +1,374 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
// Counts how many nodes consume tensor `tensorIdx` as an input.
size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(graphT->allTensors.size() > tensorIdx);
  size_t consumer_count = 0;
  for (auto &cnode : graphT->nodes) {
    MS_ASSERT(cnode != nullptr);
    consumer_count += IsContain(cnode->inputIndex, tensorIdx) ? 1 : 0;
  }
  return consumer_count;
}
|
||||
} // namespace
|
||||
|
||||
// Returns the indices of all nodes that read tensor `tensorIdx` as an input.
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> consumers;
  for (size_t node_idx = 0; node_idx < graphT.nodes.size(); ++node_idx) {
    const auto &cnode = graphT.nodes.at(node_idx);
    if (cnode != nullptr && IsContain<uint32_t>(cnode->inputIndex, tensorIdx)) {
      consumers.emplace_back(node_idx);
    }
  }
  return consumers;
}
|
||||
|
||||
// Returns the indices of all nodes that produce tensor `tensorIdx` as an output.
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> producers;
  for (size_t node_idx = 0; node_idx < graphT.nodes.size(); ++node_idx) {
    const auto &cnode = graphT.nodes.at(node_idx);
    if (cnode != nullptr && IsContain<uint32_t>(cnode->outputIndex, tensorIdx)) {
      producers.emplace_back(node_idx);
    }
  }
  return producers;
}
|
||||
|
||||
// Returns the (unique, sorted) indices of the nodes producing `node`'s inputs.
// inputIndexIdx == -1 traces all inputs; otherwise only input #inputIndexIdx.
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
                                    const int inputIndexIdx) {
  // Resolve which tensor indices to trace back.
  std::vector<uint32_t> tensor_indices;
  if (inputIndexIdx != -1) {
    MS_ASSERT(node.inputIndex.size() > static_cast<uint32_t>(inputIndexIdx));
    tensor_indices.emplace_back(node.inputIndex.at(inputIndexIdx));
  } else {
    tensor_indices = node.inputIndex;
  }
  // A set keeps the producer node indices unique and sorted.
  std::set<size_t> producer_set;
  for (auto tensor_idx : tensor_indices) {
    auto producers = GetLinkedPreIdx(graphT, tensor_idx);
    producer_set.insert(producers.begin(), producers.end());
  }
  return std::vector<size_t>(producer_set.begin(), producer_set.end());
}
|
||||
|
||||
// Convenience overload: looks the node up by index and delegates.
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
  const auto &cnode = graphT.nodes.at(nodeIdx);
  return GetInputNodeIdx(graphT, *cnode, inputIndexIdx);
}
|
||||
|
||||
// Returns the (unique, sorted) indices of the nodes consuming `node`'s outputs.
// outputIndexIdx == -1 traces all outputs; otherwise only output #outputIndexIdx.
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
                                     const int outputIndexIdx) {
  // Resolve which tensor indices to trace forward.
  std::vector<uint32_t> tensor_indices;
  if (outputIndexIdx != -1) {
    MS_ASSERT(node.outputIndex.size() > static_cast<uint32_t>(outputIndexIdx));
    tensor_indices.emplace_back(node.outputIndex.at(outputIndexIdx));
  } else {
    tensor_indices = node.outputIndex;
  }
  // A set keeps the consumer node indices unique and sorted.
  std::set<size_t> consumer_set;
  for (auto tensor_idx : tensor_indices) {
    auto consumers = GetLinkedPostIdx(graphT, tensor_idx);
    consumer_set.insert(consumers.begin(), consumers.end());
  }
  return std::vector<size_t>(consumer_set.begin(), consumer_set.end());
}
|
||||
|
||||
// Convenience overload: looks the node up by index and delegates.
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
                                     const int outputIndexIdx) {
  const auto &cnode = graphT.nodes.at(nodeIdx);
  return GetOutputNodeIdx(graphT, *cnode, outputIndexIdx);
}
|
||||
|
||||
// Rewrites every graph-level and subgraph-level output index equal to
// `old_index` so it refers to `new_index` instead.
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
  for (auto &out_idx : graphT->outputIndex) {
    if (out_idx == old_index) {
      out_idx = new_index;
    }
  }
  for (auto &subGraph : graphT->subGraph) {
    for (auto &out_idx : subGraph->outputIndices) {
      if (out_idx == old_index) {
        out_idx = new_index;
      }
    }
  }
}
|
||||
|
||||
// After tensor `deleteIdx` has been erased from allTensors, removes any
// reference to it from `node` and shifts every higher index down by one,
// for both the node's input and output index lists. Always returns RET_OK.
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx) {
  MS_ASSERT(node != nullptr);
  // Same erase-or-shift rule applies to both index lists.
  auto adjust_indices = [deleteIdx](auto *indices) {
    for (auto it = indices->begin(); it != indices->end();) {
      if (*it == deleteIdx) {
        it = indices->erase(it);
      } else {
        if (*it > deleteIdx) {
          --(*it);
        }
        ++it;
      }
    }
  };
  adjust_indices(&node->inputIndex);
  adjust_indices(&node->outputIndex);
  return RET_OK;
}
|
||||
|
||||
// Erases the listed tensors from graphT->allTensors, renumbering every index
// in the graph after each single erase so the whole structure stays
// consistent: graph inputs/outputs, each subgraph's input/output/tensor index
// lists, every node's input/output lists, and the not-yet-processed entries of
// toDeleteTensorIdxes itself. Unless forceDelete is set, a tensor that still
// has more than one consumer is skipped. Always returns RET_OK.
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
  MS_ASSERT(graphT != nullptr);
  for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
    uint32_t deleteIdx = *iter;
    if (!forceDelete) {
      // Keep tensors still consumed elsewhere; deleting would break other nodes.
      if (GetRefCount(graphT, deleteIdx) > 1) {
        iter++;
        continue;
      }
    }
    // update graph input indices (every index above the erased slot shifts down)
    for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
      if (*gInIdx > deleteIdx) {
        (*gInIdx)--;
      }
    }
    // update graph output indices
    for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
      if (*gOutIdx > deleteIdx) {
        (*gOutIdx)--;
      }
    }

    for (auto &subgraph : graphT->subGraph) {
      // update subgraph input indices
      for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
        if (*gInIdx > deleteIdx) {
          (*gInIdx)--;
        }
      }
      // update subgraph output indices
      for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
        if (*gOutIdx > deleteIdx) {
          (*gOutIdx)--;
        }
      }
      // update subgraph tensor indices
      for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
        if (*idx > deleteIdx) {
          (*idx)--;
        }
      }
    }

    // update nodes indexes: drop/shift references inside every node
    for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
      // update nodes input indexes
      UpdateNodeIndex((*node_iter).get(), deleteIdx);
    }
    // update deleteTensorIdx: the remaining to-delete entries shift down too
    // (current entry equals deleteIdx, so it is untouched and erased below)
    for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
      if (*selfIt > deleteIdx) {
        (*selfIt)--;
      }
    }
    graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
    iter = toDeleteTensorIdxes.erase(iter);
  }
  return RET_OK;
}
|
||||
|
||||
// Detaches `node` from the graph: every consumer of its single output tensor
// is rewired to its first input tensor, the output tensor is removed, and the
// node's own index lists are cleared. The node object itself remains in
// graphT->nodes.
//
// @param graphT  graph containing the node; must not be null.
// @param node    node to isolate (matched by name); must not be null.
// @return RET_OK on success, RET_PARAM_INVALID if the node is not in graphT,
//         RET_ERROR if the node has no inputs or not exactly one output,
//         RET_NULL_PTR on a null node entry.
STATUS IsolateNode(schema::MetaGraphT *graphT, schema::CNodeT *node) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(node != nullptr);
  bool found = false;
  size_t nodeIdx = 0;
  for (size_t i = 0; i < graphT->nodes.size(); i++) {
    auto &inNode = graphT->nodes.at(i);
    MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
    if (inNode->name == node->name) {
      found = true;
      nodeIdx = i;
      break;
    }
  }
  // BUG FIX: previously an unknown node silently fell through with
  // nodeIdx == 0 and rewired node 0's consumers; reject it instead
  // (consistent with the IsolateOneWayNode(CNodeT *) overload).
  if (!found) {
    MS_LOG(ERROR) << "Node " << node->name.c_str() << " is not in graphT " << graphT->name.c_str();
    return RET_PARAM_INVALID;
  }
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  if (inputTensorIdxes.empty()) {
    // BUG FIX: the old message ("should has no inputs") stated the opposite
    // of the condition being reported.
    MS_LOG(ERROR) << "Node " << node->name.c_str() << " has no inputs";
    return RET_ERROR;
  }
  if (outputTensorIdxes.size() != 1) {
    MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
                  << " should have 1 output, in fact: " << outputTensorIdxes.size();
    return RET_ERROR;
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  auto outDataTensorIdx = outputTensorIdxes.front();

  MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
  // Graph/subgraph outputs that referenced the removed tensor now reference
  // the node's input tensor instead.
  ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);

  // Rewire each downstream node's matching input to the surviving tensor.
  auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
  for (auto postNodeIdx : postNodeIdxes) {
    MS_ASSERT(graphT->nodes.size() > postNodeIdx);
    auto &postNode = graphT->nodes.at(postNodeIdx);
    MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
    for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
      if (*iter == outDataTensorIdx) {
        *iter = inDataTensorIdx;
        break;
      }
    }
  }

  RemoveTensor(graphT, outputTensorIdxes);
  node->inputIndex.clear();
  node->outputIndex.clear();

  return RET_OK;
}
|
||||
|
||||
// Detaches the nodeIdx-th node from the graph, re-wiring its single output
// tensor's consumers to the node's first input tensor. Only "one-way" nodes
// are supported: at most one producing predecessor and at most one output
// tensor. When removeTensor is true, the node's now-unused output tensors are
// also erased from graphT->allTensors.
// Returns RET_OK on success, RET_PARAM_INVALID / RET_NULL_PTR / RET_ERROR on
// the respective validation failures.
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graphT != nullptr);
  if (graphT->nodes.size() <= nodeIdx) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  schema::CNodeT *node = graphT->nodes.at(nodeIdx).get();
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is null";
    return RET_NULL_PTR;
  }
  // Copies, not references: inputIndex/outputIndex are cleared at the end.
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
  if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
    MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
    return RET_ERROR;
  }
  if (inputTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
    return RET_ERROR;
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  if (!outputTensorIdxes.empty()) {
    auto outDataTensorIdx = outputTensorIdxes.front();
    MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
    MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
    // If the node's output is also a graph output, point the graph output at
    // the node's input tensor instead.
    ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);

    // find poseNode: redirect each downstream consumer of the output tensor
    // so it reads the input tensor directly.
    auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
    for (auto postNodeIdx : postNodeIdxes) {
      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
      auto &postNode = graphT->nodes.at(postNodeIdx);
      MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == outDataTensorIdx) {
          *iter = inDataTensorIdx;
          break;  // only the first matching slot is rewired per consumer
        }
      }
    }
  }
  if (removeTensor) {
    // now all node's outputTensors are useless
    // remove all node's outputTensors
    auto status = RemoveTensor(graphT, outputTensorIdxes);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
      return RET_ERROR;
    }
  }
  // Detach the node; it stays in graphT->nodes but references no tensors.
  node->inputIndex.clear();
  node->outputIndex.clear();
  return RET_OK;
}
|
||||
|
||||
// Subgraph-aware overload kept for callers that track a subgraph index.
// NOTE(review): subGraphIdx is currently ignored and the whole-graph overload
// is used regardless — confirm callers do not rely on per-subgraph isolation.
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graph != nullptr);
  return IsolateOneWayNode(graph, nodeIdx, removeTensor);
}
|
||||
|
||||
// Name-based overload: locates `node` inside graphT by name, then delegates to
// the index-based IsolateOneWayNode. Returns RET_PARAM_INVALID when the node
// does not belong to the graph.
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(node != nullptr);
  size_t target_idx = 0;
  bool found = false;
  const size_t node_count = graphT->nodes.size();
  for (size_t idx = 0; idx < node_count && !found; ++idx) {
    auto &candidate = graphT->nodes.at(idx);
    MS_CHECK_TRUE_MSG(candidate != nullptr, RET_NULL_PTR, "inNode is nullptr");
    if (candidate->name == node->name) {
      target_idx = idx;
      found = true;
    }
  }
  if (!found) {
    MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
    return RET_PARAM_INVALID;
  }
  return IsolateOneWayNode(graphT, target_idx, removeTensor);
}
|
||||
|
||||
// Rebuilds each subgraph's tensorIndices as the union (insertion-ordered,
// deduplicated) of: the subgraph's input tensors, its output tensors, and
// every tensor referenced by its nodes.
// Bug fix: the previous revision called subgraph_indices.assign() twice, so
// the outputIndices assignment discarded the inputIndices — graph-input
// tensors consumed by no node were silently dropped from tensorIndices.
int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
  MS_ASSERT(meta_graphT != nullptr);
  for (auto &subgraph : meta_graphT->subGraph) {
    std::vector<uint32_t> subgraph_indices{};
    // Seed with the subgraph's inputs, then *append* its outputs.
    subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
    for (auto output_idx : subgraph->outputIndices) {
      if (!IsContain(subgraph_indices, output_idx)) {
        subgraph_indices.push_back(output_idx);
      }
    }
    // Add every tensor any node in the subgraph reads or writes.
    for (auto &node_idx : subgraph->nodeIndices) {
      auto &node = meta_graphT->nodes.at(node_idx);
      for (auto &input_idx : node->inputIndex) {
        if (!IsContain(subgraph_indices, input_idx)) {
          subgraph_indices.push_back(input_idx);
        }
      }
      for (auto &output_idx : node->outputIndex) {
        if (!IsContain(subgraph_indices, output_idx)) {
          subgraph_indices.push_back(output_idx);
        }
      }
    }
    subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
  }
  return RET_OK;
}
|
||||
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTIL_H
|
||||
#define MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTIL_H
|
||||
|
||||
#include <vector>
|
||||
#include "inner/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
namespace mindspore::lite {
// Indices of the nodes that produce tensor `tensorIdx`.
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);

// Indices of the nodes that consume tensor `tensorIdx`.
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);

// Indices of the nodes feeding `node`; inputIndexIdx presumably restricts the
// search to a single input slot when >= 0 (default -1 means all inputs).
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
                                    int inputIndexIdx = -1);

// As above, with the node identified by its index in graphT.nodes.
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);

// Indices of the nodes consuming `node`'s outputs; outputIndexIdx presumably
// restricts the search to a single output slot when >= 0.
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
                                     int outputIndexIdx = -1);

// As above, with the node identified by its index in graphT.nodes.
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);

// Removes `node` from the graph topology, re-wiring consumers of its single
// output tensor to its first input tensor and deleting the output tensor.
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);

// Erases the listed tensors from graphT->allTensors and renumbers every
// node/graph tensor reference accordingly.
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);

// Replaces occurrences of old_index with new_index in the graph's output list.
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);

// Decrements a node's tensor references that point past the deleted tensor.
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);

// Detach a one-way node (<= 1 predecessor, <= 1 output); the three overloads
// identify the node by (subgraph, index), pointer, or plain index.
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);

STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);

STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);

// Recomputes each subgraph's tensorIndices from its inputs, outputs and nodes.
STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT);

}  // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTIL_H
|
|
@ -30,6 +30,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/meta_graph_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
@ -131,7 +131,7 @@ STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pat
|
|||
for (auto index : graph->outputIndex) {
|
||||
auto subGraphOutputNodeIdxes = GetLinkedPreIdx(*graph, index);
|
||||
for (auto subGraphOutputNodeIdx : subGraphOutputNodeIdxes) {
|
||||
MS_ASSERT(subGraph->nodes.size() > subGraphOutputNodeIdx);
|
||||
MS_ASSERT(graph->nodes.size() > subGraphOutputNodeIdx);
|
||||
nodeQueue.push(subGraphOutputNodeIdx);
|
||||
}
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pat
|
|||
if (IsContain(sinkIdes, nodeIdx)) {
|
||||
continue;
|
||||
}
|
||||
MS_ASSERT(subGraph->nodes.size() > nodeIdx);
|
||||
MS_ASSERT(graph->nodes.size() > nodeIdx);
|
||||
auto &node = graph->nodes.at(nodeIdx);
|
||||
sinkIdes.emplace_back(nodeIdx);
|
||||
|
||||
|
@ -151,7 +151,7 @@ STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pat
|
|||
}
|
||||
auto preNodeIdxes = GetInputNodeIdx(*graph, nodeIdx);
|
||||
for (auto preNodeIdx : preNodeIdxes) {
|
||||
MS_ASSERT((subGraph->nodes.size() > preNodeIdx));
|
||||
MS_ASSERT(graph->nodes.size() > preNodeIdx);
|
||||
nodeQueue.push(preNodeIdx);
|
||||
}
|
||||
}
|
||||
|
@ -277,7 +277,7 @@ bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std:
|
|||
if (preNodeIdxInner == preNodeIdx) {
|
||||
continue;
|
||||
}
|
||||
MS_ASSERT(subGraph->nodes.size() > preNodeIdxInner);
|
||||
MS_ASSERT(graph->nodes.size() > preNodeIdxInner);
|
||||
if (MatchTree(graph, preNodeIdxInner, target->right, sinkIdes, pathSinkIdes)) {
|
||||
return true; // ignore follow match, pick the first match
|
||||
}
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/optimizer.h"
|
||||
#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h"
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include <set>
|
||||
#include <utility>
|
||||
#include "tools/converter/legacy_optimizer/fusion/fusion_pattern.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
namespace {
|
||||
constexpr int kNumBiasMatchPathLen = 2;
|
||||
constexpr const char *MulName = "MATMUL";
|
||||
constexpr const char *BiasName = "BIASADD";
|
||||
} // namespace
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
// Forwards to the generic FusionPass driver; DefinePattern/DoFusion supply the
// MatMul+BiasAdd specifics.
STATUS MatMulBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }
|
||||
|
||||
// Registers the fusion pattern: a MatMul node feeding the left input of a
// BiasAdd node. Returns RET_ERROR only if the pattern object cannot be
// allocated.
STATUS MatMulBiasAddFusionPass::DefinePattern() {
  auto matmul_pattern_op = std::make_shared<PatternOp>();
  matmul_pattern_op->id = MulName;
  matmul_pattern_op->types = {schema::PrimitiveType_MatMul};

  auto biasadd_pattern_op = std::make_shared<PatternOp>();
  biasadd_pattern_op->id = BiasName;
  biasadd_pattern_op->types = {schema::PrimitiveType_BiasAdd};
  biasadd_pattern_op->left = matmul_pattern_op;

  std::unique_ptr<FusionPattern> pattern(new (std::nothrow) FusionPattern("MatMulBiasAddFusion"));
  if (pattern == nullptr) {
    MS_LOG(ERROR) << "new fusion_pattern failed";
    return RET_ERROR;
  }
  pattern->AddPatternOp(matmul_pattern_op);
  pattern->AddPatternOp(biasadd_pattern_op);
  pattern->Finish();
  // FusionPass::patterns holds raw pointers; ownership is transferred here.
  this->patterns.emplace_back(pattern.release());
  return RET_OK;
}
|
||||
|
||||
// Fuses a matched MatMul + BiasAdd pair: the bias tensor becomes the MatMul's
// third input, the MatMul adopts the BiasAdd's outputs, and the BiasAdd node
// is erased (node indices in subGraph[0] are compacted accordingly).
// Fixes vs. previous revision:
//  - bias_node->inputIndex.at(1) was read before any size check (.at throws);
//  - matched_path[key] could insert and then dereference a null shared_ptr;
//  - graph->subGraph.at(0) threw when no subgraph exists;
//  - log typo "cat not fusion".
STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pattern_name,
                                         std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
  MS_ASSERT(graph != nullptr);
  if (matched_path.size() != kNumBiasMatchPathLen) {
    MS_LOG(ERROR) << "MatMul-BiasAdd-Fusion should have two NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }
  auto mul_path = matched_path[MulName];
  auto bias_path = matched_path[BiasName];
  if (mul_path == nullptr || bias_path == nullptr) {
    MS_LOG(ERROR) << "matched path is invalid";
    return RET_PARAM_INVALID;
  }
  auto mul_index = mul_path->nodeIdx;
  auto bias_index = bias_path->nodeIdx;
  auto &mul_node = graph->nodes.at(mul_index);
  auto &bias_node = graph->nodes.at(bias_index);
  // Fusion needs MatMul(a, b) and BiasAdd(x, bias); validate before touching
  // inputIndex.at(1).
  if (mul_node->inputIndex.size() != 2 || bias_node->inputIndex.size() < 2) {
    MS_LOG(DEBUG) << "can not fusion.";
    return RET_NO_CHANGE;
  }
  auto bias_tensor_index = bias_node->inputIndex.at(1);
  mul_node->inputIndex.push_back(bias_tensor_index);
  mul_node->outputIndex = bias_node->outputIndex;
  graph->nodes.erase(bias_index + graph->nodes.begin());
  if (graph->subGraph.empty()) {
    MS_LOG(ERROR) << "graph has no subgraph.";
    return RET_ERROR;
  }
  // Drop the erased node from subgraph 0 and shift the indices after it.
  auto &node_indices = graph->subGraph.at(0)->nodeIndices;
  auto it = find(node_indices.begin(), node_indices.end(), static_cast<uint32_t>(bias_index));
  if (it == node_indices.end()) {
    MS_LOG(ERROR) << "can not find node in subgraph.";
    return RET_ERROR;
  }
  node_indices.erase(it);
  for (size_t i = 0; i < node_indices.size(); i++) {
    if (node_indices.at(i) > static_cast<uint32_t>(bias_index)) {
      node_indices.at(i)--;
    }
  }
  return RET_OK;
}
|
||||
|
||||
// Defaulted out of line; the pass owns no resources beyond the base class.
MatMulBiasAddFusionPass::~MatMulBiasAddFusionPass() = default;
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H
|
||||
#define MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class MatMulBiasAddFusionPass : public FusionPass {
|
||||
public:
|
||||
MatMulBiasAddFusionPass() = default;
|
||||
|
||||
~MatMulBiasAddFusionPass() override;
|
||||
|
||||
STATUS DefinePattern() override;
|
||||
|
||||
STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
|
||||
std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
|
||||
|
||||
STATUS Run(MetaGraphT *graph) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H
|
|
@ -21,9 +21,11 @@
|
|||
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <queue>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
|
||||
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "src/common/quant_utils.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
Loading…
Reference in New Issue