From e2c79dbf2133223b9ab109ce5824d57c461573ed Mon Sep 17 00:00:00 2001 From: Shukun Zhang Date: Tue, 14 Jun 2022 09:21:40 +0800 Subject: [PATCH] Fix AIR Encryption Error --- mindspore/ccsrc/pipeline/jit/pipeline.cc | 2 +- mindspore/ccsrc/pipeline/jit/pipeline_ge.cc | 4 ++-- mindspore/ccsrc/pipeline/jit/pipeline_ge.h | 2 +- tests/st/export_and_load/test_train_mindir.py | 1 + 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 28e613fbd9d..6d8e1701716 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -1566,7 +1566,7 @@ void ExportGraph(const std::string &file_name, const std::string &model_type, co FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode, const py::object decrypt) { - FuncGraphPtr func_graph; + FuncGraphPtr func_graph = nullptr; if (dec_mode == "Customized") { py::bytes key_bytes(dec_key); py::bytes model_stream = decrypt(file_name, key_bytes); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc index 4ad16d54f37..7c1301cd10c 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc @@ -552,7 +552,7 @@ py::object ExecDFGraph(const std::map &info, const } } -void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::function encrypt, char *key) { +void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::object encrypt, char *key) { MS_LOG(DEBUG) << "Export graph begin."; transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase); if (wrap_ptr == nullptr) { @@ -566,7 +566,7 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase, const return; } if (key != nullptr) { - if (py::isinstance(encrypt)) { + if (py::isinstance(encrypt)) { MS_LOG(ERROR) << "ERROR: encrypt is not a function"; return; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.h b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h index cfadbd2a4ac..7ac7f287047 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_ge.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h @@ -48,7 +48,7 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, const std::string &phase); -void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::function encrypt = py::none(), +void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::object encrypt = py::none(), char *key = nullptr); } // namespace pipeline } // namespace mindspore diff --git a/tests/st/export_and_load/test_train_mindir.py b/tests/st/export_and_load/test_train_mindir.py index 4f759407486..bc5500a392d 100644 --- a/tests/st/export_and_load/test_train_mindir.py +++ b/tests/st/export_and_load/test_train_mindir.py @@ -188,3 +188,4 @@ def test_load_mindir_and_run_with_encryption(): loaded_net = nn.GraphCell(graph) outputs_after_load = loaded_net(inputs0) assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy()) + os.remove(mindir_name)