!44145 fix code cleans for dynamic obfuscation

Merge pull request !44145 from jxlang910/master
i-robot 2022-10-21 07:47:42 +00:00 committed by Gitee
commit 9e8fcda9a7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 249 additions and 31 deletions

View File

@@ -110,6 +110,7 @@ mindspore
    mindspore.load_param_into_net
    mindspore.merge_pipeline_strategys
    mindspore.merge_sliced_parameter
+   mindspore.obfuscate_model
    mindspore.parse_print
    mindspore.rank_list_for_transform
    mindspore.restore_group_info_list

View File

@@ -28,14 +28,16 @@ mindspore.export
    - **std_dev** (float) - The variance of the input data after preprocessing, used for quantizing the first layer of the network. Default: 127.5.
    - **enc_key** (str) - The byte-type key used for encryption. The valid length is 16, 24 or 32.
    - **enc_mode** (Union[str, function]) - Specifies the encryption mode, which takes effect when `enc_key` is set.

      - For models in 'AIR' and 'ONNX' format, currently only customized encryption export is supported.
      - For models in 'MINDIR' format, the supported encryption options are 'AES-GCM', 'AES-CBC' and user-defined encryption algorithms. Default: "AES-GCM".
      - For details on exporting with customized encryption, please refer to the `tutorial <https://www.mindspore.cn/mindarmour/docs/zh-CN/master/model_encrypt_protection.html>`_.

    - **dataset** (Dataset) - Specifies the preprocessing method of the dataset, which is used to import the dataset preprocessing into MindIR.
    - **obf_config** (dict) - The dictionary of model obfuscation configuration options.

      - **type** (str) - The obfuscation type. Currently only dynamic obfuscation, i.e. 'dynamic', is supported.
      - **obf_ratio** (Union[str, float]) - The obfuscation ratio of operators in the whole model, which can be a float in (0, 1] or one of the strings "small", "medium" and "large".
      - **customized_func** (function) - A Python function that needs to be set in customized-function mode, used to control the switch branches of the obfuscated structure. Its return value must be of type bool and must be constant; users can refer to opaque predicates when defining it. If `customized_func` is set, the same function must also be passed in when loading the model with the `load` interface.
-     - **obf_password** (int) - The secret password used for password mode, an integer greater than 0. If `obf_password` is set, it must also be passed in through the `nn.GraphCell()` interface when deploying the obfuscated model. Note that if both `customized_func` and `obf_password` are set, password mode will be used.
+     - **obf_password** (int) - The secret password used for password mode, an integer greater than 0. If `obf_password` is set, it must also be passed in through the :class:`mindspore.nn.GraphCell` interface when deploying the obfuscated model. Note that if both `customized_func` and `obf_password` are set, password mode will be used.
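Taken together, the obfuscation options above map onto a single `export` call. A minimal sketch of password-mode obfuscation at export time (the stand-in network, file name and password are placeholders, mirroring the new test file added in this PR):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, export

# Stand-in network; any nn.Cell that can be exported to MINDIR works the same way.
net = nn.Dense(16, 8)
input_tensor = Tensor(np.ones((1, 16)).astype(np.float32))

# Password mode: obf_ratio controls how many operators get obfuscated,
# obf_password is the secret that is needed later when deploying the model.
obf_config = {"obf_ratio": 0.8, "obf_password": 3423}
export(net, input_tensor, file_name="obf_net", file_format="MINDIR", obf_config=obf_config)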

View File

@@ -9,14 +9,24 @@ mindspore.obfuscate_model
    - **obf_config** (dict) - The dictionary of model obfuscation configuration options.

      - **type** (str) - The obfuscation type. Currently only dynamic obfuscation, i.e. 'dynamic', is supported.
      - **original_model_path** (str) - The path of the MindIR model to be obfuscated. If the model file is encrypted, `enc_key` and `enc_mode` need to be passed in through `kwargs`.
      - **save_model_path** (str) - The path for saving the obfuscated model.
-     - **model_inputs** (list[Tensor]) - The inference inputs of the model. The values of the Tensors can be random, similar to using the `export()` interface.
+     - **model_inputs** (list[Tensor]) - The inference inputs of the model. The values of the Tensors can be random, similar to using the :func:`mindspore.export` interface.
      - **obf_ratio** (Union[str, float]) - The obfuscation ratio of operators in the whole model, which can be a float in (0, 1] or one of the strings "small", "medium" and "large".
-     - **customized_func** (function) - A Python function that needs to be set in customized-function mode, used to control the switch branches of the obfuscated structure. Its return value must be of type bool and must be constant; users can refer to opaque predicates when defining it. If `customized_func` is set, the same function must also be passed in when loading the model with the `load` interface.
+     - **customized_func** (function) - A Python function that needs to be set in customized-function mode, used to control the switch branches of the obfuscated structure. Its return value must be of type bool and must be constant; users can refer to opaque predicates when defining it. If `customized_func` is set, the same function must also be passed in when loading the model with the :func:`mindspore.load` interface.
-     - **obf_password** (int) - The secret password used for password mode, an integer greater than 0. If `obf_password` is set, it must also be passed in through the `nn.GraphCell()` interface when deploying the obfuscated model. Note that if both `customized_func` and `obf_password` are set, password mode will be used.
+     - **obf_password** (int) - The secret password used for password mode, an integer greater than 0. If `obf_password` is set, it must also be passed in through the :class:`mindspore.nn.GraphCell` interface when deploying the obfuscated model. Note that if both `customized_func` and `obf_password` are set, password mode will be used.

    - **kwargs** (dict) - The dictionary of configuration options.

-     - **enc_key** (str) - The byte-type key used for encryption. The valid length is 16, 24 or 32.
+     - **enc_key** (bytes) - The byte-type key used for encryption. The valid length is 16, 24 or 32.
-     - **enc_mode** (Union[str, function]) - Specifies the encryption mode, which takes effect when `enc_key` is set. The supported encryption options are 'AES-GCM' and 'AES-CBC'. Default: "AES-GCM".
+     - **enc_mode** (str) - Specifies the encryption mode, which takes effect when `enc_key` is set. The supported encryption options are 'AES-GCM' and 'AES-CBC'. Default: "AES-GCM".

+Raises:
+    - **TypeError** - `obf_config` is not of dict type.
+    - **ValueError** - `enc_key` is passed in but `enc_mode` is not in ["AES-GCM", "AES-CBC"].
+    - **ValueError** - `original_model_path` is not provided in `obf_config`.
+    - **ValueError** - The model in `original_model_path` has already been obfuscated.
+    - **ValueError** - `save_model_path` is not provided in `obf_config`.
+    - **ValueError** - `obf_ratio` is not provided in `obf_config`.
+    - **ValueError** - Neither `customized_func` nor `obf_password` is provided in `obf_config`.
+    - **ValueError** - The value of `obf_password` is not in (0, 9223372036854775807].
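A minimal sketch of driving `mindspore.obfuscate_model` with this dictionary (the paths and the password are placeholders, and `net.mindir` is assumed to have been produced beforehand with `mindspore.export`):

import numpy as np
from mindspore import Tensor, obfuscate_model

# The dummy input only has to match the shape the original model was exported with.
input_tensor = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
obf_config = {"original_model_path": "net.mindir",
              "save_model_path": "./obf_net",
              "model_inputs": [input_tensor],
              "obf_ratio": 0.8,
              "obf_password": 3423}
obfuscate_model(obf_config)  # writes obf_net.mindir at the given path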

View File

@@ -1,7 +1,7 @@
mindspore.nn.GraphCell
======================

-.. py:class:: mindspore.nn.GraphCell(graph, params_init=None)
+.. py:class:: mindspore.nn.GraphCell(graph, params_init=None, obf_password=None)

    Runs the compute graph loaded from MindIR.
@@ -10,6 +10,7 @@ mindspore.nn.GraphCell
    Parameters:
        - **graph** (FuncGraph) - The compiled graph loaded from MindIR.
        - **params_init** (dict) - Parameters to be initialized in the graph. The key is the parameter name, of type string, and the value is a Tensor or Parameter. If the parameter name already exists in the graph, its value is updated; if it does not exist, it is ignored. Default: None.
+       - **obf_password** (int) - The password used for dynamic obfuscation protection. Dynamic obfuscation is a model protection method; see :func:`mindspore.train.serialization.obfuscate_model`. If the imported `graph` is an obfuscated model, `obf_password` should be provided. Its valid range is (0, 9223372036854775807]. Default: None.

    Raises:
        - **TypeError** - If the graph is not of FuncGraph type.
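For an obfuscated model in password mode, loading then looks like the following sketch (the file name, input shape and password are placeholders and must match what was used when the model was obfuscated):

import numpy as np
import mindspore.nn as nn
from mindspore import load, Tensor

graph = load("obf_net.mindir")
# obf_password must be the same integer that was set in obf_config.
obf_net = nn.GraphCell(graph, obf_password=3423)
result = obf_net(Tensor(np.ones((1, 1, 32, 32)).astype(np.float32)))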

View File

@@ -229,13 +229,13 @@ Serialization
    mindspore.load_param_into_net
    mindspore.merge_pipeline_strategys
    mindspore.merge_sliced_parameter
+   mindspore.obfuscate_model
    mindspore.parse_print
    mindspore.rank_list_for_transform
    mindspore.restore_group_info_list
    mindspore.save_checkpoint
    mindspore.transform_checkpoint_by_rank
    mindspore.transform_checkpoints
-   mindspore.obfuscate_model

JIT
---

View File

@@ -2288,11 +2288,14 @@ class GraphCell(Cell):
                            f"but got type {type(graph)}.")
        self.graph = graph
        self.obf_password = obf_password
-        int_64_max = 9223372036854775807
-        if (obf_password is not None) and (obf_password <= 0 or obf_password > int_64_max):
-            raise ValueError(
-                "'obf_password' must be larger than 0, and less or equal than int64 ({}),"
-                "but got {}.".format(int_64_max, obf_password))
+        if obf_password is not None:
+            if not isinstance(obf_password, int):
+                raise TypeError("'obf_password' must be int, but got {}.".format(type(obf_password)))
+            int_64_max = 9223372036854775807
+            if obf_password <= 0 or obf_password > int_64_max:
+                raise ValueError(
+                    "'obf_password' must be larger than 0, and less or equal than int64 ({}),"
+                    "but got {}.".format(int_64_max, obf_password))
        params_init = {} if params_init is None else params_init
        if not isinstance(params_init, dict):
            raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")

View File

@@ -82,6 +82,7 @@ PROTO_LIMIT_SIZE = 1024 * 1024 * 2
TOTAL_SAVE = 1024 * 1024
PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
ENCRYPT_BLOCK_SIZE = 64 * 1024
+INT_64_MAX = 9223372036854775807


def _special_process_par(par, new_par):
@@ -386,9 +387,7 @@ def _check_append_dict(append_dict):
def _check_load_obfuscate(**kwargs):
    if 'obf_func' in kwargs.keys():
-        customized_func = kwargs.get('obf_func')
-        if not callable(customized_func):
-            raise ValueError("obf_func must be a callable function, but got a {}.".format(type(customized_func)))
+        customized_func = _check_customized_func(kwargs.get('obf_func'))
        clean_funcs()
        add_opaque_predicate(customized_func.__name__, customized_func)
        return True
@@ -480,6 +479,10 @@ def _check_param_type(param_config, key, target_type, requested):
    if key in param_config:
        if not isinstance(param_config[key], target_type):
            raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
+        if key == 'obf_password':
+            if param_config[key] > INT_64_MAX or param_config[key] <= 0:
+                raise ValueError(
+                    "'obf_password' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX, param_config[key]))
        return param_config[key]
    if requested:
        raise ValueError("The parameter {} is requested, but not got.".format(key))
@@ -488,6 +491,22 @@ def _check_param_type(param_config, key, target_type, requested):
    return None


+def _check_customized_func(customized_func):
+    """ check customized function of dynamic obfuscation """
+    if not callable(customized_func):
+        raise TypeError(
+            "'customized_func' must be a function, but not got {}.".format(type(customized_func)))
+    # test customized_func
+    try:
+        func_result = customized_func(1.0, 1.0)
+    except Exception as ex:
+        raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
+    else:
+        if not isinstance(func_result, bool):
+            raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
+    return customized_func


def _check_obfuscate_params(obf_config):
    """check obfuscation parameters, including obf_password, obf_ratio, customized_func"""
    if 'obf_password' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
@@ -507,16 +526,8 @@ def _check_obfuscate_params(obf_config):
        raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
    customized_funcs = []
    if 'customized_func' in obf_config.keys():
-        if callable(obf_config['customized_func']):
-            customized_funcs.append(obf_config['customized_func'])
-        else:
-            raise TypeError(
-                "'customized_func' must be a function, but not got {}.".format(type(obf_config['customized_func'])))
+        customized_funcs.append(_check_customized_func(obf_config['customized_func']))
    obf_password = _check_param_type(obf_config, "obf_password", int, False)
-    int_64_max = 9223372036854775807
-    if obf_password > int_64_max:
-        raise ValueError(
-            "'obf_password' must be less or equal than int64 ({}), but got {}.".format(int_64_max, obf_password))
    return obf_ratio, customized_funcs, obf_password
@@ -534,8 +545,8 @@ def obfuscate_model(obf_config, **kwargs):
            - save_model_path (str): The path to save the obfuscated model.
            - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
              is the same as using `export()`.
-            - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio` should
-              be in range of (0, 1] or in ["small", "medium", "large"].
+            - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
+              should be in range of (0, 1] or in ["small", "medium", "large"].
            - customized_func (function): A python function used for customized function mode, which used for control
              the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
              function needs to ensure that its result is constant for any input. Users can refer to opaque
@@ -1046,6 +1057,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
          - For details of using the customized encryption, please check the `tutorial
            <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
+        - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
+          preprocessing of the dataset into MindIR.
        - obf_config (dict): obfuscation config.

            - type (str): The type of obfuscation, only 'dynamic' is supported until now.
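As a concrete illustration of what `_check_customized_func` above accepts: it probes the function once with two float arguments and requires a bool return value, so a customized function in the spirit of the opaque predicates mentioned in the docstring could look like the sketch below (the threshold is arbitrary; the same function must later be passed to `load` via `obf_func` when the obfuscated model is deserialized, as the new test file below shows):

def my_obf_func(x1, x2):
    # Probed as my_obf_func(1.0, 1.0) by _check_customized_func; must return a bool.
    # Chosen so that the branch outcome is effectively constant for realistic inputs.
    if x1 + x2 > 1000000000:
        return True
    return False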

View File

@@ -0,0 +1,187 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test dynamic obfuscation"""
import os
import numpy as np
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import load, Tensor, export, obfuscate_model, context
from mindspore.common.initializer import TruncatedNormal
context.set_context(mode=context.GRAPH_MODE)
def weight_variable():
    return TruncatedNormal(0.02)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


class ObfuscateNet(nn.Cell):
    def __init__(self):
        super(ObfuscateNet, self).__init__()
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.matmul = ops.MatMul()
        self.matmul_weight1 = Tensor(np.random.random((16 * 5 * 5, 120)).astype(np.float32))
        self.matmul_weight2 = Tensor(np.random.random((120, 84)).astype(np.float32))
        self.matmul_weight3 = Tensor(np.random.random((84, 10)).astype(np.float32))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.matmul(x, self.matmul_weight1)
        x = self.relu(x)
        x = self.matmul(x, self.matmul_weight2)
        x = self.relu(x)
        x = self.matmul(x, self.matmul_weight3)
        return x


def test_obfuscate_model_password_mode():
    """
    Feature: Obfuscate MindIR format model with dynamic obfuscation (password mode).
    Description: Test obfuscate a MindIR format model and then load it for prediction.
    Expectation: Success.
    """
    net = ObfuscateNet()
    input_tensor = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
    export(net, input_tensor, file_name="net", file_format="MINDIR")
    original_result = net(input_tensor).asnumpy()

    # obfuscate model
    obf_config = {"original_model_path": "net.mindir", "save_model_path": "./obf_net",
                  "model_inputs": [input_tensor], "obf_ratio": 0.8, "obf_password": 3423}
    obfuscate_model(obf_config)

    # load obfuscated model, predict with right password
    obf_graph = load("obf_net.mindir")
    obf_net = nn.GraphCell(obf_graph, obf_password=3423)
    right_password_result = obf_net(input_tensor).asnumpy()

    os.remove("net.mindir")
    os.remove("obf_net.mindir")
    assert np.all(original_result == right_password_result)


def test_obfuscate_model_customized_func_mode():
    """
    Feature: Obfuscate MindIR format model with dynamic obfuscation (customized_func mode).
    Description: Test obfuscate a MindIR format model and then load it for prediction.
    Expectation: Success.
    """
    net = ObfuscateNet()
    input_tensor = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
    export(net, input_tensor, file_name="net", file_format="MINDIR")
    original_result = net(input_tensor).asnumpy()

    # obfuscate model
    def my_func(x1, x2):
        if x1 + x2 > 1000000000:
            return True
        return False

    obf_config = {"original_model_path": "net.mindir", "save_model_path": "./obf_net",
                  "model_inputs": [input_tensor], "obf_ratio": 0.8, "customized_func": my_func}
    obfuscate_model(obf_config)

    # load obfuscated model, predict with right customized function
    obf_graph = load("obf_net.mindir", obf_func=my_func)
    obf_net = nn.GraphCell(obf_graph)
    right_func_result = obf_net(input_tensor).asnumpy()

    os.remove("net.mindir")
    os.remove("obf_net.mindir")
    assert np.all(original_result == right_func_result)


def test_export_password_mode():
    """
    Feature: Obfuscate MindIR format model with dynamic obfuscation (password mode) in export().
    Description: Test obfuscate a MindIR format model and then load it for prediction.
    Expectation: Success.
    """
    net = ObfuscateNet()
    input_tensor = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
    export(net, input_tensor, file_name="net", file_format="MINDIR")
    original_result = net(input_tensor).asnumpy()

    # obfuscate model
    obf_config = {"obf_ratio": 0.8, "obf_password": 3423}
    export(net, input_tensor, file_name="obf_net", file_format="MINDIR", obf_config=obf_config)

    # load obfuscated model, predict with right password
    obf_graph = load("obf_net.mindir")
    obf_net = nn.GraphCell(obf_graph, obf_password=3423)
    right_password_result = obf_net(input_tensor).asnumpy()

    os.remove("net.mindir")
    os.remove("obf_net.mindir")
    assert np.all(original_result == right_password_result)


def test_export_customized_func_mode():
    """
    Feature: Obfuscate MindIR format model with dynamic obfuscation (customized_func mode) in export().
    Description: Test obfuscate a MindIR format model and then load it for prediction.
    Expectation: Success.
    """
    net = ObfuscateNet()
    input_tensor = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
    export(net, input_tensor, file_name="net", file_format="MINDIR")
    original_result = net(input_tensor).asnumpy()

    # obfuscate model
    def my_func(x1, x2):
        if x1 + x2 > 1000000000:
            return True
        return False

    obf_config = {"obf_ratio": 0.8, "customized_func": my_func}
    export(net, input_tensor, file_name="obf_net", file_format="MINDIR", obf_config=obf_config)

    # load obfuscated model, predict with customized function
    obf_graph = load("obf_net.mindir", obf_func=my_func)
    obf_net = nn.GraphCell(obf_graph)
    right_func_result = obf_net(input_tensor).asnumpy()

    os.remove("net.mindir")
    os.remove("obf_net.mindir")
    assert np.all(original_result == right_func_result)