0115-fix-convert-2
This commit is contained in:
parent
70b75e1cfd
commit
fb4bd85656
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
||||
|
@ -27,13 +28,53 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
|
||||
for (auto &subgraph : graph->subGraph) {
|
||||
for (auto &idx : subgraph->nodeIndices) {
|
||||
if (idx > node_idx) {
|
||||
idx--;
|
||||
}
|
||||
std::set<uint32_t> SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
|
||||
schema::MetaGraphT *graph) {
|
||||
std::set<uint32_t> tensors_indices{};
|
||||
for (auto &node_idx : subgraph->nodeIndices) {
|
||||
auto &node = graph->nodes.at(node_idx);
|
||||
for (auto &input_idx : node->inputIndex) {
|
||||
tensors_indices.insert(input_idx);
|
||||
}
|
||||
for (auto &output_idx : node->outputIndex) {
|
||||
tensors_indices.insert(output_idx);
|
||||
}
|
||||
}
|
||||
return tensors_indices;
|
||||
}
|
||||
|
||||
bool SubgraphNodePass::IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
|
||||
const std::unique_ptr<SubGraphT> &subgraph) {
|
||||
return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
|
||||
[&tensors_indices, &subgraph](uint32_t idx) {
|
||||
return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx);
|
||||
})) &&
|
||||
(std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
|
||||
return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx);
|
||||
}));
|
||||
}
|
||||
|
||||
void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
|
||||
for (auto &subgraph : graph->subGraph) {
|
||||
std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
|
||||
[&node_idx](uint32_t idx) {
|
||||
if (idx > node_idx) {
|
||||
return --idx;
|
||||
}
|
||||
return idx;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
|
||||
for (auto &subgraph : graph->subGraph) {
|
||||
std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
|
||||
[&node_idx](uint32_t idx) {
|
||||
if (idx >= node_idx) {
|
||||
return ++idx;
|
||||
}
|
||||
return idx;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,7 +91,7 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
|
|||
auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
|
||||
if (node_idx_pos != subgraph->nodeIndices.end()) {
|
||||
subgraph->nodeIndices.erase(node_idx_pos);
|
||||
UpdateSubgraphNodeIndices(node_idx, graph);
|
||||
DecreaseSubgraphNodeIndices(node_idx, graph);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -62,10 +103,12 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
|
|||
|
||||
for (uint32_t i = 0; i < new_nodes.size(); i++) {
|
||||
if (!IsContain(old_nodes_, new_nodes[i])) {
|
||||
auto &node = graph->nodes.at(i);
|
||||
for (auto &subgraph : graph->subGraph) {
|
||||
if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) {
|
||||
subgraph->nodeIndices.push_back(old_nodes_.size());
|
||||
old_nodes_.push_back(new_nodes[i]);
|
||||
auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph);
|
||||
if (IsNodeInSubgraph(tensors_indices, node, subgraph)) {
|
||||
IncreaseSubgraphNodeIndices(i, graph);
|
||||
subgraph->nodeIndices.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include "tools/converter/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -32,7 +34,11 @@ class SubgraphNodePass : public GraphPass {
|
|||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
private:
|
||||
void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
|
||||
void DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
|
||||
void IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
|
||||
std::set<uint32_t> GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph, schema::MetaGraphT *graph);
|
||||
bool IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
|
||||
const std::unique_ptr<SubGraphT> &subgraph);
|
||||
std::vector<schema::CNodeT *> old_nodes_;
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
Loading…
Reference in New Issue