forked from mindspore-Ecosystem/mindspore
!19629 modify inc load mindir
Merge pull request !19629 from changzherui/inc_load
This commit is contained in:
commit
9523d28536
|
@ -260,7 +260,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
|
|||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
MS_LOG(DEBUG) << "Load parameter name: " << debug_info_name;
|
||||
if (load_tensor_map.find(debug_info_name) == load_tensor_map.end()) {
|
||||
if (!IsIncLoad() || load_tensor_map.find(debug_info_name) == load_tensor_map.end()) {
|
||||
load_tensor_map[debug_info_name] = tensor_info;
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Parameter: " << debug_info_name << " has been already loaded, use it again.";
|
||||
|
|
|
@ -40,7 +40,9 @@ class MSANFModelParser {
|
|||
std::string GetProducerVersion() { return model_version_; }
|
||||
std::string GetIrVersion() { return ir_version_; }
|
||||
void SetLite() { is_lite_ = true; }
|
||||
bool IsLite() { return is_lite_; }
|
||||
bool IsLite() const { return is_lite_; }
|
||||
void SetIncLoad() { inc_load_ = true; }
|
||||
bool IsIncLoad() const { return inc_load_; }
|
||||
|
||||
private:
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
|
@ -73,6 +75,7 @@ class MSANFModelParser {
|
|||
std::string model_version_;
|
||||
std::string ir_version_;
|
||||
bool is_lite_ = false;
|
||||
bool inc_load_ = false;
|
||||
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -171,19 +171,19 @@ bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigne
|
|||
|
||||
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
|
||||
const unsigned char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
const std::string &dec_mode, bool inc_load) {
|
||||
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
|
||||
MS_LOG(DEBUG) << "Load multiple MindIR files.";
|
||||
for (size_t i = 0; i < file_names.size(); ++i) {
|
||||
std::string file_name = file_names[i];
|
||||
MS_LOG(DEBUG) << "Load " << file_name;
|
||||
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode));
|
||||
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, inc_load));
|
||||
}
|
||||
return funcgraph_vec;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
|
||||
const size_t key_len, const std::string &dec_mode) {
|
||||
const size_t key_len, const std::string &dec_mode, bool inc_load) {
|
||||
if (file_name.length() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
||||
return nullptr;
|
||||
|
@ -254,6 +254,9 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
|
|||
if (is_lite) {
|
||||
model_parser.SetLite();
|
||||
}
|
||||
if (inc_load) {
|
||||
model_parser.SetIncLoad();
|
||||
}
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
|
|
|
@ -25,10 +25,11 @@
|
|||
namespace mindspore {
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false,
|
||||
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
|
||||
const std::string &dec_mode = std::string("AES-GCM"));
|
||||
const std::string &dec_mode = std::string("AES-GCM"), bool inc_load = false);
|
||||
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names, bool is_lite = false,
|
||||
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
|
||||
const std::string &dec_mode = std::string("AES-GCM"));
|
||||
const std::string &dec_mode = std::string("AES-GCM"),
|
||||
bool inc_load = true);
|
||||
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue