add model input index and output index
This commit is contained in:
parent
19e95c5c8e
commit
44ba4afe3a
|
@ -40,6 +40,8 @@ struct MS_API Model {
|
|||
using SubGraphPtrVector = Vector<SubGraph *>;
|
||||
String name_;
|
||||
String version_;
|
||||
Uint32Vector input_indices_;
|
||||
Uint32Vector output_indices_;
|
||||
TensorPtrVector all_tensors_;
|
||||
NodePtrVector all_nodes_;
|
||||
char *buf = nullptr;
|
||||
|
|
|
@ -32,7 +32,7 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model) {
|
|||
MS_ASSERT(model != nullptr);
|
||||
MS_ASSERT(!(model->sub_graphs_.empty()));
|
||||
std::vector<size_t> ret;
|
||||
for (auto graph_in_index : model->sub_graphs_.front()->input_indices_) {
|
||||
for (auto graph_in_index : model->input_indices_) {
|
||||
auto node_size = model->all_nodes_.size();
|
||||
for (size_t j = 0; j < node_size; ++j) {
|
||||
auto node = model->all_nodes_[j];
|
||||
|
@ -51,7 +51,7 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model) {
|
|||
std::vector<size_t> GetGraphOutputNodes(const lite::Model *model) {
|
||||
MS_ASSERT(model != nullptr);
|
||||
std::vector<size_t> ret;
|
||||
for (auto graph_out_index : model->sub_graphs_.front()->output_indices_) {
|
||||
for (auto graph_out_index : model->output_indices_) {
|
||||
auto node_size = model->all_nodes_.size();
|
||||
for (size_t j = 0; j < node_size; ++j) {
|
||||
auto node = model->all_nodes_[j];
|
||||
|
|
|
@ -197,6 +197,17 @@ class LiteModel : public Model {
|
|||
MS_LOG(ERROR) << "convert tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// converterInputOutput
|
||||
auto in_count = meta_graph.inputIndex()->size();
|
||||
for (uint32_t i = 0; i < in_count; ++i) {
|
||||
this->input_indices_.push_back(meta_graph.inputIndex()->Get(i));
|
||||
}
|
||||
auto out_count = meta_graph.outputIndex()->size();
|
||||
for (uint32_t i = 0; i < out_count; ++i) {
|
||||
this->output_indices_.push_back(meta_graph.outputIndex()->Get(i));
|
||||
}
|
||||
|
||||
if (meta_graph.subGraph() == nullptr) {
|
||||
int ret = MetaGraphMappingSubGraph<T>(meta_graph);
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -173,8 +173,8 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
|
|||
MS_ASSERT(model != nullptr);
|
||||
uint32_t tensor_count = model->all_tensors_.size();
|
||||
MS_ASSERT(!model->sub_graphs_.empty());
|
||||
auto model_input_indices = model->sub_graphs_.front()->input_indices_;
|
||||
auto model_output_indices = model->sub_graphs_.front()->output_indices_;
|
||||
auto model_input_indices = model->input_indices_;
|
||||
auto model_output_indices = model->output_indices_;
|
||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||
auto *src_tensor = model->all_tensors_[i];
|
||||
if (src_tensor == nullptr) {
|
||||
|
@ -218,9 +218,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
|
|||
void LiteSession::InitGraphInputTensors(const lite::Model *model) {
|
||||
MS_ASSERT(model != nullptr);
|
||||
MS_ASSERT(!(model->sub_graphs_.empty()));
|
||||
auto graph_in_size = model->sub_graphs_.front()->input_indices_.size();
|
||||
auto graph_in_size = model->input_indices_.size();
|
||||
for (size_t i = 0; i < graph_in_size; ++i) {
|
||||
auto in_tensor_idx = model->sub_graphs_.front()->input_indices_[i];
|
||||
auto in_tensor_idx = model->input_indices_[i];
|
||||
MS_ASSERT(in_tensor_idx < this->tensors_.size());
|
||||
auto *in_tensor = this->tensors_.at(in_tensor_idx);
|
||||
MS_ASSERT(in_tensor != nullptr);
|
||||
|
@ -239,9 +239,9 @@ void LiteSession::InitGraphInputMSTensors() {
|
|||
void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
|
||||
MS_ASSERT(model != nullptr);
|
||||
MS_ASSERT(this->outputs_.empty());
|
||||
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
|
||||
auto graph_out_size = model->output_indices_.size();
|
||||
for (size_t i = 0; i < graph_out_size; ++i) {
|
||||
auto out_tensor_idx = model->sub_graphs_.front()->output_indices_[i];
|
||||
auto out_tensor_idx = model->output_indices_[i];
|
||||
MS_ASSERT(out_tensor_idx < this->tensors_.size());
|
||||
auto *out_tensor = this->tensors_.at(out_tensor_idx);
|
||||
MS_ASSERT(out_tensor != nullptr);
|
||||
|
@ -253,7 +253,7 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
|
|||
MS_ASSERT(model != nullptr);
|
||||
MS_ASSERT(this->input_map_.empty());
|
||||
auto graph_input_node_indexes = GetGraphInputNodes(model);
|
||||
auto graph_in_size = model->sub_graphs_.front()->input_indices_.size();
|
||||
auto graph_in_size = model->input_indices_.size();
|
||||
for (auto in_node_index : graph_input_node_indexes) {
|
||||
auto in_node = model->all_nodes_[in_node_index];
|
||||
MS_ASSERT(in_node != nullptr);
|
||||
|
@ -263,7 +263,7 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
|
|||
auto in_tensor_index = size_t(in_node->input_indices_[i]);
|
||||
bool is_graph_input = false;
|
||||
for (size_t j = 0; j < graph_in_size; ++j) {
|
||||
if (in_tensor_index == model->sub_graphs_.front()->input_indices_[j]) {
|
||||
if (in_tensor_index == model->input_indices_[j]) {
|
||||
is_graph_input = true;
|
||||
break;
|
||||
}
|
||||
|
@ -290,7 +290,7 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
|
|||
MS_ASSERT(model != nullptr);
|
||||
MS_ASSERT(!(model->sub_graphs_.empty()));
|
||||
auto graph_output_node_indexes = GetGraphOutputNodes(model);
|
||||
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
|
||||
auto graph_out_size = model->output_indices_.size();
|
||||
for (auto out_node_index : graph_output_node_indexes) {
|
||||
auto out_node = model->all_nodes_[out_node_index];
|
||||
MS_ASSERT(out_node != nullptr);
|
||||
|
@ -299,7 +299,7 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
|
|||
auto out_tensor_index = out_node->output_indices_[i];
|
||||
bool is_graph_output = false;
|
||||
for (size_t j = 0; j < graph_out_size; ++j) {
|
||||
if (out_tensor_index == model->sub_graphs_.front()->output_indices_[j]) {
|
||||
if (out_tensor_index == model->output_indices_[j]) {
|
||||
is_graph_output = true;
|
||||
break;
|
||||
}
|
||||
|
@ -321,9 +321,9 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
|
|||
void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
|
||||
MS_ASSERT(model != nullptr);
|
||||
MS_ASSERT(this->output_tensor_map_.empty());
|
||||
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
|
||||
auto graph_out_size = model->output_indices_.size();
|
||||
for (size_t i = 0; i < graph_out_size; ++i) {
|
||||
size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
|
||||
size_t graph_out_index = model->output_indices_[i];
|
||||
MS_ASSERT(graph_out_index < this->tensors_.size());
|
||||
auto *out_tensor = this->tensors_.at(graph_out_index);
|
||||
if (out_tensor == nullptr) {
|
||||
|
@ -342,9 +342,9 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
|
|||
|
||||
void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
|
||||
MS_ASSERT(model != nullptr);
|
||||
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
|
||||
auto graph_out_size = model->output_indices_.size();
|
||||
for (size_t i = 0; i < graph_out_size; ++i) {
|
||||
size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
|
||||
size_t graph_out_index = model->output_indices_[i];
|
||||
MS_ASSERT(graph_out_index < this->tensors_.size());
|
||||
auto *out_tensor = this->tensors_.at(graph_out_index);
|
||||
if (out_tensor == nullptr) {
|
||||
|
|
|
@ -71,6 +71,18 @@ std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const siz
|
|||
return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
|
||||
}
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
|
||||
std::replace_if(
|
||||
std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
|
||||
for (auto &subGraph : graphT->subGraph) {
|
||||
std::replace_if(
|
||||
std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
|
||||
[&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
|
||||
std::vector<uint32_t> outputIndexes;
|
||||
if (outputIndexIdx == -1) {
|
||||
|
@ -146,13 +158,7 @@ STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
|
|||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
auto &gOutTensorIdx = graphT->outputIndex;
|
||||
for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
|
@ -207,13 +213,8 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool remove
|
|||
auto outDataTensorIdx = outputTensorIdxes.front();
|
||||
MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
|
||||
MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
|
||||
auto &gOutTensorIdx = graphT->outputIndex;
|
||||
for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
|
||||
if (*iter == outDataTensorIdx) {
|
||||
*iter = inDataTensorIdx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
|
||||
|
||||
// find poseNode
|
||||
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
|
||||
for (auto postNodeIdx : postNodeIdxes) {
|
||||
|
@ -605,11 +606,7 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|||
toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
|
||||
}
|
||||
if (has_insert_for_graph_out) {
|
||||
for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
|
||||
if (*iter == postTensorIdx) {
|
||||
*iter = toAddTensorIdx;
|
||||
}
|
||||
}
|
||||
ReplaceOutput(postTensorIdx, toAddTensorIdx, graphT);
|
||||
has_insert_for_graph_out = false;
|
||||
} else {
|
||||
auto &postNode = graphT->nodes.at(postNodeIdxes[is_output_index ? i - 1 : i]);
|
||||
|
|
|
@ -60,6 +60,8 @@ std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size
|
|||
|
||||
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx);
|
||||
|
||||
void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT);
|
||||
|
||||
STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node);
|
||||
|
||||
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true);
|
||||
|
|
|
@ -76,7 +76,6 @@ STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
|
|||
STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph->subGraph.size() > 0);
|
||||
graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end());
|
||||
graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue