!18637 fix no exception when export with encryption failed

Merge pull request !18637 from liuluobin/fix_issue
This commit is contained in:
i-robot 2021-06-22 03:48:01 +00:00 committed by Gitee
commit 749de6636c
2 changed files with 4 additions and 3 deletions

View File

@ -1298,8 +1298,7 @@ py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const s
auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len, auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
reinterpret_cast<Byte *>(key), key_len, enc_mode); reinterpret_cast<Byte *>(key), key_len, enc_mode);
if (encrypt_data == nullptr) { if (encrypt_data == nullptr) {
MS_LOG(ERROR) << "Encrypt failed"; MS_EXCEPTION(ValueError) << "Encrypt failed";
return py::bytes();
} }
auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len); auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
return py_encrypt_data; return py_encrypt_data;
@ -1311,7 +1310,7 @@ py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_l
mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode); mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
if (decrypt_data == nullptr) { if (decrypt_data == nullptr) {
MS_LOG(ERROR) << "Decrypt failed"; MS_LOG(ERROR) << "Decrypt failed";
return py::bytes(); return py::none();
} }
auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len); auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
return py_decrypt_data; return py_decrypt_data;

View File

@ -404,6 +404,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
pb_content = f.read() pb_content = f.read()
else: else:
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
if pb_content is None:
raise ValueError
checkpoint_list.ParseFromString(pb_content) checkpoint_list.ParseFromString(pb_content)
except BaseException as e: except BaseException as e:
if _is_cipher_file(ckpt_file_name): if _is_cipher_file(ckpt_file_name):