forked from mindspore-Ecosystem/mindspore
Add dynamic obfucation tool
This commit is contained in:
parent
3420611c13
commit
969e368cc5
|
@ -23,6 +23,8 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
|
||||
"mindspore/mindspore/ccsrc/backend/common/optimizer/op_adaptation_info_factory.h" "runtime/explicit"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/concatv2_impl.cu" "runtime/int"
|
||||
"mindspore/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc" "runtime/threadsafe_fn"
|
||||
"mindspore/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc" "runtime/references"
|
||||
|
||||
# Modelzoo
|
||||
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
|
||||
|
|
|
@ -32,3 +32,10 @@ mindspore.export
|
|||
- 对于'MINDIR'格式的模型,支持的加密选项有:'AES-GCM','AES-CBC'和用户自定义加密算法。默认值:"AES-GCM"。
|
||||
- 关于使用自定义加密导出的详情,请查看 `教程 <https://www.mindspore.cn/mindarmour/docs/zh-CN/master/model_encrypt_protection.html>`_。
|
||||
- **dataset** (Dataset) - 指定数据集的预处理方法,用于将数据集的预处理导入MindIR。
|
||||
|
||||
- **obf_config** (dict) - 模型混淆配置选项字典。
|
||||
|
||||
- **type** (str) - 混淆类型,目前支持动态混淆,即'dynamic'。
|
||||
- **obf_ratio** (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串"small"、"medium"、"large"。
|
||||
- **customized_func** (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置。如果设置了`customized_func`,那么在使用`load`接口导入模型的时候,需要把这个函数也传入。
|
||||
- **obf_password** (int) - 秘密口令,用于password模式,是一个大于0的整数。如果用户设置了`obf_password`,那么在部署混淆模型的时候,需要在`nn.GraphCell()`接口中传入`obf_password`。需要注意的是,如果用户同时设置了`customized_func`和`obf_password`,那么password模式将会被采用。
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
mindspore.obfuscate_model
|
||||
=========================
|
||||
|
||||
.. py:function:: mindspore.obfuscate_model(obf_config, **kwargs)
|
||||
|
||||
对MindIR格式的模型进行混淆,混淆主要是修改模型的网络结构但不影响它的推理精度,混淆后的模型可以防止被盗用。
|
||||
|
||||
参数:
|
||||
- **obf_config** (dict) - 模型混淆配置选项字典。
|
||||
|
||||
- **type** (str) - 混淆类型,目前支持动态混淆,即'dynamic'。
|
||||
- **original_model_path** (str) - 待混淆的MindIR模型地址。如果该模型是加密文件的,则需要在`kwargs`中传入`enc_key`和`enc_mode`。
|
||||
- **save_model_path** (str) - 混淆模型的保存地址。
|
||||
- **model_inputs** (list[Tensor]) - 模型的推理输入,Tensor的值可以是随机的,和使用`export()`接口类似。
|
||||
- **obf_ratio** (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串"small"、"medium"、"large"。
|
||||
- **customized_func** (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置。如果设置了`customized_func`,那么在使用`load`接口导入模型的时候,需要把这个函数也传入。
|
||||
- **obf_password** (int) - 秘密口令,用于password模式,是一个大于0的整数。如果用户设置了`obf_password`,那么在部署混淆模型的时候,需要在`nn.GraphCell()`接口中传入`obf_password`。需要注意的是,如果用户同时设置了`customized_func`和`obf_password`,那么password模式将会被采用。
|
||||
|
||||
- **kwargs** (dict) - 配置选项字典。
|
||||
|
||||
- **enc_key** (str) - 用于加密的字节类型密钥,有效长度为16、24或者32。
|
||||
- **enc_mode** (Union[str, function]) - 指定加密模式,当设置 `enc_key` 时启用。支持的加密选项有:'AES-GCM','AES-CBC'。默认值:"AES-GCM"。
|
|
@ -234,6 +234,7 @@ Serialization
|
|||
mindspore.save_checkpoint
|
||||
mindspore.transform_checkpoint_by_rank
|
||||
mindspore.transform_checkpoints
|
||||
mindspore.obfuscate_model
|
||||
|
||||
JIT
|
||||
---
|
||||
|
|
|
@ -147,6 +147,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("get_func_graph", &GraphExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
|
||||
.def("get_func_graph_proto", &GraphExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
|
||||
py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
|
||||
.def("get_obfuscate_func_graph_proto", &GraphExecutorPy::GetObfuscateFuncGraphProto, py::arg("phase") = py::str(""),
|
||||
py::arg("obf_ratio") = py::float_(1.0), py::arg("obf_pasword") = py::int_(0),
|
||||
py::arg("append_password") = py::int_(0), "Get graph proto of dynamic-obfuscated model.")
|
||||
.def("get_params", &GraphExecutorPy::GetParams, py::arg("phase") = py::str(""), "Get Parameters from graph")
|
||||
.def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
|
||||
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
|
||||
|
@ -200,11 +203,13 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
|
||||
(void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
|
||||
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
|
||||
|
||||
(void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
|
||||
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
|
||||
py::arg("decrypt") = py::none(), "Load model as Graph.");
|
||||
|
||||
py::arg("decrypt") = py::none(), py::arg("obfuscated") = py::bool_(false), "Load model as Graph.");
|
||||
(void)m.def("dynamic_obfuscate_mindir", &mindspore::pipeline::DynamicObfuscateMindIR, py::arg("file_name"),
|
||||
py::arg("obf_ratio"), py::arg("obf_password") = py::int_(0), py::arg("append_password") = py::int_(0),
|
||||
py::arg("dec_key") = nullptr, py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
|
||||
"Obfuscate a mindir model by dynamic obfuscation.");
|
||||
(void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster");
|
||||
(void)m.def("set_cluster_exit_with_exception", &mindspore::distributed::set_cluster_exit_with_exception,
|
||||
"Set this process exits with exception.");
|
||||
|
|
|
@ -66,6 +66,8 @@
|
|||
#include "runtime/pynative/op_executor.h"
|
||||
#include "runtime/device/stream_synchronizer.h"
|
||||
#include "distributed/collective/collective_manager.h"
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h"
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
|
||||
|
||||
#if defined(__linux__) && defined(WITH_BACKEND)
|
||||
#include "ps/constants.h"
|
||||
|
@ -442,6 +444,7 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std
|
|||
}
|
||||
|
||||
if (ir_type == IR_TYPE_MINDIR) {
|
||||
// obfuscate model
|
||||
std::string proto_str = GetBinaryProtoString(fg_ptr);
|
||||
if (proto_str.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Export MINDIR format model failed.";
|
||||
|
@ -452,6 +455,24 @@ py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std
|
|||
MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type;
|
||||
}
|
||||
|
||||
py::bytes GraphExecutorPy::GetObfuscateFuncGraphProto(const std::string &phase, const float obf_ratio,
|
||||
const int obf_password, const int append_password) {
|
||||
FuncGraphPtr fg_ptr = GetFuncGraph(phase);
|
||||
// obfuscate model
|
||||
if (obf_password == 0) {
|
||||
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
|
||||
MS_LOG(DEBUG) << "[GetObfuscateFuncGraphProto] set customized function names finished";
|
||||
}
|
||||
mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, obf_password, append_password);
|
||||
mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(fg_ptr);
|
||||
|
||||
std::string proto_str = GetBinaryProtoString(obfuscated_graph);
|
||||
if (proto_str.empty()) {
|
||||
MS_LOG(EXCEPTION) << "GetBinaryProtoString failed.";
|
||||
}
|
||||
return proto_str;
|
||||
}
|
||||
|
||||
py::bytes GraphExecutorPy::GetOptimizeGraphProto(const std::string &phase) {
|
||||
if (info_.count(phase) == 0) {
|
||||
MS_LOG(EXCEPTION) << "No phase in executor: " << phase;
|
||||
|
@ -1255,6 +1276,8 @@ void GraphExecutorPy::TerminateDebugger() {
|
|||
#endif
|
||||
|
||||
py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) {
|
||||
// init for dynamic-obfuscated model infer
|
||||
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
|
||||
// Mindspore debugger notify main thread to exit after one step, and will not run next step
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
TerminateDebugger();
|
||||
|
@ -1655,7 +1678,12 @@ void GraphExecutorPy::ExportGraph(const std::string &file_name, const std::strin
|
|||
}
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode, const py::object decrypt) {
|
||||
const std::string &dec_mode, const py::object decrypt, const bool obfuscated) {
|
||||
if (obfuscated) {
|
||||
MS_LOG(DEBUG) << "[LoadMindIR] Set customized function.";
|
||||
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
|
||||
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
|
||||
}
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
if (dec_mode == "Customized") {
|
||||
py::bytes key_bytes(dec_key);
|
||||
|
@ -1679,6 +1707,28 @@ FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const
|
|||
return func_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int obf_password,
|
||||
int append_password, char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
if (obf_password == 0) {
|
||||
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
|
||||
MS_LOG(DEBUG) << "[DynamicObfuscateMindIR] set function names finished.";
|
||||
}
|
||||
mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, obf_password, append_password);
|
||||
MindIRLoader mindir_loader(false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode, false);
|
||||
FuncGraphPtr func_graph = mindir_loader.LoadMindIR(file_name);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "[DynamicObfuscateMindIR] load mindir failed, please check the mindir file.";
|
||||
return nullptr;
|
||||
}
|
||||
mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(func_graph);
|
||||
if (obfuscated_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "[DynamicObfuscateMindIR] obfuscate model failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return obfuscated_graph;
|
||||
}
|
||||
|
||||
void CloseTsd(bool force) {
|
||||
#ifdef WITH_BACKEND
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
|
|
@ -88,6 +88,8 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
|
|||
FuncGraphPtr GetGradGraph(const std::string &phase);
|
||||
void SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase);
|
||||
py::bytes GetFuncGraphProto(const std::string &phase, const std::string &ir_type);
|
||||
py::bytes GetObfuscateFuncGraphProto(const std::string &phase, const float obf_ratio, const int obf_password,
|
||||
const int append_password);
|
||||
#ifndef ENABLE_SECURITY
|
||||
py::bytes GetOptimizeGraphProto(const std::string &phase);
|
||||
#endif
|
||||
|
@ -185,7 +187,8 @@ void CloseTsd(bool force = false);
|
|||
void MemoryRecycle();
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode, const py::object decrypt = py::none());
|
||||
const std::string &dec_mode, const py::object decrypt = py::none(),
|
||||
const bool obfuscated = false);
|
||||
|
||||
// init and exec dataset sub graph
|
||||
bool ME_EXPORT InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
|
||||
|
@ -203,6 +206,9 @@ py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_le
|
|||
py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode);
|
||||
bool PyIsCipherFile(const std::string &file_path);
|
||||
void FinalizeCluster();
|
||||
FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int obf_password,
|
||||
int append_password, char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode);
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/opaque_predicate_kernel.h"
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kInputsNum = 2;
|
||||
constexpr size_t kOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
bool OpaquePredicateKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
|
||||
output[0] =
|
||||
CustomizedOpaquePredicate::GetInstance().run_function(static_cast<float>(*input1), static_cast<float>(*input2));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool OpaquePredicateKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, OpaquePredicateKernelMod::OpaquePredicateFunc>> OpaquePredicateKernelMod::func_list_ =
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
&OpaquePredicateKernelMod::LaunchKernel<float>}};
|
||||
|
||||
std::vector<KernelAttr> OpaquePredicateKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, OpaquePredicateFunc> &item) { return item.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, OpaquePredicate, OpaquePredicateKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_
|
||||
|
||||
#include <complex>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class OpaquePredicateKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
OpaquePredicateKernelMod() = default;
|
||||
~OpaquePredicateKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using OpaquePredicateFunc =
|
||||
std::function<bool(OpaquePredicateKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
OpaquePredicateFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, OpaquePredicateFunc>> func_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_OPAQUE_PREDICATE_KERNEL_H_
|
|
@ -0,0 +1,494 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h"
|
||||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "utils/info.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ops/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
using Tensor = mindspore::tensor::Tensor;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTensorPtr;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractTuplePtr;
|
||||
|
||||
constexpr int expand_rate = 10; // total node need for a switch graph
|
||||
|
||||
ShapeVector get_node_shape(AnfNodePtr input_node) {
|
||||
if (input_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input_node is nullptr, get shape failed!";
|
||||
return {};
|
||||
}
|
||||
AbstractBasePtr input_abstract = input_node->abstract();
|
||||
if (input_abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "The abstract of input_node is nullptr, get shape failed!";
|
||||
return {};
|
||||
}
|
||||
AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
|
||||
mindspore::abstract::ShapePtr shape_ptr = input_abstract_tensor->shape();
|
||||
return shape_ptr->shape();
|
||||
}
|
||||
|
||||
TypeId get_node_dtype(AnfNodePtr input_node) {
|
||||
if (input_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input_node is nullptr, get dtype failed!";
|
||||
return {};
|
||||
}
|
||||
AbstractBasePtr input_abstract = input_node->abstract();
|
||||
if (input_abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "The abstract of input_node is nullptr, get dtype failed!";
|
||||
return {};
|
||||
}
|
||||
AbstractTensorPtr input_abstract_tensor = input_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
|
||||
AbstractBasePtr node_element = input_abstract_tensor->element();
|
||||
mindspore::abstract::AbstractScalarPtr node_element_abs =
|
||||
node_element->cast<mindspore::abstract::AbstractScalarPtr>();
|
||||
|
||||
TypeId data_type = node_element_abs->BuildType()->type_id();
|
||||
return data_type;
|
||||
}
|
||||
|
||||
std::vector<std::string> name_split(std::string &node_name, const std::string &split_sign) {
|
||||
node_name += split_sign;
|
||||
unsigned int name_len = node_name.size();
|
||||
std::string::size_type split_pos;
|
||||
std::vector<std::string> res;
|
||||
for (unsigned int i = 0; i < name_len; i++) {
|
||||
split_pos = node_name.find(split_sign, i);
|
||||
if (split_pos < name_len) {
|
||||
std::string sub_str = node_name.substr(i, split_pos - i);
|
||||
res.push_back(sub_str);
|
||||
i = split_pos + split_sign.size() - 1;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
ValueNodePtr build_tuple_value_node(std::vector<int64_t> values) {
|
||||
mindspore::ValueNodePtr v_node = std::make_shared<mindspore::ValueNode>(MakeValue(values));
|
||||
AbstractBasePtrList abs_list;
|
||||
std::transform(values.begin(), values.end(), std::back_inserter(abs_list), [](const int64 &item) {
|
||||
return std::make_shared<mindspore::abstract::AbstractScalar>(int64_t(item));
|
||||
});
|
||||
auto abs_tuple = std::make_shared<mindspore::abstract::AbstractTuple>(abs_list);
|
||||
v_node->set_abstract(abs_tuple);
|
||||
return v_node;
|
||||
}
|
||||
|
||||
ValueNodePtr make_int_node(FuncGraphPtr func_graph, int int_value) {
|
||||
ShapeVector int_shape{1, 1};
|
||||
tensor::TensorPtr int_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, int_shape);
|
||||
int *tensor_data = reinterpret_cast<int *>(int_tensor->data_c());
|
||||
for (int i = 0; i < int_tensor->data().size(); i++) {
|
||||
tensor_data[i] = int_value;
|
||||
}
|
||||
mindspore::ValueNodePtr int_tensor_node = std::make_shared<mindspore::ValueNode>(int_tensor);
|
||||
int_tensor_node->set_abstract(int_tensor->ToAbstract());
|
||||
(void)func_graph->AddValueNode(int_tensor_node);
|
||||
return int_tensor_node;
|
||||
}
|
||||
|
||||
tensor::TensorPtr make_weight_tensor(TypeId type_id, ShapeVector shape) {
|
||||
tensor::TensorPtr weight_tensor = std::make_shared<Tensor>(type_id, shape);
|
||||
std::default_random_engine generator;
|
||||
int max_count = 10000;
|
||||
int tensor_size = weight_tensor->data().size();
|
||||
if (type_id == kNumberTypeFloat64) {
|
||||
const double mean_64 = 0;
|
||||
const double stddev_64 = 1;
|
||||
std::normal_distribution<double> dist_64(mean_64, stddev_64);
|
||||
double *float_64_data = reinterpret_cast<double *>(weight_tensor->data_c());
|
||||
for (int i = 0; i < std::min(tensor_size, max_count); i++) {
|
||||
double random_float_64 = dist_64(generator);
|
||||
if (random_float_64 > 0) {
|
||||
float_64_data[i] = random_float_64;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Type id is: " << type_id << ", weights will be float_32 format.";
|
||||
const float mean = 0;
|
||||
const float stddev = 1;
|
||||
std::normal_distribution<float> dist_32(mean, stddev);
|
||||
float *float_32_data = reinterpret_cast<float *>(weight_tensor->data_c());
|
||||
for (int i = 0; i < std::min(tensor_size, max_count); i++) {
|
||||
float random_float_32 = dist_32(generator);
|
||||
if (random_float_32 > 0) {
|
||||
float_32_data[i] = random_float_32;
|
||||
}
|
||||
}
|
||||
}
|
||||
return weight_tensor;
|
||||
}
|
||||
|
||||
bool CheckIfObfuscated(const FuncGraphPtr &func_graph) {
|
||||
auto mgr = Manage(func_graph);
|
||||
auto all_nodes = mgr->all_nodes();
|
||||
for (AnfNodePtr node : all_nodes) {
|
||||
std::string node_name = node->fullname_with_scope();
|
||||
if (node_name.find("Switch") != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
FuncGraphPtr DynamicObfuscator::ObfuscateMindIR(const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(INFO) << "Start obfuscation.";
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (CheckIfObfuscated(func_graph)) {
|
||||
MS_EXCEPTION(ValueError) << "The input model has been onfuscated, do not obfuscate it again.";
|
||||
}
|
||||
auto mgr = Manage(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
auto all_nodes = mgr->all_nodes();
|
||||
int node_nums = all_nodes.size();
|
||||
MS_LOG(INFO) << "Total node num: " << node_nums;
|
||||
// init the number control node that has been build
|
||||
used_control_node_ = 0;
|
||||
if (obf_password_ == 0) {
|
||||
int obfuscate_target_num = std::ceil(all_nodes.size() * obf_ratio_ / expand_rate);
|
||||
int obfuscate_node_num = 0;
|
||||
// record customized_func computing results
|
||||
for (AnfNodePtr node : all_nodes) {
|
||||
std::string obf_type = single_op_obfuscate_type(node);
|
||||
MS_LOG(INFO) << "obf_type: " << obf_type;
|
||||
if (obf_type == "MatMul-op") {
|
||||
obfuscate_node_num += 1;
|
||||
MS_LOG(INFO) << "Find a MatMul Node: " << node->fullname_with_scope();
|
||||
bool customized_func_result = mindspore::kernel::CustomizedOpaquePredicate::GetInstance().run_function(
|
||||
static_cast<float>(1), static_cast<float>(1));
|
||||
customized_func_results_.push_back(customized_func_result);
|
||||
}
|
||||
if (obfuscate_node_num >= obfuscate_target_num) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
|
||||
}
|
||||
// do op-wise fake-branch obfuscation
|
||||
(void)op_wise_fake_branch(func_graph);
|
||||
if (used_control_node_ == 0) {
|
||||
MS_LOG(WARNING)
|
||||
<< "The model has not been obfuscated, which means obf_password or customized_func is not need to set.";
|
||||
}
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
void DynamicObfuscator::op_wise_fake_branch(FuncGraphPtr func_graph) {
|
||||
auto mgr = Manage(func_graph);
|
||||
auto all_nodes = mgr->all_nodes();
|
||||
int obfuscate_target_num = std::ceil(all_nodes.size() * obf_ratio_ / expand_rate);
|
||||
int obfuscate_node_num = 0;
|
||||
for (AnfNodePtr node : all_nodes) {
|
||||
std::string obf_type = single_op_obfuscate_type(node);
|
||||
MS_LOG(INFO) << "The obf_type is: " << obf_type;
|
||||
if (obf_type == "MatMul-op") {
|
||||
obfuscate_node_num += 1;
|
||||
MS_LOG(INFO) << "Find a MatMul Node: " << node->fullname_with_scope();
|
||||
std::vector<AnfNodePtr> node_inputs = node->cast<mindspore::CNodePtr>()->inputs();
|
||||
mindspore::AnfNodePtr input_1 = node_inputs[1];
|
||||
CNodePtr control_c_node = get_control_node(func_graph, input_1);
|
||||
(void)replace_matmul_node(node->cast<CNodePtr>(), func_graph, control_c_node);
|
||||
MS_LOG(INFO) << "Finished replacement for: " << node->fullname_with_scope();
|
||||
}
|
||||
if (obfuscate_node_num >= obfuscate_target_num) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string DynamicObfuscator::single_op_obfuscate_type(AnfNodePtr node) {
|
||||
if (node->isa<CNode>()) {
|
||||
std::string node_name = node->fullname_with_scope();
|
||||
MS_LOG(INFO) << "The node_name is: " << node_name;
|
||||
std::vector<std::string> split_words = name_split(node_name, "/");
|
||||
std::string op_name = split_words[split_words.size() - 1];
|
||||
for (std::string target_op_name : obf_target_op) {
|
||||
int op_name_len = op_name.size();
|
||||
int target_name_len = target_op_name.size();
|
||||
if ((op_name_len >= target_name_len) && (op_name.substr(0, target_name_len) == target_op_name)) {
|
||||
return target_op_name;
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
CNodePtr DynamicObfuscator::password_mode_control(FuncGraphPtr func_graph) {
|
||||
ShapeVector y_shape{1, 1};
|
||||
tensor::TensorPtr y_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeInt32, y_shape);
|
||||
if (!has_build_appended_input) {
|
||||
MS_LOG(INFO) << "Build parameter y and y_append.";
|
||||
auto y = func_graph->add_parameter();
|
||||
y->set_name("y");
|
||||
y->set_abstract(y_tensor->ToAbstract());
|
||||
auto y_append = func_graph->add_parameter();
|
||||
y_append->set_name("y_append");
|
||||
y_append->set_abstract(y_tensor->ToAbstract());
|
||||
has_build_appended_input = true;
|
||||
}
|
||||
auto y = func_graph->GetParameterByName("y");
|
||||
auto y_append = func_graph->GetParameterByName("y_append");
|
||||
|
||||
if (used_control_node_ == 0) {
|
||||
// make add function node
|
||||
mindspore::PrimitivePtr add_prim = mindspore::prim::kPrimAdd;
|
||||
add_prim->set_attr("is_load", MakeValue(true));
|
||||
mindspore::ValueNodePtr add_v_node = std::make_shared<mindspore::ValueNode>(add_prim);
|
||||
(void)func_graph->AddValueNode(add_v_node);
|
||||
CNodePtr add_c_node = func_graph->NewCNode({add_v_node, y, y_append});
|
||||
add_c_node->set_abstract(y_tensor->ToAbstract());
|
||||
// make equal function node
|
||||
ValueNodePtr equal_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimEqual);
|
||||
(void)func_graph->AddValueNode(equal_v_node);
|
||||
ValueNodePtr equal_compa_node = make_int_node(func_graph, obf_password_ + append_password_);
|
||||
CNodePtr equal_c_node = func_graph->NewCNode({equal_v_node, add_c_node, equal_compa_node});
|
||||
tensor::TensorPtr equal_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
|
||||
equal_c_node->set_abstract(equal_tensor->ToAbstract());
|
||||
(void)func_graph->AddNode(equal_c_node);
|
||||
used_control_node_ += 1;
|
||||
switch_branch_ = true;
|
||||
return equal_c_node;
|
||||
}
|
||||
// make greater function node
|
||||
int comparison_int = rand();
|
||||
ValueNodePtr greater_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimGreater);
|
||||
(void)func_graph->AddValueNode(greater_v_node);
|
||||
ValueNodePtr greater_compa_node = make_int_node(func_graph, comparison_int);
|
||||
CNodePtr greater_c_node = func_graph->NewCNode({greater_v_node, y, greater_compa_node});
|
||||
tensor::TensorPtr greater_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
|
||||
greater_c_node->set_abstract(greater_tensor->ToAbstract());
|
||||
(void)func_graph->AddNode(greater_c_node);
|
||||
used_control_node_ += 1;
|
||||
switch_branch_ = obf_password_ > comparison_int;
|
||||
return greater_c_node;
|
||||
}
|
||||
|
||||
mindspore::CNodePtr AddStrideSliceNode(FuncGraphPtr func_graph, ShapeVector begin_vector, ShapeVector stride_vector,
|
||||
ShapeVector end_vector, int end_mask, int begin_mask,
|
||||
mindspore::CNodePtr prev_node) {
|
||||
mindspore::ValueNodePtr begin_v_node = build_tuple_value_node(begin_vector);
|
||||
mindspore::ValueNodePtr stride_v_node = build_tuple_value_node(stride_vector);
|
||||
mindspore::ValueNodePtr end_v_node = build_tuple_value_node(end_vector);
|
||||
(void)func_graph->AddValueNode(begin_v_node);
|
||||
(void)func_graph->AddValueNode(stride_v_node);
|
||||
(void)func_graph->AddValueNode(end_v_node);
|
||||
mindspore::PrimitivePtr slice_prim = mindspore::prim::kPrimStridedSlice;
|
||||
slice_prim->set_attr("is_load", MakeValue(true));
|
||||
slice_prim->set_attr("new_axis_mask", MakeValue(int64_t(0)));
|
||||
slice_prim->set_attr("shrink_axis_mask", MakeValue(int64_t(1)));
|
||||
slice_prim->set_attr("end_mask", MakeValue(int64_t(end_mask)));
|
||||
slice_prim->set_attr("begin_mask", MakeValue(int64_t(begin_mask)));
|
||||
slice_prim->set_attr("ellipsis_mask", MakeValue(int64_t(0)));
|
||||
mindspore::ValueNodePtr slice_v_node = std::make_shared<mindspore::ValueNode>(slice_prim);
|
||||
(void)func_graph->AddValueNode(slice_v_node);
|
||||
mindspore::CNodePtr slice_c_node =
|
||||
func_graph->NewCNode({slice_v_node, prev_node, begin_v_node, end_v_node, stride_v_node});
|
||||
return slice_c_node;
|
||||
}
|
||||
|
||||
CNodePtr DynamicObfuscator::custom_op_mode_control(FuncGraphPtr func_graph, AnfNodePtr prev_node) {
|
||||
mindspore::PrimitivePtr reshape_prim = mindspore::prim::kPrimReshape;
|
||||
reshape_prim->set_attr("is_load", MakeValue(true));
|
||||
mindspore::ValueNodePtr reshape_v_node = std::make_shared<mindspore::ValueNode>(reshape_prim);
|
||||
(void)func_graph->AddValueNode(reshape_v_node);
|
||||
ShapeVector prev_node_shape = get_node_shape(prev_node);
|
||||
int shape_multiply = std::accumulate(prev_node_shape.begin(), prev_node_shape.end(), 1, std::multiplies<int>());
|
||||
MS_LOG(INFO) << "The shape_multiply is: " << shape_multiply;
|
||||
|
||||
ShapeVector flat_shape{1, shape_multiply};
|
||||
mindspore::ValueNodePtr shape_v_node = std::make_shared<mindspore::ValueNode>(MakeValue(flat_shape));
|
||||
(void)func_graph->AddValueNode(shape_v_node);
|
||||
mindspore::CNodePtr reshape_c_node = func_graph->NewCNode({reshape_v_node, prev_node, shape_v_node});
|
||||
TypeId data_type = get_node_dtype(prev_node);
|
||||
auto reshape_abstract = std::make_shared<Tensor>(data_type, flat_shape)->ToAbstract();
|
||||
reshape_c_node->set_abstract(reshape_abstract);
|
||||
(void)func_graph->AddNode(reshape_c_node);
|
||||
|
||||
// the first stride_slice x[0]
|
||||
ShapeVector begin_1{0, 0};
|
||||
ShapeVector stride_1{1, 1};
|
||||
mindspore::CNodePtr slice_c_node_1 =
|
||||
AddStrideSliceNode(func_graph, begin_1, stride_1, flat_shape, 2, 2, reshape_c_node);
|
||||
ShapeVector slice_1_shape{shape_multiply};
|
||||
slice_c_node_1->set_abstract(std::make_shared<Tensor>(data_type, slice_1_shape)->ToAbstract());
|
||||
(void)func_graph->AddNode(slice_c_node_1);
|
||||
|
||||
// the first stride_slice x[0][0]
|
||||
ShapeVector begin_2{0};
|
||||
ShapeVector end_2{1};
|
||||
ShapeVector stride_2{1};
|
||||
mindspore::CNodePtr slice_c_node_2 =
|
||||
AddStrideSliceNode(func_graph, begin_2, stride_2, stride_2, 0, 0, slice_c_node_1);
|
||||
ShapeVector slice_2_shape{1};
|
||||
slice_c_node_2->set_abstract(std::make_shared<Tensor>(data_type, slice_2_shape)->ToAbstract());
|
||||
(void)func_graph->AddNode(slice_c_node_2);
|
||||
|
||||
// the second stride_slice x[0][1]
|
||||
ShapeVector begin_3{1};
|
||||
ShapeVector end_3{1};
|
||||
ShapeVector stride_3{2};
|
||||
mindspore::CNodePtr slice_c_node_3 =
|
||||
AddStrideSliceNode(func_graph, begin_3, stride_3, stride_3, 0, 0, slice_c_node_1);
|
||||
ShapeVector slice_3_shape{1};
|
||||
slice_c_node_3->set_abstract(std::make_shared<Tensor>(data_type, slice_3_shape)->ToAbstract());
|
||||
(void)func_graph->AddNode(slice_c_node_3);
|
||||
|
||||
// add opaque predicate
|
||||
PrimitivePtr custom_prim = mindspore::prim::kPrimOpaquePredicate;
|
||||
custom_prim->set_attr("is_load", MakeValue(true));
|
||||
std::vector<ValuePtr> input_names_value;
|
||||
input_names_value.push_back(std::make_shared<StringImm>("x"));
|
||||
input_names_value.push_back(std::make_shared<StringImm>("y"));
|
||||
custom_prim->set_attr("input_names", std::make_shared<ValueList>(input_names_value));
|
||||
std::vector<ValuePtr> output_names_value;
|
||||
output_names_value.push_back(std::make_shared<StringImm>("output"));
|
||||
custom_prim->set_attr("output_names", std::make_shared<ValueList>(output_names_value));
|
||||
auto opaque_v_node = std::make_shared<mindspore::ValueNode>(custom_prim);
|
||||
(void)func_graph->AddValueNode(opaque_v_node);
|
||||
auto opaque_c_node = func_graph->NewCNode({opaque_v_node, slice_c_node_2, slice_c_node_3});
|
||||
ShapeVector y_shape{1, 1};
|
||||
auto bool_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
|
||||
opaque_c_node->set_abstract(bool_tensor->ToAbstract());
|
||||
(void)func_graph->AddNode(opaque_c_node);
|
||||
return opaque_c_node;
|
||||
}
|
||||
|
||||
CNodePtr DynamicObfuscator::get_control_node(FuncGraphPtr func_graph, AnfNodePtr prev_node) {
|
||||
if (obf_password_ != 0) {
|
||||
MS_LOG(INFO) << "Run password mode.";
|
||||
return password_mode_control(func_graph);
|
||||
}
|
||||
MS_LOG(INFO) << "Run customized function mode.";
|
||||
return custom_op_mode_control(func_graph, prev_node);
|
||||
}
|
||||
|
||||
void DynamicObfuscator::replace_matmul_node(CNodePtr node, FuncGraphPtr func_graph, CNodePtr control_node) {
|
||||
std::vector<AnfNodePtr> node_inputs = node->cast<mindspore::CNodePtr>()->inputs();
|
||||
mindspore::ValueNodePtr matmul_v_node = node_inputs[0]->cast<mindspore::ValueNodePtr>();
|
||||
mindspore::AnfNodePtr input_1 = node_inputs[1];
|
||||
mindspore::AnfNodePtr input_2 = node_inputs[2];
|
||||
|
||||
// construct branch 1
|
||||
mindspore::FuncGraphPtr fg_1 = std::make_shared<FuncGraph>();
|
||||
|
||||
// input_x
|
||||
ParameterPtr branch_1_input_x = fg_1->add_parameter();
|
||||
branch_1_input_x->set_abstract(input_1->abstract());
|
||||
branch_1_input_x->set_name("branch_1_input_x");
|
||||
|
||||
// input_y
|
||||
ParameterPtr branch_1_input_y = fg_1->add_parameter();
|
||||
branch_1_input_y->set_abstract(input_2->abstract());
|
||||
branch_1_input_y->set_name("branch_1_input_y");
|
||||
|
||||
mindspore::CNodePtr matmul_c_node_1 = fg_1->NewCNode({matmul_v_node, branch_1_input_x, branch_1_input_y});
|
||||
matmul_c_node_1->set_abstract(node->cast<mindspore::CNodePtr>()->abstract());
|
||||
(void)fg_1->AddNode(matmul_c_node_1);
|
||||
|
||||
// add return node
|
||||
mindspore::ValueNodePtr return_v_node_1 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
|
||||
(void)fg_1->AddValueNode(return_v_node_1);
|
||||
mindspore::CNodePtr branch_1_return = fg_1->NewCNode({return_v_node_1, matmul_c_node_1});
|
||||
(void)fg_1->AddNode(branch_1_return);
|
||||
fg_1->set_return(branch_1_return);
|
||||
fg_1->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
|
||||
mindspore::ValueNodePtr partial_v_node_1 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimPartial);
|
||||
(void)func_graph->AddValueNode(partial_v_node_1);
|
||||
mindspore::ValueNodePtr fg_1_node = std::make_shared<mindspore::ValueNode>(fg_1);
|
||||
fg_1_node->set_abstract(fg_1->ToAbstract());
|
||||
(void)func_graph->AddValueNode(fg_1_node);
|
||||
mindspore::CNodePtr partial_c_node_1 = func_graph->NewCNode({partial_v_node_1, fg_1_node, input_1, input_2});
|
||||
(void)func_graph->AddNode(partial_c_node_1);
|
||||
|
||||
// construct branch 2
|
||||
mindspore::FuncGraphPtr fg_2 = std::make_shared<FuncGraph>();
|
||||
// add input_x
|
||||
ParameterPtr branch_2_input_x = fg_2->add_parameter();
|
||||
branch_2_input_x->set_abstract(input_1->abstract());
|
||||
branch_2_input_x->set_name("branch_2_input_x");
|
||||
// add input_y
|
||||
ParameterPtr branch_2_input_y = fg_2->add_parameter();
|
||||
branch_2_input_y->set_abstract(input_2->abstract());
|
||||
branch_2_input_y->set_name("branch_2_input_y");
|
||||
|
||||
// add matmul CNode
|
||||
mindspore::CNodePtr matmul_c_node_2 = fg_2->NewCNode({matmul_v_node, branch_2_input_x, branch_2_input_y});
|
||||
matmul_c_node_2->set_abstract(node->cast<mindspore::CNodePtr>()->abstract());
|
||||
(void)fg_2->AddNode(matmul_c_node_2);
|
||||
|
||||
// add return node
|
||||
mindspore::ValueNodePtr return_v_node_2 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimReturn);
|
||||
(void)fg_2->AddValueNode(return_v_node_2);
|
||||
mindspore::CNodePtr branch_2_return = fg_2->NewCNode({return_v_node_2, matmul_c_node_2});
|
||||
(void)fg_2->AddNode(branch_2_return);
|
||||
fg_2->set_return(branch_2_return);
|
||||
fg_2->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
|
||||
// add partial for branch_2
|
||||
ShapeVector matmul_2_shape = get_node_shape(input_2);
|
||||
TypeId type_id = get_node_dtype(input_2);
|
||||
tensor::TensorPtr matmul_2_weight = make_weight_tensor(type_id, matmul_2_shape);
|
||||
mindspore::ValueNodePtr matmul_weight_v_node = std::make_shared<mindspore::ValueNode>(matmul_2_weight);
|
||||
matmul_weight_v_node->set_abstract(matmul_2_weight->ToAbstract());
|
||||
(void)func_graph->AddValueNode(matmul_weight_v_node);
|
||||
|
||||
mindspore::ValueNodePtr partial_v_node_2 = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimPartial);
|
||||
(void)func_graph->AddValueNode(partial_v_node_2);
|
||||
mindspore::ValueNodePtr fg_2_node = std::make_shared<mindspore::ValueNode>(fg_2);
|
||||
fg_2_node->set_abstract(fg_2->ToAbstract());
|
||||
(void)func_graph->AddValueNode(fg_2_node);
|
||||
mindspore::CNodePtr partial_c_node_2 =
|
||||
func_graph->NewCNode({partial_v_node_2, fg_2_node, input_1, matmul_weight_v_node});
|
||||
(void)func_graph->AddNode(partial_c_node_2);
|
||||
|
||||
// add switch node
|
||||
mindspore::ValueNodePtr switch_v_node = std::make_shared<mindspore::ValueNode>(mindspore::prim::kPrimSwitch);
|
||||
(void)func_graph->AddValueNode(switch_v_node);
|
||||
mindspore::CNodePtr switch_c_node;
|
||||
if (obf_password_ == 0) {
|
||||
int results_len = customized_func_results_.size();
|
||||
switch_branch_ = customized_func_results_[results_len - 1 - used_control_node_];
|
||||
used_control_node_ += 1;
|
||||
}
|
||||
if (switch_branch_) {
|
||||
switch_c_node = func_graph->NewCNode({switch_v_node, control_node, partial_c_node_1, partial_c_node_2});
|
||||
} else {
|
||||
switch_c_node = func_graph->NewCNode({switch_v_node, control_node, partial_c_node_2, partial_c_node_1});
|
||||
}
|
||||
func_graph->AddNode(switch_c_node);
|
||||
|
||||
// add call node
|
||||
mindspore::CNodePtr call_cnode = func_graph->NewCNode({switch_c_node});
|
||||
func_graph->AddNode(call_cnode);
|
||||
// add fg_1 and fg_2 to func_graph
|
||||
auto mgr = mindspore::Manage(func_graph);
|
||||
mgr->AddFuncGraph(fg_1);
|
||||
mgr->AddFuncGraph(fg_2);
|
||||
mgr->Replace(node, call_cnode);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_DYNAMIC_OBFUSCATION_H
|
||||
#define MINDSPORE_DYNAMIC_OBFUSCATION_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "include/common/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
class COMMON_EXPORT DynamicObfuscator {
|
||||
public:
|
||||
DynamicObfuscator(const float obf_ratio, const int obf_password, const int append_password)
|
||||
: obf_ratio_(obf_ratio), obf_password_(obf_password), append_password_(append_password) {}
|
||||
|
||||
~DynamicObfuscator() = default;
|
||||
|
||||
FuncGraphPtr ObfuscateMindIR(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
void op_wise_fake_branch(FuncGraphPtr func_graph);
|
||||
std::string single_op_obfuscate_type(AnfNodePtr node);
|
||||
CNodePtr get_control_node(FuncGraphPtr func_graph, AnfNodePtr prev_node);
|
||||
CNodePtr password_mode_control(FuncGraphPtr func_graph);
|
||||
CNodePtr custom_op_mode_control(FuncGraphPtr func_graph, AnfNodePtr prev_node);
|
||||
void replace_matmul_node(CNodePtr node, FuncGraphPtr func_graph, CNodePtr flag_node);
|
||||
|
||||
const float obf_ratio_ = 0.01;
|
||||
const int obf_password_;
|
||||
const int append_password_;
|
||||
bool has_build_appended_input = false;
|
||||
std::vector<bool> customized_func_results_;
|
||||
int used_control_node_ = 0;
|
||||
bool switch_branch_ = true;
|
||||
const std::vector<std::string> obf_target_op = {"MatMul-op", "Add-op", "Mat-op", "Sub-op", "Softmax-op", "Relu-op"};
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_DYNAMIC_OBFUSCATION_H
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
|
||||
#include <algorithm>
|
||||
#include "utils/info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
CustomizedOpaquePredicate &CustomizedOpaquePredicate::GetInstance() {
|
||||
static CustomizedOpaquePredicate instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void CustomizedOpaquePredicate::set_func_names() {
|
||||
// get the function of get_func_names()
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
static const std::string &module_name = "mindspore.ops.operations._opaque_predicate_registry";
|
||||
static const std::string &entrance = "get_func_names";
|
||||
py::module module = py::module::import(module_name.c_str());
|
||||
py::object get_pyfunc_obj = module.attr(entrance.c_str());
|
||||
if (get_pyfunc_obj.is_none()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << "in module" << module_name;
|
||||
}
|
||||
py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
|
||||
py::tuple func_name_list = get_pyfunc();
|
||||
// clear old functions
|
||||
func_names_.clear();
|
||||
for (size_t i = 0; i < func_name_list.size(); i++) {
|
||||
func_names_.push_back(py::str(func_name_list[i]));
|
||||
}
|
||||
MS_LOG(DEBUG) << "Set function names finished, the number of functions is: " << func_names_.size();
|
||||
}
|
||||
|
||||
const std::vector<std::string> CustomizedOpaquePredicate::get_func_names() {
|
||||
if (func_names_.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The number of customized function names is zero, get function names failed.";
|
||||
}
|
||||
return func_names_;
|
||||
}
|
||||
|
||||
py::function CustomizedOpaquePredicate::get_function() {
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
// get the function of get_opaque_predicate()
|
||||
static const std::string &module_name = "mindspore.ops.operations._opaque_predicate_registry";
|
||||
static const std::string &entrance = "get_opaque_predicate";
|
||||
py::module module = py::module::import(module_name.c_str());
|
||||
py::object get_pyfunc_obj = module.attr(entrance.c_str());
|
||||
if (get_pyfunc_obj.is_none()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find a python function named " << entrance << "in module" << module_name;
|
||||
}
|
||||
py::function get_pyfunc = get_pyfunc_obj.cast<py::function>();
|
||||
MS_LOG(DEBUG) << "The number of function is : " << func_names_.size();
|
||||
if (func_names_.size() == 0) {
|
||||
MS_EXCEPTION(ValueError) << "The customized_func is not set, please set it in load().";
|
||||
}
|
||||
std::string func_name = func_names_[0];
|
||||
MS_LOG(DEBUG) << "Get function name: " << func_name;
|
||||
func_name_code_.clear();
|
||||
std::transform(func_name.begin(), func_name.end(), std::back_inserter(func_name_code_),
|
||||
[](const char &item) { return static_cast<int>(item); });
|
||||
if (func_name_code_.size() == 0) {
|
||||
MS_EXCEPTION(ValueError) << "Set func_name_code_ failed.";
|
||||
}
|
||||
py::object py_func_obj = get_pyfunc(py::str(func_name));
|
||||
if (py_func_obj.is_none()) {
|
||||
MS_EXCEPTION(ValueError) << "Cannot find python func with name: " << func_name;
|
||||
}
|
||||
return py_func_obj.cast<py::function>();
|
||||
}
|
||||
|
||||
bool CustomizedOpaquePredicate::run_function(float x, float y) {
|
||||
if (Py_IsInitialized() != true) {
|
||||
MS_LOG(ERROR) << "Py_IsInitialized failed.";
|
||||
return false;
|
||||
}
|
||||
py::object customized_func = get_function();
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
int inputs_num = 2;
|
||||
py::tuple inputs(inputs_num);
|
||||
inputs[0] = py::float_(x);
|
||||
inputs[1] = py::float_(y);
|
||||
py::object result = customized_func(*inputs);
|
||||
if (result.is_none()) {
|
||||
MS_EXCEPTION(ValueError) << "Computing result of customized_func is None, please check it.";
|
||||
}
|
||||
bool bool_result = py::cast<bool>(result);
|
||||
int even_num = 2;
|
||||
if (func_name_code_[calling_count_ % func_name_code_.size()] % even_num == 0) {
|
||||
calling_count_ += 1;
|
||||
return bool_result;
|
||||
}
|
||||
calling_count_ += 1;
|
||||
return !bool_result;
|
||||
}
|
||||
|
||||
void CustomizedOpaquePredicate::init_calling_count() {
|
||||
this->calling_count_ = 0;
|
||||
MS_LOG(INFO) << "calling_count_ has been initialized to " << calling_count_;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H
|
||||
#define MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
#include <list>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include <Python.h>
|
||||
#include "pybind11/numpy.h"
|
||||
#include "include/common/visible.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class COMMON_EXPORT CustomizedOpaquePredicate {
|
||||
public:
|
||||
static CustomizedOpaquePredicate &GetInstance();
|
||||
void set_func_names();
|
||||
const std::vector<std::string> get_func_names();
|
||||
bool run_function(float x, float y);
|
||||
py::function get_function();
|
||||
void init_calling_count();
|
||||
|
||||
private:
|
||||
CustomizedOpaquePredicate() : func_names_({}) {}
|
||||
~CustomizedOpaquePredicate() = default;
|
||||
|
||||
std::vector<std::string> func_names_;
|
||||
int calling_count_ = 0;
|
||||
std::vector<int> func_name_code_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_REGISTRY_OPAQUE_PREDICATE_H
|
|
@ -56,6 +56,7 @@
|
|||
#include "ops/grad/max_pool_grad_with_argmax.h"
|
||||
#include "ops/max_pool_with_argmax.h"
|
||||
#include "ops/mirror_pad.h"
|
||||
#include "ops/opaquePredicate.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -323,6 +324,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimIdentity, R{InferImplIdentity, nullptr, true}},
|
||||
{prim::kPrimLoad, R{InferImplLoad, nullptr, true}},
|
||||
{prim::kPrimMutable, R{InferImplMutable, nullptr, true}},
|
||||
{prim::kPrimOpaquePredicate, R{ops::OpaquePredicateInfer, nullptr, true}},
|
||||
// Set impl to null as it will use PartialEvaluator;
|
||||
{prim::kPrimPartial, R{nullptr, nullptr, true}},
|
||||
{prim::kPrimEnvironCreate, R{InferImplEnvironCreate, nullptr, true}},
|
||||
|
|
|
@ -1376,6 +1376,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicLossScale, std::make_shared<Primitive>("_Dyna
|
|||
GVAR_DEF(PrimitivePtr, kPrimScaleGrad, std::make_shared<Primitive>("ScaleGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPopulationCount, std::make_shared<Primitive>("PopulationCount"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimBlackmanWindow, std::make_shared<Primitive>("BlackmanWindow"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimOpaquePredicate, std::make_shared<Primitive>("OpaquePredicate"));
|
||||
|
||||
// Structures
|
||||
GVAR_DEF(PrimitivePtr, kPrimMakeList, std::make_shared<Primitive>("make_list"));
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/opaquePredicate.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <complex>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::ShapePtr OpaquePredicateInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr OpaquePredicateInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim->name(), input_args, 0);
|
||||
auto y = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim->name(), input_args, 1);
|
||||
(void)abstract::CheckDtypeSame(prim->name(), x, y);
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat, kFloat16, kUInt16,
|
||||
kFloat64, kUInt8, kBool, kComplex64, kComplex128, kUInt32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("y", input_args[1]->BuildType(), valid_types, prim->name());
|
||||
return std::make_shared<TensorType>(kBool);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(OpaquePredicate, BaseOperator);
|
||||
AbstractBasePtr OpaquePredicateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto shape = OpaquePredicateInferShape(primitive, input_args);
|
||||
auto type = OpaquePredicateInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameOpaquePredicate, OpaquePredicate);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_
|
||||
#define MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameOpaquePredicate = "OpaquePredicate";
|
||||
/// \brief The opaque predicate used for dynamic obfuscation
|
||||
class MIND_API OpaquePredicate : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(OpaquePredicate);
|
||||
OpaquePredicate() : BaseOperator(kNameOpaquePredicate) { InitIOName({"x", "y"}, {"output"}); }
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr OpaquePredicateInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_OPAQUE_PREDICATE_H_
|
|
@ -26,6 +26,7 @@ import inspect
|
|||
import importlib
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
|
@ -46,7 +47,6 @@ from mindspore._checkparam import Validator
|
|||
from mindspore.common._utils import is_shape_unknown
|
||||
from mindspore.common.mutable import mutable
|
||||
|
||||
|
||||
# store ms_function class compiled pipeline cache
|
||||
ms_compile_cache = set()
|
||||
# store cell compiled pipeline cache,
|
||||
|
@ -670,6 +670,7 @@ class _MsFunctionCompileContext:
|
|||
"""
|
||||
ms_function compile status manager
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
@ -1064,10 +1065,12 @@ class _CellGraphExecutor:
|
|||
Returns:
|
||||
Graph, return the result of pipeline running.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# create needed graph by lazy mode
|
||||
self.is_init = False
|
||||
self.enable_tuple_broaden = False
|
||||
self.obfuscate_config = None # used for model's dynamic obfuscation
|
||||
self._graph_executor = GraphExecutor_.get_instance()
|
||||
self._graph_executor.set_py_exe_path(sys.executable)
|
||||
self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
|
||||
|
@ -1238,8 +1241,8 @@ class _CellGraphExecutor:
|
|||
return self._graph_executor.get_allreduce_fusion(real_phase)
|
||||
|
||||
def __call__(self, obj, *args, phase='predict'):
|
||||
if context.get_context("precompile_only") or\
|
||||
(_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
|
||||
if context.get_context("precompile_only") or \
|
||||
(_is_role_pserver() and not _enable_distributed_mindrt()) or _is_role_sched():
|
||||
return None
|
||||
return self.run(obj, *args, phase=phase)
|
||||
|
||||
|
@ -1290,6 +1293,21 @@ class _CellGraphExecutor:
|
|||
exec_id = exec_id + '.' + obj.arguments_key
|
||||
if self._graph_executor.has_compiled(exec_id) is False:
|
||||
return None
|
||||
if self.obfuscate_config is not None:
|
||||
if ('obf_ratio' not in self.obfuscate_config.keys()) or (
|
||||
'obf_password' not in self.obfuscate_config.keys()):
|
||||
raise ValueError("'obf_ratio' and 'obf_password' must be in obfuscate_config.")
|
||||
obf_password = self.obfuscate_config.get('obf_password')
|
||||
if obf_password == 0:
|
||||
append_password = 0
|
||||
else:
|
||||
seed_max = 2 ** 32 - 1
|
||||
int_max = 2 ** 31 - 1
|
||||
np.random.seed(obf_password % seed_max)
|
||||
append_password = np.random.randint(int_max)
|
||||
obf_password %= int_max
|
||||
return self._graph_executor.get_obfuscate_func_graph_proto(exec_id, self.obfuscate_config['obf_ratio'],
|
||||
obf_password, append_password)
|
||||
return self._graph_executor.get_func_graph_proto(exec_id, ir_type)
|
||||
|
||||
def get_optimize_graph_proto(self, obj):
|
||||
|
|
|
@ -2219,18 +2219,18 @@ class Cell(Cell_):
|
|||
if isinstance(set_input, Tensor):
|
||||
if not isinstance(net_input, Tensor):
|
||||
raise TypeError(
|
||||
f"The {index+1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
|
||||
f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
|
||||
if set_input.dtype is not net_input.dtype:
|
||||
raise ValueError(
|
||||
f"The {index+1}th input type of 'set_inputs' must be the same as network's input, "
|
||||
f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
|
||||
f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
|
||||
if net_input.dim() != 0 and set_input.dim() != net_input.dim():
|
||||
raise ValueError(
|
||||
f"The {index+1}th input dims of 'set_inputs' must be the same as network's input, "
|
||||
f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
|
||||
f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
|
||||
if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
|
||||
raise ValueError(
|
||||
f"The {index+1}th input shape of 'set_inputs' must be the same as network's input, "
|
||||
f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
|
||||
f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.")
|
||||
|
||||
|
||||
|
@ -2247,6 +2247,11 @@ class GraphCell(Cell):
|
|||
The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
|
||||
If the parameter exists in the graph according to the name, update it's value.
|
||||
If the parameter does not exist, ignore it. Default: None.
|
||||
obf_password (int): The password used for dynamic obfuscation. "dynamic obfuscation" is used for model
|
||||
protection, which can refer to `mindspore.train.serialization.obfuscate_model()`. If the input 'graph' is a
|
||||
func_graph loaded from a mindir file obfuscated in password mode, then obf_password should be provided.
|
||||
obf_password should be larger than zero and less or equal than int_64 (9223372036854775807). default: None.
|
||||
|
||||
Raises:
|
||||
TypeError: If the `graph` is not a FuncGraph.
|
||||
TypeError: If the `params_init` is not a dict.
|
||||
|
@ -2273,13 +2278,19 @@ class GraphCell(Cell):
|
|||
[6. 9. 6.]
|
||||
[4. 6. 4.]]]]
|
||||
"""
|
||||
def __init__(self, graph, params_init=None):
|
||||
|
||||
def __init__(self, graph, params_init=None, obf_password=None):
|
||||
super(GraphCell, self).__init__(auto_prefix=True)
|
||||
if not isinstance(graph, FuncGraph):
|
||||
raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
|
||||
f"but got type {type(graph)}.")
|
||||
self.graph = graph
|
||||
|
||||
self.obf_password = obf_password
|
||||
int_64_max = 9223372036854775807
|
||||
if (obf_password is not None) and (obf_password <= 0 or obf_password > int_64_max):
|
||||
raise ValueError(
|
||||
"'obf_password' must be larger than 0, and less or equal than int64 ({}),"
|
||||
"but got {}.".format(int_64_max, obf_password))
|
||||
params_init = {} if params_init is None else params_init
|
||||
if not isinstance(params_init, dict):
|
||||
raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
|
||||
|
@ -2299,7 +2310,21 @@ class GraphCell(Cell):
|
|||
def __call__(self, *inputs):
|
||||
self.phase = "graph_load_from_mindir"
|
||||
self._add_attr("graph_load_from_mindir", self.graph)
|
||||
return self.compile_and_run(*inputs)
|
||||
if not self.obf_password:
|
||||
return self.compile_and_run(*inputs)
|
||||
append_input_1, append_input_2 = _obf_appended_inputs(self.obf_password)
|
||||
return self.compile_and_run(*inputs, append_input_1, append_input_2)
|
||||
|
||||
|
||||
def _obf_appended_inputs(obf_password):
|
||||
seed_max = 2 ** 32 - 1
|
||||
int_max = 2 ** 31 - 1
|
||||
numpy.random.seed(obf_password % seed_max)
|
||||
append_password = numpy.random.randint(int_max)
|
||||
obf_password %= int_max
|
||||
append_input_1 = Tensor((numpy.ones((1, 1)) * obf_password).astype(numpy.int32))
|
||||
append_input_2 = Tensor((numpy.ones((1, 1)) * append_password).astype(numpy.int32))
|
||||
return append_input_1, append_input_2
|
||||
|
||||
|
||||
def _check_param_list_tuple(value):
|
||||
|
|
|
@ -60,3 +60,13 @@ class PyFuncRegistry(UserDict):
|
|||
if key not in self:
|
||||
raise ValueError(f"Python function with key{key} not registered.")
|
||||
return self[key]
|
||||
|
||||
|
||||
class OpaquePredicateRegistry(PyFuncRegistry):
|
||||
def __init__(self):
|
||||
super(OpaquePredicateRegistry, self).__init__()
|
||||
self.func_names = []
|
||||
|
||||
def register(self, key, value):
|
||||
self[key] = value
|
||||
self.func_names.append(key)
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Register pyfunc for opaque_func_cpu_kernel"""
|
||||
|
||||
from mindspore.ops._register_for_op import OpaquePredicateRegistry
|
||||
|
||||
|
||||
registered_func_name = OpaquePredicateRegistry()
|
||||
|
||||
|
||||
def add_opaque_predicate(fn_name, func):
|
||||
registered_func_name.register(fn_name, func)
|
||||
|
||||
|
||||
def get_opaque_predicate(fn_name):
|
||||
return registered_func_name.get(fn_name)
|
||||
|
||||
|
||||
def get_func_names():
|
||||
return registered_func_name.func_names
|
||||
|
||||
|
||||
def clean_funcs():
|
||||
registered_func_name.func_names = []
|
|
@ -26,7 +26,7 @@ from mindspore.train.amp import build_train_network
|
|||
from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
|
||||
load, parse_print, build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, \
|
||||
async_ckpt_thread_status, restore_group_info_list, convert_model
|
||||
async_ckpt_thread_status, restore_group_info_list, convert_model, obfuscate_model
|
||||
from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
|
||||
CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, \
|
||||
History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore
|
||||
|
@ -39,7 +39,7 @@ __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "bui
|
|||
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
|
||||
"load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
|
||||
"load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list", "convert_model",
|
||||
"data_sink"]
|
||||
"data_sink", "obfuscate_model"]
|
||||
__all__.extend(callback.__all__)
|
||||
__all__.extend(summary.__all__)
|
||||
__all__.extend(train_thor.__all__)
|
||||
|
|
|
@ -55,11 +55,11 @@ from mindspore.parallel._cell_wrapper import get_allgather_cell
|
|||
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
|
||||
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
||||
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
|
||||
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy,\
|
||||
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
|
||||
_restore_group_info_list
|
||||
from mindspore.train._utils import read_proto
|
||||
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
|
||||
|
||||
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir
|
||||
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
||||
|
||||
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
||||
"Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
|
||||
|
@ -384,6 +384,17 @@ def _check_append_dict(append_dict):
|
|||
return append_dict
|
||||
|
||||
|
||||
def _check_load_obfuscate(**kwargs):
|
||||
if 'obf_func' in kwargs.keys():
|
||||
customized_func = kwargs.get('obf_func')
|
||||
if not callable(customized_func):
|
||||
raise ValueError("obf_func must be a callable function, but got a {}.".format(type(customized_func)))
|
||||
clean_funcs()
|
||||
add_opaque_predicate(customized_func.__name__, customized_func)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(file_name, **kwargs):
|
||||
"""
|
||||
Load MindIR.
|
||||
|
@ -436,6 +447,9 @@ def load(file_name, **kwargs):
|
|||
"please check whether the 'file_name' is correct.")
|
||||
file_name = os.path.realpath(file_name)
|
||||
|
||||
# set customized functions for dynamic obfuscation
|
||||
obfuscated = _check_load_obfuscate(**kwargs)
|
||||
|
||||
logger.info("Execute the process of loading mindir.")
|
||||
if 'dec_key' in kwargs.keys():
|
||||
dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
|
||||
|
@ -448,9 +462,9 @@ def load(file_name, **kwargs):
|
|||
else:
|
||||
dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
|
||||
graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
|
||||
decrypt=dec_func)
|
||||
decrypt=dec_func, obfuscated=obfuscated)
|
||||
else:
|
||||
graph = load_mindir(file_name)
|
||||
graph = load_mindir(file_name, obfuscated=obfuscated)
|
||||
|
||||
if graph is None:
|
||||
if _is_cipher_file(file_name):
|
||||
|
@ -461,6 +475,152 @@ def load(file_name, **kwargs):
|
|||
return graph
|
||||
|
||||
|
||||
def _check_param_type(param_config, key, target_type, requested):
|
||||
"""check type of parameters"""
|
||||
if key in param_config:
|
||||
if not isinstance(param_config[key], target_type):
|
||||
raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
|
||||
return param_config[key]
|
||||
if requested:
|
||||
raise ValueError("The parameter {} is requested, but not got.".format(key))
|
||||
if key == "obf_password":
|
||||
return 0
|
||||
return None
|
||||
|
||||
|
||||
def _check_obfuscate_params(obf_config):
|
||||
"""check obfuscation parameters, including obf_password, obf_ratio, customized_func"""
|
||||
if 'obf_password' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
|
||||
raise ValueError(
|
||||
"At least one of 'obf_password' or 'customized_func' must be set in obf_config, but got None of them.")
|
||||
obfuscate_type = _check_param_type(obf_config, "type", str, False)
|
||||
if obfuscate_type not in (None, "dynamic"):
|
||||
raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
|
||||
if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
|
||||
if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
|
||||
raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
|
||||
obf_config['obf_ratio']))
|
||||
ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
|
||||
obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
|
||||
obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
|
||||
if (obf_ratio <= 0) or (obf_ratio > 1):
|
||||
raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
|
||||
customized_funcs = []
|
||||
if 'customized_func' in obf_config.keys():
|
||||
if callable(obf_config['customized_func']):
|
||||
customized_funcs.append(obf_config['customized_func'])
|
||||
else:
|
||||
raise TypeError(
|
||||
"'customized_func' must be a function, but not got {}.".format(type(obf_config['customized_func'])))
|
||||
obf_password = _check_param_type(obf_config, "obf_password", int, False)
|
||||
int_64_max = 9223372036854775807
|
||||
if obf_password > int_64_max:
|
||||
raise ValueError(
|
||||
"'obf_password' must be less or equal than int64 ({}), but got {}.".format(int_64_max, obf_password))
|
||||
return obf_ratio, customized_funcs, obf_password
|
||||
|
||||
|
||||
def obfuscate_model(obf_config, **kwargs):
|
||||
"""
|
||||
Obfuscate a model of MindIR format. Obfuscation means changing the struct of a network without affecting its
|
||||
predict correctness. The obfuscated model can prevent attackers from stealing the model.
|
||||
|
||||
Args:
|
||||
obf_config (dict): obfuscation config.
|
||||
|
||||
- type (str): The type of obfuscation, only 'dynamic' is supported until now.
|
||||
- original_model_path (str): The path of MindIR format model that need to be obfuscated. If the original
|
||||
model is encrypted, then enc_key and enc_mode should be provided.
|
||||
- save_model_path (str): The path to save the obfuscated model.
|
||||
- model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
|
||||
is the same as using `export()`.
|
||||
- obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio` should
|
||||
be in range of (0, 1] or in ["small", "medium", "large"].
|
||||
- customized_func (function): A python function used for customized function mode, which used for control
|
||||
the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
|
||||
function needs to ensure that its result is constant for any input. Users can refer to opaque
|
||||
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
||||
obfuscated model.
|
||||
- obf_password (int): A password used for password mode, which should be larger than zero. If
|
||||
obf_password is set, then it should be passed to `nn.GraphCell()` interface when loading obfuscated
|
||||
model. It should be noted that at least one of 'customized_func' or 'obf_password' should be set, and
|
||||
'obf_password' mode would be applied if both of them are set.
|
||||
|
||||
kwargs (dict): Configuration options dictionary.
|
||||
|
||||
- enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
|
||||
- enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
|
||||
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
|
||||
|
||||
Raises:
|
||||
TypeError: If obf_config is not a dict.
|
||||
ValueError: If enc_key is passed and enc_mode is not in ["AES-GCM", "AES-CBC"].
|
||||
ValueError: If original_model_path is not provided in obf_config.
|
||||
ValueError: If the model saved in original_model_path has been obfuscated.
|
||||
ValueError: If save_model_path is not provided in obf_config.
|
||||
ValueError: If obf_ratio is not provided in obf_config.
|
||||
ValueError: If both customized_func and obf_password are not provided in obf_config.
|
||||
ValueError: If both obf_password is not in (0, 9223372036854775807].
|
||||
|
||||
Examples:
|
||||
>>> obf_config = {'original_model_path': "./net.mindir",
|
||||
... 'save_model_path': "./obf_net",
|
||||
... 'model_inputs': [input1, ],
|
||||
... 'obf_ratio': 0.1, 'obf_password': 173262358423}
|
||||
>>> obfuscate_model(obf_config)
|
||||
>>> obf_func = load("obf_net.mindir")
|
||||
>>> obf_net = nn.GraphCell(obf_func, obf_password=173262358423)
|
||||
>>> print(obf_net(input1).asnumpy())
|
||||
"""
|
||||
if not isinstance(obf_config, dict):
|
||||
raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
|
||||
file_path = _check_param_type(obf_config, "original_model_path", str, True)
|
||||
saved_path = _check_param_type(obf_config, "save_model_path", str, True)
|
||||
model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
|
||||
for item in model_inputs:
|
||||
if not isinstance(item, Tensor):
|
||||
raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
|
||||
obf_ratio, customized_funcs, obf_password = _check_obfuscate_params(obf_config)
|
||||
if customized_funcs and obf_password > 0:
|
||||
logger.warning("Although customized_func and obf_password are set, the 'obf_password' mode would be"
|
||||
"applied, remember to set obf_password when loading obfuscated model.")
|
||||
|
||||
if obf_password == 0: # apply customized_func mode
|
||||
clean_funcs()
|
||||
for func in customized_funcs:
|
||||
add_opaque_predicate(func.__name__, func)
|
||||
append_password = 0
|
||||
else:
|
||||
seed_max = 2 ** 32 - 1
|
||||
int_max = 2 ** 31 - 1
|
||||
np.random.seed(obf_password % seed_max)
|
||||
append_password = np.random.randint(int_max)
|
||||
obf_password %= int_max
|
||||
|
||||
if 'enc_key' in kwargs.keys():
|
||||
enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
|
||||
enc_mode = "AES-GCM"
|
||||
if 'enc_mode' in kwargs.keys():
|
||||
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
||||
if enc_mode not in ["AES-GCM", "AES-CBC"]:
|
||||
raise ValueError(
|
||||
"Only MindIR files that encrypted with 'AES-GCM' or 'AES-CBC' is supported for obfuscate_model(),"
|
||||
" but got {}.".format(enc_mode))
|
||||
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, obf_password=obf_password,
|
||||
append_password=append_password, dec_key=enc_key, key_len=len(enc_key),
|
||||
dec_mode=enc_mode)
|
||||
else:
|
||||
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio, obf_password=obf_password,
|
||||
append_password=append_password)
|
||||
|
||||
obf_net = nn.GraphCell(obf_graph)
|
||||
if obf_password != 0:
|
||||
y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
||||
append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
||||
model_inputs += [y_tensor, append_y_tensor]
|
||||
export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
||||
dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
|
||||
"""
|
||||
|
@ -524,7 +684,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
np_type = tensor_to_np_type.get(data_type)
|
||||
ms_type = tensor_to_ms_type[data_type]
|
||||
if data_type == 'str':
|
||||
str_length = int(len(data)/4)
|
||||
str_length = int(len(data) / 4)
|
||||
np_type = np_type + str(str_length)
|
||||
element_data = np.frombuffer(data, np_type)
|
||||
param_data_list.append(element_data)
|
||||
|
@ -549,7 +709,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
except BaseException as e:
|
||||
logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
|
||||
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
|
||||
"failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
|
||||
"failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
|
||||
|
||||
if not parameter_dict:
|
||||
raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
|
||||
|
@ -617,10 +777,10 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
|
|||
except BaseException as e:
|
||||
if _is_cipher_file(ckpt_file_name):
|
||||
err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tempered with, " \
|
||||
"please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name)
|
||||
"please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name)
|
||||
else:
|
||||
err_info = "Failed to read the checkpoint file {}. May not have permission to read it, please check" \
|
||||
" the correct of the file.".format(ckpt_file_name)
|
||||
" the correct of the file.".format(ckpt_file_name)
|
||||
logger.error(err_info)
|
||||
raise ValueError(err_info) from e
|
||||
return checkpoint_list
|
||||
|
@ -885,6 +1045,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|||
- For details of using the customized encryption, please check the `tutorial
|
||||
<https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
|
||||
|
||||
- obf_config (dict): obfuscation config.
|
||||
|
||||
- type (str): The type of obfuscation, only 'dynamic' is supported until now.
|
||||
- obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
||||
should be in range of (0, 1] or in ["small", "medium", "large"].
|
||||
- customized_func (function): A python function used for customized function mode, which used for control
|
||||
the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
|
||||
function needs to ensure that its result is constant for any input. Users can refer to opaque
|
||||
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
||||
obfuscated model.
|
||||
- obf_password (int): A password used for password mode, which should be larger than zero. If
|
||||
obf_password is set, then it should be passed to `nn.GraphCell()` interface when loading obfuscated
|
||||
model. It should be noted that at least one of 'customized_func' or 'obf_password' should be set, and
|
||||
'obf_password' mode would be applied if both of them are set.
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import numpy as np
|
||||
|
@ -920,11 +1094,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|||
file_name = os.path.realpath(file_name)
|
||||
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
|
||||
if 'enc_key' in kwargs.keys():
|
||||
enc_key, enc_mode = _check_key_mode_type(file_format, **kwargs)
|
||||
dataset = kwargs.get('dataset')
|
||||
_export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset)
|
||||
else:
|
||||
_export(net, file_name, file_format, *inputs, **kwargs)
|
||||
kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
|
||||
_export(net, file_name, file_format, *inputs, **kwargs)
|
||||
|
||||
|
||||
def _export(net, file_name, file_format, *inputs, **kwargs):
|
||||
|
@ -1150,13 +1321,40 @@ def _cell_info(net, *inputs):
|
|||
do_convert=False, auto_parallel_mode=net._auto_parallel_mode)
|
||||
# pylint: disable=protected-access
|
||||
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
||||
# clean obfuscation config to prevent the next call
|
||||
_executor.obfuscate_config = None
|
||||
|
||||
net_dict = net.parameters_dict()
|
||||
return mindir_stream, net_dict
|
||||
|
||||
|
||||
def _set_obfuscate_config(**kwargs):
|
||||
"""Set obfuscation config for executor."""
|
||||
logger.warning("Obfuscate model.")
|
||||
if 'enc_mode' in kwargs.keys():
|
||||
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
||||
if enc_mode not in ["AES-GCM", "AES-CBC"]:
|
||||
raise ValueError(
|
||||
"Only MindIR files that encrypted with 'AES-GCM' or 'AES-CBC' is supported for obfuscation,"
|
||||
"but got {}.".format(enc_mode))
|
||||
obf_ratio, customized_funcs, obf_password = _check_obfuscate_params(kwargs.get('obf_config'))
|
||||
if customized_funcs and obf_password > 0:
|
||||
logger.warning("Although customized_func and obf_password are set, the 'obf_password' mode would be"
|
||||
"applied, remember to set obf_password when loading obfuscated model.")
|
||||
|
||||
if obf_password == 0: # apply customized_func mode
|
||||
clean_funcs()
|
||||
for func in customized_funcs:
|
||||
add_opaque_predicate(func.__name__, func)
|
||||
_executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_password': obf_password}
|
||||
|
||||
|
||||
def _save_mindir(net, file_name, *inputs, **kwargs):
|
||||
"""Save MindIR format file."""
|
||||
# set obfuscate configs
|
||||
if 'obf_config' in kwargs.keys():
|
||||
_set_obfuscate_config(**kwargs)
|
||||
|
||||
model = mindir_model()
|
||||
if not isinstance(net, nn.Cell):
|
||||
mindir_stream, net_dict = _msfunc_info(net, *inputs)
|
||||
|
@ -1248,6 +1446,7 @@ def _save_dataset_to_mindir(model, dataset):
|
|||
|
||||
def quant_mode_manage(func):
|
||||
"""Inherit the quant_mode in old version."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def warpper(network, *inputs, file_format, **kwargs):
|
||||
if 'quant_mode' not in kwargs:
|
||||
|
@ -1746,8 +1945,8 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|||
logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
|
||||
" and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
|
||||
raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
|
||||
f" in load distributed checkpoint for {param.name}. Data shape is "
|
||||
f"{split_param.data.shape} and group is {opt_shard_group}.") from e
|
||||
f" in load distributed checkpoint for {param.name}. Data shape is "
|
||||
f"{split_param.data.shape} and group is {opt_shard_group}.") from e
|
||||
split_param = Parameter(Tensor(data_slice), param.name,
|
||||
split_param.requires_grad, split_param.layerwise_parallel)
|
||||
param_dict[param.name] = split_param
|
||||
|
|
Loading…
Reference in New Issue