forked from mindspore-Ecosystem/mindspore
!19629 modify inc load mindir
Merge pull request !19629 from changzherui/inc_load
This commit is contained in:
commit
9523d28536
|
@ -260,7 +260,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
|
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
|
||||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||||
MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
|
MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
|
||||||
if (load_tensor_map.find(debug_info_name) == load_tensor_map.end()) {
|
if (!IsIncLoad() || load_tensor_map.find(debug_info_name) == load_tensor_map.end()) {
|
||||||
load_tensor_map[debug_info_name] = tensor_info;
|
load_tensor_map[debug_info_name] = tensor_info;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has been already loaded, use it again.";
|
MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has been already loaded, use it again.";
|
||||||
|
|
|
@ -40,7 +40,9 @@ class MSANFModelParser {
|
||||||
std::string GetProducerVersion() { return model_version_; }
|
std::string GetProducerVersion() { return model_version_; }
|
||||||
std::string GetIrVersion() { return ir_version_; }
|
std::string GetIrVersion() { return ir_version_; }
|
||||||
void SetLite() { is_lite_ = true; }
|
void SetLite() { is_lite_ = true; }
|
||||||
bool IsLite() { return is_lite_; }
|
bool IsLite() const { return is_lite_; }
|
||||||
|
void SetIncLoad() { inc_load_ = true; }
|
||||||
|
bool IsIncLoad() const { return inc_load_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||||
|
@ -73,6 +75,7 @@ class MSANFModelParser {
|
||||||
std::string model_version_;
|
std::string model_version_;
|
||||||
std::string ir_version_;
|
std::string ir_version_;
|
||||||
bool is_lite_ = false;
|
bool is_lite_ = false;
|
||||||
|
bool inc_load_ = false;
|
||||||
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
|
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -171,19 +171,19 @@ bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigne
|
||||||
|
|
||||||
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
|
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
|
||||||
const unsigned char *dec_key, const size_t key_len,
|
const unsigned char *dec_key, const size_t key_len,
|
||||||
const std::string &dec_mode) {
|
const std::string &dec_mode, bool inc_load) {
|
||||||
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
|
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
|
||||||
MS_LOG(DEBUG) << "Load multiple MindIR files.";
|
MS_LOG(DEBUG) << "Load multiple MindIR files.";
|
||||||
for (size_t i = 0; i < file_names.size(); ++i) {
|
for (size_t i = 0; i < file_names.size(); ++i) {
|
||||||
std::string file_name = file_names[i];
|
std::string file_name = file_names[i];
|
||||||
MS_LOG(DEBUG) << "Load " << file_name;
|
MS_LOG(DEBUG) << "Load " << file_name;
|
||||||
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode));
|
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, inc_load));
|
||||||
}
|
}
|
||||||
return funcgraph_vec;
|
return funcgraph_vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
|
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
|
||||||
const size_t key_len, const std::string &dec_mode) {
|
const size_t key_len, const std::string &dec_mode, bool inc_load) {
|
||||||
if (file_name.length() > PATH_MAX) {
|
if (file_name.length() > PATH_MAX) {
|
||||||
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -254,6 +254,9 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
|
||||||
if (is_lite) {
|
if (is_lite) {
|
||||||
model_parser.SetLite();
|
model_parser.SetLite();
|
||||||
}
|
}
|
||||||
|
if (inc_load) {
|
||||||
|
model_parser.SetIncLoad();
|
||||||
|
}
|
||||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
|
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
|
||||||
return dstgraph_ptr;
|
return dstgraph_ptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,10 +25,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false,
|
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false,
|
||||||
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
|
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
|
||||||
const std::string &dec_mode = std::string("AES-GCM"));
|
const std::string &dec_mode = std::string("AES-GCM"), bool inc_load = false);
|
||||||
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names, bool is_lite = false,
|
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names, bool is_lite = false,
|
||||||
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
|
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
|
||||||
const std::string &dec_mode = std::string("AES-GCM"));
|
const std::string &dec_mode = std::string("AES-GCM"),
|
||||||
|
bool inc_load = true);
|
||||||
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
|
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
|
||||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
|
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue