forked from mindspore-Ecosystem/mindspore
add load multi mindir api
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
1baebef648
commit
0cdd2338b8
|
@ -31,10 +31,9 @@ using Key = struct Key {
|
|||
const size_t max_key_len = 32;
|
||||
size_t len;
|
||||
unsigned char key[32];
|
||||
Key(): len(0) {}
|
||||
Key() : len(0) {}
|
||||
};
|
||||
|
||||
|
||||
class MS_API Serialization {
|
||||
public:
|
||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
|
||||
|
@ -43,17 +42,21 @@ class MS_API Serialization {
|
|||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
|
||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::string &dec_mode);
|
||||
inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
|
||||
const Key &dec_key = {}, const std::string &dec_mode = "AES-GCM");
|
||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
||||
|
||||
private:
|
||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode);
|
||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::vector<char> &dec_mode);
|
||||
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph);
|
||||
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::vector<char> &dec_mode);
|
||||
static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode);
|
||||
};
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
|
@ -69,5 +72,10 @@ Status Serialization::Load(const std::string &file, ModelType model_type, Graph
|
|||
const std::string &dec_mode) {
|
||||
return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
|
||||
const Key &dec_key, const std::string &dec_mode) {
|
||||
return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode));
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "include/api/serialization.h"
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "mindspore/core/load_mindir/load_model.h"
|
||||
|
@ -69,62 +70,48 @@ static Buffer ReadFile(const std::string &file) {
|
|||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = nullptr;
|
||||
try {
|
||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
|
||||
} catch (const std::exception &) {
|
||||
if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) {
|
||||
MS_LOG(ERROR) << "Load model failed. The model_data may be encrypted, please pass in correct key.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Load model failed.";
|
||||
}
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
return kSuccess;
|
||||
} else if (model_type == kOM) {
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||
return kMEInvalidInput;
|
||||
return Load(model_data, data_size, model_type, graph, Key{}, StringToChar("AES-GCM"));
|
||||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||
std::stringstream err_msg;
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Output args graph is nullptr.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = nullptr;
|
||||
try {
|
||||
if (dec_key.len > dec_key.max_key_len) {
|
||||
MS_LOG(ERROR) << "The key length exceeds maximum length: 32.";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
} else if (dec_key.len == 0) {
|
||||
if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) {
|
||||
err_msg << "Load model failed. The model_data may be encrypted, please pass in correct key.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
} else {
|
||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
|
||||
}
|
||||
} else {
|
||||
size_t plain_data_size;
|
||||
std::string dec_mode_str(dec_mode.begin(), dec_mode.end());
|
||||
auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast<const unsigned char *>(model_data),
|
||||
data_size, dec_key.key, dec_key.len, dec_mode_str);
|
||||
data_size, dec_key.key, dec_key.len, CharToString(dec_mode));
|
||||
if (plain_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode.";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Load model failed. Please check the valid of dec_key and dec_mode.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(plain_data.get()), plain_data_size);
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode.";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Load model failed. Please check the valid of dec_key and dec_mode.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
|
@ -134,78 +121,112 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Unsupported ModelType " << model_type;
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
std::string file_path = CharToString(file);
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = LoadMindIR(file_path);
|
||||
if (anf_graph == nullptr) {
|
||||
if (IsCipherFile(file_path)) {
|
||||
MS_LOG(ERROR) << "Load model failed. The file may be encrypted, please pass in correct key.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Load model failed.";
|
||||
}
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
return kSuccess;
|
||||
} else if (model_type == kOM) {
|
||||
Buffer data = ReadFile(file_path);
|
||||
if (data.Data() == nullptr) {
|
||||
MS_LOG(ERROR) << "Read file " << file_path << " failed.";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||
return kMEInvalidInput;
|
||||
return Load(file, model_type, graph, Key{}, StringToChar("AES-GCM"));
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::vector<char> &dec_mode) {
|
||||
std::stringstream err_msg;
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Output args graph is nullptr.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
std::string file_path = CharToString(file);
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph;
|
||||
if (dec_key.len > dec_key.max_key_len) {
|
||||
MS_LOG(ERROR) << "The key length exceeds maximum length: 32.";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
} else if (dec_key.len == 0 && IsCipherFile(file_path)) {
|
||||
err_msg << "Load model failed. The file may be encrypted, please pass in correct key.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
} else {
|
||||
std::string dec_mode_str(dec_mode.begin(), dec_mode.end());
|
||||
anf_graph = LoadMindIR(file_path, false, dec_key.key, dec_key.len, dec_mode_str);
|
||||
anf_graph = LoadMindIR(file_path, false, nullptr, dec_key.len, CharToString(dec_mode));
|
||||
}
|
||||
if (anf_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Load model failed. Please check the valid of dec_key and dec_mode";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
return kSuccess;
|
||||
} else if (model_type == kOM) {
|
||||
Buffer data = ReadFile(file_path);
|
||||
if (data.Data() == nullptr) {
|
||||
MS_LOG(ERROR) << "Read file " << file_path << " failed.";
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Read file " << file_path << " failed.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||
return kMEInvalidInput;
|
||||
err_msg << "Unsupported ModelType " << model_type;
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
|
||||
std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||
std::stringstream err_msg;
|
||||
if (graphs == nullptr) {
|
||||
err_msg << "Output args graph is nullptr.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
if (files.size() == 1) {
|
||||
std::vector<Graph> result(files.size());
|
||||
auto ret = Load(files[0], model_type, &result[0], dec_key, dec_mode);
|
||||
*graphs = std::move(result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<std::string> files_path = VectorCharToString(files);
|
||||
if (model_type == kMindIR) {
|
||||
if (dec_key.len > dec_key.max_key_len) {
|
||||
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
auto anf_graphs =
|
||||
LoadMindIRs(files_path, false, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode));
|
||||
if (anf_graphs.size() != files_path.size()) {
|
||||
err_msg << "Load model failed, " << files_path.size() << " files got " << anf_graphs.size() << " graphs.";
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
std::vector<Graph> results;
|
||||
for (size_t i = 0; i < anf_graphs.size(); ++i) {
|
||||
if (anf_graphs[i] == nullptr) {
|
||||
if (dec_key.len == 0 && IsCipherFile(files_path[i])) {
|
||||
err_msg << "Load model failed. The file " << files_path[i] << " be encrypted, please pass in correct key.";
|
||||
} else {
|
||||
err_msg << "Load model " << files_path[i] << " failed.";
|
||||
}
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
results.emplace_back(std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR));
|
||||
}
|
||||
|
||||
*graphs = std::move(results);
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
err_msg << "Unsupported ModelType " << model_type;
|
||||
MS_LOG(ERROR) << err_msg.str();
|
||||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
Status Serialization::LoadCheckPoint(const std::string &, std::map<std::string, Buffer> *) {
|
||||
|
|
Loading…
Reference in New Issue