forked from mindspore-Ecosystem/mindspore
Expand LoadMindIR with Encryption
This commit is contained in:
parent
e086ba7cf7
commit
52024260ad
|
@ -167,7 +167,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, py::arg("file_name"), py::arg("model_type"),
|
||||
py::arg("phase"), py::arg("encrypt") = py::none(), py::arg("key") = nullptr, "Export Graph.");
|
||||
(void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
|
||||
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
|
||||
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
|
||||
py::arg("decrypt") = py::none(), "Load model as Graph.");
|
||||
|
||||
(void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster");
|
||||
|
||||
|
|
|
@ -1564,10 +1564,20 @@ void ExportGraph(const std::string &file_name, const std::string &model_type, co
|
|||
#endif
|
||||
}
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
MindIRLoader mindir_loader(false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode, false);
|
||||
auto func_graph = mindir_loader.LoadMindIR(file_name);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,
|
||||
const py::object decrypt) {
|
||||
FuncGraphPtr func_graph;
|
||||
if (dec_mode == "Customized") {
|
||||
py::bytes key_bytes(dec_key);
|
||||
py::bytes model_stream = decrypt(file_name, key_bytes);
|
||||
std::string model_string(model_stream);
|
||||
|
||||
MindIRLoader mindir_loader;
|
||||
func_graph = mindir_loader.LoadMindIR(model_string.c_str(), model_string.size());
|
||||
} else {
|
||||
MindIRLoader mindir_loader(false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode, false);
|
||||
func_graph = mindir_loader.LoadMindIR(file_name);
|
||||
}
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
|
|
@ -183,7 +183,8 @@ void MemoryRecycle();
|
|||
void ExportGraph(const std::string &file_name, const std::string &model_type, const std::string &phase,
|
||||
const py::object encrypt = py::none(), char *key = nullptr);
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode,
|
||||
const py::object decrypt = py::none());
|
||||
|
||||
// init and exec dataset sub graph
|
||||
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
|
||||
|
|
|
@ -388,8 +388,8 @@ def load(file_name, **kwargs):
|
|||
kwargs (dict): Configuration options dictionary.
|
||||
|
||||
- dec_key (bytes): Byte type key used for decryption. The valid length is 16, 24, or 32.
|
||||
- dec_mode (str): Specifies the decryption mode, to take effect when dec_key is set.
|
||||
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
|
||||
- dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
|
||||
Option: 'AES-GCM', 'AES-CBC' or customized decryption. Default: 'AES-GCM'.
|
||||
Returns:
|
||||
GraphCell, a compiled graph that can executed by `GraphCell`.
|
||||
|
||||
|
@ -429,9 +429,15 @@ def load(file_name, **kwargs):
|
|||
if 'dec_key' in kwargs.keys():
|
||||
dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
|
||||
dec_mode = "AES-GCM"
|
||||
dec_func = None
|
||||
if 'dec_mode' in kwargs.keys():
|
||||
dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
|
||||
graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode)
|
||||
if callable(kwargs.get('dec_mode')):
|
||||
dec_mode = "Customized"
|
||||
dec_func = kwargs.get('dec_mode')
|
||||
else:
|
||||
dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
|
||||
graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
|
||||
decrypt=dec_func)
|
||||
else:
|
||||
graph = load_mindir(file_name)
|
||||
|
||||
|
@ -859,8 +865,9 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
|
|||
- std_dev (float): The variance of input data after preprocessing,
|
||||
used for quantizing the first layer of the network. Default: 127.5.
|
||||
- enc_key (byte): Byte type key used for encryption. The valid length is 16, 24, or 32.
|
||||
- enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
|
||||
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
|
||||
- enc_mode (Union[str, function]): Specifies the encryption mode, to take effect when enc_key is set.
|
||||
For 'AIR' and 'ONNX' models, only Customized encryption is supported. For 'MINDIR', all options are
|
||||
supported. Option: 'AES-GCM', 'AES-CBC' or Customized encryption by user. Default: 'AES-GCM'.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
|
@ -935,15 +942,16 @@ def _check_key_mode_type(file_format, **kwargs):
|
|||
enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
|
||||
enc_mode = kwargs.get('enc_mode')
|
||||
|
||||
if file_format in ('AIR', 'ONNX'):
|
||||
if callable(enc_mode):
|
||||
return enc_key, enc_mode
|
||||
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
||||
raise RuntimeError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.")
|
||||
if callable(enc_mode):
|
||||
return enc_key, enc_mode
|
||||
|
||||
enc_mode = 'AES-GCM'
|
||||
if 'enc_mode' in kwargs.keys():
|
||||
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
||||
|
||||
if file_format in ('AIR', 'ONNX'):
|
||||
raise RuntimeError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.")
|
||||
|
||||
if enc_mode in ('AES-CBC', 'AES-GCM'):
|
||||
return enc_key, enc_mode
|
||||
raise RuntimeError(f"MindIR only support AES-GCM/AES-CBC encryption, but got {enc_mode}")
|
||||
|
@ -1074,8 +1082,12 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|||
write_data = raw_data + bytes(append_size)
|
||||
offset += (data_length + append_size)
|
||||
if is_encrypt():
|
||||
write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
|
||||
len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
|
||||
if callable(kwargs.get('enc_mode')):
|
||||
enc_func = kwargs.get('enc_mode')
|
||||
write_data = enc_func(write_data, kwargs.get('enc_key'))
|
||||
else:
|
||||
write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
|
||||
len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
|
||||
f.write(write_data)
|
||||
|
||||
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
|
||||
|
@ -1147,8 +1159,12 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|||
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
model_string = model.SerializeToString()
|
||||
if is_encrypt():
|
||||
model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
|
||||
len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
|
||||
if callable(kwargs.get('enc_mode')):
|
||||
enc_func = kwargs.get('enc_mode')
|
||||
model_string = enc_func(model_string, kwargs.get('enc_key'))
|
||||
else:
|
||||
model_string = _encrypt(model_string, len(model_string), kwargs.get('enc_key'),
|
||||
len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
|
||||
f.write(model_string)
|
||||
os.chmod(file_name, stat.S_IRUSR)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
"""Test network export."""
|
||||
import os
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
@ -98,11 +99,29 @@ class TrainOneStepCell(nn.Cell):
|
|||
return self.optimizer(grads)
|
||||
|
||||
|
||||
def encrypt_func(model_stream, key):
|
||||
plain_data = BytesIO()
|
||||
plain_data.write(model_stream)
|
||||
return plain_data.getvalue()
|
||||
|
||||
|
||||
def decrypt_func(cipher_file, key):
|
||||
with open(cipher_file, 'rb') as f:
|
||||
plain_data = f.read()
|
||||
f.close()
|
||||
return plain_data
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_export_lenet_grad_mindir():
|
||||
"""
|
||||
Feature: export LeNet to MindIR file
|
||||
Description: Test export API to export network into MindIR
|
||||
Expectation: export successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
network = LeNet5()
|
||||
network.set_train()
|
||||
|
@ -119,6 +138,11 @@ def test_export_lenet_grad_mindir():
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_load_mindir_and_run():
|
||||
"""
|
||||
Feature: Load LeNet to MindIR
|
||||
Description: Test load API to load network into MindIR
|
||||
Expectation: load successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
network = LeNet5()
|
||||
network.set_train()
|
||||
|
@ -135,3 +159,32 @@ def test_load_mindir_and_run():
|
|||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(inputs0)
|
||||
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_load_mindir_and_run_with_encryption():
|
||||
"""
|
||||
Feature: Load encrypted LeNet to MindIR with decryption
|
||||
Description: Test load API to load network with encryption into MindIR
|
||||
Expectation: load successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
network = LeNet5()
|
||||
network.set_train()
|
||||
|
||||
inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
|
||||
outputs0 = network(inputs0)
|
||||
|
||||
inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
|
||||
export(network, inputs, file_name="test_lenet_load_enc", file_format='MINDIR',
|
||||
enc_key=b'123456789', enc_mode=encrypt_func)
|
||||
mindir_name = "test_lenet_load_enc.mindir"
|
||||
assert os.path.exists(mindir_name)
|
||||
|
||||
graph = load(mindir_name, dec_key=b'123456789', dec_mode=decrypt_func)
|
||||
loaded_net = nn.GraphCell(graph)
|
||||
outputs_after_load = loaded_net(inputs0)
|
||||
assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
|
||||
|
|
|
@ -180,8 +180,8 @@ def test_export_lenet_with_dataset():
|
|||
|
||||
def test_export_lenet_onnx_with_encryption():
|
||||
"""
|
||||
Feature: Export encrypted LeNet to MindIR
|
||||
Description: Test export API to save network with encryption into MindIR
|
||||
Feature: Export encrypted LeNet to ONNX
|
||||
Description: Test export API to save network with encryption into ONNX
|
||||
Expectation: save successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
@ -195,3 +195,22 @@ def test_export_lenet_onnx_with_encryption():
|
|||
verify_name = file_name + ".onnx"
|
||||
assert os.path.exists(verify_name)
|
||||
os.remove(verify_name)
|
||||
|
||||
|
||||
def test_export_lenet_mindir_with_encryption():
|
||||
"""
|
||||
Feature: Export encrypted LeNet to MindIR
|
||||
Description: Test export API to save network with encryption into MindIR
|
||||
Expectation: save successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
network = LeNet5()
|
||||
network.set_train()
|
||||
file_name = "lenet_preprocess"
|
||||
|
||||
input_tensor = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
|
||||
export(network, input_tensor, file_name=file_name, file_format='MINDIR',
|
||||
enc_key=b'123456789', enc_mode=encrypt_func)
|
||||
verify_name = file_name + ".mindir"
|
||||
assert os.path.exists(verify_name)
|
||||
os.remove(verify_name)
|
||||
|
|
Loading…
Reference in New Issue