From 969e368cc54ee2162a1b45d4eab5a248e3b741d9 Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Tue, 11 Oct 2022 09:44:23 +0800 Subject: [PATCH] Add dynamic obfuscation tool --- .jenkins/check/config/filter_cpplint.txt | 2 + .../api_python/mindspore/mindspore.export.rst | 7 + .../mindspore/mindspore.obfuscate_model.rst | 22 + docs/api/api_python_en/mindspore.rst | 1 + mindspore/ccsrc/pipeline/jit/init.cc | 11 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 52 +- mindspore/ccsrc/pipeline/jit/pipeline.h | 8 +- .../cpu/kernel/opaque_predicate_kernel.cc | 69 +++ .../cpu/kernel/opaque_predicate_kernel.h | 57 ++ .../dynamic_obfuscation.cc | 494 ++++++++++++++++++ .../dynamic_obfuscation/dynamic_obfuscation.h | 53 ++ .../registry_opaque_predicate.cc | 115 ++++ .../registry_opaque_predicate.h | 57 ++ .../core/abstract/ops/primitive_infer_map.cc | 2 + mindspore/core/ops/core_ops.h | 1 + mindspore/core/ops/opaquePredicate.cc | 63 +++ mindspore/core/ops/opaquePredicate.h | 41 ++ mindspore/python/mindspore/common/api.py | 24 +- mindspore/python/mindspore/nn/cell.py | 39 +- .../python/mindspore/ops/_register_for_op.py | 10 + .../operations/_opaque_predicate_registry.py | 37 ++ mindspore/python/mindspore/train/__init__.py | 4 +- .../python/mindspore/train/serialization.py | 231 +++++++- 23 files changed, 1367 insertions(+), 33 deletions(-) create mode 100644 docs/api/api_python/mindspore/mindspore.obfuscate_model.rst create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.h create mode 100644 mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc create mode 100644 mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h create mode 100644 mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.cc create mode 100644 mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h create mode 100644 mindspore/core/ops/opaquePredicate.cc
create mode 100644 mindspore/core/ops/opaquePredicate.h create mode 100644 mindspore/python/mindspore/ops/operations/_opaque_predicate_registry.py diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt index 92e834b5899..a01ec86b4b4 100644 --- a/.jenkins/check/config/filter_cpplint.txt +++ b/.jenkins/check/config/filter_cpplint.txt @@ -23,6 +23,8 @@ "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include" "mindspore/mindspore/ccsrc/backend/common/optimizer/op_adaptation_info_factory.h" "runtime/explicit" "mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cu" "runtime/int" +"mindspore/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc" "runtime/threadsafe_fn" +"mindspore/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc" "runtime/references" # Modelzoo "mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references" diff --git a/docs/api/api_python/mindspore/mindspore.export.rst b/docs/api/api_python/mindspore/mindspore.export.rst index 775070b63f0..4f4746996ff 100644 --- a/docs/api/api_python/mindspore/mindspore.export.rst +++ b/docs/api/api_python/mindspore/mindspore.export.rst @@ -32,3 +32,10 @@ mindspore.export - 对于'MINDIR'格式的模型,支持的加密选项有:'AES-GCM','AES-CBC'和用户自定义加密算法。默认值:"AES-GCM"。 - 关于使用自定义加密导出的详情,请查看 `教程 `_。 - **dataset** (Dataset) - 指定数据集的预处理方法,用于将数据集的预处理导入MindIR。 + + - **obf_config** (dict) - 模型混淆配置选项字典。 + + - **type** (str) - 混淆类型,目前支持动态混淆,即'dynamic'。 + - **obf_ratio** (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串"small"、"medium"、"large"。 + - **customized_func** (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置。如果设置了`customized_func`,那么在使用`load`接口导入模型的时候,需要把这个函数也传入。 + - **obf_password** (int) - 
秘密口令,用于password模式,是一个大于0的整数。如果用户设置了`obf_password`,那么在部署混淆模型的时候,需要在`nn.GraphCell()`接口中传入`obf_password`。需要注意的是,如果用户同时设置了`customized_func`和`obf_password`,那么password模式将会被采用。 diff --git a/docs/api/api_python/mindspore/mindspore.obfuscate_model.rst b/docs/api/api_python/mindspore/mindspore.obfuscate_model.rst new file mode 100644 index 00000000000..954be4b98a2 --- /dev/null +++ b/docs/api/api_python/mindspore/mindspore.obfuscate_model.rst @@ -0,0 +1,22 @@ +mindspore.obfuscate_model +========================= + +.. py:function:: mindspore.obfuscate_model(obf_config, **kwargs) + + 对MindIR格式的模型进行混淆,混淆主要是修改模型的网络结构但不影响它的推理精度,混淆后的模型可以防止被盗用。 + + 参数: + - **obf_config** (dict) - 模型混淆配置选项字典。 + + - **type** (str) - 混淆类型,目前支持动态混淆,即'dynamic'。 + - **original_model_path** (str) - 待混淆的MindIR模型地址。如果该模型是加密文件的,则需要在`kwargs`中传入`enc_key`和`enc_mode`。 + - **save_model_path** (str) - 混淆模型的保存地址。 + - **model_inputs** (list[Tensor]) - 模型的推理输入,Tensor的值可以是随机的,和使用`export()`接口类似。 + - **obf_ratio** (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串"small"、"medium"、"large"。 + - **customized_func** (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置。如果设置了`customized_func`,那么在使用`load`接口导入模型的时候,需要把这个函数也传入。 + - **obf_password** (int) - 秘密口令,用于password模式,是一个大于0的整数。如果用户设置了`obf_password`,那么在部署混淆模型的时候,需要在`nn.GraphCell()`接口中传入`obf_password`。需要注意的是,如果用户同时设置了`customized_func`和`obf_password`,那么password模式将会被采用。 + + - **kwargs** (dict) - 配置选项字典。 + + - **enc_key** (str) - 用于加密的字节类型密钥,有效长度为16、24或者32。 + - **enc_mode** (Union[str, function]) - 指定加密模式,当设置 `enc_key` 时启用。支持的加密选项有:'AES-GCM','AES-CBC'。默认值:"AES-GCM"。 diff --git a/docs/api/api_python_en/mindspore.rst b/docs/api/api_python_en/mindspore.rst index 46211dbb3ad..825fcd058e6 100644 --- a/docs/api/api_python_en/mindspore.rst +++ b/docs/api/api_python_en/mindspore.rst @@ -234,6 +234,7 @@ Serialization mindspore.save_checkpoint mindspore.transform_checkpoint_by_rank mindspore.transform_checkpoints + mindspore.obfuscate_model JIT --- 
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 192bf43c0a0..44312ccc7cc 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -147,6 +147,9 @@ PYBIND11_MODULE(_c_expression, m) { .def("get_func_graph", &GraphExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.") .def("get_func_graph_proto", &GraphExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""), py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") + .def("get_obfuscate_func_graph_proto", &GraphExecutorPy::GetObfuscateFuncGraphProto, py::arg("phase") = py::str(""), + py::arg("obf_ratio") = py::float_(1.0), py::arg("obf_password") = py::int_(0), + py::arg("append_password") = py::int_(0), "Get graph proto of dynamic-obfuscated model.") .def("get_params", &GraphExecutorPy::GetParams, py::arg("phase") = py::str(""), "Get Parameters from graph") .def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") @@ -200,11 +203,13 @@ PYBIND11_MODULE(_c_expression, m) { py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset."); (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); (void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline."); - (void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr, py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), - py::arg("decrypt") = py::none(), "Load model as Graph."); - + py::arg("decrypt") = py::none(), py::arg("obfuscated") = py::bool_(false), "Load model as Graph."); + (void)m.def("dynamic_obfuscate_mindir", &mindspore::pipeline::DynamicObfuscateMindIR, py::arg("file_name"), + py::arg("obf_ratio"),
py::arg("obf_password") = py::int_(0), py::arg("append_password") = py::int_(0), + py::arg("dec_key") = nullptr, py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), + "Obfuscate a mindir model by dynamic obfuscation."); (void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster"); (void)m.def("set_cluster_exit_with_exception", &mindspore::distributed::set_cluster_exit_with_exception, "Set this process exits with exception."); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 96cf99b2884..c3df935dd95 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -66,6 +66,8 @@ #include "runtime/pynative/op_executor.h" #include "runtime/device/stream_synchronizer.h" #include "distributed/collective/collective_manager.h" +#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h" +#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h" #if defined(__linux__) && defined(WITH_BACKEND) #include "ps/constants.h" @@ -442,6 +444,7 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std } if (ir_type == IR_TYPE_MINDIR) { + // obfuscate model std::string proto_str = GetBinaryProtoString(fg_ptr); if (proto_str.empty()) { MS_LOG(EXCEPTION) << "Export MINDIR format model failed."; @@ -452,6 +455,24 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; } +py::bytes GraphExecutorPy::GetObfuscateFuncGraphProto(const std::string &phase, const float obf_ratio, + const int obf_password, const int append_password) { + FuncGraphPtr fg_ptr = GetFuncGraph(phase); + // obfuscate model + if (obf_password == 0) { + (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names(); + MS_LOG(DEBUG) << "[GetObfuscateFuncGraphProto] set customized function names finished"; + } + 
mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, obf_password, append_password); + mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(fg_ptr); + + std::string proto_str = GetBinaryProtoString(obfuscated_graph); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "GetBinaryProtoString failed."; + } + return proto_str; +} + py::bytes GraphExecutorPy::GetOptimizeGraphProto(const std::string &phase) { if (info_.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor: " << phase; @@ -1255,6 +1276,8 @@ void GraphExecutorPy::TerminateDebugger() { #endif py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) { + // init for dynamic-obfuscated model infer + (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count(); // Mindspore debugger notify main thread to exit after one step, and will not run next step #ifdef ENABLE_DEBUGGER TerminateDebugger(); @@ -1655,7 +1678,12 @@ void GraphExecutorPy::ExportGraph(const std::string &file_name, const std::strin } FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len, - const std::string &dec_mode, const py::object decrypt) { + const std::string &dec_mode, const py::object decrypt, const bool obfuscated) { + if (obfuscated) { + MS_LOG(DEBUG) << "[LoadMindIR] Set customized function."; + (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names(); + (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count(); + } FuncGraphPtr func_graph = nullptr; if (dec_mode == "Customized") { py::bytes key_bytes(dec_key); @@ -1679,6 +1707,28 @@ FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const return func_graph; } +FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int obf_password, + int append_password, char *dec_key, const size_t key_len, + const std::string &dec_mode) { + if (obf_password == 0) 
{ + (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names(); + MS_LOG(DEBUG) << "[DynamicObfuscateMindIR] set function names finished."; + } + mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, obf_password, append_password); + MindIRLoader mindir_loader(false, reinterpret_cast(dec_key), key_len, dec_mode, false); + FuncGraphPtr func_graph = mindir_loader.LoadMindIR(file_name); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "[DynamicObfuscateMindIR] load mindir failed, please check the mindir file."; + return nullptr; + } + mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(func_graph); + if (obfuscated_graph == nullptr) { + MS_LOG(ERROR) << "[DynamicObfuscateMindIR] obfuscate model failed."; + return nullptr; + } + return obfuscated_graph; +} + void CloseTsd(bool force) { #ifdef WITH_BACKEND auto context_ptr = MsContext::GetInstance(); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index 0f7d996404c..3762e4dfa1f 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -88,6 +88,8 @@ class GraphExecutorPy : public std::enable_shared_from_this { FuncGraphPtr GetGradGraph(const std::string &phase); void SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase); py::bytes GetFuncGraphProto(const std::string &phase, const std::string &ir_type); + py::bytes GetObfuscateFuncGraphProto(const std::string &phase, const float obf_ratio, const int obf_password, + const int append_password); #ifndef ENABLE_SECURITY py::bytes GetOptimizeGraphProto(const std::string &phase); #endif @@ -185,7 +187,8 @@ void CloseTsd(bool force = false); void MemoryRecycle(); FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len, - const std::string &dec_mode, const py::object decrypt = py::none()); + const std::string &dec_mode, const py::object decrypt = py::none(), + const bool 
obfuscated = false); // init and exec dataset sub graph bool ME_EXPORT InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, @@ -203,6 +206,9 @@ py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_le py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode); bool PyIsCipherFile(const std::string &file_path); void FinalizeCluster(); +FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int obf_password, + int append_password, char *dec_key, const size_t key_len, + const std::string &dec_mode); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.cc new file mode 100644 index 00000000000..8c6d92c6a07 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "plugin/device/cpu/kernel/opaque_predicate_kernel.h" +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kInputsNum = 2; +constexpr size_t kOutputsNum = 1; +} // namespace + +template +bool OpaquePredicateKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + auto input1 = reinterpret_cast(inputs[0]->addr); + auto input2 = reinterpret_cast(inputs[1]->addr); + bool *output = reinterpret_cast(outputs[0]->addr); + output[0] = + CustomizedOpaquePredicate::GetInstance().run_function(static_cast(*input1), static_cast(*input2)); + return true; +} + +bool OpaquePredicateKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + return true; +} + +std::vector> OpaquePredicateKernelMod::func_list_ = + {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + &OpaquePredicateKernelMod::LaunchKernel}}; + +std::vector OpaquePredicateKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, OpaquePredicate, OpaquePredicateKernelMod); 
+} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.h new file mode 100644 index 00000000000..f70e70ec60d --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/opaque_predicate_kernel.h @@ -0,0 +1,57 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_ + +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class OpaquePredicateKernelMod : public NativeCpuKernelMod { + public: + OpaquePredicateKernelMod() = default; + ~OpaquePredicateKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using OpaquePredicateFunc = + std::function &, + const 
std::vector &, const std::vector &)>; + OpaquePredicateFunc kernel_func_; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_ diff --git a/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc b/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc new file mode 100644 index 00000000000..16dccb4077d --- /dev/null +++ b/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc @@ -0,0 +1,494 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h" +#include +#include +#include +#include +#include +#include +#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h" +#include "include/common/debug/anf_ir_dump.h" +#include "utils/info.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "ops/core_ops.h" + +namespace mindspore { +using Tensor = mindspore::tensor::Tensor; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTensorPtr; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; + +constexpr int expand_rate = 10; // total node need for a switch graph + +ShapeVector get_node_shape(AnfNodePtr input_node) { + if (input_node == nullptr) { + MS_LOG(ERROR) << "Input_node is nullptr, get shape failed!"; + return {}; + } + AbstractBasePtr input_abstract = input_node->abstract(); + if (input_abstract == nullptr) { + MS_LOG(ERROR) << "The abstract of input_node is nullptr, get shape failed!"; + return {}; + } + AbstractTensorPtr input_abstract_tensor = input_abstract->cast(); + mindspore::abstract::ShapePtr shape_ptr = input_abstract_tensor->shape(); + return shape_ptr->shape(); +} + +TypeId get_node_dtype(AnfNodePtr input_node) { + if (input_node == nullptr) { + MS_LOG(ERROR) << "Input_node is nullptr, get dtype failed!"; + return {}; + } + AbstractBasePtr input_abstract = input_node->abstract(); + if (input_abstract == nullptr) { + MS_LOG(ERROR) << "The abstract of input_node is nullptr, get dtype failed!"; + return {}; + } + AbstractTensorPtr input_abstract_tensor = input_abstract->cast(); + AbstractBasePtr node_element = input_abstract_tensor->element(); + mindspore::abstract::AbstractScalarPtr node_element_abs = + node_element->cast(); + + TypeId data_type = node_element_abs->BuildType()->type_id(); + return data_type; +} + +std::vector name_split(std::string &node_name, const std::string &split_sign) { + node_name += split_sign; + unsigned 
int name_len = node_name.size(); + std::string::size_type split_pos; + std::vector res; + for (unsigned int i = 0; i < name_len; i++) { + split_pos = node_name.find(split_sign, i); + if (split_pos < name_len) { + std::string sub_str = node_name.substr(i, split_pos - i); + res.push_back(sub_str); + i = split_pos + split_sign.size() - 1; + } + } + return res; +} + +ValueNodePtr build_tuple_value_node(std::vector values) { + mindspore::ValueNodePtr v_node = std::make_shared(MakeValue(values)); + AbstractBasePtrList abs_list; + std::transform(values.begin(), values.end(), std::back_inserter(abs_list), [](const int64 &item) { + return std::make_shared(int64_t(item)); + }); + auto abs_tuple = std::make_shared(abs_list); + v_node->set_abstract(abs_tuple); + return v_node; +} + +ValueNodePtr make_int_node(FuncGraphPtr func_graph, int int_value) { + ShapeVector int_shape{1, 1}; + tensor::TensorPtr int_tensor = std::make_shared(mindspore::kNumberTypeInt32, int_shape); + int *tensor_data = reinterpret_cast(int_tensor->data_c()); + for (int i = 0; i < int_tensor->data().size(); i++) { + tensor_data[i] = int_value; + } + mindspore::ValueNodePtr int_tensor_node = std::make_shared(int_tensor); + int_tensor_node->set_abstract(int_tensor->ToAbstract()); + (void)func_graph->AddValueNode(int_tensor_node); + return int_tensor_node; +} + +tensor::TensorPtr make_weight_tensor(TypeId type_id, ShapeVector shape) { + tensor::TensorPtr weight_tensor = std::make_shared(type_id, shape); + std::default_random_engine generator; + int max_count = 10000; + int tensor_size = weight_tensor->data().size(); + if (type_id == kNumberTypeFloat64) { + const double mean_64 = 0; + const double stddev_64 = 1; + std::normal_distribution dist_64(mean_64, stddev_64); + double *float_64_data = reinterpret_cast(weight_tensor->data_c()); + for (int i = 0; i < std::min(tensor_size, max_count); i++) { + double random_float_64 = dist_64(generator); + if (random_float_64 > 0) { + float_64_data[i] = random_float_64; + 
} + } + } else { + MS_LOG(DEBUG) << "Type id is: " << type_id << ", weights will be float_32 format."; + const float mean = 0; + const float stddev = 1; + std::normal_distribution dist_32(mean, stddev); + float *float_32_data = reinterpret_cast(weight_tensor->data_c()); + for (int i = 0; i < std::min(tensor_size, max_count); i++) { + float random_float_32 = dist_32(generator); + if (random_float_32 > 0) { + float_32_data[i] = random_float_32; + } + } + } + return weight_tensor; +} + +bool CheckIfObfuscated(const FuncGraphPtr &func_graph) { + auto mgr = Manage(func_graph); + auto all_nodes = mgr->all_nodes(); + for (AnfNodePtr node : all_nodes) { + std::string node_name = node->fullname_with_scope(); + if (node_name.find("Switch") != std::string::npos) { + return true; + } + } + return false; +} + +FuncGraphPtr DynamicObfuscator::ObfuscateMindIR(const FuncGraphPtr &func_graph) { + MS_LOG(INFO) << "Start obfuscation."; + MS_EXCEPTION_IF_NULL(func_graph); + if (CheckIfObfuscated(func_graph)) { + MS_EXCEPTION(ValueError) << "The input model has been obfuscated, do not obfuscate it again."; + } + auto mgr = Manage(func_graph); + MS_EXCEPTION_IF_NULL(mgr); + auto all_nodes = mgr->all_nodes(); + int node_nums = all_nodes.size(); + MS_LOG(INFO) << "Total node num: " << node_nums; + // init the number control node that has been build + used_control_node_ = 0; + if (obf_password_ == 0) { + int obfuscate_target_num = std::ceil(all_nodes.size() * obf_ratio_ / expand_rate); + int obfuscate_node_num = 0; + // record customized_func computing results + for (AnfNodePtr node : all_nodes) { + std::string obf_type = single_op_obfuscate_type(node); + MS_LOG(INFO) << "obf_type: " << obf_type; + if (obf_type == "MatMul-op") { + obfuscate_node_num += 1; + MS_LOG(INFO) << "Find a MatMul Node: " << node->fullname_with_scope(); + bool customized_func_result = mindspore::kernel::CustomizedOpaquePredicate::GetInstance().run_function( + static_cast(1), static_cast(1)); +
customized_func_results_.push_back(customized_func_result); + } + if (obfuscate_node_num >= obfuscate_target_num) { + break; + } + } + (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count(); + } + // do op-wise fake-branch obfuscation + (void)op_wise_fake_branch(func_graph); + if (used_control_node_ == 0) { + MS_LOG(WARNING) + << "The model has not been obfuscated, which means obf_password or customized_func does not need to be set."; + } + return func_graph; +} + +void DynamicObfuscator::op_wise_fake_branch(FuncGraphPtr func_graph) { + auto mgr = Manage(func_graph); + auto all_nodes = mgr->all_nodes(); + int obfuscate_target_num = std::ceil(all_nodes.size() * obf_ratio_ / expand_rate); + int obfuscate_node_num = 0; + for (AnfNodePtr node : all_nodes) { + std::string obf_type = single_op_obfuscate_type(node); + MS_LOG(INFO) << "The obf_type is: " << obf_type; + if (obf_type == "MatMul-op") { + obfuscate_node_num += 1; + MS_LOG(INFO) << "Find a MatMul Node: " << node->fullname_with_scope(); + std::vector node_inputs = node->cast()->inputs(); + mindspore::AnfNodePtr input_1 = node_inputs[1]; + CNodePtr control_c_node = get_control_node(func_graph, input_1); + (void)replace_matmul_node(node->cast(), func_graph, control_c_node); + MS_LOG(INFO) << "Finished replacement for: " << node->fullname_with_scope(); + } + if (obfuscate_node_num >= obfuscate_target_num) { + break; + } + } +} + +std::string DynamicObfuscator::single_op_obfuscate_type(AnfNodePtr node) { + if (node->isa()) { + std::string node_name = node->fullname_with_scope(); + MS_LOG(INFO) << "The node_name is: " << node_name; + std::vector split_words = name_split(node_name, "/"); + std::string op_name = split_words[split_words.size() - 1]; + for (std::string target_op_name : obf_target_op) { + int op_name_len = op_name.size(); + int target_name_len = target_op_name.size(); + if ((op_name_len >= target_name_len) && (op_name.substr(0, target_name_len) == target_op_name)) { + return
target_op_name; + } + } + return ""; + } + return ""; +} + +CNodePtr DynamicObfuscator::password_mode_control(FuncGraphPtr func_graph) { + ShapeVector y_shape{1, 1}; + tensor::TensorPtr y_tensor = std::make_shared(mindspore::kNumberTypeInt32, y_shape); + if (!has_build_appended_input) { + MS_LOG(INFO) << "Build parameter y and y_append."; + auto y = func_graph->add_parameter(); + y->set_name("y"); + y->set_abstract(y_tensor->ToAbstract()); + auto y_append = func_graph->add_parameter(); + y_append->set_name("y_append"); + y_append->set_abstract(y_tensor->ToAbstract()); + has_build_appended_input = true; + } + auto y = func_graph->GetParameterByName("y"); + auto y_append = func_graph->GetParameterByName("y_append"); + + if (used_control_node_ == 0) { + // make add function node + mindspore::PrimitivePtr add_prim = mindspore::prim::kPrimAdd; + add_prim->set_attr("is_load", MakeValue(true)); + mindspore::ValueNodePtr add_v_node = std::make_shared(add_prim); + (void)func_graph->AddValueNode(add_v_node); + CNodePtr add_c_node = func_graph->NewCNode({add_v_node, y, y_append}); + add_c_node->set_abstract(y_tensor->ToAbstract()); + // make equal function node + ValueNodePtr equal_v_node = std::make_shared(mindspore::prim::kPrimEqual); + (void)func_graph->AddValueNode(equal_v_node); + ValueNodePtr equal_compa_node = make_int_node(func_graph, obf_password_ + append_password_); + CNodePtr equal_c_node = func_graph->NewCNode({equal_v_node, add_c_node, equal_compa_node}); + tensor::TensorPtr equal_tensor = std::make_shared(mindspore::kNumberTypeBool, y_shape); + equal_c_node->set_abstract(equal_tensor->ToAbstract()); + (void)func_graph->AddNode(equal_c_node); + used_control_node_ += 1; + switch_branch_ = true; + return equal_c_node; + } + // make greater function node + int comparison_int = rand(); + ValueNodePtr greater_v_node = std::make_shared(mindspore::prim::kPrimGreater); + (void)func_graph->AddValueNode(greater_v_node); + ValueNodePtr greater_compa_node = 
make_int_node(func_graph, comparison_int); + CNodePtr greater_c_node = func_graph->NewCNode({greater_v_node, y, greater_compa_node}); + tensor::TensorPtr greater_tensor = std::make_shared(mindspore::kNumberTypeBool, y_shape); + greater_c_node->set_abstract(greater_tensor->ToAbstract()); + (void)func_graph->AddNode(greater_c_node); + used_control_node_ += 1; + switch_branch_ = obf_password_ > comparison_int; + return greater_c_node; +} + +mindspore::CNodePtr AddStrideSliceNode(FuncGraphPtr func_graph, ShapeVector begin_vector, ShapeVector stride_vector, + ShapeVector end_vector, int end_mask, int begin_mask, + mindspore::CNodePtr prev_node) { + mindspore::ValueNodePtr begin_v_node = build_tuple_value_node(begin_vector); + mindspore::ValueNodePtr stride_v_node = build_tuple_value_node(stride_vector); + mindspore::ValueNodePtr end_v_node = build_tuple_value_node(end_vector); + (void)func_graph->AddValueNode(begin_v_node); + (void)func_graph->AddValueNode(stride_v_node); + (void)func_graph->AddValueNode(end_v_node); + mindspore::PrimitivePtr slice_prim = mindspore::prim::kPrimStridedSlice; + slice_prim->set_attr("is_load", MakeValue(true)); + slice_prim->set_attr("new_axis_mask", MakeValue(int64_t(0))); + slice_prim->set_attr("shrink_axis_mask", MakeValue(int64_t(1))); + slice_prim->set_attr("end_mask", MakeValue(int64_t(end_mask))); + slice_prim->set_attr("begin_mask", MakeValue(int64_t(begin_mask))); + slice_prim->set_attr("ellipsis_mask", MakeValue(int64_t(0))); + mindspore::ValueNodePtr slice_v_node = std::make_shared(slice_prim); + (void)func_graph->AddValueNode(slice_v_node); + mindspore::CNodePtr slice_c_node = + func_graph->NewCNode({slice_v_node, prev_node, begin_v_node, end_v_node, stride_v_node}); + return slice_c_node; +} + +CNodePtr DynamicObfuscator::custom_op_mode_control(FuncGraphPtr func_graph, AnfNodePtr prev_node) { + mindspore::PrimitivePtr reshape_prim = mindspore::prim::kPrimReshape; + reshape_prim->set_attr("is_load", MakeValue(true)); + 
mindspore::ValueNodePtr reshape_v_node = std::make_shared(reshape_prim); + (void)func_graph->AddValueNode(reshape_v_node); + ShapeVector prev_node_shape = get_node_shape(prev_node); + int shape_multiply = std::accumulate(prev_node_shape.begin(), prev_node_shape.end(), 1, std::multiplies()); + MS_LOG(INFO) << "The shape_multiply is: " << shape_multiply; + + ShapeVector flat_shape{1, shape_multiply}; + mindspore::ValueNodePtr shape_v_node = std::make_shared(MakeValue(flat_shape)); + (void)func_graph->AddValueNode(shape_v_node); + mindspore::CNodePtr reshape_c_node = func_graph->NewCNode({reshape_v_node, prev_node, shape_v_node}); + TypeId data_type = get_node_dtype(prev_node); + auto reshape_abstract = std::make_shared(data_type, flat_shape)->ToAbstract(); + reshape_c_node->set_abstract(reshape_abstract); + (void)func_graph->AddNode(reshape_c_node); + + // the first stride_slice x[0] + ShapeVector begin_1{0, 0}; + ShapeVector stride_1{1, 1}; + mindspore::CNodePtr slice_c_node_1 = + AddStrideSliceNode(func_graph, begin_1, stride_1, flat_shape, 2, 2, reshape_c_node); + ShapeVector slice_1_shape{shape_multiply}; + slice_c_node_1->set_abstract(std::make_shared(data_type, slice_1_shape)->ToAbstract()); + (void)func_graph->AddNode(slice_c_node_1); + + // the first stride_slice x[0][0] + ShapeVector begin_2{0}; + ShapeVector end_2{1}; + ShapeVector stride_2{1}; + mindspore::CNodePtr slice_c_node_2 = + AddStrideSliceNode(func_graph, begin_2, stride_2, stride_2, 0, 0, slice_c_node_1); + ShapeVector slice_2_shape{1}; + slice_c_node_2->set_abstract(std::make_shared(data_type, slice_2_shape)->ToAbstract()); + (void)func_graph->AddNode(slice_c_node_2); + + // the second stride_slice x[0][1] + ShapeVector begin_3{1}; + ShapeVector end_3{1}; + ShapeVector stride_3{2}; + mindspore::CNodePtr slice_c_node_3 = + AddStrideSliceNode(func_graph, begin_3, stride_3, stride_3, 0, 0, slice_c_node_1); + ShapeVector slice_3_shape{1}; + slice_c_node_3->set_abstract(std::make_shared(data_type, 
slice_3_shape)->ToAbstract()); + (void)func_graph->AddNode(slice_c_node_3); + + // add opaque predicate + PrimitivePtr custom_prim = mindspore::prim::kPrimOpaquePredicate; + custom_prim->set_attr("is_load", MakeValue(true)); + std::vector input_names_value; + input_names_value.push_back(std::make_shared("x")); + input_names_value.push_back(std::make_shared("y")); + custom_prim->set_attr("input_names", std::make_shared(input_names_value)); + std::vector output_names_value; + output_names_value.push_back(std::make_shared("output")); + custom_prim->set_attr("output_names", std::make_shared(output_names_value)); + auto opaque_v_node = std::make_shared(custom_prim); + (void)func_graph->AddValueNode(opaque_v_node); + auto opaque_c_node = func_graph->NewCNode({opaque_v_node, slice_c_node_2, slice_c_node_3}); + ShapeVector y_shape{1, 1}; + auto bool_tensor = std::make_shared(mindspore::kNumberTypeBool, y_shape); + opaque_c_node->set_abstract(bool_tensor->ToAbstract()); + (void)func_graph->AddNode(opaque_c_node); + return opaque_c_node; +} + +CNodePtr DynamicObfuscator::get_control_node(FuncGraphPtr func_graph, AnfNodePtr prev_node) { + if (obf_password_ != 0) { + MS_LOG(INFO) << "Run password mode."; + return password_mode_control(func_graph); + } + MS_LOG(INFO) << "Run customized function mode."; + return custom_op_mode_control(func_graph, prev_node); +} + +void DynamicObfuscator::replace_matmul_node(CNodePtr node, FuncGraphPtr func_graph, CNodePtr control_node) { + std::vector node_inputs = node->cast()->inputs(); + mindspore::ValueNodePtr matmul_v_node = node_inputs[0]->cast(); + mindspore::AnfNodePtr input_1 = node_inputs[1]; + mindspore::AnfNodePtr input_2 = node_inputs[2]; + + // construct branch 1 + mindspore::FuncGraphPtr fg_1 = std::make_shared(); + + // input_x + ParameterPtr branch_1_input_x = fg_1->add_parameter(); + branch_1_input_x->set_abstract(input_1->abstract()); + branch_1_input_x->set_name("branch_1_input_x"); + + // input_y + ParameterPtr 
branch_1_input_y = fg_1->add_parameter(); + branch_1_input_y->set_abstract(input_2->abstract()); + branch_1_input_y->set_name("branch_1_input_y"); + + mindspore::CNodePtr matmul_c_node_1 = fg_1->NewCNode({matmul_v_node, branch_1_input_x, branch_1_input_y}); + matmul_c_node_1->set_abstract(node->cast()->abstract()); + (void)fg_1->AddNode(matmul_c_node_1); + + // add return node + mindspore::ValueNodePtr return_v_node_1 = std::make_shared(mindspore::prim::kPrimReturn); + (void)fg_1->AddValueNode(return_v_node_1); + mindspore::CNodePtr branch_1_return = fg_1->NewCNode({return_v_node_1, matmul_c_node_1}); + (void)fg_1->AddNode(branch_1_return); + fg_1->set_return(branch_1_return); + fg_1->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); + + mindspore::ValueNodePtr partial_v_node_1 = std::make_shared(mindspore::prim::kPrimPartial); + (void)func_graph->AddValueNode(partial_v_node_1); + mindspore::ValueNodePtr fg_1_node = std::make_shared(fg_1); + fg_1_node->set_abstract(fg_1->ToAbstract()); + (void)func_graph->AddValueNode(fg_1_node); + mindspore::CNodePtr partial_c_node_1 = func_graph->NewCNode({partial_v_node_1, fg_1_node, input_1, input_2}); + (void)func_graph->AddNode(partial_c_node_1); + + // construct branch 2 + mindspore::FuncGraphPtr fg_2 = std::make_shared(); + // add input_x + ParameterPtr branch_2_input_x = fg_2->add_parameter(); + branch_2_input_x->set_abstract(input_1->abstract()); + branch_2_input_x->set_name("branch_2_input_x"); + // add input_y + ParameterPtr branch_2_input_y = fg_2->add_parameter(); + branch_2_input_y->set_abstract(input_2->abstract()); + branch_2_input_y->set_name("branch_2_input_y"); + + // add matmul CNode + mindspore::CNodePtr matmul_c_node_2 = fg_2->NewCNode({matmul_v_node, branch_2_input_x, branch_2_input_y}); + matmul_c_node_2->set_abstract(node->cast()->abstract()); + (void)fg_2->AddNode(matmul_c_node_2); + + // add return node + mindspore::ValueNodePtr return_v_node_2 = std::make_shared(mindspore::prim::kPrimReturn); + 
(void)fg_2->AddValueNode(return_v_node_2); + mindspore::CNodePtr branch_2_return = fg_2->NewCNode({return_v_node_2, matmul_c_node_2}); + (void)fg_2->AddNode(branch_2_return); + fg_2->set_return(branch_2_return); + fg_2->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); + + // add partial for branch_2 + ShapeVector matmul_2_shape = get_node_shape(input_2); + TypeId type_id = get_node_dtype(input_2); + tensor::TensorPtr matmul_2_weight = make_weight_tensor(type_id, matmul_2_shape); + mindspore::ValueNodePtr matmul_weight_v_node = std::make_shared(matmul_2_weight); + matmul_weight_v_node->set_abstract(matmul_2_weight->ToAbstract()); + (void)func_graph->AddValueNode(matmul_weight_v_node); + + mindspore::ValueNodePtr partial_v_node_2 = std::make_shared(mindspore::prim::kPrimPartial); + (void)func_graph->AddValueNode(partial_v_node_2); + mindspore::ValueNodePtr fg_2_node = std::make_shared(fg_2); + fg_2_node->set_abstract(fg_2->ToAbstract()); + (void)func_graph->AddValueNode(fg_2_node); + mindspore::CNodePtr partial_c_node_2 = + func_graph->NewCNode({partial_v_node_2, fg_2_node, input_1, matmul_weight_v_node}); + (void)func_graph->AddNode(partial_c_node_2); + + // add switch node + mindspore::ValueNodePtr switch_v_node = std::make_shared(mindspore::prim::kPrimSwitch); + (void)func_graph->AddValueNode(switch_v_node); + mindspore::CNodePtr switch_c_node; + if (obf_password_ == 0) { + int results_len = customized_func_results_.size(); + switch_branch_ = customized_func_results_[results_len - 1 - used_control_node_]; + used_control_node_ += 1; + } + if (switch_branch_) { + switch_c_node = func_graph->NewCNode({switch_v_node, control_node, partial_c_node_1, partial_c_node_2}); + } else { + switch_c_node = func_graph->NewCNode({switch_v_node, control_node, partial_c_node_2, partial_c_node_1}); + } + func_graph->AddNode(switch_c_node); + + // add call node + mindspore::CNodePtr call_cnode = func_graph->NewCNode({switch_c_node}); + func_graph->AddNode(call_cnode); + // add fg_1 and 
fg_2 to func_graph + auto mgr = mindspore::Manage(func_graph); + mgr->AddFuncGraph(fg_1); + mgr->AddFuncGraph(fg_2); + mgr->Replace(node, call_cnode); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h b/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h new file mode 100644 index 00000000000..757d4d74928 --- /dev/null +++ b/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h @@ -0,0 +1,53 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_DYNAMIC_OBFUSCATION_H +#define MINDSPORE_DYNAMIC_OBFUSCATION_H + +#include +#include +#include "load_mindir/load_model.h" +#include "include/common/visible.h" + +namespace mindspore { +class COMMON_EXPORT DynamicObfuscator { + public: + DynamicObfuscator(const float obf_ratio, const int obf_password, const int append_password) + : obf_ratio_(obf_ratio), obf_password_(obf_password), append_password_(append_password) {} + + ~DynamicObfuscator() = default; + + FuncGraphPtr ObfuscateMindIR(const FuncGraphPtr &func_graph); + + private: + void op_wise_fake_branch(FuncGraphPtr func_graph); + std::string single_op_obfuscate_type(AnfNodePtr node); + CNodePtr get_control_node(FuncGraphPtr func_graph, AnfNodePtr prev_node); + CNodePtr password_mode_control(FuncGraphPtr func_graph); + CNodePtr custom_op_mode_control(FuncGraphPtr func_graph, AnfNodePtr prev_node); + void replace_matmul_node(CNodePtr node, FuncGraphPtr func_graph, CNodePtr flag_node); + + const float obf_ratio_ = 0.01; + const int obf_password_; + const int append_password_; + bool has_build_appended_input = false; + std::vector customized_func_results_; + int used_control_node_ = 0; + bool switch_branch_ = true; + const std::vector obf_target_op = {"MatMul-op", "Add-op", "Mat-op", "Sub-op", "Softmax-op", "Relu-op"}; +}; +} // namespace mindspore +#endif // MINDSPORE_DYNAMIC_OBFUSCATION_H diff --git a/mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.cc b/mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.cc new file mode 100644 index 00000000000..54109500841 --- /dev/null +++ b/mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h" +#include +#include "utils/info.h" + +namespace mindspore { +namespace kernel { +CustomizedOpaquePredicate &CustomizedOpaquePredicate::GetInstance() { + static CustomizedOpaquePredicate instance; + return instance; +} + +void CustomizedOpaquePredicate::set_func_names() { + // get the function of get_func_names() + py::gil_scoped_acquire gil_acquire; + static const std::string &module_name = "mindspore.ops.operations._opaque_predicate_registry"; + static const std::string &entrance = "get_func_names"; + py::module module = py::module::import(module_name.c_str()); + py::object get_pyfunc_obj = module.attr(entrance.c_str()); + if (get_pyfunc_obj.is_none()) { + MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << "in module" << module_name; + } + py::function get_pyfunc = get_pyfunc_obj.cast(); + py::tuple func_name_list = get_pyfunc(); + // clear old functions + func_names_.clear(); + for (size_t i = 0; i < func_name_list.size(); i++) { + func_names_.push_back(py::str(func_name_list[i])); + } + MS_LOG(DEBUG) << "Set function names finished, the number of functions is: " << func_names_.size(); +} + +const std::vector CustomizedOpaquePredicate::get_func_names() { + if (func_names_.size() == 0) { + MS_LOG(EXCEPTION) << "The number of customized function names is zero, get function names failed."; + } + return func_names_; +} + +py::function CustomizedOpaquePredicate::get_function() { + py::gil_scoped_acquire gil_acquire; + // get the function of 
get_opaque_predicate() + static const std::string &module_name = "mindspore.ops.operations._opaque_predicate_registry"; + static const std::string &entrance = "get_opaque_predicate"; + py::module module = py::module::import(module_name.c_str()); + py::object get_pyfunc_obj = module.attr(entrance.c_str()); + if (get_pyfunc_obj.is_none()) { + MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << "in module" << module_name; + } + py::function get_pyfunc = get_pyfunc_obj.cast(); + MS_LOG(DEBUG) << "The number of function is : " << func_names_.size(); + if (func_names_.size() == 0) { + MS_EXCEPTION(ValueError) << "The customized_func is not set, please set it in load()."; + } + std::string func_name = func_names_[0]; + MS_LOG(DEBUG) << "Get function name: " << func_name; + func_name_code_.clear(); + std::transform(func_name.begin(), func_name.end(), std::back_inserter(func_name_code_), + [](const char &item) { return static_cast(item); }); + if (func_name_code_.size() == 0) { + MS_EXCEPTION(ValueError) << "Set func_name_code_ failed."; + } + py::object py_func_obj = get_pyfunc(py::str(func_name)); + if (py_func_obj.is_none()) { + MS_EXCEPTION(ValueError) << "Cannot find python func with name: " << func_name; + } + return py_func_obj.cast(); +} + +bool CustomizedOpaquePredicate::run_function(float x, float y) { + if (Py_IsInitialized() != true) { + MS_LOG(ERROR) << "Py_IsInitialized failed."; + return false; + } + py::object customized_func = get_function(); + py::gil_scoped_acquire gil_acquire; + int inputs_num = 2; + py::tuple inputs(inputs_num); + inputs[0] = py::float_(x); + inputs[1] = py::float_(y); + py::object result = customized_func(*inputs); + if (result.is_none()) { + MS_EXCEPTION(ValueError) << "Computing result of customized_func is None, please check it."; + } + bool bool_result = py::cast(result); + int even_num = 2; + if (func_name_code_[calling_count_ % func_name_code_.size()] % even_num == 0) { + calling_count_ += 1; + return 
bool_result; + } + calling_count_ += 1; + return !bool_result; +} + +void CustomizedOpaquePredicate::init_calling_count() { + this->calling_count_ = 0; + MS_LOG(INFO) << "calling_count_ has been initialized to " << calling_count_; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h b/mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h new file mode 100644 index 00000000000..b81259c73bb --- /dev/null +++ b/mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h @@ -0,0 +1,57 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H +#define MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include +#include "pybind11/numpy.h" +#include "include/common/visible.h" + +namespace py = pybind11; +namespace mindspore { +namespace kernel { +class COMMON_EXPORT CustomizedOpaquePredicate { + public: + static CustomizedOpaquePredicate &GetInstance(); + void set_func_names(); + const std::vector get_func_names(); + bool run_function(float x, float y); + py::function get_function(); + void init_calling_count(); + + private: + CustomizedOpaquePredicate() : func_names_({}) {} + ~CustomizedOpaquePredicate() = default; + + std::vector func_names_; + int calling_count_ = 0; + std::vector func_name_code_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc index 6320050b72d..e6a11b7c25f 100644 --- a/mindspore/core/abstract/ops/primitive_infer_map.cc +++ b/mindspore/core/abstract/ops/primitive_infer_map.cc @@ -56,6 +56,7 @@ #include "ops/grad/max_pool_grad_with_argmax.h" #include "ops/max_pool_with_argmax.h" #include "ops/mirror_pad.h" +#include "ops/opaquePredicate.h" namespace mindspore { namespace abstract { @@ -323,6 +324,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimIdentity, R{InferImplIdentity, nullptr, true}}, {prim::kPrimLoad, R{InferImplLoad, nullptr, true}}, {prim::kPrimMutable, R{InferImplMutable, nullptr, true}}, + {prim::kPrimOpaquePredicate, R{ops::OpaquePredicateInfer, nullptr, true}}, // Set impl to null as it will use PartialEvaluator; {prim::kPrimPartial, R{nullptr, nullptr, true}}, {prim::kPrimEnvironCreate, R{InferImplEnvironCreate, nullptr, true}}, diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 
8f810a8fc31..70e35fe719c 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -1376,6 +1376,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicLossScale, std::make_shared("_Dyna GVAR_DEF(PrimitivePtr, kPrimScaleGrad, std::make_shared("ScaleGrad")); GVAR_DEF(PrimitivePtr, kPrimPopulationCount, std::make_shared("PopulationCount")); GVAR_DEF(PrimitivePtr, kPrimBlackmanWindow, std::make_shared("BlackmanWindow")); +GVAR_DEF(PrimitivePtr, kPrimOpaquePredicate, std::make_shared("OpaquePredicate")); // Structures GVAR_DEF(PrimitivePtr, kPrimMakeList, std::make_shared("make_list")); diff --git a/mindspore/core/ops/opaquePredicate.cc b/mindspore/core/ops/opaquePredicate.cc new file mode 100644 index 00000000000..1f58acca5f5 --- /dev/null +++ b/mindspore/core/ops/opaquePredicate.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/opaquePredicate.h" +#include +#include +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +abstract::ShapePtr OpaquePredicateInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + return BroadCastInferShape(op_name, input_args); +} + +TypePtr OpaquePredicateInferType(const PrimitivePtr &prim, const std::vector &input_args) { + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto x = CheckAndConvertUtils::CheckArgs(prim->name(), input_args, 0); + auto y = CheckAndConvertUtils::CheckArgs(prim->name(), input_args, 1); + (void)abstract::CheckDtypeSame(prim->name(), x, y); + const std::set valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat, kFloat16, kUInt16, + kFloat64, kUInt8, kBool, kComplex64, kComplex128, kUInt32}; + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeValid("y", input_args[1]->BuildType(), valid_types, prim->name()); + return std::make_shared(kBool); +} + +MIND_API_OPERATOR_IMPL(OpaquePredicate, BaseOperator); +AbstractBasePtr OpaquePredicateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto shape = OpaquePredicateInferShape(primitive, input_args); + auto type = OpaquePredicateInferType(primitive, input_args); + return abstract::MakeAbstract(shape, type); +} + +REGISTER_PRIMITIVE_C(kNameOpaquePredicate, OpaquePredicate); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/opaquePredicate.h b/mindspore/core/ops/opaquePredicate.h new file mode 100644 index 00000000000..e82b2de7771 --- /dev/null +++ 
b/mindspore/core/ops/opaquePredicate.h @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_ +#define MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_ +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameOpaquePredicate = "OpaquePredicate"; +/// \brief The opaque predicate used for dynamic obfuscation +class MIND_API OpaquePredicate : public BaseOperator { + public: + MIND_API_BASE_MEMBER(OpaquePredicate); + OpaquePredicate() : BaseOperator(kNameOpaquePredicate) { InitIOName({"x", "y"}, {"output"}); } + void Init() const {} +}; + +abstract::AbstractBasePtr OpaquePredicateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_ diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index 303dc835443..dd1cba05a0e 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -26,6 +26,7 @@ import inspect import importlib from collections import OrderedDict from functools import wraps +import numpy as np import mindspore as ms from mindspore import context from mindspore import log as logger @@ -46,7 +47,6 @@ from 
mindspore._checkparam import Validator from mindspore.common._utils import is_shape_unknown from mindspore.common.mutable import mutable - # store ms_function class compiled pipeline cache ms_compile_cache = set() # store cell compiled pipeline cache, @@ -670,6 +670,7 @@ class _MsFunctionCompileContext: """ ms_function compile status manager """ + def __init__(self): pass @@ -1064,10 +1065,12 @@ class _CellGraphExecutor: Returns: Graph, return the result of pipeline running. """ + def __init__(self): # create needed graph by lazy mode self.is_init = False self.enable_tuple_broaden = False + self.obfuscate_config = None # used for model's dynamic obfuscation self._graph_executor = GraphExecutor_.get_instance() self._graph_executor.set_py_exe_path(sys.executable) self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep) @@ -1238,8 +1241,8 @@ class _CellGraphExecutor: return self._graph_executor.get_allreduce_fusion(real_phase) def __call__(self, obj, *args, phase='predict'): - if context.get_context("precompile_only") or\ - (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched(): + if context.get_context("precompile_only") or \ + (_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched(): return None return self.run(obj, *args, phase=phase) @@ -1290,6 +1293,21 @@ class _CellGraphExecutor: exec_id = exec_id + '.' 
+ obj.arguments_key if self._graph_executor.has_compiled(exec_id) is False: return None + if self.obfuscate_config is not None: + if ('obf_ratio' not in self.obfuscate_config.keys()) or ( + 'obf_password' not in self.obfuscate_config.keys()): + raise ValueError("'obf_ratio' and 'obf_password' must be in obfuscate_config.") + obf_password = self.obfuscate_config.get('obf_password') + if obf_password == 0: + append_password = 0 + else: + seed_max = 2 ** 32 - 1 + int_max = 2 ** 31 - 1 + np.random.seed(obf_password % seed_max) + append_password = np.random.randint(int_max) + obf_password %= int_max + return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, self.obfuscate_config['obf_ratio'], + obf_password, append_password) return self._graph_executor.get_func_graph_proto(exec_id, ir_type) def get_optimize_graph_proto(self, obj): diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index cbca32a986f..edb4837f360 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -2219,18 +2219,18 @@ class Cell(Cell_): if isinstance(set_input, Tensor): if not isinstance(net_input, Tensor): raise TypeError( - f"The {index+1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.") + f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.") if set_input.dtype is not net_input.dtype: raise ValueError( - f"The {index+1}th input type of 'set_inputs' must be the same as network's input, " + f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, " f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.") if net_input.dim() != 0 and set_input.dim() != net_input.dim(): raise ValueError( - f"The {index+1}th input dims of 'set_inputs' must be the same as network's input, " + f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, " f"but got 'set_inputs': 
{set_input.dim()} and network's input: {net_input.dim()}.") if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]): raise ValueError( - f"The {index+1}th input shape of 'set_inputs' must be the same as network's input, " + f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, " f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.") @@ -2247,6 +2247,11 @@ class GraphCell(Cell): The key is the parameter name whose type is str, and the value is a Tensor or Parameter. If the parameter exists in the graph according to the name, update it's value. If the parameter does not exist, ignore it. Default: None. + obf_password (int): The password used for dynamic obfuscation. Dynamic obfuscation is a model protection + technique; see `mindspore.train.serialization.obfuscate_model()` for details. If the input 'graph' is a + func_graph loaded from a MindIR file that was obfuscated in password mode, obf_password must be provided. + It must be greater than zero and less than or equal to the int64 maximum (9223372036854775807). Default: None. + Raises: TypeError: If the `graph` is not a FuncGraph. TypeError: If the `params_init` is not a dict. @@ -2273,13 +2278,19 @@ class GraphCell(Cell): [6. 9. 6.] [4. 6. 
4.]]]] """ - def __init__(self, graph, params_init=None): + + def __init__(self, graph, params_init=None, obf_password=None): super(GraphCell, self).__init__(auto_prefix=True) if not isinstance(graph, FuncGraph): raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, " f"but got type {type(graph)}.") self.graph = graph - + self.obf_password = obf_password + int_64_max = 9223372036854775807 + if (obf_password is not None) and (obf_password <= 0 or obf_password > int_64_max): + raise ValueError( + "'obf_password' must be larger than 0, and less or equal than int64 ({})," + "but got {}.".format(int_64_max, obf_password)) params_init = {} if params_init is None else params_init if not isinstance(params_init, dict): raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.") @@ -2299,7 +2310,21 @@ class GraphCell(Cell): def __call__(self, *inputs): self.phase = "graph_load_from_mindir" self._add_attr("graph_load_from_mindir", self.graph) - return self.compile_and_run(*inputs) + if not self.obf_password: + return self.compile_and_run(*inputs) + append_input_1, append_input_2 = _obf_appended_inputs(self.obf_password) + return self.compile_and_run(*inputs, append_input_1, append_input_2) + + +def _obf_appended_inputs(obf_password): + seed_max = 2 ** 32 - 1 + int_max = 2 ** 31 - 1 + numpy.random.seed(obf_password % seed_max) + append_password = numpy.random.randint(int_max) + obf_password %= int_max + append_input_1 = Tensor((numpy.ones((1, 1)) * obf_password).astype(numpy.int32)) + append_input_2 = Tensor((numpy.ones((1, 1)) * append_password).astype(numpy.int32)) + return append_input_1, append_input_2 def _check_param_list_tuple(value): diff --git a/mindspore/python/mindspore/ops/_register_for_op.py b/mindspore/python/mindspore/ops/_register_for_op.py index 88630922b68..2e6ea5537eb 100644 --- a/mindspore/python/mindspore/ops/_register_for_op.py +++ 
b/mindspore/python/mindspore/ops/_register_for_op.py @@ -60,3 +60,13 @@ class PyFuncRegistry(UserDict): if key not in self: raise ValueError(f"Python function with key{key} not registered.") return self[key] + + +class OpaquePredicateRegistry(PyFuncRegistry): + def __init__(self): + super(OpaquePredicateRegistry, self).__init__() + self.func_names = [] + + def register(self, key, value): + self[key] = value + self.func_names.append(key) diff --git a/mindspore/python/mindspore/ops/operations/_opaque_predicate_registry.py b/mindspore/python/mindspore/ops/operations/_opaque_predicate_registry.py new file mode 100644 index 00000000000..b8fb3a6d8db --- /dev/null +++ b/mindspore/python/mindspore/ops/operations/_opaque_predicate_registry.py @@ -0,0 +1,37 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Register pyfunc for opaque_func_cpu_kernel""" + +from mindspore.ops._register_for_op import OpaquePredicateRegistry + + +registered_func_name = OpaquePredicateRegistry() + + +def add_opaque_predicate(fn_name, func): + registered_func_name.register(fn_name, func) + + +def get_opaque_predicate(fn_name): + return registered_func_name.get(fn_name) + + +def get_func_names(): + return registered_func_name.func_names + + +def clean_funcs(): + registered_func_name.func_names = [] diff --git a/mindspore/python/mindspore/train/__init__.py b/mindspore/python/mindspore/train/__init__.py index 68db214a9d6..8de28e5ea09 100644 --- a/mindspore/python/mindspore/train/__init__.py +++ b/mindspore/python/mindspore/train/__init__.py @@ -26,7 +26,7 @@ from mindspore.train.amp import build_train_network from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \ load, parse_print, build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, \ - async_ckpt_thread_status, restore_group_info_list, convert_model + async_ckpt_thread_status, restore_group_info_list, convert_model, obfuscate_model from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \ CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, \ History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore @@ -39,7 +39,7 @@ __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "bui "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter", "load_distributed_checkpoint", "async_ckpt_thread_status", 
"restore_group_info_list", "convert_model", - "data_sink"] + "data_sink", "obfuscate_model"] __all__.extend(callback.__all__) __all__.extend(summary.__all__) __all__.extend(train_thor.__all__) diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 1b0400d7e72..d25f5aae56f 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -55,11 +55,11 @@ from mindspore.parallel._cell_wrapper import get_allgather_cell from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices -from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy,\ +from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \ _restore_group_info_list from mindspore.train._utils import read_proto -from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file - +from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir +from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16, "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64, @@ -384,6 +384,17 @@ def _check_append_dict(append_dict): return append_dict +def _check_load_obfuscate(**kwargs): + if 'obf_func' in kwargs.keys(): + customized_func = kwargs.get('obf_func') + if not callable(customized_func): + raise ValueError("obf_func must be a callable function, but got a {}.".format(type(customized_func))) + clean_funcs() + 
def _check_load_obfuscate(**kwargs):
    """
    Register the customized obfuscation function passed to `load()`, if there is one.

    Args:
        kwargs (dict): Keyword arguments of `load()`. Only the key 'obf_func' is inspected.

    Returns:
        bool, True if 'obf_func' was supplied and registered, False if 'obf_func' is absent.

    Raises:
        ValueError: If 'obf_func' is supplied but is not callable.
    """
    if 'obf_func' not in kwargs:
        return False
    customized_func = kwargs.get('obf_func')
    if not callable(customized_func):
        raise ValueError("obf_func must be a callable function, but got a {}.".format(type(customized_func)))
    # Reset the recorded names first so that only this function is listed afterwards.
    clean_funcs()
    add_opaque_predicate(customized_func.__name__, customized_func)
    return True


