!19744 [MS][LITE] r1.3 remove redundant load api
Merge pull request !19744 from zhengjun10/api1.3
This commit is contained in:
commit
fad679a7c0
|
@ -27,24 +27,24 @@
|
|||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
using Key = struct Key {
|
||||
constexpr char kDecModeAesGcm[] = "AES-GCM";
|
||||
|
||||
struct MS_API Key {
|
||||
const size_t max_key_len = 32;
|
||||
size_t len;
|
||||
unsigned char key[32];
|
||||
Key() : len(0) {}
|
||||
Key(const char *dec_key, size_t key_len);
|
||||
};
|
||||
|
||||
class MS_API Serialization {
|
||||
public:
|
||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
|
||||
inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::string &dec_mode);
|
||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
|
||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::string &dec_mode);
|
||||
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
|
||||
const std::string &dec_mode = kDecModeAesGcm);
|
||||
inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
|
||||
const Key &dec_key = {}, const std::string &dec_mode = "AES-GCM");
|
||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
||||
|
@ -64,10 +64,6 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
|||
return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) {
|
||||
return Load(StringToChar(file), model_type, graph);
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::string &dec_mode) {
|
||||
return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
|
||||
|
|
|
@ -35,6 +35,7 @@ enum ModelType : uint32_t {
|
|||
kAIR = 1,
|
||||
kOM = 2,
|
||||
kONNX = 3,
|
||||
kFlatBuffer = 4,
|
||||
// insert new data type here
|
||||
kUnknownType = 0xFFFFFFFF
|
||||
};
|
||||
|
|
|
@ -79,8 +79,20 @@ static Buffer ReadFile(const std::string &file) {
|
|||
return buffer;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
|
||||
return Load(model_data, data_size, model_type, graph, Key{}, StringToChar("AES-GCM"));
|
||||
Key::Key(const char *dec_key, size_t key_len) {
|
||||
len = 0;
|
||||
if (key_len >= max_key_len) {
|
||||
MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
|
||||
return;
|
||||
}
|
||||
|
||||
auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len);
|
||||
if (sec_ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret;
|
||||
return;
|
||||
}
|
||||
|
||||
len = key_len;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
|
@ -137,7 +149,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
|||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
||||
return Load(file, model_type, graph, Key{}, StringToChar("AES-GCM"));
|
||||
return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm));
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
|
@ -256,11 +268,6 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
|
|||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
Status Serialization::LoadCheckPoint(const std::string &, std::map<std::string, Buffer> *) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return kMEFailed;
|
||||
}
|
||||
|
||||
Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model *) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return kMEFailed;
|
||||
|
|
|
@ -24,10 +24,27 @@
|
|||
#include "include/model.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "src/cxx_api/graph/graph_data.h"
|
||||
#include "src/cxx_api/model/model_impl.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
|
||||
Key::Key(const char *dec_key, size_t key_len) {
|
||||
len = 0;
|
||||
if (key_len >= max_key_len) {
|
||||
MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
|
||||
return;
|
||||
}
|
||||
memcpy(key, dec_key, key_len);
|
||||
len = key_len;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||
if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
if (model_data == nullptr) {
|
||||
MS_LOG(ERROR) << "model data is nullptr.";
|
||||
return kLiteNullptr;
|
||||
|
@ -40,6 +57,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
|||
MS_LOG(ERROR) << "Unsupported IR.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
|
||||
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "New model failed.";
|
||||
|
@ -54,26 +72,45 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::vector<char> &dec_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "graph is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
if (model_type != kFlatBuffer) {
|
||||
MS_LOG(ERROR) << "Unsupported IR.";
|
||||
return kLiteInputParamInvalid;
|
||||
}
|
||||
|
||||
std::string filename = file.data();
|
||||
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
|
||||
filename = filename + ".ms";
|
||||
}
|
||||
|
||||
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "New model failed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
|
||||
if (graph_data == nullptr) {
|
||||
MS_LOG(ERROR) << "New graph data failed.";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
*graph = Graph(graph_data);
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return kMEFailed;
|
||||
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
|
||||
std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
Status Serialization::SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model) {
|
||||
|
|
|
@ -46,11 +46,8 @@ TEST_F(TestCxxApiSerialization, test_load_file_not_exist_FAILED) {
|
|||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_SUCCESS) {
|
||||
Graph graph;
|
||||
std::string key_str = "0123456789ABCDEF";
|
||||
Key key;
|
||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
||||
key.len = key_str.size();
|
||||
ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
||||
key, "AES-GCM") == kSuccess);
|
||||
Key(key_str.c_str(), key_str.size()), kDecModeAesGcm) == kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
|
||||
|
@ -65,21 +62,16 @@ TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
|
|||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_with_wrong_key_FAILED) {
|
||||
Graph graph;
|
||||
std::string key_str = "WRONG_KEY";
|
||||
Key key;
|
||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
||||
key.len = key_str.size();
|
||||
auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
||||
key, "AES-GCM");
|
||||
Key(key_str.c_str(), key_str.size()), kDecModeAesGcm);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_with_wrong_key_FAILED) {
|
||||
Graph graph;
|
||||
std::string key_str = "WRONG_KEY";
|
||||
Key key;
|
||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
||||
key.len = key_str.size();
|
||||
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, key, "AES-GCM");
|
||||
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph,
|
||||
Key(key_str.c_str(), key_str.size()), kDecModeAesGcm);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue