replace fixed string name with global references

This commit is contained in:
jin-xiulang 2023-02-22 14:20:49 +08:00
parent 3f3cee445d
commit 99b02895f3
3 changed files with 42 additions and 19 deletions

View File

@ -856,6 +856,8 @@ constexpr auto kCheckValidOpName = "CheckValid";
constexpr auto kSoftmaxGradFusionOpName = "SoftmaxGradFusion";
constexpr auto kSoftMarginLossOpName = "SoftMarginLoss";
constexpr auto kZerosLikeOpName = "ZerosLike";
constexpr auto kSoftplusOpName = "Softplus";
constexpr auto kSoftsignOpName = "Softsign";
// Sequence ops
constexpr auto kScalarToTensorOpName = "ScalarToTensor";

View File

@ -51,10 +51,6 @@ constexpr int keyExpandRate = 10; // total node need for a switch graph
constexpr int kWeightIndex = 2;
constexpr int kSwitchInputsNum = 2;
constexpr int kNodeWithWeightInputsNum = 3;
constexpr auto keyConv2DOpName = "Conv2D-op";
constexpr auto keyReluOpName = "ReLU-op";
constexpr auto keySigmoidOpName = "Sigmoid-op";
constexpr auto keyMatMulOpName = "MatMul-op";
ShapeVector get_node_shape(const AnfNodePtr &input_node) {
if (input_node == nullptr) {
@ -123,8 +119,15 @@ std::string get_node_name(const AnfNodePtr &node) {
MS_LOG(WARNING) << "Input node name is empty.";
return "";
}
std::string name = split_words[split_words.size() - 1];
return name;
std::string name = split_words[split_words.size() - 1]; // name is like x-opx
std::vector<string> split_name = name_split(name, "-");
size_t qualified_split_len = 2;
if (split_name.size() != qualified_split_len) {
MS_LOG(ERROR) << "The size of split op_name must be 2, but got: " << split_name.size()
<< ". Complete name is: " << name;
return "";
}
return split_name[0];
}
int get_op_num(const AnfNodePtr &node) {
@ -154,9 +157,10 @@ ParameterPtr get_node_param(const FuncGraphPtr func_graph, const CNodePtr &node)
std::string parameter_name = "";
for (auto input : node->inputs()) {
std::string op_name = get_node_name(input);
int op_name_len = LongToInt(op_name.size());
int load_len = 7;
if ((op_name_len >= load_len) && (op_name.substr(0, load_len) == "Load-op")) {
MS_LOG(INFO) << "op_name is: " << op_name;
int op_name_len = op_name.size();
int load_len = 4;
if ((op_name_len >= load_len) && (op_name.substr(0, load_len) == "Load")) {
for (auto param : input->cast<mindspore::CNodePtr>()->inputs()) {
if (param->fullname_with_scope().find("weight") != std::string::npos) {
parameter_name = param->fullname_with_scope();
@ -329,6 +333,10 @@ CNodePtr DynamicObfuscator::RandomSeedModeControl(const FuncGraphPtr func_graph)
func_graph->AddValueNode(equal_v_node);
ValueNodePtr equal_compa_node = make_int_node(func_graph, branch_control_input_);
CNodePtr equal_c_node = func_graph->NewCNode({equal_v_node, y_append, equal_compa_node});
if (equal_c_node == nullptr) {
MS_LOG(ERROR) << "equal_c_node is nullptr.";
return nullptr;
}
tensor::TensorPtr equal_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
equal_c_node->set_abstract(equal_tensor->ToAbstract());
func_graph->AddNode(equal_c_node);
@ -342,6 +350,10 @@ CNodePtr DynamicObfuscator::RandomSeedModeControl(const FuncGraphPtr func_graph)
func_graph->AddValueNode(greater_v_node);
ValueNodePtr greater_compa_node = make_int_node(func_graph, comparison_int);
CNodePtr greater_c_node = func_graph->NewCNode({greater_v_node, y_append, greater_compa_node});
if (greater_c_node == nullptr) {
MS_LOG(ERROR) << "greater_c_node is nullptr.";
return nullptr;
}
tensor::TensorPtr greater_tensor = std::make_shared<Tensor>(mindspore::kNumberTypeBool, y_shape);
greater_c_node->set_abstract(greater_tensor->ToAbstract());
func_graph->AddNode(greater_c_node);
@ -533,20 +545,19 @@ void DynamicObfuscator::CheckDuplicatedParent(const AnfNodePtr &node) {
}
bool DynamicObfuscator::IsTarget(std::string &cnode_name) {
std::vector<string> split_words = name_split(cnode_name, "/");
if (split_words.empty()) {
if (cnode_name.empty()) {
MS_LOG(WARNING) << "CNode name is empty.";
return false;
}
std::string op_name = split_words[split_words.size() - 1];
std::vector<std::string> target_op_list;
target_op_list.insert(target_op_list.end(), single_input_target_op_.begin(), single_input_target_op_.end());
target_op_list.insert(target_op_list.end(), single_input_with_weight_target_op_.begin(),
single_input_with_weight_target_op_.end());
for (std::string target_op_name : target_op_list) {
int op_name_len = SizeToInt(op_name.size());
int target_name_len = SizeToInt(target_op_name.size());
if ((op_name_len >= target_name_len) && (op_name.substr(0, target_name_len) == target_op_name)) {
int op_name_len = cnode_name.size();
int target_name_len = target_op_name.size();
if ((op_name_len >= target_name_len) && (cnode_name.substr(0, target_name_len) == target_op_name)) {
MS_LOG(WARNING) << "find target node.";
return true;
}
}
@ -671,6 +682,7 @@ FuncGraphPtr DynamicObfuscator::CloneSubGraph(const std::vector<mindspore::CNode
mindspore::AnfNodePtr last_node = input_x;
for (auto node : node_arr) {
std::string obf_type = ObfuscateOpType(node);
MS_LOG(WARNING) << "obf_type: " << obf_type;
mindspore::ObfCase obf_case = ObfuscateOpCase(obf_type);
switch (obf_case) {
case ObfCase::OneInputNoWeightNode: {
@ -836,7 +848,7 @@ mindspore::CNodePtr DynamicObfuscator::AddPartialBranch(const FuncGraphPtr fg, F
subgraph_inputs.push_back(nodes[0]->inputs()[1]);
for (unsigned i = 0; i < nodes.size(); i++) {
std::string obf_type = ObfuscateOpType(nodes[i]);
if ((obf_type == keyConv2DOpName || obf_type == keyMatMulOpName) &&
if ((obf_type == kConv2DOpName || obf_type == kMatMulOpName) &&
nodes[i]->inputs().size() >= kNodeWithWeightInputsNum) {
subgraph_inputs.push_back(nodes[i]->inputs()[kWeightIndex]);
}
@ -914,10 +926,18 @@ void DynamicObfuscator::AddSwitchNode(const FuncGraphPtr fg) {
} else {
switch_c_node = fg->NewCNode({switch_v_node, control_node, switch_partial_fake_c, switch_partial_clone_c});
}
if (switch_c_node == nullptr) {
MS_LOG(ERROR) << "switch_c_node is nullptr.";
return;
}
switch_c_node->set_abstract(fg_subgraph_clone->ToAbstract());
fg->AddNode(switch_c_node);
mindspore::CNodePtr call_cnode = fg->NewCNode({switch_c_node});
if (call_cnode == nullptr) {
MS_LOG(ERROR) << "call_cnode is nullptr.";
return;
}
fg->AddNode(call_cnode);
if (child_node != nullptr) {

View File

@ -24,6 +24,7 @@
#include <set>
#include "load_mindir/load_model.h"
#include "include/common/visible.h"
#include "include/common/utils/utils.h"
#include "ops/core_ops.h"
namespace mindspore {
@ -70,9 +71,9 @@ class COMMON_EXPORT DynamicObfuscator {
int subgraph_obf_num_ = 0;
bool switch_branch_ = true;
const std::vector<std::string> single_input_target_op_ = {
"ReLU-op", "Sigmoid-op", "ReLU6-op", "Softplus-op", "HSigmoid-op", "FastGeLU-op", "HSwish-op",
"Softsign-op", "SeLU-op", "Tanh-op", "Square-op", "AvgPool-op", "MaxPool-op"};
const std::vector<std::string> single_input_with_weight_target_op_ = {"Conv2D-op", "MatMul-op"};
kReLUOpName, kSigmoidOpName, kReLU6OpName, kSoftplusOpName, kHSigmoidOpName, kFastGeLUOpName, kHSwishOpName,
kSoftsignOpName, kSeLUOpName, kTanhOpName, kSquareOpName, kAvgPoolOpName, kMaxPoolOpName};
const std::vector<std::string> single_input_with_weight_target_op_ = {kConv2DOpName, kMatMulOpName};
const std::vector<PrimitivePtr> one_input_prim_ = {
mindspore::prim::kPrimReLU, mindspore::prim::kPrimSigmoid, mindspore::prim::kPrimReLU6,
mindspore::prim::kPrimSoftplus, mindspore::prim::kPrimHSigmoid, mindspore::prim::kPrimFastGeLU,