forked from mindspore-Ecosystem/mindspore
!22726 [LITE] train output tensor name
Merge pull request !22726 from yefeng/153-train_output_input_tensor_name
This commit is contained in:
commit
7a69ee5ee2
|
@ -557,7 +557,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
|
|||
bool train_flag) {
|
||||
this->train_flag_ = train_flag;
|
||||
// hardcode for nnie and train
|
||||
this->reorder_input_ = !(train_flag) && !(ConverterContext::GetInstance()->GetGraphInputTensorNames().empty());
|
||||
this->reorder_input_ = !(ConverterContext::GetInstance()->GetGraphInputTensorNames().empty());
|
||||
this->graph_inputs_map_.clear();
|
||||
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
||||
MS_CHECK_TRUE_MSG(meta_graphT != nullptr, nullptr, "meta_graphT is nullptr");
|
||||
|
|
|
@ -132,13 +132,54 @@ int BenchmarkBase::LoadInput() {
|
|||
|
||||
// calibData is FP32
|
||||
int BenchmarkBase::ReadCalibData() {
|
||||
const char *calib_data_path = flags_->benchmark_data_file_.c_str();
|
||||
auto new_file = flags_->benchmark_data_file_ + ".new";
|
||||
const char *calib_data_path = new_file.c_str();
|
||||
// read calib data
|
||||
std::ifstream in_file(calib_data_path);
|
||||
if (!in_file.good()) {
|
||||
std::cerr << "file: " << calib_data_path << " is not exist" << std::endl;
|
||||
MS_LOG(ERROR) << "file: " << calib_data_path << " is not exist";
|
||||
return RET_ERROR;
|
||||
|
||||
auto old_file = flags_->benchmark_data_file_;
|
||||
const char *old_calib_data_path = old_file.c_str();
|
||||
std::ifstream old_in_file(old_calib_data_path);
|
||||
|
||||
if (!old_in_file.good()) {
|
||||
std::cerr << "file: " << old_calib_data_path << " is not exist" << std::endl;
|
||||
MS_LOG(ERROR) << "file: " << old_calib_data_path << " is not exist";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!old_in_file.is_open()) {
|
||||
std::cerr << "file: " << old_calib_data_path << " open failed" << std::endl;
|
||||
MS_LOG(ERROR) << "file: " << old_calib_data_path << " open failed";
|
||||
old_in_file.close();
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "Start reading calibData file";
|
||||
std::string line;
|
||||
std::string tensor_name;
|
||||
|
||||
while (!old_in_file.eof()) {
|
||||
getline(old_in_file, line);
|
||||
std::stringstream string_line1(line);
|
||||
size_t dim = 0;
|
||||
string_line1 >> tensor_name >> dim;
|
||||
std::vector<size_t> dims;
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
size_t tmp_dim;
|
||||
string_line1 >> tmp_dim;
|
||||
dims.push_back(tmp_dim);
|
||||
}
|
||||
auto ret = ReadTensorData(old_in_file, tensor_name, dims);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read tensor data failed, tensor name: " << tensor_name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
old_in_file.close();
|
||||
MS_LOG(INFO) << "Finish reading calibData file";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (!in_file.is_open()) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "tools/converter/import/mindspore_importer.h"
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
@ -100,6 +101,94 @@ size_t MindsporeImporter::Hex2ByteArray(const std::string &hex_str, unsigned cha
|
|||
return byte_len;
|
||||
}
|
||||
|
||||
STATUS MindsporeImporter::ProcessDependCnode(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (!opt::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
|
||||
output_tensor_name_.push_back(cnode->fullname_with_scope());
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
auto depend_input = cnode->input(1);
|
||||
if (utils::isa<CNodePtr>(depend_input)) {
|
||||
auto depend_input_cnode = utils::cast<CNodePtr>(depend_input);
|
||||
auto status = ProcessDependCnode(depend_input_cnode);
|
||||
if (status == RET_NO_CHANGE) {
|
||||
return RET_OK;
|
||||
}
|
||||
} else if (utils::isa<ParameterPtr>(depend_input) || utils::isa<ValueNode>(depend_input)) {
|
||||
output_tensor_name_.push_back(depend_input->fullname_with_scope());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MindsporeImporter::GetFuncGraphOutputName(const CNodePtr &return_node) {
|
||||
MS_ASSERT(return_node != nullptr);
|
||||
for (size_t i = 0; i < return_node->inputs().size(); i++) {
|
||||
auto output_node = return_node->input(i);
|
||||
if (output_node == nullptr) {
|
||||
MS_LOG(ERROR) << "output_node is nullptr.";
|
||||
return RET_ERROR;
|
||||
} else if (output_node->isa<mindspore::CNode>()) {
|
||||
if (opt::CheckPrimitiveType(output_node, prim::kPrimUpdateState) ||
|
||||
opt::CheckPrimitiveType(output_node, prim::kPrimLoad)) {
|
||||
continue;
|
||||
}
|
||||
auto output_cnode = utils::cast<CNodePtr>(output_node);
|
||||
if (opt::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
|
||||
for (size_t j = 0; j < output_cnode->inputs().size(); j++) {
|
||||
auto tuple_input = output_cnode->input(j);
|
||||
if (!utils::isa<CNodePtr>(tuple_input)) {
|
||||
continue;
|
||||
}
|
||||
auto tuple_input_cnode = utils::cast<CNodePtr>(tuple_input);
|
||||
if (opt::CheckPrimitiveType(output_node, prim::kPrimUpdateState) ||
|
||||
opt::CheckPrimitiveType(output_node, prim::kPrimLoad)) {
|
||||
continue;
|
||||
}
|
||||
auto status = ProcessDependCnode(tuple_input_cnode);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "ProcessDependCnode failed.";
|
||||
}
|
||||
}
|
||||
} else if (opt::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
|
||||
auto status = ProcessDependCnode(output_cnode);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "ProcessDependCnode failed.";
|
||||
}
|
||||
} else {
|
||||
output_tensor_name_.push_back(output_cnode->fullname_with_scope());
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
|
||||
std::map<AnfNodePtr, bool> graph_input_map;
|
||||
for (auto &input : func_graph->get_inputs()) {
|
||||
graph_input_map[input] = false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t i = 0; i < cnode->inputs().size(); i++) {
|
||||
for (auto &input : func_graph->get_inputs()) {
|
||||
if (input == cnode->input(i) && graph_input_map.count(input) == 1) {
|
||||
graph_input_map[input] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &item : graph_input_map) {
|
||||
if (item.second == false) {
|
||||
func_graph->DropNode(item.first);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
||||
FuncGraphPtr func_graph;
|
||||
if (flag.dec_key.size() != 0) {
|
||||
|
@ -125,14 +214,28 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
|||
}
|
||||
func_graph->set_attr("graph_name", MakeValue("main_graph"));
|
||||
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeMs)));
|
||||
auto status = RemoveUnusedGraphInput(func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "RemoveUnusedGraphInput failed.";
|
||||
return nullptr;
|
||||
}
|
||||
for (auto input : func_graph->get_inputs()) {
|
||||
ConverterContext::GetInstance()->AddGraphInputTensorNames(input->fullname_with_scope());
|
||||
}
|
||||
status = GetFuncGraphOutputName(func_graph->get_return());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GetFuncGraphOutputName failed.";
|
||||
return nullptr;
|
||||
}
|
||||
if (output_tensor_name_.empty()) {
|
||||
MS_LOG(ERROR) << "Can not find output name.";
|
||||
return nullptr;
|
||||
}
|
||||
ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_tensor_name_);
|
||||
#ifdef ENABLE_LITE_ACL
|
||||
MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend310.";
|
||||
return func_graph;
|
||||
#endif
|
||||
STATUS status;
|
||||
if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
|
||||
|
@ -30,8 +31,12 @@ class MindsporeImporter {
|
|||
FuncGraphPtr ImportMindIR(const converter::Flags &flag);
|
||||
|
||||
private:
|
||||
STATUS RemoveUnusedGraphInput(const FuncGraphPtr &func_graph);
|
||||
STATUS ProcessDependCnode(const CNodePtr &cnode);
|
||||
STATUS GetFuncGraphOutputName(const CNodePtr &cnode);
|
||||
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
|
||||
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
|
||||
std::vector<std::string> output_tensor_name_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
Loading…
Reference in New Issue