def _check_param_type(param_config, key, target_type, requested):
    """
    Fetch `key` from `param_config` and check its type.

    Args:
        param_config (dict): Configuration to read from.
        key (str): Name of the parameter.
        target_type (type or tuple of types): Accepted type(s), as understood by `isinstance`.
        requested (bool): Whether the parameter is mandatory.

    Returns:
        The stored value if `key` is present. If it is absent and optional, 0 for 'obf_password'
        and None for every other key.

    Raises:
        TypeError: If the stored value is not an instance of `target_type`.
        ValueError: If `key` is absent and `requested` is True.
    """
    if key in param_config:
        if not isinstance(param_config[key], target_type):
            raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
        return param_config[key]
    if requested:
        raise ValueError("The parameter {} is requested, but not got.".format(key))
    if key == "obf_password":
        # 0 is the "no password" marker that callers compare against.
        return 0
    return None


def _check_obfuscate_params(obf_config):
    """
    Check obfuscation parameters, including obf_password, obf_ratio and customized_func.

    The passed `obf_config` is only read, never modified.

    Args:
        obf_config (dict): Obfuscation config.

    Returns:
        tuple, `(obf_ratio, customized_funcs, obf_password)`: the ratio as a float in (0, 1], a list holding
        the customized function (empty if none was given), and the password (0 if none was given).

    Raises:
        ValueError: If neither 'obf_password' nor 'customized_func' is set, if 'type' is not 'dynamic',
            if 'obf_ratio' is missing, is an unknown preset name or is outside (0, 1], or if 'obf_password'
            is given but not in (0, 9223372036854775807].
        TypeError: If 'type', 'obf_ratio' or 'obf_password' has the wrong type, or 'customized_func' is not
            callable.
    """
    if 'obf_password' not in obf_config and 'customized_func' not in obf_config:
        raise ValueError(
            "At least one of 'obf_password' or 'customized_func' must be set in obf_config, but got None of them.")
    obfuscate_type = _check_param_type(obf_config, "type", str, False)
    if obfuscate_type not in (None, "dynamic"):
        raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))

    # Resolve preset names on a shallow copy so the caller's config keeps the value it was given.
    ratio_config = dict(obf_config)
    raw_ratio = ratio_config.get('obf_ratio')
    if isinstance(raw_ratio, str):
        ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
        if raw_ratio not in ratio_dict:
            raise ValueError(
                "'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(raw_ratio))
        ratio_config['obf_ratio'] = ratio_dict[raw_ratio]
    obf_ratio = _check_param_type(ratio_config, "obf_ratio", float, True)
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0 < obf_ratio <= 1:
        raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_ratio))

    customized_funcs = []
    if 'customized_func' in obf_config:
        if not callable(obf_config['customized_func']):
            raise TypeError(
                "'customized_func' must be a function, but got {}.".format(type(obf_config['customized_func'])))
        customized_funcs.append(obf_config['customized_func'])

    obf_password = _check_param_type(obf_config, "obf_password", int, False)
    int_64_max = 9223372036854775807
    if obf_password > int_64_max:
        raise ValueError(
            "'obf_password' must be less or equal than int64 ({}), but got {}.".format(int_64_max, obf_password))
    # 0 is reserved for "no password"; only an explicitly supplied value has to be positive.
    if 'obf_password' in obf_config and obf_password <= 0:
        raise ValueError("'obf_password' must be larger than zero, but got {}.".format(obf_password))
    return obf_ratio, customized_funcs, obf_password


