forked from mindspore-Ecosystem/mindspore
Fix dropout bug when saving inference file
This commit is contained in:
parent
01708bff83
commit
29ffaa5751
|
@ -142,7 +142,13 @@ set(TRAIN_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_monitor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/graph_dropout.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/meta_graph_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../tools/converter/optimizer.cc
|
||||
)
|
||||
if(ENABLE_V0)
|
||||
set(TRAIN_SRC
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "src/train/graph_dropout.h"
|
||||
#include "tools/converter/optimizer.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT) {
|
||||
std::vector<schema::CNodeT *> old_nodes{};
|
||||
old_nodes.resize(graph_defT.nodes.size());
|
||||
std::transform(graph_defT.nodes.begin(), graph_defT.nodes.end(), old_nodes.begin(),
|
||||
[](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
|
||||
return old_nodes;
|
||||
}
|
||||
|
||||
STATUS GraphDropout::Run(schema::MetaGraphT *graph) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "graph is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
Optimizer dropout_optimizer;
|
||||
auto old_nodes = GetGraphNodes(*graph);
|
||||
dropout_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass());
|
||||
dropout_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
dropout_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
auto status = dropout_optimizer.Run(graph);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "graph fusion failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_
|
||||
#define MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_
|
||||
|
||||
#include "tools/converter/optimizer.h"
|
||||
#include "inner/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class GraphDropout {
|
||||
public:
|
||||
GraphDropout() = default;
|
||||
~GraphDropout() = default;
|
||||
STATUS Run(schema::MetaGraphT *graph);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_TRAIN_GRAPH_DROPOUT_H_
|
|
@ -23,6 +23,7 @@
|
|||
#include <set>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/train/train_utils.h"
|
||||
#include "src/train/graph_dropout.h"
|
||||
#include "src/common/quant_utils.h"
|
||||
#include "tools/common/storage.h"
|
||||
|
||||
|
@ -420,6 +421,15 @@ int TrainExport::IsInputTensor(const schema::TensorT &t) {
|
|||
return ((t.data.size() == 0) && (total_dims != 0));
|
||||
}
|
||||
|
||||
int TrainExport::TrainModelDrop() {
|
||||
GraphDropout graph_dropout;
|
||||
auto status = graph_dropout.Run(meta_graph_);
|
||||
if (status != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TrainExport::~TrainExport() { delete meta_graph_; }
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,6 +47,7 @@ class TrainExport {
|
|||
void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
|
||||
int LoadModel(void *buf, size_t buf_size);
|
||||
int AddTransformNode();
|
||||
int TrainModelDrop();
|
||||
|
||||
protected:
|
||||
virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);
|
||||
|
|
|
@ -716,6 +716,13 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
|
|||
MS_LOG(ERROR) << "cannot export Network";
|
||||
return status;
|
||||
}
|
||||
if (model_type == MT_INFERENCE) {
|
||||
status = texport.TrainModelDrop();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "TrainModelDrop failed.";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
status = texport.SaveToFile();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to save to " << file_name;
|
||||
|
|
|
@ -270,6 +270,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/parallel/spliter.cc
|
||||
${LITE_DIR}/tools/optimizer/parallel/split_strategy.cc
|
||||
${LITE_DIR}/tools/common/graph_util.cc
|
||||
${LITE_DIR}/tools/common/meta_graph_utils.cc
|
||||
${LITE_DIR}/tools/common/tensor_util.cc
|
||||
${LITE_DIR}/tools/common/node_util.cc
|
||||
${LITE_DIR}/tools/common/storage.cc
|
||||
|
@ -281,6 +282,15 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/converter/import/primitive_adjust.cc
|
||||
${LITE_DIR}/tools/converter/import/mindir_adjust.cc
|
||||
)
|
||||
else()
|
||||
set(TEST_LITE_SRC
|
||||
${TEST_LITE_SRC}
|
||||
${LITE_DIR}/tools/common/meta_graph_utils.cc
|
||||
${LITE_DIR}/tools/converter/optimizer.cc
|
||||
${LITE_DIR}/tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc
|
||||
)
|
||||
endif()
|
||||
### train
|
||||
if(SUPPORT_TRAIN)
|
||||
|
@ -292,6 +302,7 @@ if(SUPPORT_TRAIN)
|
|||
${LITE_DIR}/src/train/train_export.cc
|
||||
${LITE_DIR}/src/train/train_utils.cc
|
||||
${LITE_DIR}/src/train/transfer_session.cc
|
||||
${LITE_DIR}/src/train/graph_dropout.cc
|
||||
${LITE_DIR}/src/lite_session.cc
|
||||
${LITE_DIR}/tools/common/storage.cc
|
||||
)
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
|
|
|
@ -2,6 +2,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
|||
|
||||
file(GLOB CONVERTER_COMMON_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/meta_graph_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/node_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/tensor_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/storage.cc
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
@ -69,315 +70,6 @@ OpDefCopyer GetSimpleOpCopyer() {
|
|||
};
|
||||
}
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
|
||||
return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
|
||||
}
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
|
||||
std::vector<uint32_t> inputIndexes;
|
||||
if (inputIndexIdx == -1) {
|
||||
inputIndexes = node.inputIndex;
|
||||
} else {
|
||||
MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
|
||||
inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
|
||||
}
|
||||
std::set<size_t> inputNodeIdx;
|
||||
for (uint32_t inputIdx : inputIndexes) {
|
||||
auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
|
||||
inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
|
||||
}
|
||||
std::vector<size_t> ret;
|
||||
ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
|
||||
const int outputIndexIdx) {
|
||||
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
|
||||
}
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
|
||||
std::replace_if(
|
||||
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
|
||||
for (auto &subGraph : graphT->subGraph) {
|
||||
std::replace_if(
|
||||
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
|
||||
std::vector<uint32_t> outputIndexes;
|
||||
if (outputIndexIdx == -1) {
|
||||
outputIndexes = node.outputIndex;
|
||||
} else {
|
||||
MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
|
||||
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
|
||||
}
|
||||
std::set<size_t> outputNodeIdx;
|
||||
for (uint32_t outputIdx : outputIndexes) {
|
||||
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
|
||||
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
|
||||
}
|
||||
std::vector<size_t> ret;
|
||||
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
|
||||
std::vector<size_t> preNodeIdx;
|
||||
for (size_t i = 0; i < graphT.nodes.size(); i++) {
|
||||
auto &oldNode = graphT.nodes.at(i);
|
||||
if (oldNode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto outputIndexes = oldNode->outputIndex;
|
||||
if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
|
||||
preNodeIdx.emplace_back(i);
|
||||
}
|
||||
}
|
||||
return preNodeIdx;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
|
||||
std::vector<size_t> postNodeIdx;
|
||||
for (size_t i = 0; i < graphT.nodes.size(); i++) {
|
||||
auto &oldNode = graphT.nodes.at(i);
|
||||
if (oldNode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto inputIndexes = oldNode->inputIndex;
|
||||
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
|
||||
postNodeIdx.emplace_back(i);
|
||||
}
|
||||
}
|
||||
return postNodeIdx;
|
||||
}
|
||||
|
||||
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
size_t nodeIdx = 0;
|
||||
for (size_t i = 0; i < graphT->nodes.size(); i++) {
|
||||
auto &inNode = graphT->nodes.at(i);
|
||||
MS_ASSERT(inNode != nullptr);
|
||||
if (inNode->name == node->name) {
|
||||
nodeIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto inputTensorIdxes = node->inputIndex;
|
||||
auto outputTensorIdxes = node->outputIndex;
|
||||
if (inputTensorIdxes.empty()) {
|
||||
MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputTensorIdxes.size() != 1) {
|
||||
MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
|
||||
<< "should has 1 output, in fact: " << outputTensorIdxes.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inDataTensorIdx = inputTensorIdxes.front();
|
||||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
for (auto postNodeIdx : postNodeIdxes) {
|
||||
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
|
||||
auto &postNode = graphT->nodes.at(postNodeIdx);
|
||||
MS_ASSERT(postNode != nullptr);
|
||||
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RemoveTensor(graphT, outputTensorIdxes);
|
||||
node->inputIndex.clear();
|
||||
node->outputIndex.clear();
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
return IsolateOneWayNode(graph, nodeIdx, removeTensor);
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
if (graphT->nodes.size() <= nodeIdx) {
|
||||
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
CNodeT *node = graphT->nodes.at(nodeIdx).get();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "node is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto inputTensorIdxes = node->inputIndex;
|
||||
auto outputTensorIdxes = node->outputIndex;
|
||||
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
|
||||
if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
|
||||
MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (inputTensorIdxes.empty()) {
|
||||
MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inDataTensorIdx = inputTensorIdxes.front();
|
||||
if (!outputTensorIdxes.empty()) {
|
||||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
for (auto postNodeIdx : postNodeIdxes) {
|
||||
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
|
||||
auto &postNode = graphT->nodes.at(postNodeIdx);
|
||||
MS_ASSERT(postNode != nullptr);
|
||||
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (removeTensor) {
|
||||
// now all node's outputTensors are useless
|
||||
// remove all node's outputTensors
|
||||
auto status = RemoveTensor(graphT, outputTensorIdxes);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
node->inputIndex.clear();
|
||||
node->outputIndex.clear();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
bool isSubNode = false;
|
||||
size_t nodeIdx = 0;
|
||||
for (size_t i = 0; i < graphT->nodes.size(); i++) {
|
||||
auto &inNode = graphT->nodes.at(i);
|
||||
MS_ASSERT(inNode != nullptr);
|
||||
if (inNode->name == node->name) {
|
||||
isSubNode = true;
|
||||
nodeIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isSubNode) {
|
||||
MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
|
||||
return RET_PARAM_INVALID;
|
||||
} else {
|
||||
return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
|
||||
}
|
||||
}
|
||||
|
||||
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
|
||||
uint32_t deleteIdx = *iter;
|
||||
if (!forceDelete) {
|
||||
if (GetRefCount(graphT, deleteIdx) > 1) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// update graph input indices
|
||||
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
|
||||
if (*gInIdx > deleteIdx) {
|
||||
(*gInIdx)--;
|
||||
}
|
||||
}
|
||||
// update graph output indices
|
||||
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
|
||||
if (*gOutIdx > deleteIdx) {
|
||||
(*gOutIdx)--;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &subgraph : graphT->subGraph) {
|
||||
// update subgraph input indices
|
||||
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
|
||||
if (*gInIdx > deleteIdx) {
|
||||
(*gInIdx)--;
|
||||
}
|
||||
}
|
||||
// update subgraph output indices
|
||||
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
|
||||
if (*gOutIdx > deleteIdx) {
|
||||
(*gOutIdx)--;
|
||||
}
|
||||
}
|
||||
// update subgraph output indices
|
||||
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
|
||||
if (*idx > deleteIdx) {
|
||||
(*idx)--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update nodes indexes
|
||||
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
|
||||
// update nodes input indexes
|
||||
UpdateNodeIndex((*node_iter).get(), deleteIdx);
|
||||
}
|
||||
// update deleteTensorIdx
|
||||
for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
|
||||
if (*selfIt > deleteIdx) {
|
||||
(*selfIt)--;
|
||||
}
|
||||
}
|
||||
graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
|
||||
iter = toDeleteTensorIdxes.erase(iter);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
|
||||
MS_ASSERT(node != nullptr);
|
||||
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
|
||||
if (*inIdxIt == deleteIdx) {
|
||||
inIdxIt = node->inputIndex.erase(inIdxIt);
|
||||
} else {
|
||||
if (*inIdxIt > deleteIdx) {
|
||||
(*inIdxIt)--;
|
||||
}
|
||||
inIdxIt++;
|
||||
}
|
||||
}
|
||||
// update nodes output indexes
|
||||
for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
|
||||
if (*outIdxIt == deleteIdx) {
|
||||
outIdxIt = node->outputIndex.erase(outIdxIt);
|
||||
} else {
|
||||
if (*outIdxIt > deleteIdx) {
|
||||
(*outIdxIt)--;
|
||||
}
|
||||
outIdxIt++;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
|
||||
InsertPlace place) {
|
||||
if (nodeIdx >= graphT->nodes.size()) {
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H
|
||||
#define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_
|
||||
#define MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_map>
|
||||
|
@ -48,34 +48,6 @@ OpDefCopyer GetSimpleOpCopyer();
|
|||
|
||||
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
int inputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
int outputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
|
||||
|
||||
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
|
||||
|
||||
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);
|
||||
|
||||
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);
|
||||
|
||||
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);
|
||||
|
||||
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<schema::TensorT> tensor,
|
||||
InsertPlace place = kBefore);
|
||||
|
||||
|
@ -320,4 +292,4 @@ bool PackRepetition(size_t bit_num, schema::TensorT *tensor) {
|
|||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_COMMON_GRAPH_UTIL_H_
|
||||
|
|
|
@ -0,0 +1,346 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "inner/model_generated.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
MS_ASSERT(graphT->allTensors.size() > tensorIdx);
|
||||
size_t refCount = 0;
|
||||
for (auto &node : graphT->nodes) {
|
||||
MS_ASSERT(node != nullptr);
|
||||
if (IsContain(node->inputIndex, tensorIdx)) {
|
||||
refCount++;
|
||||
}
|
||||
}
|
||||
return refCount;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
|
||||
std::vector<size_t> postNodeIdx;
|
||||
for (size_t i = 0; i < graphT.nodes.size(); i++) {
|
||||
auto &oldNode = graphT.nodes.at(i);
|
||||
if (oldNode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto inputIndexes = oldNode->inputIndex;
|
||||
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
|
||||
postNodeIdx.emplace_back(i);
|
||||
}
|
||||
}
|
||||
return postNodeIdx;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
|
||||
std::vector<size_t> preNodeIdx;
|
||||
for (size_t i = 0; i < graphT.nodes.size(); i++) {
|
||||
auto &oldNode = graphT.nodes.at(i);
|
||||
if (oldNode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto outputIndexes = oldNode->outputIndex;
|
||||
if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
|
||||
preNodeIdx.emplace_back(i);
|
||||
}
|
||||
}
|
||||
return preNodeIdx;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
const int inputIndexIdx) {
|
||||
std::vector<uint32_t> inputIndexes;
|
||||
if (inputIndexIdx == -1) {
|
||||
inputIndexes = node.inputIndex;
|
||||
} else {
|
||||
MS_ASSERT(node.inputIndex.size() > static_cast<uint32_t>(inputIndexIdx));
|
||||
inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
|
||||
}
|
||||
std::set<size_t> inputNodeIdx;
|
||||
for (uint32_t inputIdx : inputIndexes) {
|
||||
auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
|
||||
inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
|
||||
}
|
||||
std::vector<size_t> ret;
|
||||
ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
|
||||
return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
const int outputIndexIdx) {
|
||||
std::vector<uint32_t> outputIndexes;
|
||||
if (outputIndexIdx == -1) {
|
||||
outputIndexes = node.outputIndex;
|
||||
} else {
|
||||
MS_ASSERT(node.outputIndex.size() > static_cast<uint32_t>(outputIndexIdx));
|
||||
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
|
||||
}
|
||||
std::set<size_t> outputNodeIdx;
|
||||
for (uint32_t outputIdx : outputIndexes) {
|
||||
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
|
||||
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
|
||||
}
|
||||
std::vector<size_t> ret;
|
||||
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
|
||||
const int outputIndexIdx) {
|
||||
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
|
||||
}
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
|
||||
std::replace_if(
|
||||
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
|
||||
for (auto &subGraph : graphT->subGraph) {
|
||||
std::replace_if(
|
||||
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
}
|
||||
}
|
||||
|
||||
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx) {
|
||||
MS_ASSERT(node != nullptr);
|
||||
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
|
||||
if (*inIdxIt == deleteIdx) {
|
||||
inIdxIt = node->inputIndex.erase(inIdxIt);
|
||||
} else {
|
||||
if (*inIdxIt > deleteIdx) {
|
||||
(*inIdxIt)--;
|
||||
}
|
||||
inIdxIt++;
|
||||
}
|
||||
}
|
||||
// update nodes output indexes
|
||||
for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
|
||||
if (*outIdxIt == deleteIdx) {
|
||||
outIdxIt = node->outputIndex.erase(outIdxIt);
|
||||
} else {
|
||||
if (*outIdxIt > deleteIdx) {
|
||||
(*outIdxIt)--;
|
||||
}
|
||||
outIdxIt++;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
|
||||
uint32_t deleteIdx = *iter;
|
||||
if (!forceDelete) {
|
||||
if (GetRefCount(graphT, deleteIdx) > 1) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// update graph input indices
|
||||
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
|
||||
if (*gInIdx > deleteIdx) {
|
||||
(*gInIdx)--;
|
||||
}
|
||||
}
|
||||
// update graph output indices
|
||||
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
|
||||
if (*gOutIdx > deleteIdx) {
|
||||
(*gOutIdx)--;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &subgraph : graphT->subGraph) {
|
||||
// update subgraph input indices
|
||||
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
|
||||
if (*gInIdx > deleteIdx) {
|
||||
(*gInIdx)--;
|
||||
}
|
||||
}
|
||||
// update subgraph output indices
|
||||
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
|
||||
if (*gOutIdx > deleteIdx) {
|
||||
(*gOutIdx)--;
|
||||
}
|
||||
}
|
||||
// update subgraph output indices
|
||||
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
|
||||
if (*idx > deleteIdx) {
|
||||
(*idx)--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update nodes indexes
|
||||
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
|
||||
// update nodes input indexes
|
||||
UpdateNodeIndex((*node_iter).get(), deleteIdx);
|
||||
}
|
||||
// update deleteTensorIdx
|
||||
for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
|
||||
if (*selfIt > deleteIdx) {
|
||||
(*selfIt)--;
|
||||
}
|
||||
}
|
||||
graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
|
||||
iter = toDeleteTensorIdxes.erase(iter);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS IsolateNode(schema::MetaGraphT *graphT, schema::CNodeT *node) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
size_t nodeIdx = 0;
|
||||
for (size_t i = 0; i < graphT->nodes.size(); i++) {
|
||||
auto &inNode = graphT->nodes.at(i);
|
||||
MS_ASSERT(postNode != nullptr);
|
||||
if (inNode->name == node->name) {
|
||||
nodeIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto inputTensorIdxes = node->inputIndex;
|
||||
auto outputTensorIdxes = node->outputIndex;
|
||||
if (inputTensorIdxes.empty()) {
|
||||
MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputTensorIdxes.size() != 1) {
|
||||
MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
|
||||
<< "should has 1 output, in fact: " << outputTensorIdxes.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inDataTensorIdx = inputTensorIdxes.front();
|
||||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
for (auto postNodeIdx : postNodeIdxes) {
|
||||
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
|
||||
auto &postNode = graphT->nodes.at(postNodeIdx);
|
||||
MS_ASSERT(postNode != nullptr);
|
||||
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
RemoveTensor(graphT, outputTensorIdxes);
|
||||
node->inputIndex.clear();
|
||||
node->outputIndex.clear();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
if (graphT->nodes.size() <= nodeIdx) {
|
||||
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
schema::CNodeT *node = graphT->nodes.at(nodeIdx).get();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "node is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto inputTensorIdxes = node->inputIndex;
|
||||
auto outputTensorIdxes = node->outputIndex;
|
||||
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
|
||||
if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
|
||||
MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (inputTensorIdxes.empty()) {
|
||||
MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inDataTensorIdx = inputTensorIdxes.front();
|
||||
if (!outputTensorIdxes.empty()) {
|
||||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
for (auto postNodeIdx : postNodeIdxes) {
|
||||
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
|
||||
auto &postNode = graphT->nodes.at(postNodeIdx);
|
||||
MS_ASSERT(postNode != nullptr);
|
||||
for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (removeTensor) {
|
||||
// now all node's outputTensors are useless
|
||||
// remove all node's outputTensors
|
||||
auto status = RemoveTensor(graphT, outputTensorIdxes);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
node->inputIndex.clear();
|
||||
node->outputIndex.clear();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
return IsolateOneWayNode(graph, nodeIdx, removeTensor);
|
||||
}
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor) {
|
||||
MS_ASSERT(graphT != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
bool isSubNode = false;
|
||||
size_t nodeIdx = 0;
|
||||
for (size_t i = 0; i < graphT->nodes.size(); i++) {
|
||||
auto &inNode = graphT->nodes.at(i);
|
||||
MS_ASSERT(postNode != nullptr);
|
||||
if (inNode->name == node->name) {
|
||||
isSubNode = true;
|
||||
nodeIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isSubNode) {
|
||||
MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
|
||||
return RET_PARAM_INVALID;
|
||||
} else {
|
||||
return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_
|
||||
|
||||
#include <vector>
|
||||
#include "inner/model_generated.h"
|
||||
#include "include/errorcode.h"
|
||||
namespace mindspore::lite {
|
||||
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
|
||||
|
||||
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
int inputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
int outputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1);
|
||||
|
||||
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
|
||||
|
||||
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete = false);
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
|
||||
|
||||
STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_COMMON_META_GRAPH_UTILS_H_
|
|
@ -22,6 +22,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/quant_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/meta_graph_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
|
@ -95,7 +96,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
|
|||
const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
|
||||
MS_ASSERT(mulNodeBiasTensor != nullptr);
|
||||
if (mulNodeBiasTensor->nodeType != NodeType_ValueNode) {
|
||||
// dont fusion, return
|
||||
// don't fusion, return
|
||||
return RET_OK;
|
||||
}
|
||||
if (mulNodeBiasTensor->dataType == TypeId::kNumberTypeUInt8) {
|
||||
|
@ -112,7 +113,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
|
|||
const auto &addNodeBiasTensor = graph->allTensors.at(addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
|
||||
MS_ASSERT(addNodeBiasTensor != nullptr);
|
||||
if (addNodeBiasTensor->nodeType != NodeType_ValueNode) {
|
||||
// dont fusion, return
|
||||
// don't fusion, return
|
||||
return RET_OK;
|
||||
}
|
||||
// scale requires scale shape tail sub of input shape, scale shape same as bias shape
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
|
||||
#include <queue>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "src/common/quant_utils.h"
|
||||
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
#include <utility>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
|
|
Loading…
Reference in New Issue