0115-fix-convert-2

This commit is contained in:
yefeng 2020-12-22 15:35:51 +08:00
parent 70b75e1cfd
commit fb4bd85656
2 changed files with 60 additions and 11 deletions

View File

@ -15,6 +15,7 @@
*/
#include <vector>
#include <set>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
@ -27,13 +28,53 @@
namespace mindspore {
namespace lite {
void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
for (auto &subgraph : graph->subGraph) {
for (auto &idx : subgraph->nodeIndices) {
if (idx > node_idx) {
idx--;
}
std::set<uint32_t> SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
schema::MetaGraphT *graph) {
std::set<uint32_t> tensors_indices{};
for (auto &node_idx : subgraph->nodeIndices) {
auto &node = graph->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
tensors_indices.insert(input_idx);
}
for (auto &output_idx : node->outputIndex) {
tensors_indices.insert(output_idx);
}
}
return tensors_indices;
}
bool SubgraphNodePass::IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph) {
return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
[&tensors_indices, &subgraph](uint32_t idx) {
return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx);
})) &&
(std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx);
}));
}
void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
for (auto &subgraph : graph->subGraph) {
std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
[&node_idx](uint32_t idx) {
if (idx > node_idx) {
return --idx;
}
return idx;
});
}
}
void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
for (auto &subgraph : graph->subGraph) {
std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
[&node_idx](uint32_t idx) {
if (idx >= node_idx) {
return ++idx;
}
return idx;
});
}
}
@ -50,7 +91,7 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
if (node_idx_pos != subgraph->nodeIndices.end()) {
subgraph->nodeIndices.erase(node_idx_pos);
UpdateSubgraphNodeIndices(node_idx, graph);
DecreaseSubgraphNodeIndices(node_idx, graph);
break;
}
}
@ -62,10 +103,12 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
for (uint32_t i = 0; i < new_nodes.size(); i++) {
if (!IsContain(old_nodes_, new_nodes[i])) {
auto &node = graph->nodes.at(i);
for (auto &subgraph : graph->subGraph) {
if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) {
subgraph->nodeIndices.push_back(old_nodes_.size());
old_nodes_.push_back(new_nodes[i]);
auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph);
if (IsNodeInSubgraph(tensors_indices, node, subgraph)) {
IncreaseSubgraphNodeIndices(i, graph);
subgraph->nodeIndices.push_back(i);
}
}
}

View File

@ -19,6 +19,8 @@
#include <vector>
#include <utility>
#include <set>
#include <memory>
#include "tools/converter/optimizer.h"
namespace mindspore {
@ -32,7 +34,11 @@ class SubgraphNodePass : public GraphPass {
STATUS Run(schema::MetaGraphT *graph) override;
private:
void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
void DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
void IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
std::set<uint32_t> GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph, schema::MetaGraphT *graph);
bool IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph);
std::vector<schema::CNodeT *> old_nodes_;
};
} // namespace lite