!19629 modify inc load mindir

Merge pull request !19629 from changzherui/inc_load
This commit is contained in:
i-robot 2021-07-09 09:04:57 +00:00 committed by Gitee
commit 9523d28536
4 changed files with 14 additions and 7 deletions

View File

@ -260,7 +260,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto); tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
MS_EXCEPTION_IF_NULL(tensor_info); MS_EXCEPTION_IF_NULL(tensor_info);
MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name; MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
if (load_tensor_map.find(debug_info_name) == load_tensor_map.end()) { if (!IsIncLoad() || load_tensor_map.find(debug_info_name) == load_tensor_map.end()) {
load_tensor_map[debug_info_name] = tensor_info; load_tensor_map[debug_info_name] = tensor_info;
} else { } else {
MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has been already loaded, use it again."; MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has been already loaded, use it again.";

View File

@ -40,7 +40,9 @@ class MSANFModelParser {
std::string GetProducerVersion() { return model_version_; } std::string GetProducerVersion() { return model_version_; }
std::string GetIrVersion() { return ir_version_; } std::string GetIrVersion() { return ir_version_; }
void SetLite() { is_lite_ = true; } void SetLite() { is_lite_ = true; }
bool IsLite() { return is_lite_; } bool IsLite() const { return is_lite_; }
void SetIncLoad() { inc_load_ = true; }
bool IsIncLoad() const { return inc_load_; }
private: private:
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
@ -73,6 +75,7 @@ class MSANFModelParser {
std::string model_version_; std::string model_version_;
std::string ir_version_; std::string ir_version_;
bool is_lite_ = false; bool is_lite_ = false;
bool inc_load_ = false;
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -171,19 +171,19 @@ bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigne
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite, std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
const unsigned char *dec_key, const size_t key_len, const unsigned char *dec_key, const size_t key_len,
const std::string &dec_mode) { const std::string &dec_mode, bool inc_load) {
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec; std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
MS_LOG(DEBUG) << "Load multiple MindIR files."; MS_LOG(DEBUG) << "Load multiple MindIR files.";
for (size_t i = 0; i < file_names.size(); ++i) { for (size_t i = 0; i < file_names.size(); ++i) {
std::string file_name = file_names[i]; std::string file_name = file_names[i];
MS_LOG(DEBUG) << "Load " << file_name; MS_LOG(DEBUG) << "Load " << file_name;
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode)); funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, inc_load));
} }
return funcgraph_vec; return funcgraph_vec;
} }
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key, std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
const size_t key_len, const std::string &dec_mode) { const size_t key_len, const std::string &dec_mode, bool inc_load) {
if (file_name.length() > PATH_MAX) { if (file_name.length() > PATH_MAX) {
MS_LOG(ERROR) << "The length of the file name exceeds the limit."; MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr; return nullptr;
@ -254,6 +254,9 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
if (is_lite) { if (is_lite) {
model_parser.SetLite(); model_parser.SetLite();
} }
if (inc_load) {
model_parser.SetIncLoad();
}
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model); FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
return dstgraph_ptr; return dstgraph_ptr;
} }

View File

@ -25,10 +25,11 @@
namespace mindspore { namespace mindspore {
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false, std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false,
const unsigned char *dec_key = nullptr, const size_t key_len = 0, const unsigned char *dec_key = nullptr, const size_t key_len = 0,
const std::string &dec_mode = std::string("AES-GCM")); const std::string &dec_mode = std::string("AES-GCM"), bool inc_load = false);
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names, bool is_lite = false, std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names, bool is_lite = false,
const unsigned char *dec_key = nullptr, const size_t key_len = 0, const unsigned char *dec_key = nullptr, const size_t key_len = 0,
const std::string &dec_mode = std::string("AES-GCM")); const std::string &dec_mode = std::string("AES-GCM"),
bool inc_load = true);
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file); std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false); std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
} // namespace mindspore } // namespace mindspore