Add dynamic obfuscation tool

jin-xiulang 2022-10-11 09:44:23 +08:00
parent 3420611c13
commit 969e368cc5
23 changed files with 1367 additions and 33 deletions

View File

@@ -23,6 +23,8 @@
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
"mindspore/mindspore/ccsrc/backend/common/optimizer/op_adaptation_info_factory.h" "runtime/explicit"
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cu" "runtime/int"
"mindspore/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc" "runtime/threadsafe_fn"
"mindspore/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc" "runtime/references"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"

View File

@@ -32,3 +32,10 @@ mindspore.export
- For models in 'MINDIR' format, the supported encryption options are: 'AES-GCM', 'AES-CBC' and user-defined encryption algorithms. Default: "AES-GCM".
- For details on exporting with user-defined encryption, see the `tutorial <https://www.mindspore.cn/mindarmour/docs/zh-CN/master/model_encrypt_protection.html>`_.
- **dataset** (Dataset) - Specifies the preprocessing pipeline of the dataset, used to embed the dataset preprocessing into the MindIR.
- **obf_config** (dict) - A dictionary of model obfuscation configuration options (a usage sketch follows this list).
- **type** (str) - The obfuscation type. Currently only dynamic obfuscation, i.e. 'dynamic', is supported.
- **obf_ratio** (Union[str, float]) - The ratio of obfuscated operators among all operators of the model. It can be a float in (0, 1] or one of the strings "small", "medium" and "large".
- **customized_func** (function) - A Python function that must be set in customized-function mode to control the branch choices in the obfuscated structure. Its return value should be a constant bool; users can design it with reference to opaque predicates. If `customized_func` is set, the same function must also be passed in when loading the model with the `load` interface.
- **obf_password** (int) - A secret password used in password mode, an integer greater than 0. If `obf_password` is set, it must also be passed to the `nn.GraphCell()` interface when deploying the obfuscated model. Note that if both `customized_func` and `obf_password` are set, password mode takes precedence.
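A minimal usage sketch for password mode; the network, input and file names below are illustrative, not part of the documented API:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    net = nn.Dense(16, 10)  # any Cell; dynamic obfuscation currently targets MatMul-based ops
    inputs = Tensor(np.ones((1, 16)).astype(np.float32))
    obf_config = {'type': 'dynamic', 'obf_ratio': 0.1, 'obf_password': 3423}
    ms.export(net, inputs, file_name='dense_obf', file_format='MINDIR', obf_config=obf_config)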

View File

@@ -0,0 +1,22 @@
mindspore.obfuscate_model
=========================
.. py:function:: mindspore.obfuscate_model(obf_config, **kwargs)
Obfuscate a model in MindIR format. Obfuscation mainly rewrites the model's network structure without affecting its inference accuracy, so an obfuscated model is protected from being stolen (a usage sketch follows the parameter list).
Parameters:
- **obf_config** (dict) - A dictionary of model obfuscation configuration options.
- **type** (str) - The obfuscation type. Currently only dynamic obfuscation, i.e. 'dynamic', is supported.
- **original_model_path** (str) - Path of the MindIR model to be obfuscated. If the model file is encrypted, `enc_key` and `enc_mode` need to be passed in `kwargs`.
- **save_model_path** (str) - Path to save the obfuscated model.
- **model_inputs** (list[Tensor]) - Inference inputs of the model; the Tensor values can be random, similar to using the `export()` interface.
- **obf_ratio** (Union[str, float]) - The ratio of obfuscated operators among all operators of the model. It can be a float in (0, 1] or one of the strings "small", "medium" and "large".
- **customized_func** (function) - A Python function that must be set in customized-function mode to control the branch choices in the obfuscated structure. Its return value should be a constant bool; users can design it with reference to opaque predicates. If `customized_func` is set, the same function must also be passed in when loading the model with the `load` interface.
- **obf_password** (int) - A secret password used in password mode, an integer greater than 0. If `obf_password` is set, it must also be passed to the `nn.GraphCell()` interface when deploying the obfuscated model. Note that if both `customized_func` and `obf_password` are set, password mode takes precedence.
- **kwargs** (dict) - A dictionary of configuration options.
- **enc_key** (str) - A byte-type key used for encryption; the valid length is 16, 24 or 32.
- **enc_mode** (Union[str, function]) - Specifies the encryption mode, enabled when `enc_key` is set. Supported options: 'AES-GCM', 'AES-CBC'. Default: "AES-GCM".
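A hedged usage sketch in password mode; all paths and values below are illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    obf_config = {'type': 'dynamic',
                  'original_model_path': './net.mindir',  # illustrative input model
                  'save_model_path': './net_obf',
                  'model_inputs': [Tensor(np.ones((1, 16)).astype(np.float32))],
                  'obf_ratio': 0.1,
                  'obf_password': 3423}
    ms.obfuscate_model(obf_config)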

View File

@@ -234,6 +234,7 @@ Serialization
mindspore.save_checkpoint
mindspore.transform_checkpoint_by_rank
mindspore.transform_checkpoints
mindspore.obfuscate_model
JIT
---

View File

@@ -147,6 +147,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_func_graph", &GraphExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
.def("get_func_graph_proto", &GraphExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
.def("get_obfuscate_func_graph_proto", &GraphExecutorPy::GetObfuscateFuncGraphProto, py::arg("phase") = py::str(""),
py::arg("obf_ratio") = py::float_(1.0), py::arg("obf_pasword") = py::int_(0),
py::arg("append_password") = py::int_(0), "Get graph proto of dynamic-obfuscated model.")
.def("get_params", &GraphExecutorPy::GetParams, py::arg("phase") = py::str(""), "Get Parameters from graph") .def("get_params", &GraphExecutorPy::GetParams, py::arg("phase") = py::str(""), "Get Parameters from graph")
.def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), .def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
@ -200,11 +203,13 @@ PYBIND11_MODULE(_c_expression, m) {
py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset."); py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
(void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline."); (void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
(void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr, (void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
py::arg("decrypt") = py::none(), "Load model as Graph."); py::arg("decrypt") = py::none(), py::arg("obfuscated") = py::bool_(false), "Load model as Graph.");
(void)m.def("dynamic_obfuscate_mindir", &mindspore::pipeline::DynamicObfuscateMindIR, py::arg("file_name"),
py::arg("obf_ratio"), py::arg("obf_password") = py::int_(0), py::arg("append_password") = py::int_(0),
py::arg("dec_key") = nullptr, py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
"Obfuscate a mindir model by dynamic obfuscation.");
(void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster"); (void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster");
(void)m.def("set_cluster_exit_with_exception", &mindspore::distributed::set_cluster_exit_with_exception, (void)m.def("set_cluster_exit_with_exception", &mindspore::distributed::set_cluster_exit_with_exception,
"Set this process exits with exception."); "Set this process exits with exception.");

View File

@@ -66,6 +66,8 @@
#include "runtime/pynative/op_executor.h"
#include "runtime/device/stream_synchronizer.h"
#include "distributed/collective/collective_manager.h"
#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h"
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "ps/constants.h"
@@ -442,6 +444,7 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std
}
if (ir_type == IR_TYPE_MINDIR) {
// obfuscate model
std::string proto_str = GetBinaryProtoString(fg_ptr);
if (proto_str.empty()) {
MS_LOG(EXCEPTION) << "Export MINDIR format model failed.";
@@ -452,6 +455,24 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std
MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
}
py::bytes GraphExecutorPy::GetObfuscateFuncGraphProto(const std::string &phase, const float obf_ratio,
const int obf_password, const int append_password) {
FuncGraphPtr fg_ptr = GetFuncGraph(phase);
// obfuscate model
if (obf_password == 0) {
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
MS_LOG(DEBUG) << "[GetObfuscateFuncGraphProto] set customized function names finished";
}
mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, obf_password, append_password);
mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(fg_ptr);
std::string proto_str = GetBinaryProtoString(obfuscated_graph);
if (proto_str.empty()) {
MS_LOG(EXCEPTION) << "GetBinaryProtoString failed.";
}
return proto_str;
}
py::bytes GraphExecutorPy::GetOptimizeGraphProto(const std::string &phase) {
if (info_.count(phase) == 0) {
MS_LOG(EXCEPTION) << "No phase in executor: " << phase;
@@ -1255,6 +1276,8 @@ void GraphExecutorPy::TerminateDebugger() {
#endif
py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) {
// init for dynamic-obfuscated model infer
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
// Mindspore debugger notify main thread to exit after one step, and will not run next step
#ifdef ENABLE_DEBUGGER
TerminateDebugger();
@@ -1655,7 +1678,12 @@ void GraphExecutorPy::ExportGraph(const std::string &file_name, const std::strin
}
FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
const std::string &dec_mode, const py::object decrypt, const bool obfuscated) {
if (obfuscated) {
MS_LOG(DEBUG) << "[LoadMindIR] Set customized function.";
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
}
FuncGraphPtr func_graph = nullptr;
if (dec_mode == "Customized") {
py::bytes key_bytes(dec_key);
@@ -1679,6 +1707,28 @@ FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const
return func_graph;
}
FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int obf_password,
int append_password, char *dec_key, const size_t key_len,
const std::string &dec_mode) {
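// Note: an obf_password of 0 selects customized-function mode, so the Python
// predicate functions registered on the Python side are fetched before
// obfuscation; any positive password selects password mode instead.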
if (obf_password == 0) {
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
MS_LOG(DEBUG) << "[DynamicObfuscateMindIR] set function names finished.";
}
mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, obf_password, append_password);
MindIRLoader mindir_loader(false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode, false);
FuncGraphPtr func_graph = mindir_loader.LoadMindIR(file_name);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "[DynamicObfuscateMindIR] load mindir failed, please check the mindir file.";
return nullptr;
}
mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(func_graph);
if (obfuscated_graph == nullptr) {
MS_LOG(ERROR) << "[DynamicObfuscateMindIR] obfuscate model failed.";
return nullptr;
}
return obfuscated_graph;
}
void CloseTsd(bool force) {
#ifdef WITH_BACKEND
auto context_ptr = MsContext::GetInstance();

View File

@@ -88,6 +88,8 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
FuncGraphPtr GetGradGraph(const std::string &phase);
void SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase);
py::bytes GetFuncGraphProto(const std::string &phase, const std::string &ir_type);
py::bytes GetObfuscateFuncGraphProto(const std::string &phase, const float obf_ratio, const int obf_password,
const int append_password);
#ifndef ENABLE_SECURITY
py::bytes GetOptimizeGraphProto(const std::string &phase);
#endif
@@ -185,7 +187,8 @@ void CloseTsd(bool force = false);
void MemoryRecycle();
FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
const std::string &dec_mode, const py::object decrypt = py::none(),
const bool obfuscated = false);
// init and exec dataset sub graph
bool ME_EXPORT InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
@@ -203,6 +206,9 @@ py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_le
py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode);
bool PyIsCipherFile(const std::string &file_path);
void FinalizeCluster();
FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int obf_password,
int append_password, char *dec_key, const size_t key_len,
const std::string &dec_mode);
} // namespace pipeline
} // namespace mindspore

View File

@@ -0,0 +1,69 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/opaque_predicate_kernel.h"
#include <utility>
#include <algorithm>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 2;
constexpr size_t kOutputsNum = 1;
} // namespace
template <typename T>
bool OpaquePredicateKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
auto input1 = reinterpret_cast<T *>(inputs[0]->addr);
auto input2 = reinterpret_cast<T *>(inputs[1]->addr);
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
output[0] =
CustomizedOpaquePredicate::GetInstance().run_function(static_cast<float>(*input1), static_cast<float>(*input2));
return true;
}
bool OpaquePredicateKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
std::vector<std::pair<KernelAttr, OpaquePredicateKernelMod::OpaquePredicateFunc>> OpaquePredicateKernelMod::func_list_ =
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
&OpaquePredicateKernelMod::LaunchKernel<float>}};
std::vector<KernelAttr> OpaquePredicateKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, OpaquePredicateFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, OpaquePredicate, OpaquePredicateKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_
#include <complex>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class OpaquePredicateKernelMod : public NativeCpuKernelMod {
public:
OpaquePredicateKernelMod() = default;
~OpaquePredicateKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using OpaquePredicateFunc =
std::function<bool(OpaquePredicateKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
OpaquePredicateFunc kernel_func_;
static std::vector<std::pair<KernelAttr, OpaquePredicateFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_

View File

@@ -0,0 +1,494 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h"
#include <cstdlib>
#include <algorithm>
#include <map>
#include <memory>
#include <functional>
#include <random>
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#include "include/common/debug/anf_ir_dump.h"
#include "utils/info.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "ops/core_ops.h"
namespace mindspore {
using Tensor = mindspore::tensor::Tensor;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
constexpr int expand_rate = 10;  // approximate number of nodes needed for one switch graph
ShapeVector get_node_shape(AnfNodePtr input_node) {
if (input_node == nullptr) {
MS_LOG(ERROR) << "Input_node is nullptr, get shape failed!";
return {};
}
AbstractBasePtr input_abstract = input_node->abstract();
if (input_abstract == nullptr) {
MS_LOG(ERROR) << "The abstract of input_node is nullptr, get shape failed!";
return {};
}
AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
mindspore::abstract::ShapePtr shape_ptr = input_abstract_tensor->shape();
return shape_ptr->shape();
}
TypeId get_node_dtype(AnfNodePtr input_node) {
if (input_node == nullptr) {
MS_LOG(ERROR) << "Input_node is nullptr, get dtype failed!";
return {};
}
AbstractBasePtr input_abstract = input_node->abstract();
if (input_abstract == nullptr) {
MS_LOG(ERROR) << "The abstract of input_node is nullptr, get dtype failed!";
return {};
}
AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
AbstractBasePtr node_element = input_abstract_tensor->element();
mindspore::abstract::AbstractScalarPtr node_element_abs =
node_element->cast<mindspore::abstract::AbstractScalarPtr>();
TypeId data_type = node_element_abs->BuildType()->type_id();
return data_type;
}
std::vector<std::string> name_split(std::string &node_name, const std::string &split_sign) {
node_name += split_sign;
unsigned int name_len = node_name.size();
std::string::size_type split_pos;
std::vector<std::string> res;
for (unsigned int i = 0; i < name_len; i++) {
split_pos = node_name.find(split_sign, i);
if (split_pos < name_len) {
std::string sub_str = node_name.substr(i, split_pos - i);
res.push_back(sub_str);
i = split_pos + split_sign.size() - 1;
}
}
return res;
}
ValueNodePtr build_tuple_value_node(std::vector<int64_t> values) {
mindspore::ValueNodePtr v_node = std::make_shared<mindspore::ValueNode>(MakeValue(values));
AbstractBasePtrList abs_list;
std::transform(values.begin(), values.end(), std::back_inserter(abs_list), [](const int64_t &item) {
return std::make_shared<mindspore::abstract::AbstractScalar>(int64_t(item));
});
auto abs_tuple = std::make_shared<mindspore::abstract::AbstractTuple>(abs_list);
v_node->set_abstract(abs_tuple);
return v_node;
}
ValueNodePtr make_int_node(FuncGraphPtr func_graph, int int_value) {
ShapeVector int_shape{1, 1};
tensor::TensorPtr int_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, int_shape);
int *tensor_data = reinterpret_cast<int *>(int_tensor->data_c());
for (int i = 0; i < int_tensor->data().size(); i++) {
tensor_data[i] = int_value;
}
mindspore::ValueNodePtr int_tensor_node = std::make_shared<mindspore::ValueNode>(int_tensor);
int_tensor_node->set_abstract(int_tensor->ToAbstract());
(void)func_graph->AddValueNode(int_tensor_node);
return int_tensor_node;
}
tensor::TensorPtr make_weight_tensor(TypeId type_id, ShapeVector shape) {
tensor::TensorPtr weight_tensor = std::make_shared<Tensor>(type_id, shape);
std::default_random_engine generator;
int max_count = 10000;
int tensor_size = weight_tensor->data().size();
if (type_id == kNumberTypeFloat64) {
const double mean_64 = 0;
const double stddev_64 = 1;
std::normal_distribution<double> dist_64(mean_64, stddev_64);
double *float_64_data = reinterpret_cast<double *>(weight_tensor->data_c());
for (int i = 0; i < std::min(tensor_size, max_count); i++) {
double random_float_64 = dist_64(generator);
if (random_float_64 > 0) {
float_64_data[i] = random_float_64;
}
}
} else {
MS_LOG(DEBUG) << "Type id is: " << type_id << ", weights will be float_32 format.";
const float mean = 0;
const float stddev = 1;
std::normal_distribution<float> dist_32(mean, stddev);
float *float_32_data = reinterpret_cast<float *>(weight_tensor->data_c());
for (int i = 0; i < std::min(tensor_size, max_count); i++) {
float random_float_32 = dist_32(generator);
if (random_float_32 > 0) {
float_32_data[i] = random_float_32;
}
}
}
return weight_tensor;
}
bool CheckIfObfuscated(const FuncGraphPtr &func_graph) {
auto mgr = Manage(func_graph);
auto all_nodes = mgr->all_nodes();
for (AnfNodePtr node : all_nodes) {
std::string node_name = node->fullname_with_scope();
if (node_name.find("Switch") != std::string::npos) {
return true;
}
}
return false;
}
FuncGraphPtr DynamicObfuscator::ObfuscateMindIR(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Start obfuscation.";
MS_EXCEPTION_IF_NULL(func_graph);
if (CheckIfObfuscated(func_graph)) {
MS_EXCEPTION(ValueError) << "The input model has been onfuscated, do not obfuscate it again.";
}
auto mgr = Manage(func_graph);
MS_EXCEPTION_IF_NULL(mgr);
auto all_nodes = mgr->all_nodes();
int node_nums = all_nodes.size();
MS_LOG(INFO) << "Total node num: " << node_nums;
// init the number of control nodes that have been built
used_control_node_ = 0;
if (obf_password_ == 0) {
int obfuscate_target_num = std::ceil(all_nodes.size() * obf_ratio_ / expand_rate);
int obfuscate_node_num = 0;
// record customized_func computing results
for (AnfNodePtr node : all_nodes) {
std::string obf_type = single_op_obfuscate_type(node);
MS_LOG(INFO) << "obf_type: " << obf_type;
if (obf_type == "MatMul-op") {
obfuscate_node_num += 1;
MS_LOG(INFO) << "Find a MatMul Node: " << node->fullname_with_scope();
bool customized_func_result = mindspore::kernel::CustomizedOpaquePredicate::GetInstance().run_function(
static_cast<float>(1), static_cast<float>(1));
customized_func_results_.push_back(customized_func_result);
}
if (obfuscate_node_num >= obfuscate_target_num) {
break;
}
}
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
}
// do op-wise fake-branch obfuscation
(void)op_wise_fake_branch(func_graph);
if (used_control_node_ == 0) {
MS_LOG(WARNING)
<< "The model has not been obfuscated, which means obf_password or customized_func is not need to set.";
}
return func_graph;
}
void DynamicObfuscator::op_wise_fake_branch(FuncGraphPtr func_graph) {
auto mgr = Manage(func_graph);
auto all_nodes = mgr->all_nodes();
int obfuscate_target_num = std::ceil(all_nodes.size() * obf_ratio_ / expand_rate);
int obfuscate_node_num = 0;
for (AnfNodePtr node : all_nodes) {
std::string obf_type = single_op_obfuscate_type(node);
MS_LOG(INFO) << "The obf_type is: " << obf_type;
if (obf_type == "MatMul-op") {
obfuscate_node_num += 1;
MS_LOG(INFO) << "Find a MatMul Node: " << node->fullname_with_scope();
std::vector<AnfNodePtr> node_inputs = node->cast<mindspore::CNodePtr>()->inputs();
mindspore::AnfNodePtr input_1 = node_inputs[1];
CNodePtr control_c_node = get_control_node(func_graph, input_1);
(void)replace_matmul_node(node->cast<CNodePtr>(), func_graph, control_c_node);
MS_LOG(INFO) << "Finished replacement for: " << node->fullname_with_scope();
}
if (obfuscate_node_num >= obfuscate_target_num) {
break;
}
}
}
std::string DynamicObfuscator::single_op_obfuscate_type(AnfNodePtr node) {
if (node->isa<CNode>()) {
std::string node_name = node->fullname_with_scope();
MS_LOG(INFO) << "The node_name is: " << node_name;
std::vector<std::string> split_words = name_split(node_name, "/");
std::string op_name = split_words[split_words.size() - 1];
for (std::string target_op_name : obf_target_op) {
int op_name_len = op_name.size();
int target_name_len = target_op_name.size();
if ((op_name_len >= target_name_len) && (op_name.substr(0, target_name_len) == target_op_name)) {
return target_op_name;
}
}
return "";
}
return "";
}
CNodePtr DynamicObfuscator::password_mode_control(FuncGraphPtr func_graph) {
ShapeVector y_shape{1, 1};
tensor::TensorPtr y_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, y_shape);
if (!has_build_appended_input) {
MS_LOG(INFO) << "Build parameter y and y_append.";
auto y = func_graph->add_parameter();
y->set_name("y");
y->set_abstract(y_tensor->ToAbstract());
auto y_append = func_graph->add_parameter();
y_append->set_name("y_append");
y_append->set_abstract(y_tensor->ToAbstract());
has_build_appended_input = true;
}
auto y = func_graph->GetParameterByName("y");
auto y_append = func_graph->GetParameterByName("y_append");
if (used_control_node_ == 0) {
// make add function node
mindspore::PrimitivePtr add_prim = mindspore::prim::kPrimAdd;
add_prim->set_attr("is_load", MakeValue(true));
mindspore::ValueNodePtr add_v_node = std::make_shared<mindspore::ValueNode>(add_prim);
(void)func_graph->AddValueNode(add_v_node);
CNodePtr add_c_node = func_graph->NewCNode({add_v_node, y, y_append});
add_c_node->set_abstract(y_tensor->ToAbstract());
// make equal function node
ValueNodePtr equal_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimEqual);
(void)func_graph->AddValueNode(equal_v_node);
ValueNodePtr equal_compa_node = make_int_node(func_graph, obf_password_ + append_password_);
CNodePtr equal_c_node = func_graph->NewCNode({equal_v_node, add_c_node, equal_compa_node});
tensor::TensorPtr equal_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
equal_c_node->set_abstract(equal_tensor->ToAbstract());
(void)func_graph->AddNode(equal_c_node);
used_control_node_ += 1;
switch_branch_ = true;
return equal_c_node;
}
// make greater function node
int comparison_int = rand();
ValueNodePtr greater_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimGreater);
(void)func_graph->AddValueNode(greater_v_node);
ValueNodePtr greater_compa_node = make_int_node(func_graph, comparison_int);
CNodePtr greater_c_node = func_graph->NewCNode({greater_v_node, y, greater_compa_node});
tensor::TensorPtr greater_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
greater_c_node->set_abstract(greater_tensor->ToAbstract());
(void)func_graph->AddNode(greater_c_node);
used_control_node_ += 1;
switch_branch_ = obf_password_ > comparison_int;
return greater_c_node;
}
mindspore::CNodePtr AddStrideSliceNode(FuncGraphPtr func_graph, ShapeVector begin_vector, ShapeVector stride_vector,
ShapeVector end_vector, int end_mask, int begin_mask,
mindspore::CNodePtr prev_node) {
mindspore::ValueNodePtr begin_v_node = build_tuple_value_node(begin_vector);
mindspore::ValueNodePtr stride_v_node = build_tuple_value_node(stride_vector);
mindspore::ValueNodePtr end_v_node = build_tuple_value_node(end_vector);
(void)func_graph->AddValueNode(begin_v_node);
(void)func_graph->AddValueNode(stride_v_node);
(void)func_graph->AddValueNode(end_v_node);
mindspore::PrimitivePtr slice_prim = mindspore::prim::kPrimStridedSlice;
slice_prim->set_attr("is_load", MakeValue(true));
slice_prim->set_attr("new_axis_mask", MakeValue(int64_t(0)));
slice_prim->set_attr("shrink_axis_mask", MakeValue(int64_t(1)));
slice_prim->set_attr("end_mask", MakeValue(int64_t(end_mask)));
slice_prim->set_attr("begin_mask", MakeValue(int64_t(begin_mask)));
slice_prim->set_attr("ellipsis_mask", MakeValue(int64_t(0)));
mindspore::ValueNodePtr slice_v_node = std::make_shared<mindspore::ValueNode>(slice_prim);
(void)func_graph->AddValueNode(slice_v_node);
mindspore::CNodePtr slice_c_node =
func_graph->NewCNode({slice_v_node, prev_node, begin_v_node, end_v_node, stride_v_node});
return slice_c_node;
}
CNodePtr DynamicObfuscator::custom_op_mode_control(FuncGraphPtr func_graph, AnfNodePtr prev_node) {
mindspore::PrimitivePtr reshape_prim = mindspore::prim::kPrimReshape;
reshape_prim->set_attr("is_load", MakeValue(true));
mindspore::ValueNodePtr reshape_v_node = std::make_shared<mindspore::ValueNode>(reshape_prim);
(void)func_graph->AddValueNode(reshape_v_node);
ShapeVector prev_node_shape = get_node_shape(prev_node);
int shape_multiply = std::accumulate(prev_node_shape.begin(), prev_node_shape.end(), 1, std::multiplies<int>());
MS_LOG(INFO) << "The shape_multiply is: " << shape_multiply;
ShapeVector flat_shape{1, shape_multiply};
mindspore::ValueNodePtr shape_v_node = std::make_shared<mindspore::ValueNode>(MakeValue(flat_shape));
(void)func_graph->AddValueNode(shape_v_node);
mindspore::CNodePtr reshape_c_node = func_graph->NewCNode({reshape_v_node, prev_node, shape_v_node});
TypeId data_type = get_node_dtype(prev_node);
auto reshape_abstract = std::make_shared<Tensor>(data_type, flat_shape)->ToAbstract();
reshape_c_node->set_abstract(reshape_abstract);
(void)func_graph->AddNode(reshape_c_node);
// the first stride_slice x[0]
ShapeVector begin_1{0, 0};
ShapeVector stride_1{1, 1};
mindspore::CNodePtr slice_c_node_1 =
AddStrideSliceNode(func_graph, begin_1, stride_1, flat_shape, 2, 2, reshape_c_node);
ShapeVector slice_1_shape{shape_multiply};
slice_c_node_1->set_abstract(std::make_shared<Tensor>(data_type, slice_1_shape)->ToAbstract());
(void)func_graph->AddNode(slice_c_node_1);
// the second stride_slice x[0][0]
ShapeVector begin_2{0};
ShapeVector end_2{1};
ShapeVector stride_2{1};
mindspore::CNodePtr slice_c_node_2 =
AddStrideSliceNode(func_graph, begin_2, stride_2, stride_2, 0, 0, slice_c_node_1);
ShapeVector slice_2_shape{1};
slice_c_node_2->set_abstract(std::make_shared<Tensor>(data_type, slice_2_shape)->ToAbstract());
(void)func_graph->AddNode(slice_c_node_2);
// the third stride_slice x[0][1]
ShapeVector begin_3{1};
ShapeVector end_3{1};
ShapeVector stride_3{2};
mindspore::CNodePtr slice_c_node_3 =
AddStrideSliceNode(func_graph, begin_3, stride_3, stride_3, 0, 0, slice_c_node_1);
ShapeVector slice_3_shape{1};
slice_c_node_3->set_abstract(std::make_shared<Tensor>(data_type, slice_3_shape)->ToAbstract());
(void)func_graph->AddNode(slice_c_node_3);
// add opaque predicate
PrimitivePtr custom_prim = mindspore::prim::kPrimOpaquePredicate;
custom_prim->set_attr("is_load", MakeValue(true));
std::vector<ValuePtr> input_names_value;
input_names_value.push_back(std::make_shared<StringImm>("x"));
input_names_value.push_back(std::make_shared<StringImm>("y"));
custom_prim->set_attr("input_names", std::make_shared<ValueList>(input_names_value));
std::vector<ValuePtr> output_names_value;
output_names_value.push_back(std::make_shared<StringImm>("output"));
custom_prim->set_attr("output_names", std::make_shared<ValueList>(output_names_value));
auto opaque_v_node = std::make_shared<mindspore::ValueNode>(custom_prim);
(void)func_graph->AddValueNode(opaque_v_node);
auto opaque_c_node = func_graph->NewCNode({opaque_v_node, slice_c_node_2, slice_c_node_3});
ShapeVector y_shape{1, 1};
auto bool_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
opaque_c_node->set_abstract(bool_tensor->ToAbstract());
(void)func_graph->AddNode(opaque_c_node);
return opaque_c_node;
}
CNodePtr DynamicObfuscator::get_control_node(FuncGraphPtr func_graph, AnfNodePtr prev_node) {
if (obf_password_ != 0) {
MS_LOG(INFO) << "Run password mode.";
return password_mode_control(func_graph);
}
MS_LOG(INFO) << "Run customized function mode.";
return custom_op_mode_control(func_graph, prev_node);
}
void DynamicObfuscator::replace_matmul_node(CNodePtr node, FuncGraphPtr func_graph, CNodePtr control_node) {
std::vector<AnfNodePtr> node_inputs = node->cast<mindspore::CNodePtr>()->inputs();
mindspore::ValueNodePtr matmul_v_node = node_inputs[0]->cast<mindspore::ValueNodePtr>();
mindspore::AnfNodePtr input_1 = node_inputs[1];
mindspore::AnfNodePtr input_2 = node_inputs[2];
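// Branch 1 below reproduces the original MatMul on the real inputs; branch 2
// is a decoy that multiplies input_1 by freshly generated random weights.
// A Switch node driven by control_node picks one branch at run time, and
// switch_branch_ orders the two partials so the true path is the real one.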
// construct branch 1
mindspore::FuncGraphPtr fg_1 = std::make_shared<FuncGraph>();
// input_x
ParameterPtr branch_1_input_x = fg_1->add_parameter();
branch_1_input_x->set_abstract(input_1->abstract());
branch_1_input_x->set_name("branch_1_input_x");
// input_y
ParameterPtr branch_1_input_y = fg_1->add_parameter();
branch_1_input_y->set_abstract(input_2->abstract());
branch_1_input_y->set_name("branch_1_input_y");
mindspore::CNodePtr matmul_c_node_1 = fg_1->NewCNode({matmul_v_node, branch_1_input_x, branch_1_input_y});
matmul_c_node_1->set_abstract(node->cast<mindspore::CNodePtr>()->abstract());
(void)fg_1->AddNode(matmul_c_node_1);
// add return node
mindspore::ValueNodePtr return_v_node_1 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
(void)fg_1->AddValueNode(return_v_node_1);
mindspore::CNodePtr branch_1_return = fg_1->NewCNode({return_v_node_1, matmul_c_node_1});
(void)fg_1->AddNode(branch_1_return);
fg_1->set_return(branch_1_return);
fg_1->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
mindspore::ValueNodePtr partial_v_node_1 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimPartial);
(void)func_graph->AddValueNode(partial_v_node_1);
mindspore::ValueNodePtr fg_1_node = std::make_shared<mindspore::ValueNode>(fg_1);
fg_1_node->set_abstract(fg_1->ToAbstract());
(void)func_graph->AddValueNode(fg_1_node);
mindspore::CNodePtr partial_c_node_1 = func_graph->NewCNode({partial_v_node_1, fg_1_node, input_1, input_2});
(void)func_graph->AddNode(partial_c_node_1);
// construct branch 2
mindspore::FuncGraphPtr fg_2 = std::make_shared<FuncGraph>();
// add input_x
ParameterPtr branch_2_input_x = fg_2->add_parameter();
branch_2_input_x->set_abstract(input_1->abstract());
branch_2_input_x->set_name("branch_2_input_x");
// add input_y
ParameterPtr branch_2_input_y = fg_2->add_parameter();
branch_2_input_y->set_abstract(input_2->abstract());
branch_2_input_y->set_name("branch_2_input_y");
// add matmul CNode
mindspore::CNodePtr matmul_c_node_2 = fg_2->NewCNode({matmul_v_node, branch_2_input_x, branch_2_input_y});
matmul_c_node_2->set_abstract(node->cast<mindspore::CNodePtr>()->abstract());
(void)fg_2->AddNode(matmul_c_node_2);
// add return node
mindspore::ValueNodePtr return_v_node_2 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
(void)fg_2->AddValueNode(return_v_node_2);
mindspore::CNodePtr branch_2_return = fg_2->NewCNode({return_v_node_2, matmul_c_node_2});
(void)fg_2->AddNode(branch_2_return);
fg_2->set_return(branch_2_return);
fg_2->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
// add partial for branch_2
ShapeVector matmul_2_shape = get_node_shape(input_2);
TypeId type_id = get_node_dtype(input_2);
tensor::TensorPtr matmul_2_weight = make_weight_tensor(type_id, matmul_2_shape);
mindspore::ValueNodePtr matmul_weight_v_node = std::make_shared<mindspore::ValueNode>(matmul_2_weight);
matmul_weight_v_node->set_abstract(matmul_2_weight->ToAbstract());
(void)func_graph->AddValueNode(matmul_weight_v_node);
mindspore::ValueNodePtr partial_v_node_2 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimPartial);
(void)func_graph->AddValueNode(partial_v_node_2);
mindspore::ValueNodePtr fg_2_node = std::make_shared<mindspore::ValueNode>(fg_2);
fg_2_node->set_abstract(fg_2->ToAbstract());
(void)func_graph->AddValueNode(fg_2_node);
mindspore::CNodePtr partial_c_node_2 =
func_graph->NewCNode({partial_v_node_2, fg_2_node, input_1, matmul_weight_v_node});
(void)func_graph->AddNode(partial_c_node_2);
// add switch node
mindspore::ValueNodePtr switch_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimSwitch);
(void)func_graph->AddValueNode(switch_v_node);
mindspore::CNodePtr switch_c_node;
if (obf_password_ == 0) {
int results_len = customized_func_results_.size();
switch_branch_ = customized_func_results_[results_len - 1 - used_control_node_];
used_control_node_ += 1;
}
if (switch_branch_) {
switch_c_node = func_graph->NewCNode({switch_v_node, control_node, partial_c_node_1, partial_c_node_2});
} else {
switch_c_node = func_graph->NewCNode({switch_v_node, control_node, partial_c_node_2, partial_c_node_1});
}
func_graph->AddNode(switch_c_node);
// add call node
mindspore::CNodePtr call_cnode = func_graph->NewCNode({switch_c_node});
func_graph->AddNode(call_cnode);
// add fg_1 and fg_2 to func_graph
auto mgr = mindspore::Manage(func_graph);
mgr->AddFuncGraph(fg_1);
mgr->AddFuncGraph(fg_2);
mgr->Replace(node, call_cnode);
}
} // namespace mindspore

View File

@@ -0,0 +1,53 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_DYNAMIC_OBFUSCATION_H
#define MINDSPORE_DYNAMIC_OBFUSCATION_H
#include <vector>
#include <string>
#include "load_mindir/load_model.h"
#include "include/common/visible.h"
namespace mindspore {
class COMMON_EXPORT DynamicObfuscator {
public:
DynamicObfuscator(const float obf_ratio, const int obf_password, const int append_password)
: obf_ratio_(obf_ratio), obf_password_(obf_password), append_password_(append_password) {}
~DynamicObfuscator() = default;
FuncGraphPtr ObfuscateMindIR(const FuncGraphPtr &func_graph);
private:
void op_wise_fake_branch(FuncGraphPtr func_graph);
std::string single_op_obfuscate_type(AnfNodePtr node);
CNodePtr get_control_node(FuncGraphPtr func_graph, AnfNodePtr prev_node);
CNodePtr password_mode_control(FuncGraphPtr func_graph);
CNodePtr custom_op_mode_control(FuncGraphPtr func_graph, AnfNodePtr prev_node);
void replace_matmul_node(CNodePtr node, FuncGraphPtr func_graph, CNodePtr flag_node);
const float obf_ratio_ = 0.01;
const int obf_password_;
const int append_password_;
bool has_build_appended_input = false;
std::vector<bool> customized_func_results_;
int used_control_node_ = 0;
bool switch_branch_ = true;
const std::vector<std::string> obf_target_op = {"MatMul-op", "Add-op", "Mat-op", "Sub-op", "Softmax-op", "Relu-op"};
};
} // namespace mindspore
#endif // MINDSPORE_DYNAMIC_OBFUSCATION_H

View File

@@ -0,0 +1,115 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#include <algorithm>
#include "utils/info.h"
namespace mindspore {
namespace kernel {
CustomizedOpaquePredicate &CustomizedOpaquePredicate::GetInstance() {
static CustomizedOpaquePredicate instance;
return instance;
}
void CustomizedOpaquePredicate::set_func_names() {
// get the function of get_func_names()
py::gil_scoped_acquire gil_acquire;
static const std::string &module_name = "mindspore.ops.operations._opaque_predicate_registry";
static const std::string &entrance = "get_func_names";
py::module module = py::module::import(module_name.c_str());
py::object get_pyfunc_obj = module.attr(entrance.c_str());
if (get_pyfunc_obj.is_none()) {
MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << "in module" << module_name;
}
py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
py::tuple func_name_list = get_pyfunc();
// clear old functions
func_names_.clear();
for (size_t i = 0; i < func_name_list.size(); i++) {
func_names_.push_back(py::str(func_name_list[i]));
}
MS_LOG(DEBUG) << "Set function names finished, the number of functions is: " << func_names_.size();
}
const std::vector<std::string> CustomizedOpaquePredicate::get_func_names() {
if (func_names_.size() == 0) {
MS_LOG(EXCEPTION) << "The number of customized function names is zero, get function names failed.";
}
return func_names_;
}
py::function CustomizedOpaquePredicate::get_function() {
py::gil_scoped_acquire gil_acquire;
// get the function of get_opaque_predicate()
static const std::string &module_name = "mindspore.ops.operations._opaque_predicate_registry";
static const std::string &entrance = "get_opaque_predicate";
py::module module = py::module::import(module_name.c_str());
py::object get_pyfunc_obj = module.attr(entrance.c_str());
if (get_pyfunc_obj.is_none()) {
MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << "in module" << module_name;
}
py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
MS_LOG(DEBUG) << "The number of function is : " << func_names_.size();
if (func_names_.size() == 0) {
MS_EXCEPTION(ValueError) << "The customized_func is not set, please set it in load().";
}
std::string func_name = func_names_[0];
MS_LOG(DEBUG) << "Get function name: " << func_name;
func_name_code_.clear();
std::transform(func_name.begin(), func_name.end(), std::back_inserter(func_name_code_),
[](const char &item) { return static_cast<int>(item); });
if (func_name_code_.size() == 0) {
MS_EXCEPTION(ValueError) << "Set func_name_code_ failed.";
}
py::object py_func_obj = get_pyfunc(py::str(func_name));
if (py_func_obj.is_none()) {
MS_EXCEPTION(ValueError) << "Cannot find python func with name: " << func_name;
}
return py_func_obj.cast<py::function>();
}
bool CustomizedOpaquePredicate::run_function(float x, float y) {
if (!Py_IsInitialized()) {
MS_LOG(ERROR) << "Python interpreter is not initialized.";
return false;
}
py::object customized_func = get_function();
py::gil_scoped_acquire gil_acquire;
int inputs_num = 2;
py::tuple inputs(inputs_num);
inputs[0] = py::float_(x);
inputs[1] = py::float_(y);
py::object result = customized_func(*inputs);
if (result.is_none()) {
MS_EXCEPTION(ValueError) << "Computing result of customized_func is None, please check it.";
}
bool bool_result = py::cast<bool>(result);
int even_num = 2;
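// Negate the predicate result depending on the parity of one character code
// of the registered function name, cycling through the name as calling_count_
// grows; this ties the branch decisions to the exact customized function the
// user registered.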
if (func_name_code_[calling_count_ % func_name_code_.size()] % even_num == 0) {
calling_count_ += 1;
return bool_result;
}
calling_count_ += 1;
return !bool_result;
}
void CustomizedOpaquePredicate::init_calling_count() {
this->calling_count_ = 0;
MS_LOG(INFO) << "calling_count_ has been initialized to " << calling_count_;
}
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H
#define MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H
#include <vector>
#include <utility>
#include <string>
#include <memory>
#include <map>
#include <mutex>
#include <unordered_map>
#include <list>
#include "pybind11/pybind11.h"
#include <Python.h>
#include "pybind11/numpy.h"
#include "include/common/visible.h"
namespace py = pybind11;
namespace mindspore {
namespace kernel {
class COMMON_EXPORT CustomizedOpaquePredicate {
public:
static CustomizedOpaquePredicate &GetInstance();
void set_func_names();
const std::vector<std::string> get_func_names();
bool run_function(float x, float y);
py::function get_function();
void init_calling_count();
private:
CustomizedOpaquePredicate() : func_names_({}) {}
~CustomizedOpaquePredicate() = default;
std::vector<std::string> func_names_;
int calling_count_ = 0;
std::vector<int> func_name_code_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H

View File

@@ -56,6 +56,7 @@
#include "ops/grad/max_pool_grad_with_argmax.h"
#include "ops/max_pool_with_argmax.h"
#include "ops/mirror_pad.h"
#include "ops/opaquePredicate.h"
namespace mindspore {
namespace abstract {
@@ -323,6 +324,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimIdentity, R{InferImplIdentity, nullptr, true}},
{prim::kPrimLoad, R{InferImplLoad, nullptr, true}},
{prim::kPrimMutable, R{InferImplMutable, nullptr, true}},
{prim::kPrimOpaquePredicate, R{ops::OpaquePredicateInfer, nullptr, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, R{nullptr, nullptr, true}},
{prim::kPrimEnvironCreate, R{InferImplEnvironCreate, nullptr, true}},

View File

@@ -1376,6 +1376,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicLossScale, std::make_shared<Primitive>("_Dyna
GVAR_DEF(PrimitivePtr, kPrimScaleGrad, std::make_shared<Primitive>("ScaleGrad"));
GVAR_DEF(PrimitivePtr, kPrimPopulationCount, std::make_shared<Primitive>("PopulationCount"));
GVAR_DEF(PrimitivePtr, kPrimBlackmanWindow, std::make_shared<Primitive>("BlackmanWindow"));
GVAR_DEF(PrimitivePtr, kPrimOpaquePredicate, std::make_shared<Primitive>("OpaquePredicate"));
// Structures
GVAR_DEF(PrimitivePtr, kPrimMakeList, std::make_shared<Primitive>("make_list"));

View File

@@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/opaquePredicate.h"
#include <map>
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include <complex>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
abstract::ShapePtr OpaquePredicateInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}
TypePtr OpaquePredicateInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim->name(), input_args, 0);
auto y = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim->name(), input_args, 1);
(void)abstract::CheckDtypeSame(prim->name(), x, y);
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat, kFloat16, kUInt16,
kFloat64, kUInt8, kBool, kComplex64, kComplex128, kUInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("y", input_args[1]->BuildType(), valid_types, prim->name());
return std::make_shared<TensorType>(kBool);
}
MIND_API_OPERATOR_IMPL(OpaquePredicate, BaseOperator);
AbstractBasePtr OpaquePredicateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto shape = OpaquePredicateInferShape(primitive, input_args);
auto type = OpaquePredicateInferType(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_C(kNameOpaquePredicate, OpaquePredicate);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_
#define MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameOpaquePredicate = "OpaquePredicate";
/// \brief The opaque predicate used for dynamic obfuscation
class MIND_API OpaquePredicate : public BaseOperator {
public:
MIND_API_BASE_MEMBER(OpaquePredicate);
OpaquePredicate() : BaseOperator(kNameOpaquePredicate) { InitIOName({"x", "y"}, {"output"}); }
void Init() const {}
};
abstract::AbstractBasePtr OpaquePredicateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_

View File

@@ -26,6 +26,7 @@ import inspect
import importlib
from collections import OrderedDict
from functools import wraps
import numpy as np
import mindspore as ms
from mindspore import context
from mindspore import log as logger
@@ -46,7 +47,6 @@ from mindspore._checkparam import Validator
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
# store ms_function class compiled pipeline cache
ms_compile_cache = set()
# store cell compiled pipeline cache,
@@ -670,6 +670,7 @@ class _MsFunctionCompileContext:
"""
ms_function compile status manager
"""
def __init__(self):
pass
@@ -1064,10 +1065,12 @@ class _CellGraphExecutor:
Returns:
Graph, return the result of pipeline running.
"""
def __init__(self):
# create needed graph by lazy mode
self.is_init = False
self.enable_tuple_broaden = False
self.obfuscate_config = None  # used for model's dynamic obfuscation
self._graph_executor = GraphExecutor_.get_instance()
self._graph_executor.set_py_exe_path(sys.executable)
self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
@@ -1238,8 +1241,8 @@ class _CellGraphExecutor:
return self._graph_executor.get_allreduce_fusion(real_phase)
def __call__(self, obj, *args, phase='predict'):
if context.get_context("precompile_only") or \
(_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
return None
return self.run(obj, *args, phase=phase)
@@ -1290,6 +1293,21 @@ class _CellGraphExecutor:
exec_id = exec_id + '.' + obj.arguments_key
if self._graph_executor.has_compiled(exec_id) is False:
return None
if self.obfuscate_config is not None:
if ('obf_ratio' not in self.obfuscate_config.keys()) or (
'obf_password' not in self.obfuscate_config.keys()):
raise ValueError("'obf_ratio' and 'obf_password' must be in obfuscate_config.")
obf_password = self.obfuscate_config.get('obf_password')
if obf_password == 0:
append_password = 0
else:
seed_max = 2 ** 32 - 1
int_max = 2 ** 31 - 1
np.random.seed(obf_password % seed_max)
append_password = np.random.randint(int_max)
obf_password %= int_max
return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, self.obfuscate_config['obf_ratio'],
obf_password, append_password)
return self._graph_executor.get_func_graph_proto(exec_id, ir_type) return self._graph_executor.get_func_graph_proto(exec_id, ir_type)
def get_optimize_graph_proto(self, obj): def get_optimize_graph_proto(self, obj):
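
The password split above is the same derivation the loader performs in `GraphCell.__call__` (see cell.py below). Isolated as a standalone sketch, with a hypothetical helper name:

import numpy as np

def split_obf_password(obf_password):
    """Hypothetical restatement of the derivation above: turn the user
    password into the pair of values the obfuscated graph consumes."""
    seed_max = 2 ** 32 - 1   # np.random.seed accepts values in [0, 2**32 - 1]
    int_max = 2 ** 31 - 1    # keep both outputs inside the int32 range
    np.random.seed(obf_password % seed_max)
    append_password = np.random.randint(int_max)
    return obf_password % int_max, append_password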

View File

@@ -2219,18 +2219,18 @@ class Cell(Cell_):
        if isinstance(set_input, Tensor):
            if not isinstance(net_input, Tensor):
                raise TypeError(
                    f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
            if set_input.dtype is not net_input.dtype:
                raise ValueError(
                    f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
                    f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
            if net_input.dim() != 0 and set_input.dim() != net_input.dim():
                raise ValueError(
                    f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
                    f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
            if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
                raise ValueError(
                    f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
                    f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.")
@@ -2247,6 +2247,11 @@ class GraphCell(Cell):
            The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
            If the parameter exists in the graph according to the name, update its value.
            If the parameter does not exist, ignore it. Default: None.
        obf_password (int): The password used for dynamic obfuscation, a model-protection scheme described in
            `mindspore.train.serialization.obfuscate_model()`. If the input `graph` is a func_graph loaded from a
            MindIR file obfuscated in password mode, then obf_password should be provided. obf_password should be
            larger than zero and less than or equal to int64 (9223372036854775807). Default: None.

    Raises:
        TypeError: If the `graph` is not a FuncGraph.
        TypeError: If the `params_init` is not a dict.
@@ -2273,13 +2278,19 @@ class GraphCell(Cell):
          [6. 9. 6.]
          [4. 6. 4.]]]]
    """
    def __init__(self, graph, params_init=None, obf_password=None):
        super(GraphCell, self).__init__(auto_prefix=True)
        if not isinstance(graph, FuncGraph):
            raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
                            f"but got type {type(graph)}.")
        self.graph = graph
        self.obf_password = obf_password
        int_64_max = 9223372036854775807
        if (obf_password is not None) and (obf_password <= 0 or obf_password > int_64_max):
            raise ValueError(
                "'obf_password' must be larger than 0, and less than or equal to int64 ({}), "
                "but got {}.".format(int_64_max, obf_password))
        params_init = {} if params_init is None else params_init
        if not isinstance(params_init, dict):
            raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
@@ -2299,7 +2310,21 @@ class GraphCell(Cell):
    def __call__(self, *inputs):
        self.phase = "graph_load_from_mindir"
        self._add_attr("graph_load_from_mindir", self.graph)
        if not self.obf_password:
            return self.compile_and_run(*inputs)
        append_input_1, append_input_2 = _obf_appended_inputs(self.obf_password)
        return self.compile_and_run(*inputs, append_input_1, append_input_2)


def _obf_appended_inputs(obf_password):
    seed_max = 2 ** 32 - 1
    int_max = 2 ** 31 - 1
    numpy.random.seed(obf_password % seed_max)
    append_password = numpy.random.randint(int_max)
    obf_password %= int_max
    append_input_1 = Tensor((numpy.ones((1, 1)) * obf_password).astype(numpy.int32))
    append_input_2 = Tensor((numpy.ones((1, 1)) * append_password).astype(numpy.int32))
    return append_input_1, append_input_2
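
Deploying a password-mode obfuscated model then looks like the following sketch (file name, password and input shape are placeholders; the password must be the one used at obfuscation time):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

graph = ms.load("obf_net.mindir")
obf_net = nn.GraphCell(graph, obf_password=173262358423)
x = Tensor(np.ones((1, 16)).astype(np.float32))  # must match the exported input shape
print(obf_net(x).asnumpy())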
def _check_param_list_tuple(value):

View File

@@ -60,3 +60,13 @@ class PyFuncRegistry(UserDict):
        if key not in self:
            raise ValueError(f"Python function with key {key} not registered.")
        return self[key]


class OpaquePredicateRegistry(PyFuncRegistry):
    def __init__(self):
        super(OpaquePredicateRegistry, self).__init__()
        self.func_names = []

    def register(self, key, value):
        self[key] = value
        self.func_names.append(key)

View File

@@ -0,0 +1,37 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Register pyfunc for opaque_func_cpu_kernel"""
from mindspore.ops._register_for_op import OpaquePredicateRegistry
registered_func_name = OpaquePredicateRegistry()
def add_opaque_predicate(fn_name, func):
    registered_func_name.register(fn_name, func)


def get_opaque_predicate(fn_name):
    return registered_func_name.get(fn_name)


def get_func_names():
    return registered_func_name.func_names


def clean_funcs():
    registered_func_name.func_names = []
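
To make the registry's role concrete, a short sketch of how a customized function flows through it (the predicate itself is illustrative):

from mindspore.ops.operations._opaque_predicate_registry import (
    add_opaque_predicate, get_opaque_predicate, get_func_names, clean_funcs)

def my_func(x, y):
    # Illustrative opaque predicate: constant True for any input.
    return True

clean_funcs()  # drop functions registered by an earlier load
add_opaque_predicate(my_func.__name__, my_func)
assert get_func_names() == ["my_func"]
assert get_opaque_predicate("my_func") is my_func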

View File

@@ -26,7 +26,7 @@ from mindspore.train.amp import build_train_network
from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
    load, parse_print, build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, \
    async_ckpt_thread_status, restore_group_info_list, convert_model, obfuscate_model
from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
    CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, \
    History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore
@@ -39,7 +39,7 @@ __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "bui
           "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
           "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
           "load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list", "convert_model",
           "data_sink", "obfuscate_model"]
__all__.extend(callback.__all__)
__all__.extend(summary.__all__)
__all__.extend(train_thor.__all__)

View File

@@ -55,11 +55,11 @@ from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
    _restore_group_info_list
from mindspore.train._utils import read_proto
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs

tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                     "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
@@ -384,6 +384,17 @@ def _check_append_dict(append_dict):
    return append_dict
def _check_load_obfuscate(**kwargs):
    if 'obf_func' in kwargs.keys():
        customized_func = kwargs.get('obf_func')
        if not callable(customized_func):
            raise ValueError("obf_func must be a callable function, but got {}.".format(type(customized_func)))
        clean_funcs()
        add_opaque_predicate(customized_func.__name__, customized_func)
        return True
    return False
def load(file_name, **kwargs):
    """
    Load MindIR.
@@ -436,6 +447,9 @@ def load(file_name, **kwargs):
                         "please check whether the 'file_name' is correct.")
    file_name = os.path.realpath(file_name)

    # set customized functions for dynamic obfuscation
    obfuscated = _check_load_obfuscate(**kwargs)

    logger.info("Execute the process of loading mindir.")
    if 'dec_key' in kwargs.keys():
        dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
@@ -448,9 +462,9 @@ def load(file_name, **kwargs):
        else:
            dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
        graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
                            decrypt=dec_func, obfuscated=obfuscated)
    else:
        graph = load_mindir(file_name, obfuscated=obfuscated)

    if graph is None:
        if _is_cipher_file(file_name):
@@ -461,6 +475,152 @@ def load(file_name, **kwargs):
    return graph
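
When the model was obfuscated in customized-function mode, the same function is handed back to `load()` through the `obf_func` keyword checked above. A minimal sketch (function and file name are placeholders):

import mindspore as ms
from mindspore import nn

def my_func(x, y):
    return True  # constant result, as the obfuscation contract requires

graph = ms.load("obf_net.mindir", obf_func=my_func)
obf_net = nn.GraphCell(graph)  # no obf_password needed in this mode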
def _check_param_type(param_config, key, target_type, requested):
    """Check the type of parameters."""
    if key in param_config:
        if not isinstance(param_config[key], target_type):
            raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
        return param_config[key]
    if requested:
        raise ValueError("The parameter {} is required, but not provided.".format(key))
    if key == "obf_password":
        return 0
    return None
def _check_obfuscate_params(obf_config):
    """Check obfuscation parameters, including obf_password, obf_ratio and customized_func."""
    if 'obf_password' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
        raise ValueError(
            "At least one of 'obf_password' or 'customized_func' must be set in obf_config, but none of them is.")
    obfuscate_type = _check_param_type(obf_config, "type", str, False)
    if obfuscate_type not in (None, "dynamic"):
        raise ValueError("Only 'dynamic' type is supported for now, but got {}.".format(obfuscate_type))
    if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
        if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
            raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or a float, but got {}.".format(
                obf_config['obf_ratio']))
        ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
        obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
    obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
    if (obf_ratio <= 0) or (obf_ratio > 1):
        raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
    customized_funcs = []
    if 'customized_func' in obf_config.keys():
        if callable(obf_config['customized_func']):
            customized_funcs.append(obf_config['customized_func'])
        else:
            raise TypeError(
                "'customized_func' must be a function, but got {}.".format(type(obf_config['customized_func'])))
    obf_password = _check_param_type(obf_config, "obf_password", int, False)
    int_64_max = 9223372036854775807
    if obf_password > int_64_max:
        raise ValueError(
            "'obf_password' must be less than or equal to int64 ({}), but got {}.".format(int_64_max, obf_password))
    return obf_ratio, customized_funcs, obf_password
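
A quick illustration of the validation above, using this module's helper directly (values are placeholders; "small" is mapped to 0.1 before the float range check runs):

cfg = {'type': 'dynamic', 'obf_ratio': 'small', 'obf_password': 173262358423}
obf_ratio, customized_funcs, obf_password = _check_obfuscate_params(cfg)
print(obf_ratio, customized_funcs, obf_password)  # 0.1 [] 173262358423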
def obfuscate_model(obf_config, **kwargs):
    """
    Obfuscate a model of MindIR format. Obfuscation means changing the structure of a network without affecting its
    prediction correctness. The obfuscated model can prevent attackers from stealing it.

    Args:
        obf_config (dict): obfuscation config.

            - type (str): The type of obfuscation; only 'dynamic' is supported so far.
            - original_model_path (str): The path of the MindIR format model that needs to be obfuscated. If the
              original model is encrypted, then enc_key and enc_mode should be provided.
            - save_model_path (str): The path to save the obfuscated model.
            - model_inputs (list[Tensor]): The inputs of the original model; the values of the Tensors can be
              random, the same as when using `export()`.
            - obf_ratio (Union[float, str]): The ratio of nodes in the original model that would be obfuscated.
              `obf_ratio` should be in the range (0, 1] or in ["small", "medium", "large"].
            - customized_func (function): A Python function used for the customized-function mode, which controls
              the switch branches of the obfuscation structure. Its output should be boolean and constant for any
              input (users can refer to opaque predicates). If customized_func is set, then it should also be
              passed to the `load()` interface when loading the obfuscated model.
            - obf_password (int): A password used for password mode, which should be larger than zero. If
              obf_password is set, then it should be passed to the `nn.GraphCell()` interface when loading the
              obfuscated model. Note that at least one of 'customized_func' or 'obf_password' should be set, and
              password mode is applied if both of them are set.
        kwargs (dict): Configuration options dictionary.

            - enc_key (bytes): Byte-type key used for encryption. The valid length is 16, 24, or 32.
            - enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
              Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        TypeError: If obf_config is not a dict.
        ValueError: If enc_key is passed and enc_mode is not in ["AES-GCM", "AES-CBC"].
        ValueError: If original_model_path is not provided in obf_config.
        ValueError: If the model saved in original_model_path has been obfuscated.
        ValueError: If save_model_path is not provided in obf_config.
        ValueError: If obf_ratio is not provided in obf_config.
        ValueError: If neither customized_func nor obf_password is provided in obf_config.
        ValueError: If obf_password is not in (0, 9223372036854775807].

    Examples:
        >>> obf_config = {'original_model_path': "./net.mindir",
        ...               'save_model_path': "./obf_net",
        ...               'model_inputs': [input1, ],
        ...               'obf_ratio': 0.1, 'obf_password': 173262358423}
        >>> obfuscate_model(obf_config)
        >>> obf_func = load("obf_net.mindir")
        >>> obf_net = nn.GraphCell(obf_func, obf_password=173262358423)
        >>> print(obf_net(input1).asnumpy())
    """
    if not isinstance(obf_config, dict):
        raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
    file_path = _check_param_type(obf_config, "original_model_path", str, True)
    saved_path = _check_param_type(obf_config, "save_model_path", str, True)
    model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
    for item in model_inputs:
        if not isinstance(item, Tensor):
            raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
    obf_ratio, customized_funcs, obf_password = _check_obfuscate_params(obf_config)
    if customized_funcs and obf_password > 0:
        logger.warning("Although customized_func and obf_password are set, the 'obf_password' mode would be "
                       "applied, remember to set obf_password when loading obfuscated model.")
    if obf_password == 0:  # apply customized_func mode
        clean_funcs()
        for func in customized_funcs:
            add_opaque_predicate(func.__name__, func)
        append_password = 0
    else:  # apply password mode
        seed_max = 2 ** 32 - 1
        int_max = 2 ** 31 - 1
        np.random.seed(obf_password % seed_max)
        append_password = np.random.randint(int_max)
        obf_password %= int_max
    if 'enc_key' in kwargs.keys():
        enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
        enc_mode = "AES-GCM"
        if 'enc_mode' in kwargs.keys():
            enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
            if enc_mode not in ["AES-GCM", "AES-CBC"]:
                raise ValueError(
                    "Only MindIR files encrypted with 'AES-GCM' or 'AES-CBC' are supported for obfuscate_model(),"
                    " but got {}.".format(enc_mode))
        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, obf_password=obf_password,
                                             append_password=append_password, dec_key=enc_key, key_len=len(enc_key),
                                             dec_mode=enc_mode)
    else:
        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, obf_password=obf_password,
                                             append_password=append_password)
    obf_net = nn.GraphCell(obf_graph)
    if obf_password != 0:
        y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
        append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
        model_inputs += [y_tensor, append_y_tensor]
    export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
                    dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
    """
@@ -524,7 +684,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
            np_type = tensor_to_np_type.get(data_type)
            ms_type = tensor_to_ms_type[data_type]
            if data_type == 'str':
                str_length = int(len(data) / 4)
                np_type = np_type + str(str_length)
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
@@ -549,7 +709,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
    except BaseException as e:
        logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
        raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
                         "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e

    if not parameter_dict:
        raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -617,10 +777,10 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
    except BaseException as e:
        if _is_cipher_file(ckpt_file_name):
            err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tampered with, " \
                       "please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name)
        else:
            err_info = "Failed to read the checkpoint file {}. May not have permission to read it, please check" \
                       " the correctness of the file.".format(ckpt_file_name)
        logger.error(err_info)
        raise ValueError(err_info) from e
    return checkpoint_list
@@ -885,6 +1045,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):
            - For details of using the customized encryption, please check the `tutorial
              <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.

        - obf_config (dict): obfuscation config.

            - type (str): The type of obfuscation; only 'dynamic' is supported so far.
            - obf_ratio (Union[float, str]): The ratio of nodes in the original model that would be obfuscated.
              `obf_ratio` should be in the range (0, 1] or in ["small", "medium", "large"].
            - customized_func (function): A Python function used for the customized-function mode, which controls
              the switch branches of the obfuscation structure. Its output should be boolean and constant for any
              input (users can refer to opaque predicates). If customized_func is set, then it should also be
              passed to the `load()` interface when loading the obfuscated model.
            - obf_password (int): A password used for password mode, which should be larger than zero. If
              obf_password is set, then it should be passed to the `nn.GraphCell()` interface when loading the
              obfuscated model. Note that at least one of 'customized_func' or 'obf_password' should be set, and
              password mode is applied if both of them are set.

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
@@ -920,11 +1094,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
    file_name = os.path.realpath(file_name)
    net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
    if 'enc_key' in kwargs.keys():
        kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
    _export(net, file_name, file_format, *inputs, **kwargs)
def _export(net, file_name, file_format, *inputs, **kwargs):
@@ -1150,13 +1321,40 @@ def _cell_info(net, *inputs):
                                 do_convert=False, auto_parallel_mode=net._auto_parallel_mode)
    # pylint: disable=protected-access
    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
    # clean the obfuscation config so it does not affect the next call
    _executor.obfuscate_config = None
    net_dict = net.parameters_dict()
    return mindir_stream, net_dict
def _set_obfuscate_config(**kwargs):
    """Set obfuscation config for the executor."""
    logger.warning("Obfuscate model.")
    if 'enc_mode' in kwargs.keys():
        enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
        if enc_mode not in ["AES-GCM", "AES-CBC"]:
            raise ValueError(
                "Only MindIR files encrypted with 'AES-GCM' or 'AES-CBC' are supported for obfuscation, "
                "but got {}.".format(enc_mode))
    obf_ratio, customized_funcs, obf_password = _check_obfuscate_params(kwargs.get('obf_config'))
    if customized_funcs and obf_password > 0:
        logger.warning("Although customized_func and obf_password are set, the 'obf_password' mode would be "
                       "applied, remember to set obf_password when loading obfuscated model.")
    if obf_password == 0:  # apply customized_func mode
        clean_funcs()
        for func in customized_funcs:
            add_opaque_predicate(func.__name__, func)
    _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_password': obf_password}
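
This config is reached through the `obf_config` keyword of `export()` documented above. A minimal password-mode sketch (the network and shapes are stand-ins, not part of this commit):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

net = nn.Dense(16, 8)  # any trained Cell would do; Dense is just a placeholder
x = Tensor(np.ones((1, 16)).astype(np.float32))
obf_config = {'type': 'dynamic', 'obf_ratio': 0.1, 'obf_password': 173262358423}
ms.export(net, x, file_name="obf_net", file_format="MINDIR", obf_config=obf_config)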
def _save_mindir(net, file_name, *inputs, **kwargs):
    """Save MindIR format file."""
    # set obfuscate configs
    if 'obf_config' in kwargs.keys():
        _set_obfuscate_config(**kwargs)

    model = mindir_model()
    if not isinstance(net, nn.Cell):
        mindir_stream, net_dict = _msfunc_info(net, *inputs)
@@ -1248,6 +1446,7 @@ def _save_dataset_to_mindir(model, dataset):
def quant_mode_manage(func):
    """Inherit the quant_mode in old version."""

    @functools.wraps(func)
    def warpper(network, *inputs, file_format, **kwargs):
        if 'quant_mode' not in kwargs:
@@ -1746,8 +1945,8 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
            logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
                            " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
            raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
                               f" in load distributed checkpoint for {param.name}. Data shape is "
                               f"{split_param.data.shape} and group is {opt_shard_group}.") from e
        split_param = Parameter(Tensor(data_slice), param.name,
                                split_param.requires_grad, split_param.layerwise_parallel)
        param_dict[param.name] = split_param