forked from mindspore-Ecosystem/mindspore
set inference output tensor name for trained model && fix inference model for view
This commit is contained in:
parent
a789221722
commit
439435a86b
|
@ -69,7 +69,8 @@ class MS_API Serialization {
|
||||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
||||||
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true);
|
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
|
||||||
|
std::vector<std::string> output_tensor_name = {});
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
|
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||||
|
|
|
@ -278,7 +278,8 @@ Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool) {
|
Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool,
|
||||||
|
std::vector<std::string> output_tensor_name) {
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
MS_LOG(ERROR) << "Unsupported feature.";
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -182,9 +182,11 @@ class MS_API LiteSession {
|
||||||
/// \param[in] model_type indication whether to save full model or only the inference part
|
/// \param[in] model_type indication whether to save full model or only the inference part
|
||||||
/// \param[in] quant_type indication whether to quantize exported model
|
/// \param[in] quant_type indication whether to quantize exported model
|
||||||
/// \param[in] format of exported file (currently only FT_FLATBUFFERS is supported)
|
/// \param[in] format of exported file (currently only FT_FLATBUFFERS is supported)
|
||||||
|
/// \param[in] out_put_tensor_name of exported tensorname
|
||||||
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
|
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
|
||||||
virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
|
virtual int Export(const std::string &file_name, lite::ModelType model_type = lite::MT_TRAIN,
|
||||||
lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS) {
|
lite::QuantizationType quant_type = lite::QT_DEFAULT, lite::FormatType = lite::FT_FLATBUFFERS,
|
||||||
|
std::vector<std::string> out_put_tensor_name = {}) {
|
||||||
return mindspore::lite::RET_ERROR;
|
return mindspore::lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -123,7 +123,8 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
||||||
QuantizationType quantization_type, bool export_inference_only) {
|
QuantizationType quantization_type, bool export_inference_only,
|
||||||
|
std::vector<std::string> output_tensor_name) {
|
||||||
if (model.impl_ == nullptr) {
|
if (model.impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return kLiteUninitializedObj;
|
return kLiteUninitializedObj;
|
||||||
|
@ -137,7 +138,7 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
|
||||||
return kLiteParamInvalid;
|
return kLiteParamInvalid;
|
||||||
}
|
}
|
||||||
auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
|
auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
|
||||||
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS);
|
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name);
|
||||||
|
|
||||||
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <queue>
|
||||||
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
@ -28,8 +30,42 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
namespace {
|
||||||
constexpr static int kFmkVal = 3;
|
constexpr static int kFmkVal = 3;
|
||||||
constexpr static int kTransformTensorDim = 4;
|
constexpr static int kTransformTensorDim = 4;
|
||||||
|
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
|
||||||
|
std::vector<size_t> postNodeIdx;
|
||||||
|
for (size_t i = 0; i < graphT.nodes.size(); i++) {
|
||||||
|
auto &oldNode = graphT.nodes.at(i);
|
||||||
|
if (oldNode == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto inputIndexes = oldNode->inputIndex;
|
||||||
|
if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
|
||||||
|
postNodeIdx.emplace_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return postNodeIdx;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||||
|
const int outputIndexIdx = -1) {
|
||||||
|
std::vector<uint32_t> outputIndexes;
|
||||||
|
if (outputIndexIdx == -1) {
|
||||||
|
outputIndexes = node.outputIndex;
|
||||||
|
} else {
|
||||||
|
outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
|
||||||
|
}
|
||||||
|
std::set<size_t> outputNodeIdx;
|
||||||
|
for (uint32_t outputIdx : outputIndexes) {
|
||||||
|
auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
|
||||||
|
outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
|
||||||
|
}
|
||||||
|
std::vector<size_t> ret;
|
||||||
|
ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
|
std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
|
||||||
uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data_c());
|
uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data_c());
|
||||||
|
@ -397,16 +433,90 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TagQuantizedNodes(); // do another loop to mark QUANT_WEIGHT_NODES
|
TagQuantizedNodes(); // do another loop to mark QUANT_WEIGHT_NODES
|
||||||
|
auto status = TopologicalSort();
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "TopologicalSort failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int TrainExport::TopologicalSort() {
|
||||||
|
MS_ASSERT(meta_graph_ != nullptr);
|
||||||
|
std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
|
||||||
|
std::vector<size_t> sinked_tensor_idxes;
|
||||||
|
for (auto &subgraph : meta_graph_->subGraph) {
|
||||||
|
std::copy(subgraph->inputIndices.begin(), subgraph->inputIndices.end(), std::back_inserter(sinked_tensor_idxes));
|
||||||
|
}
|
||||||
|
// put all const tensor index into sinked_tensor_idxes
|
||||||
|
for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) {
|
||||||
|
if (meta_graph_->allTensors.at(i)->nodeType == NodeType_ValueNode) {
|
||||||
|
sinked_tensor_idxes.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto &old_nodes = meta_graph_->nodes;
|
||||||
|
std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
|
||||||
|
// put all none depend node into queue
|
||||||
|
for (size_t i = 0; i < meta_graph_->subGraph.size(); i++) {
|
||||||
|
std::vector<unsigned int> new_subgraph_node_indices = {};
|
||||||
|
auto subgraph_node_indices = meta_graph_->subGraph[i]->nodeIndices;
|
||||||
|
|
||||||
|
for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
|
||||||
|
auto &node = old_nodes[subgraph_node_indices[j]];
|
||||||
|
if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
|
||||||
|
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
|
||||||
|
op_queue.push(std::move(node));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (!op_queue.empty()) {
|
||||||
|
auto &node = op_queue.front();
|
||||||
|
auto post_node_idxes = GetOutputNodeIdx(*meta_graph_, *(node.get()));
|
||||||
|
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
|
||||||
|
for (auto post_node_idx : post_node_idxes) {
|
||||||
|
if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
|
||||||
|
auto &post_node = old_nodes.at(post_node_idx);
|
||||||
|
// check if post_node is non-depended
|
||||||
|
if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
|
||||||
|
op_queue.push(std::move(post_node));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_nodes.emplace_back(std::move(node));
|
||||||
|
new_subgraph_node_indices.push_back(new_nodes.size() - 1);
|
||||||
|
op_queue.pop();
|
||||||
|
}
|
||||||
|
meta_graph_->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
|
||||||
|
}
|
||||||
|
if (new_nodes.size() != old_nodes.size()) {
|
||||||
|
MS_LOG(ERROR) << "Unknown error in TopologicalSort, old_nodes size: " << old_nodes.size()
|
||||||
|
<< ", new_nodes size: " << new_nodes.size();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
meta_graph_->nodes.swap(new_nodes);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TrainExport::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
|
||||||
|
const std::vector<size_t> &sinked_tensor_idxes) {
|
||||||
|
MS_ASSERT(node != nullptr);
|
||||||
|
return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
|
||||||
|
[&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
|
||||||
|
}
|
||||||
|
|
||||||
int TrainExport::ExportInit(const std::string model_name, std::string version) {
|
int TrainExport::ExportInit(const std::string model_name, std::string version) {
|
||||||
meta_graph_ = new (std::nothrow) schema::MetaGraphT();
|
meta_graph_ = new (std::nothrow) schema::MetaGraphT();
|
||||||
if (meta_graph_ == nullptr) {
|
if (meta_graph_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "cannot allocate meta_graph";
|
MS_LOG(ERROR) << "cannot allocate meta_graph";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
auto sub_graph = std::make_unique<schema::SubGraphT>();
|
||||||
|
if (sub_graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "cannot allocate SubGraphT";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
sub_graph->name = model_name + "_subgraph";
|
||||||
|
meta_graph_->subGraph.emplace_back(std::move(sub_graph));
|
||||||
meta_graph_->fmkType = kFmkVal;
|
meta_graph_->fmkType = kFmkVal;
|
||||||
meta_graph_->name = model_name;
|
meta_graph_->name = model_name;
|
||||||
meta_graph_->version = version;
|
meta_graph_->version = version;
|
||||||
|
|
|
@ -57,6 +57,8 @@ class TrainExport {
|
||||||
std::vector<size_t> out_idx_;
|
std::vector<size_t> out_idx_;
|
||||||
std::map<size_t, size_t> remap_;
|
std::map<size_t, size_t> remap_;
|
||||||
std::unordered_map<size_t, size_t> connect_; // connection map (backbone tenor id-> head tensor id)
|
std::unordered_map<size_t, size_t> connect_; // connection map (backbone tenor id-> head tensor id)
|
||||||
|
bool IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, const std::vector<size_t> &sinked_tensor_idxes);
|
||||||
|
int TopologicalSort();
|
||||||
void PrepareRemap(int offset);
|
void PrepareRemap(int offset);
|
||||||
Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model);
|
Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model);
|
||||||
std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor);
|
std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor);
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "src/executor.h"
|
#include "src/executor.h"
|
||||||
|
@ -793,8 +794,63 @@ int TrainSession::Resize(const std::vector<tensor::MSTensor *> &inputs, const st
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int TrainSession::FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
|
||||||
|
const std::vector<lite::Tensor *> &kernel_in_tensors,
|
||||||
|
const std::vector<kernel::LiteKernel *> &inference_kernels) {
|
||||||
|
for (size_t i = 0; i < inference_kernels.size(); i++) {
|
||||||
|
for (size_t j = 0; j < kernel_in_tensors.size(); j++) {
|
||||||
|
if (IsContain(inference_kernels[i]->out_tensors(), kernel_in_tensors[j])) {
|
||||||
|
use_in_tensor_kernels->push_back(inference_kernels[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int TrainSession::FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
|
||||||
|
const std::vector<std::string> &export_output_tensor_names,
|
||||||
|
const std::vector<kernel::LiteKernel *> &inference_kernels) {
|
||||||
|
std::vector<std::string> all_kernel_name = {};
|
||||||
|
std::transform(inference_kernels.begin(), inference_kernels.end(), std::back_inserter(all_kernel_name),
|
||||||
|
[](kernel::LiteKernel *kernel) { return kernel->name(); });
|
||||||
|
std::queue<std::string> need_kernel_names;
|
||||||
|
// Find the kernel name according to the tensor name
|
||||||
|
for (auto &kernel : inference_kernels) {
|
||||||
|
if (std::any_of(kernel->out_tensors().begin(), kernel->out_tensors().end(), [&](lite::Tensor *out_tensor) {
|
||||||
|
return IsContain(export_output_tensor_names, out_tensor->tensor_name());
|
||||||
|
})) {
|
||||||
|
need_kernel_names.push(kernel->name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// find all kernel
|
||||||
|
while (!need_kernel_names.empty()) {
|
||||||
|
auto kernel_name = need_kernel_names.front();
|
||||||
|
need_kernel_names.pop();
|
||||||
|
auto it = find(all_kernel_name.begin(), all_kernel_name.end(), kernel_name);
|
||||||
|
if (it == all_kernel_name.end()) {
|
||||||
|
MS_LOG(ERROR) << "not find kernel name in export trained model.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto kernel = inference_kernels[it - all_kernel_name.begin()];
|
||||||
|
if (!IsContain(*export_kernels, kernel)) {
|
||||||
|
export_kernels->push_back(kernel);
|
||||||
|
}
|
||||||
|
auto kernel_in_tensors = kernel->in_tensors();
|
||||||
|
std::vector<kernel::LiteKernel *> use_in_tensor_kernels;
|
||||||
|
auto status = FindUseInTensorKernel(&use_in_tensor_kernels, kernel_in_tensors, inference_kernels);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "FindUseInTensorKernel failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < use_in_tensor_kernels.size(); i++) {
|
||||||
|
need_kernel_names.push(use_in_tensor_kernels[i]->name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
|
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
|
||||||
FormatType format) {
|
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||||
if (file_name.empty()) {
|
if (file_name.empty()) {
|
||||||
MS_LOG(ERROR) << "File name cannot be empty";
|
MS_LOG(ERROR) << "File name cannot be empty";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -820,9 +876,21 @@ int TrainSession::Export(const std::string &file_name, ModelType model_type, Qua
|
||||||
MS_LOG(ERROR) << "cannot init export";
|
MS_LOG(ERROR) << "cannot init export";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
|
|
||||||
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
|
if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
|
||||||
model_.get(), quant_type);
|
std::vector<kernel::LiteKernel *> export_kernels = {};
|
||||||
|
status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "FindExportKernels failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
|
||||||
|
} else {
|
||||||
|
status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_,
|
||||||
|
(model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_,
|
||||||
|
model_.get(), quant_type);
|
||||||
|
}
|
||||||
|
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "cannot export Network";
|
MS_LOG(ERROR) << "cannot export Network";
|
||||||
return status;
|
return status;
|
||||||
|
|
|
@ -89,11 +89,18 @@ class TrainSession : virtual public lite::LiteSession {
|
||||||
}
|
}
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType) override;
|
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||||
|
std::vector<std::string> out_put_tensor_name = {}) override;
|
||||||
|
|
||||||
std::vector<tensor::MSTensor *> GetFeatureMaps() const override;
|
std::vector<tensor::MSTensor *> GetFeatureMaps() const override;
|
||||||
|
|
||||||
int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) override;
|
int UpdateFeatureMaps(const std::vector<tensor::MSTensor *> &features_map) override;
|
||||||
|
int FindUseInTensorKernel(std::vector<kernel::LiteKernel *> *use_in_tensor_kernels,
|
||||||
|
const std::vector<lite::Tensor *> &kernel_in_tensors,
|
||||||
|
const std::vector<kernel::LiteKernel *> &inference_kernels);
|
||||||
|
int FindExportKernels(std::vector<kernel::LiteKernel *> *export_kernels,
|
||||||
|
const std::vector<std::string> &export_output_tensor_names,
|
||||||
|
const std::vector<kernel::LiteKernel *> &inference_kernels);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
int AllocWorkSpace();
|
int AllocWorkSpace();
|
||||||
|
|
|
@ -179,7 +179,7 @@ std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
|
||||||
}
|
}
|
||||||
|
|
||||||
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
|
int TransferSession::Export(const std::string &filename, ModelType model_type, QuantizationType quant_type,
|
||||||
FormatType format) {
|
FormatType format, std::vector<std::string> out_put_tensor_name) {
|
||||||
if (format != FT_FLATBUFFERS) {
|
if (format != FT_FLATBUFFERS) {
|
||||||
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
|
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -206,7 +206,17 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type);
|
if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) {
|
||||||
|
std::vector<kernel::LiteKernel *> export_kernels = {};
|
||||||
|
status = FindExportKernels(&export_kernels, out_put_tensor_name, inference_kernels_);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "FindExportKernels failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type);
|
||||||
|
} else {
|
||||||
|
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type);
|
||||||
|
}
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "cannot serialize head";
|
MS_LOG(ERROR) << "cannot serialize head";
|
||||||
return status;
|
return status;
|
||||||
|
|
|
@ -61,7 +61,8 @@ class TransferSession : public lite::TrainSession {
|
||||||
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override;
|
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override;
|
||||||
|
|
||||||
int CompileTransferGraph();
|
int CompileTransferGraph();
|
||||||
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType) override;
|
int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType,
|
||||||
|
std::vector<std::string> out_put_tensor_name = {}) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
lite::LiteSession *backbone_session_ = nullptr;
|
lite::LiteSession *backbone_session_ = nullptr;
|
||||||
|
|
Loading…
Reference in New Issue