!19629 modify inc load mindir

Merge pull request !19629 from changzherui/inc_load
This commit is contained in:
i-robot 2021-07-09 09:04:57 +00:00 committed by Gitee
commit 9523d28536
4 changed files with 14 additions and 7 deletions

View File

@ -260,7 +260,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
if (load_tensor_map.find(debug_info_name) == load_tensor_map.end()) {
if (!IsIncLoad() || load_tensor_map.find(debug_info_name) == load_tensor_map.end()) {
load_tensor_map[debug_info_name] = tensor_info;
} else {
MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has been already loaded, use it again.";

View File

@ -40,7 +40,9 @@ class MSANFModelParser {
std::string GetProducerVersion() { return model_version_; }
std::string GetIrVersion() { return ir_version_; }
void SetLite() { is_lite_ = true; }
bool IsLite() { return is_lite_; }
bool IsLite() const { return is_lite_; }
void SetIncLoad() { inc_load_ = true; }
bool IsIncLoad() const { return inc_load_; }
private:
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
@ -73,6 +75,7 @@ class MSANFModelParser {
std::string model_version_;
std::string ir_version_;
bool is_lite_ = false;
bool inc_load_ = false;
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
};
} // namespace mindspore

View File

@ -171,19 +171,19 @@ bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigne
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
const unsigned char *dec_key, const size_t key_len,
const std::string &dec_mode) {
const std::string &dec_mode, bool inc_load) {
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
MS_LOG(DEBUG) << "Load multiple MindIR files.";
for (size_t i = 0; i < file_names.size(); ++i) {
std::string file_name = file_names[i];
MS_LOG(DEBUG) << "Load " << file_name;
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode));
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, inc_load));
}
return funcgraph_vec;
}
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
const size_t key_len, const std::string &dec_mode) {
const size_t key_len, const std::string &dec_mode, bool inc_load) {
if (file_name.length() > PATH_MAX) {
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr;
@ -254,6 +254,9 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
if (is_lite) {
model_parser.SetLite();
}
if (inc_load) {
model_parser.SetIncLoad();
}
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
return dstgraph_ptr;
}

View File

@ -25,10 +25,11 @@
namespace mindspore {
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false,
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
const std::string &dec_mode = std::string("AES-GCM"));
const std::string &dec_mode = std::string("AES-GCM"), bool inc_load = false);
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names, bool is_lite = false,
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
const std::string &dec_mode = std::string("AES-GCM"));
const std::string &dec_mode = std::string("AES-GCM"),
bool inc_load = true);
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
} // namespace mindspore