forked from mindspore-Ecosystem/mindspore
Fix AIR Encryption Error
This commit is contained in:
parent
52024260ad
commit
e2c79dbf21
|
@ -1566,7 +1566,7 @@ void ExportGraph(const std::string &file_name, const std::string &model_type, co
|
|||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,
|
||||
const py::object decrypt) {
|
||||
FuncGraphPtr func_graph;
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
if (dec_mode == "Customized") {
|
||||
py::bytes key_bytes(dec_key);
|
||||
py::bytes model_stream = decrypt(file_name, key_bytes);
|
||||
|
|
|
@ -552,7 +552,7 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
|
|||
}
|
||||
}
|
||||
|
||||
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::function encrypt, char *key) {
|
||||
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::object encrypt, char *key) {
|
||||
MS_LOG(DEBUG) << "Export graph begin.";
|
||||
transform::DfGraphWrapperPtr wrap_ptr = transform::GetGraphByName(phase);
|
||||
if (wrap_ptr == nullptr) {
|
||||
|
@ -566,7 +566,7 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase, const
|
|||
return;
|
||||
}
|
||||
if (key != nullptr) {
|
||||
if (py::isinstance<py::none>(encrypt)) {
|
||||
if (py::isinstance<py::none()>(encrypt)) {
|
||||
MS_LOG(ERROR) << "ERROR: encrypt is not a function";
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
|
|||
const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
|
||||
const std::vector<int64_t> &input_indexes, const std::string &phase);
|
||||
|
||||
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::function encrypt = py::none(),
|
||||
void ExportDFGraph(const std::string &file_name, const std::string &phase, const py::object encrypt = py::none(),
|
||||
char *key = nullptr);
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -188,3 +188,4 @@ def test_load_mindir_and_run_with_encryption():
|
|||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(inputs0)
|
||||
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
|
||||
os.remove(mindir_name)
|
||||
|
|
Loading…
Reference in New Issue