add model input index and output index

mengyuanli 2021-06-28 09:47:47 +08:00
parent 19e95c5c8e
commit 44ba4afe3a
7 changed files with 47 additions and 36 deletions
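In short, the commit lifts the graph-level input/output tensor indices out of the first subgraph and stores them directly on the Model struct, so runtime code reads model->input_indices_ / model->output_indices_ instead of going through model->sub_graphs_.front(). A minimal before/after sketch of that access pattern (the helper name is hypothetical; the identifiers follow the diff):

#include <vector>
// Hypothetical helper, for illustration only: collect the graph input tensor indices.
std::vector<size_t> CollectGraphInputIndices(const lite::Model *model) {
  std::vector<size_t> indices;
  // Before this commit: indices were read from the first subgraph, e.g.
  //   for (auto idx : model->sub_graphs_.front()->input_indices_) indices.push_back(idx);
  // After this commit: the Model itself carries the graph-level indices.
  for (auto idx : model->input_indices_) {
    indices.push_back(idx);
  }
  return indices;
}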

View File

@@ -40,6 +40,8 @@ struct MS_API Model {
using SubGraphPtrVector = Vector<SubGraph *>;
String name_;
String version_;
+ Uint32Vector input_indices_;
+ Uint32Vector output_indices_;
TensorPtrVector all_tensors_;
NodePtrVector all_nodes_;
char *buf = nullptr;

View File

@@ -32,7 +32,7 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(!(model->sub_graphs_.empty()));
std::vector<size_t> ret;
- for (auto graph_in_index : model->sub_graphs_.front()->input_indices_) {
+ for (auto graph_in_index : model->input_indices_) {
auto node_size = model->all_nodes_.size();
for (size_t j = 0; j < node_size; ++j) {
auto node = model->all_nodes_[j];
@@ -51,7 +51,7 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model) {
std::vector<size_t> GetGraphOutputNodes(const lite::Model *model) {
MS_ASSERT(model != nullptr);
std::vector<size_t> ret;
- for (auto graph_out_index : model->sub_graphs_.front()->output_indices_) {
+ for (auto graph_out_index : model->output_indices_) {
auto node_size = model->all_nodes_.size();
for (size_t j = 0; j < node_size; ++j) {
auto node = model->all_nodes_[j];

View File

@@ -197,6 +197,17 @@ class LiteModel : public Model {
MS_LOG(ERROR) << "convert tensor failed";
return RET_ERROR;
}
+ // converterInputOutput
+ auto in_count = meta_graph.inputIndex()->size();
+ for (uint32_t i = 0; i < in_count; ++i) {
+ this->input_indices_.push_back(meta_graph.inputIndex()->Get(i));
+ }
+ auto out_count = meta_graph.outputIndex()->size();
+ for (uint32_t i = 0; i < out_count; ++i) {
+ this->output_indices_.push_back(meta_graph.outputIndex()->Get(i));
+ }
if (meta_graph.subGraph() == nullptr) {
int ret = MetaGraphMappingSubGraph<T>(meta_graph);
if (ret != RET_OK) {
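One detail worth noting in the added block above: the new loops dereference meta_graph.inputIndex() and meta_graph.outputIndex() directly, while flatbuffers vector accessors return nullptr when the field is absent from the serialized table. A defensive variant might look like the sketch below (not part of the commit; CopyIndexVector and the guard are assumptions):

#include "flatbuffers/flatbuffers.h"
// Sketch only: copy a flatbuffers index vector into a model-side vector,
// tolerating a missing field instead of dereferencing a null pointer.
template <typename DstVector>
void CopyIndexVector(const flatbuffers::Vector<uint32_t> *src, DstVector *dst) {
  if (src == nullptr) {
    return;  // field not present in the serialized meta graph
  }
  for (uint32_t i = 0; i < src->size(); ++i) {
    dst->push_back(src->Get(i));
  }
}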

View File

@@ -173,8 +173,8 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
uint32_t tensor_count = model->all_tensors_.size();
MS_ASSERT(!model->sub_graphs_.empty());
- auto model_input_indices = model->sub_graphs_.front()->input_indices_;
- auto model_output_indices = model->sub_graphs_.front()->output_indices_;
+ auto model_input_indices = model->input_indices_;
+ auto model_output_indices = model->output_indices_;
for (uint32_t i = 0; i < tensor_count; ++i) {
auto *src_tensor = model->all_tensors_[i];
if (src_tensor == nullptr) {
@@ -218,9 +218,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
void LiteSession::InitGraphInputTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(!(model->sub_graphs_.empty()));
- auto graph_in_size = model->sub_graphs_.front()->input_indices_.size();
+ auto graph_in_size = model->input_indices_.size();
for (size_t i = 0; i < graph_in_size; ++i) {
- auto in_tensor_idx = model->sub_graphs_.front()->input_indices_[i];
+ auto in_tensor_idx = model->input_indices_[i];
MS_ASSERT(in_tensor_idx < this->tensors_.size());
auto *in_tensor = this->tensors_.at(in_tensor_idx);
MS_ASSERT(in_tensor != nullptr);
@@ -239,9 +239,9 @@ void LiteSession::InitGraphInputMSTensors() {
void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->outputs_.empty());
- auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
+ auto graph_out_size = model->output_indices_.size();
for (size_t i = 0; i < graph_out_size; ++i) {
- auto out_tensor_idx = model->sub_graphs_.front()->output_indices_[i];
+ auto out_tensor_idx = model->output_indices_[i];
MS_ASSERT(out_tensor_idx < this->tensors_.size());
auto *out_tensor = this->tensors_.at(out_tensor_idx);
MS_ASSERT(out_tensor != nullptr);
@@ -253,7 +253,7 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->input_map_.empty());
auto graph_input_node_indexes = GetGraphInputNodes(model);
- auto graph_in_size = model->sub_graphs_.front()->input_indices_.size();
+ auto graph_in_size = model->input_indices_.size();
for (auto in_node_index : graph_input_node_indexes) {
auto in_node = model->all_nodes_[in_node_index];
MS_ASSERT(in_node != nullptr);
@@ -263,7 +263,7 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
auto in_tensor_index = size_t(in_node->input_indices_[i]);
bool is_graph_input = false;
for (size_t j = 0; j < graph_in_size; ++j) {
- if (in_tensor_index == model->sub_graphs_.front()->input_indices_[j]) {
+ if (in_tensor_index == model->input_indices_[j]) {
is_graph_input = true;
break;
}
@@ -290,7 +290,7 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(!(model->sub_graphs_.empty()));
auto graph_output_node_indexes = GetGraphOutputNodes(model);
- auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
+ auto graph_out_size = model->output_indices_.size();
for (auto out_node_index : graph_output_node_indexes) {
auto out_node = model->all_nodes_[out_node_index];
MS_ASSERT(out_node != nullptr);
@@ -299,7 +299,7 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
auto out_tensor_index = out_node->output_indices_[i];
bool is_graph_output = false;
for (size_t j = 0; j < graph_out_size; ++j) {
- if (out_tensor_index == model->sub_graphs_.front()->output_indices_[j]) {
+ if (out_tensor_index == model->output_indices_[j]) {
is_graph_output = true;
break;
}
@@ -321,9 +321,9 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->output_tensor_map_.empty());
- auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
+ auto graph_out_size = model->output_indices_.size();
for (size_t i = 0; i < graph_out_size; ++i) {
- size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
+ size_t graph_out_index = model->output_indices_[i];
MS_ASSERT(graph_out_index < this->tensors_.size());
auto *out_tensor = this->tensors_.at(graph_out_index);
if (out_tensor == nullptr) {
@@ -342,9 +342,9 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
MS_ASSERT(model != nullptr);
- auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
+ auto graph_out_size = model->output_indices_.size();
for (size_t i = 0; i < graph_out_size; ++i) {
- size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
+ size_t graph_out_index = model->output_indices_[i];
MS_ASSERT(graph_out_index < this->tensors_.size());
auto *out_tensor = this->tensors_.at(graph_out_index);
if (out_tensor == nullptr) {

View File

@@ -71,6 +71,18 @@ std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const siz
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
}
+ void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
+ std::replace_if(
+ std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
+ [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
+ for (auto &subGraph : graphT->subGraph) {
+ std::replace_if(
+ std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
+ [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
+ }
+ }
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
std::vector<uint32_t> outputIndexes;
if (outputIndexIdx == -1) {
@@ -146,13 +158,7 @@ STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
- auto &gOutTensorIdx = graphT->outputIndex;
- for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
- if (*iter == outDataTensorIdx) {
- *iter = inDataTensorIdx;
- break;
- }
- }
+ ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find poseNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
@@ -207,13 +213,8 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool remove
auto outDataTensorIdx = outputTensorIdxes.front();
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
- auto &gOutTensorIdx = graphT->outputIndex;
- for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
- if (*iter == outDataTensorIdx) {
- *iter = inDataTensorIdx;
- break;
- }
- }
+ ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
// find poseNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
@@ -605,11 +606,7 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
}
if (has_insert_for_graph_out) {
- for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
- if (*iter == postTensorIdx) {
- *iter = toAddTensorIdx;
- }
- }
+ ReplaceOutput(postTensorIdx, toAddTensorIdx, graphT);
has_insert_for_graph_out = false;
} else {
auto &postNode = graphT->nodes.at(postNodeIdxes[is_output_index ? i - 1 : i]);
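The new ReplaceOutput helper above collapses three hand-written replace loops (in IsolateNode, IsolateOneWayNode, and InsertNodeAfter) into one call, and additionally patches each subGraph's outputIndices. Note that the old loops in IsolateNode and IsolateOneWayNode stopped at the first match (break), whereas std::replace_if rewrites every occurrence. A self-contained sketch of the same replace semantics on plain vectors (FakeGraph is an assumption standing in for schema::MetaGraphT):

#include <algorithm>
#include <cstdint>
#include <vector>

// Stand-in for the graph-level and per-subgraph output index lists.
struct FakeGraph {
  std::vector<uint32_t> outputIndex;                   // graph-level output tensor indices
  std::vector<std::vector<uint32_t>> subGraphOutputs;  // outputIndices of each subgraph
};

// Mirrors ReplaceOutput: rewrite every occurrence of old_index with new_index,
// both at the graph level and inside every subgraph.
void ReplaceOutputLike(uint32_t old_index, uint32_t new_index, FakeGraph *g) {
  std::replace(g->outputIndex.begin(), g->outputIndex.end(), old_index, new_index);
  for (auto &sub : g->subGraphOutputs) {
    std::replace(sub.begin(), sub.end(), old_index, new_index);
  }
}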

View File

@@ -60,6 +60,8 @@ std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
+ void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);

View File

@@ -76,7 +76,6 @@ STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) {
MS_ASSERT(graph->subGraph.size() > 0);
graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end());
graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end());
return RET_OK;
}