Expand LoadMindIR with Encryption

Shukun Zhang 2022-06-09 11:13:57 +08:00
parent e086ba7cf7
commit 52024260ad
7 changed files with 124 additions and 23 deletions

View File

@@ -167,7 +167,8 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, py::arg("file_name"), py::arg("model_type"),
py::arg("phase"), py::arg("encrypt") = py::none(), py::arg("key") = nullptr, "Export Graph.");
(void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
py::arg("decrypt") = py::none(), "Load model as Graph.");
(void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster");

View File

@@ -1564,10 +1564,20 @@ void ExportGraph(const std::string &file_name, const std::string &model_type, co
#endif
}
-FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len,
-                        const std::string &dec_mode) {
+FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,
+                        const py::object decrypt) {
+  FuncGraphPtr func_graph;
+  if (dec_mode == "Customized") {
+    py::bytes key_bytes(dec_key);
+    py::bytes model_stream = decrypt(file_name, key_bytes);
+    std::string model_string(model_stream);
+    MindIRLoader mindir_loader;
+    func_graph = mindir_loader.LoadMindIR(model_string.c_str(), model_string.size());
+  } else {
+    MindIRLoader mindir_loader(false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode, false);
-  auto func_graph = mindir_loader.LoadMindIR(file_name);
+    func_graph = mindir_loader.LoadMindIR(file_name);
+  }
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
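In the new "Customized" branch, decryption is delegated back to Python: the `decrypt` object is called as `decrypt(file_name, key_bytes)` and must return the decrypted, serialized MindIR bytes, which are then parsed from memory rather than from the file path. A minimal stand-in for that contract (illustrative only, not part of this change; the new tests below use a similar `decrypt_func`) could look like:

    def passthrough_decrypt(cipher_file, key):
        # Illustrative placeholder: a real callable would decrypt the file contents with `key`;
        # here the file is assumed to already hold plain MindIR bytes.
        with open(cipher_file, 'rb') as f:
            return f.read()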

View File

@@ -183,7 +183,8 @@ void MemoryRecycle();
void ExportGraph(const std::string &file_name, const std::string &model_type, const std::string &phase,
const py::object encrypt = py::none(), char *key = nullptr);
-FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode);
+FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,
+                        const py::object decrypt = py::none());
// init and exec dataset sub graph
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,

View File

@@ -388,8 +388,8 @@ def load(file_name, **kwargs):
kwargs (dict): Configuration options dictionary.
- dec_key (bytes): Byte type key used for decryption. The valid length is 16, 24, or 32.
-- dec_mode (str): Specifies the decryption mode, to take effect when dec_key is set.
-  Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
+- dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
+  Option: 'AES-GCM', 'AES-CBC' or customized decryption. Default: 'AES-GCM'.
Returns:
GraphCell, a compiled graph that can be executed by `GraphCell`.
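A hedged usage sketch of the customized option described above (the file name, key value and `xor_decrypt` helper are illustrative, not part of this commit): passing any callable as `dec_mode` selects the "Customized" path, and the callable receives the MindIR file path plus `dec_key` and must return the decrypted model bytes.

    import mindspore as ms

    def xor_decrypt(cipher_file, key):
        # Toy scheme for illustration only: XOR every byte with the first key byte.
        with open(cipher_file, 'rb') as f:
            data = f.read()
        return bytes(b ^ key[0] for b in data)

    graph = ms.load("lenet_enc.mindir", dec_key=b'0123456789ABCDEF', dec_mode=xor_decrypt)
    # `graph` can then be wrapped in nn.GraphCell for execution, as in the tests below.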
@@ -429,9 +429,15 @@ def load(file_name, **kwargs):
if 'dec_key' in kwargs.keys():
dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
dec_mode = "AES-GCM"
+        dec_func = None
if 'dec_mode' in kwargs.keys():
+            if callable(kwargs.get('dec_mode')):
+                dec_mode = "Customized"
+                dec_func = kwargs.get('dec_mode')
+            else:
+                dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
-        graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode)
+        graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
+                            decrypt=dec_func)
else:
graph = load_mindir(file_name)
@@ -859,8 +865,9 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
- std_dev (float): The variance of input data after preprocessing,
used for quantizing the first layer of the network. Default: 127.5.
- enc_key (byte): Byte type key used for encryption. The valid length is 16, 24, or 32.
-- enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
-  Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
+- enc_mode (Union[str, function]): Specifies the encryption mode, to take effect when enc_key is set.
+  For 'AIR' and 'ONNX' models, only Customized encryption is supported. For 'MINDIR', all options are
+  supported. Option: 'AES-GCM', 'AES-CBC' or Customized encryption by user. Default: 'AES-GCM'.
Examples:
>>> import mindspore as ms
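For the export side, a hedged sketch under the same toy scheme (the `TinyNet` cell, file name and key below are illustrative, not part of this commit): with `file_format='MINDIR'`, a callable `enc_mode` receives the serialized model bytes (or each data chunk for very large, split saves) together with `enc_key`, and must return the bytes that are written to the output file.

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    class TinyNet(nn.Cell):
        def construct(self, x):
            return x * 2

    def xor_encrypt(model_stream, key):
        # Toy scheme for illustration only: must be reversible by the matching dec_mode callable.
        return bytes(b ^ key[0] for b in model_stream)

    ms.export(TinyNet(), Tensor(np.ones([2, 3], np.float32)), file_name="tiny_enc",
              file_format='MINDIR', enc_key=b'0123456789ABCDEF', enc_mode=xor_encrypt)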
@@ -935,15 +942,16 @@ def _check_key_mode_type(file_format, **kwargs):
enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
enc_mode = kwargs.get('enc_mode')
if file_format in ('AIR', 'ONNX'):
if callable(enc_mode):
return enc_key, enc_mode
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
raise RuntimeError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.")
enc_mode = 'AES-GCM'
if 'enc_mode' in kwargs.keys():
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
if file_format in ('AIR', 'ONNX'):
raise RuntimeError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.")
if enc_mode in ('AES-CBC', 'AES-GCM'):
return enc_key, enc_mode
raise RuntimeError(f"MindIR only support AES-GCM/AES-CBC encryption, but got {enc_mode}")
@@ -1074,6 +1082,10 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
write_data = raw_data + bytes(append_size)
offset += (data_length + append_size)
if is_encrypt():
+    if callable(kwargs.get('enc_mode')):
+        enc_func = kwargs.get('enc_mode')
+        write_data = enc_func(write_data, kwargs.get('enc_key'))
+    else:
+        write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
+                              len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
f.write(write_data)
@@ -1147,6 +1159,10 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
if is_encrypt():
+    if callable(kwargs.get('enc_mode')):
+        enc_func = kwargs.get('enc_mode')
+        model_string = enc_func(model_string, kwargs.get('enc_key'))
+    else:
+        model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
+                                len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
f.write(model_string)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
"""Test network export."""
import os
+from io import BytesIO
import numpy as np
import pytest

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+from io import BytesIO
import numpy as np
import pytest
@@ -98,11 +99,29 @@ class TrainOneStepCell(nn.Cell):
return self.optimizer(grads)
+def encrypt_func(model_stream, key):
+    plain_data = BytesIO()
+    plain_data.write(model_stream)
+    return plain_data.getvalue()
+def decrypt_func(cipher_file, key):
+    with open(cipher_file, 'rb') as f:
+        plain_data = f.read()
+    f.close()
+    return plain_data
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_export_lenet_grad_mindir():
"""
Feature: export LeNet to MindIR file
Description: Test export API to export network into MindIR
Expectation: export successfully
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
network = LeNet5()
network.set_train()
@@ -119,6 +138,11 @@ def test_export_lenet_grad_mindir():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_load_mindir_and_run():
"""
Feature: Load LeNet to MindIR
Description: Test load API to load network into MindIR
Expectation: load successfully
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
network = LeNet5()
network.set_train()
@@ -135,3 +159,32 @@ def test_load_mindir_and_run():
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(inputs0)
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
+@pytest.mark.level1
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+def test_load_mindir_and_run_with_encryption():
+    """
+    Feature: Load encrypted LeNet to MindIR with decryption
+    Description: Test load API to load network with encryption into MindIR
+    Expectation: load successfully
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    network = LeNet5()
+    network.set_train()
+    inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
+    outputs0 = network(inputs0)
+    inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
+    export(network, inputs, file_name="test_lenet_load_enc", file_format='MINDIR',
+           enc_key=b'123456789', enc_mode=encrypt_func)
+    mindir_name = "test_lenet_load_enc.mindir"
+    assert os.path.exists(mindir_name)
+    graph = load(mindir_name, dec_key=b'123456789', dec_mode=decrypt_func)
+    loaded_net = nn.GraphCell(graph)
+    outputs_after_load = loaded_net(inputs0)
+    assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())

View File

@@ -180,8 +180,8 @@ def test_export_lenet_with_dataset():
def test_export_lenet_onnx_with_encryption():
"""
-    Feature: Export encrypted LeNet to MindIR
-    Description: Test export API to save network with encryption into MindIR
+    Feature: Export encrypted LeNet to ONNX
+    Description: Test export API to save network with encryption into ONNX
Expectation: save successfully
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@@ -195,3 +195,22 @@ def test_export_lenet_onnx_with_encryption():
verify_name = file_name + ".onnx"
assert os.path.exists(verify_name)
os.remove(verify_name)
+def test_export_lenet_mindir_with_encryption():
+    """
+    Feature: Export encrypted LeNet to MindIR
+    Description: Test export API to save network with encryption into MindIR
+    Expectation: save successfully
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    network = LeNet5()
+    network.set_train()
+    file_name = "lenet_preprocess"
+    input_tensor = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
+    export(network, input_tensor, file_name=file_name, file_format='MINDIR',
+           enc_key=b'123456789', enc_mode=encrypt_func)
+    verify_name = file_name + ".mindir"
+    assert os.path.exists(verify_name)
+    os.remove(verify_name)