set inference output tensor name for trained model && fix inference model for view

yefeng 2021-08-23 11:36:51 +08:00
parent a789221722
commit 439435a86b
10 changed files with 216 additions and 13 deletions

View File

@@ -69,7 +69,8 @@ class MS_API Serialization {
   static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
   static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
   static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
-                            QuantizationType quantization_type = kNoQuant, bool export_inference_only = true);
+                            QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
+                            std::vector<std::string> output_tensor_name = {});
 
  private:
   static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
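Note: for reference, a minimal caller-side sketch of the extended ExportModel interface declared above; the include path, ModelType value, file name, and output tensor name are illustrative assumptions, not part of this commit.

#include <string>
#include <vector>
#include "include/api/serialization.h"

// Export only the inference sub-graph that ends at the named output tensor.
mindspore::Status ExportInferenceModel(const mindspore::Model &model) {
  std::vector<std::string> output_tensor_name = {"Softmax-65"};
  return mindspore::Serialization::ExportModel(model, mindspore::kMindIR, "export_infer.ms",
                                               mindspore::kNoQuant, /*export_inference_only=*/true,
                                               output_tensor_name);
}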

View File

@@ -278,7 +278,8 @@ Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
   return kMEFailed;
 }
 
-Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool) {
+Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool,
+                                  std::vector<std::string> output_tensor_name) {
   MS_LOG(ERROR) << "Unsupported feature.";
   return kMEFailed;
 }

View File

@@ -182,9 +182,11 @@ class MS_API LiteSession {
   /// \param[in] model_type indication whether to save full model or only the inference part
   /// \param[in] quant_type indication whether to quantize exported model
   /// \param[in] format of exported file (currently only FT_FLATBUFFERS is supported)
+  /// \param[in] out_put_tensor_name of exported tensorname
   /// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
   virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
-                     lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS) {
+                     lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
+                     std::vector<std::string> out_put_tensor_name = {}) {
     return mindspore::lite::RET_ERROR;
   }
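Note: for reference, a minimal caller-side sketch of the extended session-level Export declared above, assuming a session created through the training API; the include path and tensor name are placeholders, not part of this commit.

#include <string>
#include <vector>
#include "include/lite_session.h"

// Export only the kernels required to produce the named tensor from the inference graph.
int ExportInferenceGraph(mindspore::session::LiteSession *train_session) {
  std::vector<std::string> out_put_tensor_name = {"Softmax-65"};
  return train_session->Export("export_infer.ms", mindspore::lite::MT_INFERENCE,
                               mindspore::lite::QT_DEFAULT, mindspore::lite::FT_FLATBUFFERS,
                               out_put_tensor_name);
}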

View File

@@ -123,7 +123,8 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
 }
 
 Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
-                                  QuantizationType quantization_type, bool export_inference_only) {
+                                  QuantizationType quantization_type, bool export_inference_only,
+                                  std::vector<std::string> output_tensor_name) {
   if (model.impl_ == nullptr) {
     MS_LOG(ERROR) << "Model implement is null.";
     return kLiteUninitializedObj;
@@ -137,7 +138,7 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
     return kLiteParamInvalid;
   }
 
   auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
-                                           A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS);
+                                           A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name);
   return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
 }

View File

@@ -18,6 +18,8 @@
 #include <sys/stat.h>
 #include <fstream>
 #include <utility>
+#include <queue>
+#include <algorithm>
 #include <functional>
 #include <map>
 #include <set>
@@ -28,8 +30,42 @@
 namespace mindspore {
 namespace lite {
+namespace {
 constexpr static int kFmkVal = 3;
 constexpr static int kTransformTensorDim = 4;
+
+std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
+  std::vector<size_t> postNodeIdx;
+  for (size_t i = 0; i < graphT.nodes.size(); i++) {
+    auto &oldNode = graphT.nodes.at(i);
+    if (oldNode == nullptr) {
+      continue;
+    }
+    auto inputIndexes = oldNode->inputIndex;
+    if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
+      postNodeIdx.emplace_back(i);
+    }
+  }
+  return postNodeIdx;
+}
+
+std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
+                                     const int outputIndexIdx = -1) {
+  std::vector<uint32_t> outputIndexes;
+  if (outputIndexIdx == -1) {
+    outputIndexes = node.outputIndex;
+  } else {
+    outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
+  }
+  std::set<size_t> outputNodeIdx;
+  for (uint32_t outputIdx : outputIndexes) {
+    auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
+    outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
+  }
+  std::vector<size_t> ret;
+  ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
+  return ret;
+}
+}  // namespace
+
 std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
   uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data_c());
@@ -397,16 +433,90 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &k
     }
   }
   TagQuantizedNodes();  // do another loop to mark QUANT_WEIGHT_NODES
+
+  auto status = TopologicalSort();
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "TopologicalSort failed.";
+    return RET_ERROR;
+  }
+
   return RET_OK;
 }
 
+int TrainExport::TopologicalSort() {
+  MS_ASSERT(meta_graph_ != nullptr);
+  std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
+  std::vector<size_t> sinked_tensor_idxes;
+  for (auto &subgraph : meta_graph_->subGraph) {
+    std::copy(subgraph->inputIndices.begin(), subgraph->inputIndices.end(), std::back_inserter(sinked_tensor_idxes));
+  }
+  // put all const tensor index into sinked_tensor_idxes
+  for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) {
+    if (meta_graph_->allTensors.at(i)->nodeType == NodeType_ValueNode) {
+      sinked_tensor_idxes.push_back(i);
+    }
+  }
+  auto &old_nodes = meta_graph_->nodes;
+  std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
+  // put all none depend node into queue
+  for (size_t i = 0; i < meta_graph_->subGraph.size(); i++) {
+    std::vector<unsigned int> new_subgraph_node_indices = {};
+    auto subgraph_node_indices = meta_graph_->subGraph[i]->nodeIndices;
+    for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
+      auto &node = old_nodes[subgraph_node_indices[j]];
+      if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
+        sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
+        op_queue.push(std::move(node));
+      }
+    }
+    while (!op_queue.empty()) {
+      auto &node = op_queue.front();
+      auto post_node_idxes = GetOutputNodeIdx(*meta_graph_, *(node.get()));
+      sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
+      for (auto post_node_idx : post_node_idxes) {
+        if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
+          auto &post_node = old_nodes.at(post_node_idx);
+          // check if post_node is non-depended
+          if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
+            op_queue.push(std::move(post_node));
+          }
+        }
+      }
+      new_nodes.emplace_back(std::move(node));
+      new_subgraph_node_indices.push_back(new_nodes.size() - 1);
+      op_queue.pop();
+    }
+    meta_graph_->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
+  }
+  if (new_nodes.size() != old_nodes.size()) {
+    MS_LOG(ERROR) << "Unknown error in TopologicalSort, old_nodes size: " << old_nodes.size()
+                  << ", new_nodes size: " << new_nodes.size();
+    return RET_ERROR;
+  }
+  meta_graph_->nodes.swap(new_nodes);
+  return RET_OK;
+}
+
+bool TrainExport::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
+                                  const std::vector<size_t> &sinked_tensor_idxes) {
+  MS_ASSERT(node != nullptr);
+  return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
+                     [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
+}
+
 int TrainExport::ExportInit(const std::string model_name, std::string version) {
   meta_graph_ = new (std::nothrow) schema::MetaGraphT();
   if (meta_graph_ == nullptr) {
     MS_LOG(ERROR) << "cannot allocate meta_graph";
     return RET_ERROR;
   }
+  auto sub_graph = std::make_unique<schema::SubGraphT>();
+  if (sub_graph == nullptr) {
+    MS_LOG(ERROR) << "cannot allocate SubGraphT";
+    return RET_ERROR;
+  }
+  sub_graph->name = model_name + "_subgraph";
+  meta_graph_->subGraph.emplace_back(std::move(sub_graph));
   meta_graph_->fmkType = kFmkVal;
   meta_graph_->name = model_name;
   meta_graph_->version = version;
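Note: the new TopologicalSort above orders MetaGraphT nodes with a Kahn-style queue: a node becomes ready once every one of its input tensors is "sinked" (a graph input, a constant, or an output of a node already emitted). For reference, a standalone sketch of the same idea on plain index vectors; the names and the naive successor rescan (standing in for GetOutputNodeIdx) are illustrative assumptions, not the commit's code.

#include <cstddef>
#include <queue>
#include <vector>

// node_inputs / node_outputs: per node, the ids of its input / output tensors.
// sinked: per tensor, whether it is available before any node runs (graph input or constant).
std::vector<size_t> TopoOrder(const std::vector<std::vector<size_t>> &node_inputs,
                              const std::vector<std::vector<size_t>> &node_outputs,
                              std::vector<bool> sinked) {
  const size_t n = node_inputs.size();
  std::vector<size_t> order;
  std::vector<bool> emitted(n, false);
  std::queue<size_t> ready;
  auto is_ready = [&](size_t i) {
    for (size_t t : node_inputs[i]) {
      if (!sinked[t]) return false;
    }
    return true;
  };
  for (size_t i = 0; i < n; ++i) {
    if (is_ready(i)) ready.push(i);
  }
  while (!ready.empty()) {
    size_t i = ready.front();
    ready.pop();
    if (emitted[i]) continue;
    emitted[i] = true;
    order.push_back(i);
    for (size_t t : node_outputs[i]) sinked[t] = true;  // this node's outputs become available
    for (size_t j = 0; j < n; ++j) {                    // rescan for newly ready successors
      if (!emitted[j] && is_ready(j)) ready.push(j);
    }
  }
  return order;  // order.size() < n indicates a cycle in the graph
}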

View File

@@ -57,6 +57,8 @@ class TrainExport {
   std::vector<size_t> out_idx_;
   std::map<size_t, size_t> remap_;
   std::unordered_map<size_t, size_t> connect_;  // connection map (backbone tenor id-> head tensor id)
+  bool IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, const std::vector<size_t> &sinked_tensor_idxes);
+  int TopologicalSort();
   void PrepareRemap(int offset);
   Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model);
   std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor);

View File

@@ -22,6 +22,7 @@
 #include <iostream>
 #include <fstream>
 #include <memory>
+#include <queue>
 #include <map>
 #include "include/errorcode.h"
 #include "src/executor.h"
@@ -793,8 +794,63 @@ int TrainSession::Resize(const std::vector<tensor::MSTensor *> &inputs, const st
   return RET_OK;
 }
 
+int TrainSession::FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
+                                        const std::vector<lite::Tensor *> &kernel_in_tensors,
+                                        const std::vector<kernel::LiteKernel *> &inference_kernels) {
+  for (size_t i = 0; i < inference_kernels.size(); i++) {
+    for (size_t j = 0; j < kernel_in_tensors.size(); j++) {
+      if (IsContain(inference_kernels[i]->out_tensors(), kernel_in_tensors[j])) {
+        use_in_tensor_kernels->push_back(inference_kernels[i]);
+      }
+    }
+  }
+  return RET_OK;
+}
+
+int TrainSession::FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
+                                    const std::vector<std::string> &export_output_tensor_names,
+                                    const std::vector<kernel::LiteKernel *> &inference_kernels) {
+  std::vector<std::string> all_kernel_name = {};
+  std::transform(inference_kernels.begin(), inference_kernels.end(), std::back_inserter(all_kernel_name),
+                 [](kernel::LiteKernel *kernel) { return kernel->name(); });
+  std::queue<std::string> need_kernel_names;
+  // Find the kernel name according to the tensor name
+  for (auto &kernel : inference_kernels) {
+    if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(), [&](lite::Tensor *out_tensor) {
+          return IsContain(export_output_tensor_names, out_tensor->tensor_name());
+        })) {
+      need_kernel_names.push(kernel->name());
+    }
+  }
+  // find all kernel
+  while (!need_kernel_names.empty()) {
+    auto kernel_name = need_kernel_names.front();
+    need_kernel_names.pop();
+    auto it = find(all_kernel_name.begin(), all_kernel_name.end(), kernel_name);
+    if (it == all_kernel_name.end()) {
+      MS_LOG(ERROR) << "not find kernel name in export trained model.";
+      return RET_ERROR;
+    }
+    auto kernel = inference_kernels[it - all_kernel_name.begin()];
+    if (!IsContain(*export_kernels, kernel)) {
+      export_kernels->push_back(kernel);
+    }
+    auto kernel_in_tensors = kernel->in_tensors();
+    std::vector<kernel::LiteKernel *> use_in_tensor_kernels;
+    auto status = FindUseInTensorKernel(&use_in_tensor_kernels, kernel_in_tensors, inference_kernels);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "FindUseInTensorKernel failed.";
+      return RET_ERROR;
+    }
+    for (size_t i = 0; i < use_in_tensor_kernels.size(); i++) {
+      need_kernel_names.push(use_in_tensor_kernels[i]->name());
+    }
+  }
+  return RET_OK;
+}
+
 int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
-                         FormatType format) {
+                         FormatType format, std::vector<std::string> out_put_tensor_name) {
   if (file_name.empty()) {
     MS_LOG(ERROR) << "File name cannot be empty";
     return RET_ERROR;
@@ -820,9 +876,21 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
     MS_LOG(ERROR) << "cannot init export";
     return status;
   }
-  status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
-                             (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
-                             model_.get(), quant_type);
+
+  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
+    std::vector<kernel::LiteKernel *> export_kernels = {};
+    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "FindExportKernels failed.";
+      return RET_ERROR;
+    }
+    status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
+  } else {
+    status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
+                               (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
+                               model_.get(), quant_type);
+  }
+
   if (status != RET_OK) {
     MS_LOG(ERROR) << "cannot export Network";
     return status;

View File

@@ -89,11 +89,18 @@ class TrainSession : virtual public lite::LiteSession {
     }
     return outputs;
   }
 
-  int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType) override;
+  int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
+             std::vector<std::string> out_put_tensor_name = {}) override;
   std::vector<tensor::MSTensor *> GetFeatureMaps() const override;
   int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) override;
+  int FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
+                            const std::vector<lite::Tensor *> &kernel_in_tensors,
+                            const std::vector<kernel::LiteKernel *> &inference_kernels);
+  int FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
+                        const std::vector<std::string> &export_output_tensor_names,
+                        const std::vector<kernel::LiteKernel *> &inference_kernels);
 
  protected:
   int AllocWorkSpace();

View File

@@ -179,7 +179,7 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
 }
 
 int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
-                            FormatType format) {
+                            FormatType format, std::vector<std::string> out_put_tensor_name) {
   if (format != FT_FLATBUFFERS) {
     MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
     return RET_ERROR;
@@ -206,7 +206,17 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
       return status;
     }
   }
-  status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type);
+  if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
+    std::vector<kernel::LiteKernel *> export_kernels = {};
+    status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "FindExportKernels failed.";
+      return RET_ERROR;
+    }
+    status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
+  } else {
+    status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type);
+  }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "cannot serialize head";
     return status;

View File

@@ -61,7 +61,8 @@ class TransferSession : public lite::TrainSession {
   mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override;
 
   int CompileTransferGraph();
-  int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType) override;
+  int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
+             std::vector<std::string> out_put_tensor_name = {}) override;
 
  protected:
   lite::LiteSession *backbone_session_ = nullptr;