forked from mindspore-Ecosystem/mindspore
optimize dynamic obfuscation code
This commit is contained in:
parent
b71f1ea735
commit
c7f1186020
|
@ -271,7 +271,7 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const Func
|
|||
auto is_load = primitive->GetAttr("is_load");
|
||||
if (abstract::GetPrimEvaluator(primitive, engine) == nullptr && is_load != nullptr &&
|
||||
GetValue<bool>(is_load)) {
|
||||
MS_LOG(WARNING) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
|
||||
MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -949,11 +949,11 @@ void DynamicObfuscator::SubGraphFakeBranch(FuncGraphPtr func_graph) {
|
|||
auto all_nodes = mgr->all_nodes();
|
||||
int node_nums = all_nodes.size();
|
||||
int obfuscate_target_num = std::ceil(node_nums * obf_ratio_ / keyExpandRate);
|
||||
int op_num = 0;
|
||||
int op_num = node_nums; // Initialize op_num to the maximum node number
|
||||
std::vector<mindspore::AnfNodePtr> sorted_nodes;
|
||||
for (auto node : all_nodes) {
|
||||
MS_LOG(INFO) << "The last node name is: " << node->fullname_with_scope();
|
||||
sorted_nodes = TopoSort(node); // the node number in front of sorted nodes is the smallest
|
||||
op_num = get_op_num(node);
|
||||
break;
|
||||
}
|
||||
std::reverse(sorted_nodes.begin(), sorted_nodes.end());
|
||||
|
|
|
@ -1245,7 +1245,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|||
It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
|
||||
"""
|
||||
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
||||
|
||||
if "obf_config" in kwargs and file_format != "MINDIR":
|
||||
raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
|
||||
if file_format == 'GEIR':
|
||||
logger.warning(f"For 'export', format 'GEIR' is deprecated, "
|
||||
f"it would be removed in future release, use 'AIR' instead.")
|
||||
|
|
|
@ -135,7 +135,6 @@ def test_obfuscate_model_customized_func_mode():
|
|||
os.remove("obf_net_2.mindir")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="random failures")
|
||||
def test_export_password_mode():
|
||||
"""
|
||||
Feature: Obfuscate MindIR format model with dynamic obfuscation (password mode) in export().
|
||||
|
@ -189,3 +188,19 @@ def test_export_customized_func_mode():
|
|||
|
||||
if os.path.exists("obf_net_4.mindir"):
|
||||
os.remove("obf_net_4.mindir")
|
||||
|
||||
|
||||
def test_wrong_file_format_input():
|
||||
"""
|
||||
Feature: Obfuscate MindIR format model with dynamic obfuscation (customized_func mode) in export().
|
||||
Description: Test wrong file_formar input.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net_5 = ObfuscateNet()
|
||||
input_tensor = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
|
||||
|
||||
# obfuscate model
|
||||
obf_config = {"obf_ratio": 0.8, "obf_password": 3423}
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
export(net_5, input_tensor, file_name="obf_net_3", file_format="ONNX", obf_config=obf_config)
|
||||
assert str(error_info.value) == "Dynamic obfuscation only support for MindIR format, but got ONNX format."
|
||||
|
|
Loading…
Reference in New Issue