Fix AIR Encryption Error

This commit is contained in:
Shukun Zhang 2022-06-14 09:21:40 +08:00
parent 52024260ad
commit e2c79dbf21
4 changed files with 5 additions and 4 deletions

View File

@ -1566,7 +1566,7 @@ void ExportGraph(const std::string &file_name, const std::string &model_type, co
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,
const py::object decrypt) {
FuncGraphPtr func_graph;
FuncGraphPtr func_graph = nullptr;
if (dec_mode == "Customized") {
py::bytes key_bytes(dec_key);
py::bytes model_stream = decrypt(file_name, key_bytes);

View File

@ -552,7 +552,7 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
}
}
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::function encrypt, char *key) {
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::object encrypt, char *key) {
MS_LOG(DEBUG) << "Export graph begin.";
transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase);
if (wrap_ptr == nullptr) {
@ -566,7 +566,7 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase, const
return;
}
if (key != nullptr) {
if (py::isinstance<py::none>(encrypt)) {
if (py::isinstance<py::none()>(encrypt)) {
MS_LOG(ERROR) << "ERROR: encrypt is not a function";
return;
}

View File

@ -48,7 +48,7 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
const std::vector<int64_t> &input_indexes, const std::string &phase);
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::function encrypt = py::none(),
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::object encrypt = py::none(),
char *key = nullptr);
} // namespace pipeline
} // namespace mindspore

View File

@ -188,3 +188,4 @@ def test_load_mindir_and_run_with_encryption():
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(inputs0)
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
os.remove(mindir_name)