def _derive_append_password(obf_password):
    """
    Derive the auxiliary password used together with `obf_password` in password mode.

    A private `RandomState` seeded with `obf_password % (2 ** 32 - 1)` is used. It yields the same value as
    seeding the global NumPy generator with that seed and calling `np.random.randint(2 ** 31 - 1)`, but it
    leaves the global generator state of the calling program untouched.
    """
    seed_max = 2 ** 32 - 1
    int_max = 2 ** 31 - 1
    return np.random.RandomState(obf_password % seed_max).randint(int_max)


def obfuscate_model(obf_config, **kwargs):
    """
    Obfuscate a model of MindIR format. Obfuscation means changing the structure of a network without affecting
    its predict correctness. The obfuscated model can prevent attackers from stealing the model.

    Args:
        obf_config (dict): obfuscation config. It is not modified by this function.

            - type (str): The type of obfuscation, only 'dynamic' is supported until now.
            - original_model_path (str): The path of MindIR format model that need to be obfuscated. If the
              original model is encrypted, then enc_key and enc_mode should be provided.
            - save_model_path (str): The path to save the obfuscated model.
            - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random,
              which is the same as using `export()`.
            - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
              should be in range of (0, 1] or in ["small", "medium", "large"].
            - customized_func (function): A python function used for customized function mode, which is used to
              control the switch branch of obfuscation structure. The outputs of customized_func should be
              boolean. This function needs to ensure that its result is constant for any input. Users can refer
              to opaque predicates. If customized_func is set, then it should be passed to `load()` interface
              when loading obfuscated model.
            - obf_password (int): A password used for password mode, which should be larger than zero. If
              obf_password is set, then it should be passed to `nn.GraphCell()` interface when loading
              obfuscated model. It should be noted that at least one of 'customized_func' or 'obf_password'
              should be set, and 'obf_password' mode would be applied if both of them are set.

        kwargs (dict): Configuration options dictionary. It is also forwarded to `export()`.

            - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
            - enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
              Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.

    Raises:
        TypeError: If obf_config is not a dict.
        TypeError: If an item of model_inputs is not a Tensor.
        ValueError: If enc_key is passed and enc_mode is not in ["AES-GCM", "AES-CBC"].
        ValueError: If original_model_path is not provided in obf_config.
        ValueError: If the model saved in original_model_path has been obfuscated.
        ValueError: If save_model_path is not provided in obf_config.
        ValueError: If obf_ratio is not provided in obf_config.
        ValueError: If both customized_func and obf_password are not provided in obf_config.
        ValueError: If obf_password is not in (0, 9223372036854775807].

    Examples:
        >>> obf_config = {'original_model_path': "./net.mindir",
        ...          'save_model_path': "./obf_net",
        ...          'model_inputs': [input1, ],
        ...          'obf_ratio': 0.1, 'obf_password': 173262358423}
        >>> obfuscate_model(obf_config)
        >>> obf_func = load("obf_net.mindir")
        >>> obf_net = nn.GraphCell(obf_func, obf_password=173262358423)
        >>> print(obf_net(input1).asnumpy())
    """
    if not isinstance(obf_config, dict):
        raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
    file_path = _check_param_type(obf_config, "original_model_path", str, True)
    saved_path = _check_param_type(obf_config, "save_model_path", str, True)
    model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
    for item in model_inputs:
        if not isinstance(item, Tensor):
            raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
    obf_ratio, customized_funcs, obf_password = _check_obfuscate_params(obf_config)
    if customized_funcs and obf_password > 0:
        logger.warning("Although customized_func and obf_password are set, the 'obf_password' mode would be "
                       "applied, remember to set obf_password when loading obfuscated model.")

    if obf_password == 0:  # apply customized_func mode
        clean_funcs()
        for func in customized_funcs:
            add_opaque_predicate(func.__name__, func)
        append_password = 0
    else:
        append_password = _derive_append_password(obf_password)
        # NOTE(review): a password that is an exact multiple of 2 ** 31 - 1 is reduced to 0 here, which the
        # check below treats as "no password" - confirm how the backend handles that case.
        obf_password %= 2 ** 31 - 1

    if 'enc_key' in kwargs:
        enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
        enc_mode = "AES-GCM"
        if 'enc_mode' in kwargs:
            enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
            if enc_mode not in ["AES-GCM", "AES-CBC"]:
                raise ValueError(
                    "Only MindIR files that encrypted with 'AES-GCM' or 'AES-CBC' is supported for obfuscate_model(),"
                    " but got {}.".format(enc_mode))
        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, obf_password=obf_password,
                                             append_password=append_password, dec_key=enc_key, key_len=len(enc_key),
                                             dec_mode=enc_mode)
    else:
        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, obf_password=obf_password,
                                             append_password=append_password)

    obf_net = nn.GraphCell(obf_graph)
    # Work on a copy: the extra inputs must not leak into the list owned by the caller's obf_config.
    export_inputs = list(model_inputs)
    if obf_password != 0:
        y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
        append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
        export_inputs += [y_tensor, append_y_tensor]
    export(obf_net, *export_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
The file may be encrypted or tempered with, " \ - "please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name) + "please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name) else: err_info = "Failed to read the checkpoint file {}. May not have permission to read it, please check" \ - " the correct of the file.".format(ckpt_file_name) + " the correct of the file.".format(ckpt_file_name) logger.error(err_info) raise ValueError(err_info) from e return checkpoint_list @@ -885,6 +1045,20 @@ def export(net, *inputs, file_name, file_format, **kwargs): - For details of using the customized encryption, please check the `tutorial `_. + - obf_config (dict): obfuscation config. + + - type (str): The type of obfuscation, only 'dynamic' is supported until now. + - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio` + should be in range of (0, 1] or in ["small", "medium", "large"]. + - customized_func (function): A python function used for customized function mode, which used for control + the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This + function needs to ensure that its result is constant for any input. Users can refer to opaque + predicates. If customized_func is set, then it should be passed to `load()` interface when loading + obfuscated model. + - obf_password (int): A password used for password mode, which should be larger than zero. If + obf_password is set, then it should be passed to `nn.GraphCell()` interface when loading obfuscated + model. It should be noted that at least one of 'customized_func' or 'obf_password' should be set, and + 'obf_password' mode would be applied if both of them are set. 
def _set_obfuscate_config(**kwargs):
    """
    Set obfuscation config for executor.

    Validates the 'obf_config' entry of `kwargs` (and 'enc_mode' if present), registers the customized
    function when no password is given, and stores the resulting ratio and password on
    `_executor.obfuscate_config`.

    Args:
        kwargs (dict): Keyword arguments of the export call. The keys 'obf_config' and 'enc_mode' are read.

    Raises:
        TypeError: If 'obf_config' is not a dict.
        ValueError: If 'enc_mode' is given and is not in ["AES-GCM", "AES-CBC"].
    """
    logger.warning("Obfuscate model.")
    if 'enc_mode' in kwargs:
        enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
        if enc_mode not in ["AES-GCM", "AES-CBC"]:
            raise ValueError(
                "Only MindIR files that encrypted with 'AES-GCM' or 'AES-CBC' is supported for obfuscation, "
                "but got {}.".format(enc_mode))
    obf_config = kwargs.get('obf_config')
    # Fail with a clear message instead of an AttributeError raised from inside the validator.
    if not isinstance(obf_config, dict):
        raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
    obf_ratio, customized_funcs, obf_password = _check_obfuscate_params(obf_config)
    if customized_funcs and obf_password > 0:
        logger.warning("Although customized_func and obf_password are set, the 'obf_password' mode would be "
                       "applied, remember to set obf_password when loading obfuscated model.")

    if obf_password == 0:  # apply customized_func mode
        clean_funcs()
        for func in customized_funcs:
            add_opaque_predicate(func.__name__, func)
    _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_password': obf_password}