forked from mindspore-Ecosystem/mindspore
!18637 fix no exception when export with encryption failed
Merge pull request !18637 from liuluobin/fix_issue
This commit is contained in:
commit
749de6636c
|
@ -1298,8 +1298,7 @@ py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const s
|
||||||
auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
|
auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
|
||||||
reinterpret_cast<Byte *>(key), key_len, enc_mode);
|
reinterpret_cast<Byte *>(key), key_len, enc_mode);
|
||||||
if (encrypt_data == nullptr) {
|
if (encrypt_data == nullptr) {
|
||||||
MS_LOG(ERROR) << "Encrypt failed";
|
MS_EXCEPTION(ValueError) << "Encrypt failed";
|
||||||
return py::bytes();
|
|
||||||
}
|
}
|
||||||
auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
|
auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
|
||||||
return py_encrypt_data;
|
return py_encrypt_data;
|
||||||
|
@ -1311,7 +1310,7 @@ py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_l
|
||||||
mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
|
mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
|
||||||
if (decrypt_data == nullptr) {
|
if (decrypt_data == nullptr) {
|
||||||
MS_LOG(ERROR) << "Decrypt failed";
|
MS_LOG(ERROR) << "Decrypt failed";
|
||||||
return py::bytes();
|
return py::none();
|
||||||
}
|
}
|
||||||
auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
|
auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
|
||||||
return py_decrypt_data;
|
return py_decrypt_data;
|
||||||
|
|
|
@ -404,6 +404,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
pb_content = f.read()
|
pb_content = f.read()
|
||||||
else:
|
else:
|
||||||
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
||||||
|
if pb_content is None:
|
||||||
|
raise ValueError
|
||||||
checkpoint_list.ParseFromString(pb_content)
|
checkpoint_list.ParseFromString(pb_content)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
if _is_cipher_file(ckpt_file_name):
|
if _is_cipher_file(ckpt_file_name):
|
||||||
|
|
Loading…
Reference in New Issue