From 2f978ac247cc541d83abeb50b765ee60140b7ac4 Mon Sep 17 00:00:00 2001 From: liuluobin Date: Mon, 21 Jun 2021 15:15:24 +0800 Subject: [PATCH] fix no exception when export with encryption is failed --- mindspore/ccsrc/pipeline/jit/pipeline.cc | 5 ++--- mindspore/train/serialization.py | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index fa2ff975f88..3c6f4372286 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -1298,8 +1298,7 @@ py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const s auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast(plain_data), plain_len, reinterpret_cast(key), key_len, enc_mode); if (encrypt_data == nullptr) { - MS_LOG(ERROR) << "Encrypt failed"; - return py::bytes(); + MS_EXCEPTION(ValueError) << "Encrypt failed"; } auto py_encrypt_data = py::bytes(reinterpret_cast(encrypt_data.get()), encrypt_len); return py_encrypt_data; @@ -1311,7 +1310,7 @@ py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_l mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast(key), key_len, dec_mode); if (decrypt_data == nullptr) { MS_LOG(ERROR) << "Decrypt failed"; - return py::bytes(); + return py::none(); } auto py_decrypt_data = py::bytes(reinterpret_cast(decrypt_data.get()), decrypt_len); return py_decrypt_data; diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 9b4ba25489b..661d4972c6d 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -404,6 +404,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N pb_content = f.read() else: pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) + if pb_content is None: + raise ValueError checkpoint_list.ParseFromString(pb_content) except BaseException as e: if _is_cipher_file(ckpt_file_name):