!22726 [LITE] train output tensor name

Merge pull request !22726 from yefeng/153-train_output_input_tensor_name
i-robot, 2021-09-02 12:49:30 +00:00 (committed by Gitee)
commit 7a69ee5ee2
4 changed files with 153 additions and 4 deletions

File: mindspore/lite/tools/converter/anf_exporter/anf_exporter.cc

@@ -557,7 +557,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
                                         bool train_flag) {
   this->train_flag_ = train_flag;
   // hardcode for nnie and train
-  this->reorder_input_ = !(train_flag) && !(ConverterContext::GetInstance()->GetGraphInputTensorNames().empty());
+  this->reorder_input_ = !(ConverterContext::GetInstance()->GetGraphInputTensorNames().empty());
   this->graph_inputs_map_.clear();
   auto meta_graphT = std::make_unique<schema::MetaGraphT>();
   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, nullptr, "meta_graphT is nullptr");
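Note: the exporter previously skipped input reordering for training graphs. After this change, reorder_input_ keys only on whether ConverterContext holds user-supplied graph-input tensor names, which the MindIR importer changed below now registers unconditionally.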

File: mindspore/lite/tools/benchmark/benchmark_base.cc

@@ -132,13 +132,54 @@ int BenchmarkBase::LoadInput() {
 // calibData is FP32
 int BenchmarkBase::ReadCalibData() {
-  const char *calib_data_path = flags_->benchmark_data_file_.c_str();
+  auto new_file = flags_->benchmark_data_file_ + ".new";
+  const char *calib_data_path = new_file.c_str();
   // read calib data
   std::ifstream in_file(calib_data_path);
   if (!in_file.good()) {
-    std::cerr << "file: " << calib_data_path << " is not exist" << std::endl;
-    MS_LOG(ERROR) << "file: " << calib_data_path << " is not exist";
-    return RET_ERROR;
+    auto old_file = flags_->benchmark_data_file_;
+    const char *old_calib_data_path = old_file.c_str();
+    std::ifstream old_in_file(old_calib_data_path);
+    if (!old_in_file.good()) {
+      std::cerr << "file: " << old_calib_data_path << " is not exist" << std::endl;
+      MS_LOG(ERROR) << "file: " << old_calib_data_path << " is not exist";
+      return RET_ERROR;
+    }
+    if (!old_in_file.is_open()) {
+      std::cerr << "file: " << old_calib_data_path << " open failed" << std::endl;
+      MS_LOG(ERROR) << "file: " << old_calib_data_path << " open failed";
+      old_in_file.close();
+      return RET_ERROR;
+    }
+    MS_LOG(INFO) << "Start reading calibData file";
+    std::string line;
+    std::string tensor_name;
+    while (!old_in_file.eof()) {
+      getline(old_in_file, line);
+      std::stringstream string_line1(line);
+      size_t dim = 0;
+      string_line1 >> tensor_name >> dim;
+      std::vector<size_t> dims;
+      for (size_t i = 0; i < dim; i++) {
+        size_t tmp_dim;
+        string_line1 >> tmp_dim;
+        dims.push_back(tmp_dim);
+      }
+      auto ret = ReadTensorData(old_in_file, tensor_name, dims);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "Read tensor data failed, tensor name: " << tensor_name;
+        return RET_ERROR;
+      }
+    }
+    old_in_file.close();
+    MS_LOG(INFO) << "Finish reading calibData file";
+    return RET_OK;
   }
   if (!in_file.is_open()) {
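The ".new" suffix apparently lets a calibration file regenerated under the new output tensor names live beside the old one, with a fallback to the old file when the new one is absent. For context, the fallback loop assumes the calibration file is plain text with one record per tensor: a header line of the form "<name> <rank> <d0> ... <dn>" followed by the data values that ReadTensorData consumes. Below is a minimal standalone sketch of that parsing, under the assumption that each record's data sits on one line of space-separated floats (the file name and data layout are illustrative, not taken from the benchmark sources); it uses a getline-driven loop, which avoids the spurious final iteration that a while (!file.eof()) loop can perform:

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Parse "<tensor_name> <rank> <d0> ... <dn>" header lines, each followed by a
// line of float data, mirroring the shape parsing in ReadCalibData() above.
int main() {
  std::ifstream in("calib.out");  // hypothetical calibration file
  std::string line;
  while (std::getline(in, line)) {  // one header line per tensor record
    std::stringstream header(line);
    std::string tensor_name;
    size_t rank = 0;
    header >> tensor_name >> rank;
    std::vector<size_t> dims;
    for (size_t i = 0; i < rank; ++i) {
      size_t d = 0;
      header >> d;
      dims.push_back(d);
    }
    size_t count = 1;
    for (size_t d : dims) count *= d;
    // Data line: assumed one float per element (ReadTensorData's job above).
    if (!std::getline(in, line)) break;
    std::stringstream data(line);
    std::vector<float> values(count);
    for (size_t i = 0; i < count; ++i) data >> values[i];
    std::cout << tensor_name << ": " << count << " values\n";
  }
  return 0;
}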

File: mindspore/lite/tools/converter/import/mindspore_importer.cc

@@ -16,6 +16,7 @@
 #include "tools/converter/import/mindspore_importer.h"
 #include <memory>
 #include <map>
+#include <set>
 #include <vector>
 #include <regex>
@@ -100,6 +101,94 @@ size_t MindsporeImporter::Hex2ByteArray(const std::string &hex_str, unsigned cha
   return byte_len;
 }
 
+STATUS MindsporeImporter::ProcessDependCnode(const CNodePtr &cnode) {
+  MS_ASSERT(cnode != nullptr);
+  if (!opt::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
+    output_tensor_name_.push_back(cnode->fullname_with_scope());
+    return RET_NO_CHANGE;
+  }
+  auto depend_input = cnode->input(1);
+  if (utils::isa<CNodePtr>(depend_input)) {
+    auto depend_input_cnode = utils::cast<CNodePtr>(depend_input);
+    auto status = ProcessDependCnode(depend_input_cnode);
+    if (status == RET_NO_CHANGE) {
+      return RET_OK;
+    }
+  } else if (utils::isa<ParameterPtr>(depend_input) || utils::isa<ValueNode>(depend_input)) {
+    output_tensor_name_.push_back(depend_input->fullname_with_scope());
+  }
+  return RET_OK;
+}
+
+STATUS MindsporeImporter::GetFuncGraphOutputName(const CNodePtr &return_node) {
+  MS_ASSERT(return_node != nullptr);
+  for (size_t i = 0; i < return_node->inputs().size(); i++) {
+    auto output_node = return_node->input(i);
+    if (output_node == nullptr) {
+      MS_LOG(ERROR) << "output_node is nullptr.";
+      return RET_ERROR;
+    } else if (output_node->isa<mindspore::CNode>()) {
+      if (opt::CheckPrimitiveType(output_node, prim::kPrimUpdateState) ||
+          opt::CheckPrimitiveType(output_node, prim::kPrimLoad)) {
+        continue;
+      }
+      auto output_cnode = utils::cast<CNodePtr>(output_node);
+      if (opt::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
+        for (size_t j = 0; j < output_cnode->inputs().size(); j++) {
+          auto tuple_input = output_cnode->input(j);
+          if (!utils::isa<CNodePtr>(tuple_input)) {
+            continue;
+          }
+          auto tuple_input_cnode = utils::cast<CNodePtr>(tuple_input);
+          if (opt::CheckPrimitiveType(tuple_input_cnode, prim::kPrimUpdateState) ||
+              opt::CheckPrimitiveType(tuple_input_cnode, prim::kPrimLoad)) {
+            continue;
+          }
+          auto status = ProcessDependCnode(tuple_input_cnode);
+          if (status != RET_OK && status != RET_NO_CHANGE) {
+            MS_LOG(ERROR) << "ProcessDependCnode failed.";
+          }
+        }
+      } else if (opt::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
+        auto status = ProcessDependCnode(output_cnode);
+        if (status != RET_OK && status != RET_NO_CHANGE) {
+          MS_LOG(ERROR) << "ProcessDependCnode failed.";
+        }
+      } else {
+        output_tensor_name_.push_back(output_cnode->fullname_with_scope());
+      }
+    }
+  }
+  return RET_OK;
+}
+
+STATUS MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
+  std::map<AnfNodePtr, bool> graph_input_map;
+  for (auto &input : func_graph->get_inputs()) {
+    graph_input_map[input] = false;
+  }
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNode>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    for (size_t i = 0; i < cnode->inputs().size(); i++) {
+      for (auto &input : func_graph->get_inputs()) {
+        if (input == cnode->input(i) && graph_input_map.count(input) == 1) {
+          graph_input_map[input] = true;
+        }
+      }
+    }
+  }
+  for (auto &item : graph_input_map) {
+    if (item.second == false) {
+      func_graph->DropNode(item.first);
+    }
+  }
+  return RET_OK;
+}
+
 FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
   FuncGraphPtr func_graph;
   if (flag.dec_key.size() != 0) {
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
FuncGraphPtr func_graph;
if (flag.dec_key.size() != 0) {
@@ -125,14 +214,28 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
   }
   func_graph->set_attr("graph_name", MakeValue("main_graph"));
   func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeMs)));
+  auto status = RemoveUnusedGraphInput(func_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "RemoveUnusedGraphInput failed.";
+    return nullptr;
+  }
+  for (auto input : func_graph->get_inputs()) {
+    ConverterContext::GetInstance()->AddGraphInputTensorNames(input->fullname_with_scope());
+  }
+  status = GetFuncGraphOutputName(func_graph->get_return());
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "GetFuncGraphOutputName failed.";
+    return nullptr;
+  }
+  if (output_tensor_name_.empty()) {
+    MS_LOG(ERROR) << "Can not find output name.";
+    return nullptr;
+  }
+  ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_tensor_name_);
 #ifdef ENABLE_LITE_ACL
   MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend310.";
   return func_graph;
 #endif
-  STATUS status;
   if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
     MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
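In ProcessDependCnode above, a Depend(x, y) node forwards the value of x and only adds a control-order edge to y, so the user-visible output name is found by recursing into input index 1 (index 0 holds the primitive) until a non-Depend node is reached. A toy illustration of that walk, with a plain struct standing in for the mindspore CNode API (not runnable against mindspore itself):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy node model (not the mindspore AnfNode API): Depend(x, y) forwards the
// value of x, so the real output name lives at input index 1.
struct Node {
  std::string op;    // "Depend", "Conv2D", ...
  std::string name;  // analogue of fullname_with_scope()
  std::vector<std::shared_ptr<Node>> inputs;  // index 0 is the primitive slot
};

void CollectOutputName(const std::shared_ptr<Node> &node, std::vector<std::string> *names) {
  if (node->op != "Depend") {
    names->push_back(node->name);  // a real operator: record its name
    return;
  }
  // The first data input may itself be another Depend, hence the recursion.
  CollectOutputName(node->inputs.at(1), names);
}

int main() {
  auto conv = std::make_shared<Node>(Node{"Conv2D", "conv2d_1", {}});
  auto dep = std::make_shared<Node>(Node{"Depend", "depend_0", {nullptr, conv}});
  std::vector<std::string> names;
  CollectOutputName(dep, &names);
  std::cout << names.front() << "\n";  // prints conv2d_1
  return 0;
}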
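RemoveUnusedGraphInput above is a mark-and-sweep over the formal parameters: mark every graph input that some CNode references anywhere in the topological order, then drop the unmarked ones. The same pattern in a self-contained sketch, with plain STL containers standing in for FuncGraph, TopoSort, and DropNode:

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Formal graph inputs, mapped to a "used" flag (all false initially).
  std::vector<std::string> graph_inputs = {"x", "y", "unused_z"};
  std::map<std::string, bool> used;
  for (const auto &in : graph_inputs) used[in] = false;

  // Stand-in for walking every CNode input in topological order.
  std::vector<std::string> cnode_inputs = {"x", "w1", "y", "x"};
  for (const auto &ref : cnode_inputs)
    if (used.count(ref) == 1) used[ref] = true;

  // Drop inputs that were never referenced (the DropNode analogue).
  graph_inputs.erase(std::remove_if(graph_inputs.begin(), graph_inputs.end(),
                                    [&](const std::string &in) { return !used[in]; }),
                     graph_inputs.end());

  for (const auto &in : graph_inputs) std::cout << in << "\n";  // x, y
  return 0;
}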

File: mindspore/lite/tools/converter/import/mindspore_importer.h

@@ -19,6 +19,7 @@
 #include <set>
 #include <string>
+#include <vector>
 #include "tools/converter/converter_flags.h"
 #include "load_mindir/load_model.h"
@@ -30,8 +31,12 @@ class MindsporeImporter {
   FuncGraphPtr ImportMindIR(const converter::Flags &flag);
 
  private:
+  STATUS RemoveUnusedGraphInput(const FuncGraphPtr &func_graph);
+  STATUS ProcessDependCnode(const CNodePtr &cnode);
+  STATUS GetFuncGraphOutputName(const CNodePtr &cnode);
   STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
   size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
+  std::vector<std::string> output_tensor_name_;
 };
 }  // namespace mindspore::lite
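Taken together, the pieces of this PR communicate through a process-wide registry: the importer records the model's original input and output tensor names, and the exporter's reorder_input_ (first hunk) keys on them. A toy sketch of that singleton pattern, where NameRegistry is a hypothetical stand-in rather than the converter's actual ConverterContext class:

#include <iostream>
#include <string>
#include <vector>

// Minimal singleton name registry mirroring how ConverterContext is used in
// this PR: the importer records names, the exporter queries them later.
class NameRegistry {
 public:
  static NameRegistry *GetInstance() {
    static NameRegistry instance;
    return &instance;
  }
  void AddGraphInputTensorName(const std::string &name) { inputs_.push_back(name); }
  void SetGraphOutputTensorNames(const std::vector<std::string> &names) { outputs_ = names; }
  const std::vector<std::string> &GetGraphInputTensorNames() const { return inputs_; }
  const std::vector<std::string> &GetGraphOutputTensorNames() const { return outputs_; }

 private:
  NameRegistry() = default;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};

int main() {
  // Importer side: record names while walking the graph.
  NameRegistry::GetInstance()->AddGraphInputTensorName("input_x");
  NameRegistry::GetInstance()->SetGraphOutputTensorNames({"conv2d_1"});
  // Exporter side: reorder_input keys on whether any input names exist.
  bool reorder_input = !NameRegistry::GetInstance()->GetGraphInputTensorNames().empty();
  std::cout << std::boolalpha << reorder_input << "\n";  // true
  return 0;
}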