Fix dropout bug when saving inference file
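When a TrainSession was exported as an inference model (MT_INFERENCE), Dropout nodes were left in the saved graph. This commit adds a GraphDropout pass that runs the converter's DropoutNodeRemovePass, IsolatedNodeRemovePass and SubgraphNodePass on the exported MetaGraphT before the file is written, and moves the shared MetaGraphT helpers out of tools/common/graph_util.cc into a new tools/common/meta_graph_utils.cc so the train build can link them as well.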

nizzan 2022-02-20 18:05:46 +02:00
parent 01708bff83
commit 29ffaa5751
22 changed files with 540 additions and 345 deletions

View File

@@ -142,7 +142,13 @@ set(TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_monitor.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/graph_dropout.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/meta_graph_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/optimizer.cc
)
if(ENABLE_V0)
set(TRAIN_SRC

View File

@@ -0,0 +1,55 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include <memory>
#include "src/train/graph_dropout.h"
#include "tools/converter/optimizer.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "src/common/utils.h"
namespace mindspore {
namespace lite {
std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT) {
std::vector<schema::CNodeT *> old_nodes{};
old_nodes.resize(graph_defT.nodes.size());
std::transform(graph_defT.nodes.begin(), graph_defT.nodes.end(), old_nodes.begin(),
[](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
return old_nodes;
}
STATUS GraphDropout::Run(schema::MetaGraphT *graph) {
if (graph == nullptr) {
MS_LOG(ERROR) << "graph is nullptr.";
return RET_ERROR;
}
Optimizer dropout_optimizer;
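// snapshot the current node pointers; SubgraphNodePass uses them to remap subgraph node indices after nodes are removed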
auto old_nodes = GetGraphNodes(*graph);
dropout_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass());
dropout_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
dropout_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
auto status = dropout_optimizer.Run(graph);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "graph fusion failed.";
return RET_ERROR;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
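// Usage sketch (illustrative, not part of this commit): the pass mutates the
// MetaGraphT in place and returns RET_OK on success; "meta_graph" stands for a
// valid, caller-owned schema::MetaGraphT pointer.
//
//   mindspore::lite::GraphDropout graph_dropout;
//   if (graph_dropout.Run(meta_graph) != mindspore::lite::RET_OK) {
//     MS_LOG(ERROR) << "failed to remove Dropout nodes.";
//   }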

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_
#define MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_
#include "tools/converter/optimizer.h"
#include "inner/model_generated.h"
#include "include/errorcode.h"
namespace mindspore {
namespace lite {
class GraphDropout {
public:
GraphDropout() = default;
~GraphDropout() = default;
STATUS Run(schema::MetaGraphT *graph);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_

View File

@@ -23,6 +23,7 @@
#include <set>
#include "schema/inner/model_generated.h"
#include "src/train/train_utils.h"
#include "src/train/graph_dropout.h"
#include "src/common/quant_utils.h"
#include "tools/common/storage.h"
@@ -420,6 +421,15 @@ int TrainExport::IsInputTensor(const schema::TensorT &t) {
return ((t.data.size() == 0) && (total_dims != 0));
}
int TrainExport::TrainModelDrop() {
GraphDropout graph_dropout;
auto status = graph_dropout.Run(meta_graph_);
if (status != RET_OK) {
return RET_ERROR;
}
return RET_OK;
}
TrainExport::~TrainExport() { delete meta_graph_; }
} // namespace lite
} // namespace mindspore

View File

@@ -47,6 +47,7 @@ class TrainExport {
void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
int LoadModel(void *buf, size_t buf_size);
int AddTransformNode();
int TrainModelDrop();
protected:
virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);

View File

@@ -716,6 +716,13 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
MS_LOG(ERROR) << "cannot export Network";
return status;
}
if (model_type == MT_INFERENCE) {
status = texport.TrainModelDrop();
if (status != RET_OK) {
MS_LOG(ERROR) << "TrainModelDrop failed.";
return status;
}
}
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << file_name;

View File

@@ -270,6 +270,7 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
${LITE_DIR}/tools/common/graph_util.cc
${LITE_DIR}/tools/common/meta_graph_utils.cc
${LITE_DIR}/tools/common/tensor_util.cc
${LITE_DIR}/tools/common/node_util.cc
${LITE_DIR}/tools/common/storage.cc
@@ -281,6 +282,15 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/converter/import/primitive_adjust.cc
${LITE_DIR}/tools/converter/import/mindir_adjust.cc
)
else()
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${LITE_DIR}/tools/common/meta_graph_utils.cc
${LITE_DIR}/tools/converter/optimizer.cc
${LITE_DIR}/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc
${LITE_DIR}/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc
${LITE_DIR}/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc
)
endif()
### train
if(SUPPORT_TRAIN)
@@ -292,6 +302,7 @@ if(SUPPORT_TRAIN)
${LITE_DIR}/src/train/train_export.cc
${LITE_DIR}/src/train/train_utils.cc
${LITE_DIR}/src/train/transfer_session.cc
${LITE_DIR}/src/train/graph_dropout.cc
${LITE_DIR}/src/lite_session.cc
${LITE_DIR}/tools/common/storage.cc
)

View File

@@ -37,6 +37,7 @@
#include "tools/converter/quantizer/bitpacking.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "src/ops/ops_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"

View File

@@ -2,6 +2,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
file(GLOB CONVERTER_COMMON_SRC
${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/meta_graph_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/node_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/storage.cc

View File

@@ -24,6 +24,7 @@
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/bitpacking.h"
#include "tools/common/node_util.h"
#include "tools/common/meta_graph_utils.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/converter/ops/ops_def.h"
@@ -69,315 +70,6 @@ OpDefCopyer GetSimpleOpCopyer() {
};
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
std::vector<uint32_t> inputIndexes;
if (inputIndexIdx == -1) {
inputIndexes = node.inputIndex;
} else {
MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
}
std::set<size_t> inputNodeIdx;
for (uint32_t inputIdx : inputIndexes) {
auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
return ret;
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
const int outputIndexIdx) {
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
}
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
std::replace_if(
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
for (auto &subGraph : graphT->subGraph) {
std::replace_if(
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
}
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
std::vector<uint32_t> outputIndexes;
if (outputIndexIdx == -1) {
outputIndexes = node.outputIndex;
} else {
MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
}
std::set<size_t> outputNodeIdx;
for (uint32_t outputIdx : outputIndexes) {
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
return ret;
}
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> preNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto outputIndexes = oldNode->outputIndex;
if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
preNodeIdx.emplace_back(i);
}
}
return preNodeIdx;
}
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> postNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto inputIndexes = oldNode->inputIndex;
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
postNodeIdx.emplace_back(i);
}
}
return postNodeIdx;
}
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_ASSERT(inNode != nullptr);
if (inNode->name == node->name) {
nodeIdx = i;
break;
}
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
return RET_ERROR;
}
if (outputTensorIdxes.size() != 1) {
MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
<< "should has 1 output, in fact: " << outputTensorIdxes.size();
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_ASSERT(postNode != nullptr);
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
RemoveTensor(graphT, outputTensorIdxes);
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graph != nullptr);
return IsolateOneWayNode(graph, nodeIdx, removeTensor);
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
if (graphT->nodes.size() <= nodeIdx) {
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
return RET_PARAM_INVALID;
}
CNodeT *node = graphT->nodes.at(nodeIdx).get();
if (node == nullptr) {
MS_LOG(ERROR) << "node is null";
return RET_NULL_PTR;
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
return RET_ERROR;
}
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
if (!outputTensorIdxes.empty()) {
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_ASSERT(postNode != nullptr);
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
}
if (removeTensor) {
// the node's output tensors are now unused; remove them
auto status = RemoveTensor(graphT, outputTensorIdxes);
if (status != RET_OK) {
MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
return RET_ERROR;
}
}
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
bool isSubNode = false;
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_ASSERT(inNode != nullptr);
if (inNode->name == node->name) {
isSubNode = true;
nodeIdx = i;
break;
}
}
if (!isSubNode) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
return RET_PARAM_INVALID;
} else {
return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
}
}
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
MS_ASSERT(graphT != nullptr);
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
uint32_t deleteIdx = *iter;
if (!forceDelete) {
if (GetRefCount(graphT, deleteIdx) > 1) {
iter++;
continue;
}
}
// update graph input indices
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update graph output indices
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
for (auto &subgraph : graphT->subGraph) {
// update subgraph input indices
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update subgraph output indices
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
// update subgraph tensor indices
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
if (*idx > deleteIdx) {
(*idx)--;
}
}
}
// update nodes indexes
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
// update the node's input and output indices
UpdateNodeIndex((*node_iter).get(), deleteIdx);
}
// update deleteTensorIdx
for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
if (*selfIt > deleteIdx) {
(*selfIt)--;
}
}
graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
iter = toDeleteTensorIdxes.erase(iter);
}
return RET_OK;
}
STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
MS_ASSERT(node != nullptr);
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
if (*inIdxIt == deleteIdx) {
inIdxIt = node->inputIndex.erase(inIdxIt);
} else {
if (*inIdxIt > deleteIdx) {
(*inIdxIt)--;
}
inIdxIt++;
}
}
// update nodes output indexes
for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
if (*outIdxIt == deleteIdx) {
outIdxIt = node->outputIndex.erase(outIdxIt);
} else {
if (*outIdxIt > deleteIdx) {
(*outIdxIt)--;
}
outIdxIt++;
}
}
return RET_OK;
}
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
InsertPlace place) {
if (nodeIdx >= graphT->nodes.size()) {

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H
#define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H
#ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_
#define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_
#include <cstdlib>
#include <unordered_map>
@@ -48,34 +48,6 @@ OpDefCopyer GetSimpleOpCopyer();
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int inputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int outputIndexIdx = -1);
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<schema::TensorT> tensor,
InsertPlace place = kBefore);
@@ -320,4 +292,4 @@ bool PackRepetition(size_t bit_num, schema::TensorT *tensor) {
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H
#endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_

View File

@@ -0,0 +1,346 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/common/meta_graph_utils.h"
#include <vector>
#include <set>
#include "inner/model_generated.h"
#include "src/common/utils.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
namespace {
size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(graphT->allTensors.size() > tensorIdx);
size_t refCount = 0;
for (auto &node : graphT->nodes) {
MS_ASSERT(node != nullptr);
if (IsContain(node->inputIndex, tensorIdx)) {
refCount++;
}
}
return refCount;
}
} // namespace
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> postNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto inputIndexes = oldNode->inputIndex;
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
postNodeIdx.emplace_back(i);
}
}
return postNodeIdx;
}
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
std::vector<size_t> preNodeIdx;
for (size_t i = 0; i < graphT.nodes.size(); i++) {
auto &oldNode = graphT.nodes.at(i);
if (oldNode == nullptr) {
continue;
}
auto outputIndexes = oldNode->outputIndex;
if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
preNodeIdx.emplace_back(i);
}
}
return preNodeIdx;
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
const int inputIndexIdx) {
std::vector<uint32_t> inputIndexes;
if (inputIndexIdx == -1) {
inputIndexes = node.inputIndex;
} else {
MS_ASSERT(node.inputIndex.size() > static_cast<uint32_t>(inputIndexIdx));
inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
}
std::set<size_t> inputNodeIdx;
for (uint32_t inputIdx : inputIndexes) {
auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
return ret;
}
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
const int outputIndexIdx) {
std::vector<uint32_t> outputIndexes;
if (outputIndexIdx == -1) {
outputIndexes = node.outputIndex;
} else {
MS_ASSERT(node.outputIndex.size() > static_cast<uint32_t>(outputIndexIdx));
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
}
std::set<size_t> outputNodeIdx;
for (uint32_t outputIdx : outputIndexes) {
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
}
std::vector<size_t> ret;
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
return ret;
}
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
const int outputIndexIdx) {
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
}
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
std::replace_if(
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
for (auto &subGraph : graphT->subGraph) {
std::replace_if(
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
}
}
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx) {
MS_ASSERT(node != nullptr);
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
if (*inIdxIt == deleteIdx) {
inIdxIt = node->inputIndex.erase(inIdxIt);
} else {
if (*inIdxIt > deleteIdx) {
(*inIdxIt)--;
}
inIdxIt++;
}
}
// update nodes output indexes
for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
if (*outIdxIt == deleteIdx) {
outIdxIt = node->outputIndex.erase(outIdxIt);
} else {
if (*outIdxIt > deleteIdx) {
(*outIdxIt)--;
}
outIdxIt++;
}
}
return RET_OK;
}
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
MS_ASSERT(graphT != nullptr);
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
uint32_t deleteIdx = *iter;
if (!forceDelete) {
if (GetRefCount(graphT, deleteIdx) > 1) {
iter++;
continue;
}
}
// update graph input indices
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update graph output indices
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
for (auto &subgraph : graphT->subGraph) {
// update subgraph input indices
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update subgraph output indices
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
// update subgraph tensor indices
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
if (*idx > deleteIdx) {
(*idx)--;
}
}
}
// update nodes indexes
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
// update the node's input and output indices
UpdateNodeIndex((*node_iter).get(), deleteIdx);
}
// update deleteTensorIdx
for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
if (*selfIt > deleteIdx) {
(*selfIt)--;
}
}
graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
iter = toDeleteTensorIdxes.erase(iter);
}
return RET_OK;
}
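// Worked example of the index bookkeeping above (hypothetical graph): deleting
// tensor 2 from allTensors {0, 1, 2, 3, 4} decrements every stored index
// greater than 2, so a node with inputIndex {1, 4} ends up with {1, 3}, and a
// pending delete list {2, 4} shrinks to {3} once 2 itself is erased.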
STATUS IsolateNode(schema::MetaGraphT *graphT, schema::CNodeT *node) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_ASSERT(inNode != nullptr);
if (inNode->name == node->name) {
nodeIdx = i;
break;
}
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
return RET_ERROR;
}
if (outputTensorIdxes.size() != 1) {
MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
<< "should has 1 output, in fact: " << outputTensorIdxes.size();
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_ASSERT(postNode != nullptr);
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
RemoveTensor(graphT, outputTensorIdxes);
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
if (graphT->nodes.size() <= nodeIdx) {
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
return RET_PARAM_INVALID;
}
schema::CNodeT *node = graphT->nodes.at(nodeIdx).get();
if (node == nullptr) {
MS_LOG(ERROR) << "node is null";
return RET_NULL_PTR;
}
auto inputTensorIdxes = node->inputIndex;
auto outputTensorIdxes = node->outputIndex;
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
return RET_ERROR;
}
if (inputTensorIdxes.empty()) {
MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
return RET_ERROR;
}
auto inDataTensorIdx = inputTensorIdxes.front();
if (!outputTensorIdxes.empty()) {
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find postNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_ASSERT(postNode != nullptr);
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
if (*iter == outDataTensorIdx) {
*iter = inDataTensorIdx;
break;
}
}
}
}
if (removeTensor) {
// the node's output tensors are now unused; remove them
auto status = RemoveTensor(graphT, outputTensorIdxes);
if (status != RET_OK) {
MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
return RET_ERROR;
}
}
node->inputIndex.clear();
node->outputIndex.clear();
return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
MS_ASSERT(graph != nullptr);
return IsolateOneWayNode(graph, nodeIdx, removeTensor);
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
bool isSubNode = false;
size_t nodeIdx = 0;
for (size_t i = 0; i < graphT->nodes.size(); i++) {
auto &inNode = graphT->nodes.at(i);
MS_ASSERT(inNode != nullptr);
if (inNode->name == node->name) {
isSubNode = true;
nodeIdx = i;
break;
}
}
if (!isSubNode) {
MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
return RET_PARAM_INVALID;
} else {
return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
}
}
} // namespace mindspore::lite

View File

@@ -0,0 +1,54 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_
#define MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_
#include <vector>
#include "inner/model_generated.h"
#include "include/errorcode.h"
namespace mindspore::lite {
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int inputIndexIdx = -1);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
int outputIndexIdx = -1);
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_
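// Usage sketch (illustrative, not part of this commit): detaching a
// single-input, single-output node by index; "graph" and "node_idx" are
// hypothetical caller-provided values.
//
//   auto status = mindspore::lite::IsolateOneWayNode(graph, node_idx, true);
//   if (status != mindspore::lite::RET_OK) {
//     // the node had more than one input/output, or node_idx was out of range
//   }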

View File

@@ -22,6 +22,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/meta_graph_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc

View File

@@ -27,6 +27,7 @@
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

View File

@@ -21,6 +21,7 @@
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
@@ -95,7 +96,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
MS_ASSERT(mulNodeBiasTensor != nullptr);
if (mulNodeBiasTensor->nodeType != NodeType_ValueNode) {
// dont fusion, return
// don't fuse, return
return RET_OK;
}
if (mulNodeBiasTensor->dataType == TypeId::kNumberTypeUInt8) {
@@ -112,7 +113,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
const auto &addNodeBiasTensor = graph->allTensors.at(addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
MS_ASSERT(addNodeBiasTensor != nullptr);
if (addNodeBiasTensor->nodeType != NodeType_ValueNode) {
// dont fusion, return
// don't fuse, return
return RET_OK;
}
// scale requires the scale shape to be a tail slice of the input shape, and the scale shape to match the bias shape

View File

@@ -21,6 +21,7 @@
#include "src/common/log_adapter.h"
#include "securec/include/securec.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

View File

@@ -17,7 +17,7 @@
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include <queue>
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

View File

@@ -20,7 +20,6 @@
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

View File

@@ -21,7 +21,6 @@
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

View File

@@ -22,6 +22,7 @@
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/common/tensor_util.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "tools/common/node_util.h"
#include "src/common/quant_utils.h"

View File

@@ -18,8 +18,10 @@
#include <utility>
#include <memory>
#include <vector>
#include <algorithm>
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/common/node_util.h"
#include "tools/common/meta_graph_utils.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"