From c7f11860200459a2e74d26107109de5d78049026 Mon Sep 17 00:00:00 2001 From: jin-xiulang Date: Fri, 10 Feb 2023 09:38:51 +0800 Subject: [PATCH] optimize dynamic obfuscation code --- mindspore/ccsrc/pipeline/jit/action.cc | 2 +- .../dynamic_obfuscation/dynamic_obfuscation.cc | 4 ++-- .../python/mindspore/train/serialization.py | 3 ++- .../python/mindir/test_dynamic_obfuscation.py | 17 ++++++++++++++++- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 315390dd6a3..6ed2d20e9a4 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -271,7 +271,7 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &resource, const Func auto is_load = primitive->GetAttr("is_load"); if (abstract::GetPrimEvaluator(primitive, engine) == nullptr && is_load != nullptr && GetValue(is_load)) { - MS_LOG(WARNING) << "The primitive is not defined in front end. Primitive: " << primitive->ToString(); + MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString(); continue; } } diff --git a/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc b/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc index 68ebeaeceae..0dbd3b8d5e8 100644 --- a/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc +++ b/mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.cc @@ -949,11 +949,11 @@ void DynamicObfuscator::SubGraphFakeBranch(FuncGraphPtr func_graph) { auto all_nodes = mgr->all_nodes(); int node_nums = all_nodes.size(); int obfuscate_target_num = std::ceil(node_nums * obf_ratio_ / keyExpandRate); - int op_num = 0; + int op_num = node_nums; // Initialize op_num to the maximum node number std::vector sorted_nodes; for (auto node : all_nodes) { + MS_LOG(INFO) << "The last node name is: " << node->fullname_with_scope(); sorted_nodes = TopoSort(node); // the node number in front of sorted nodes is the smallest - op_num = get_op_num(node); break; } std::reverse(sorted_nodes.begin(), sorted_nodes.end()); diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index aaa6715e6ad..3f693b8d149 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -1245,7 +1245,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs): It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format. """ logger.info("exporting model file:%s format:%s.", file_name, file_format) - + if "obf_config" in kwargs and file_format != "MINDIR": + raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.") if file_format == 'GEIR': logger.warning(f"For 'export', format 'GEIR' is deprecated, " f"it would be removed in future release, use 'AIR' instead.") diff --git a/tests/ut/python/mindir/test_dynamic_obfuscation.py b/tests/ut/python/mindir/test_dynamic_obfuscation.py index 39c544368a5..4f6c5ddd548 100644 --- a/tests/ut/python/mindir/test_dynamic_obfuscation.py +++ b/tests/ut/python/mindir/test_dynamic_obfuscation.py @@ -135,7 +135,6 @@ def test_obfuscate_model_customized_func_mode(): os.remove("obf_net_2.mindir") -@pytest.mark.skip(reason="random failures") def test_export_password_mode(): """ Feature: Obfuscate MindIR format model with dynamic obfuscation (password mode) in export(). @@ -189,3 +188,19 @@ def test_export_customized_func_mode(): if os.path.exists("obf_net_4.mindir"): os.remove("obf_net_4.mindir") + + +def test_wrong_file_format_input(): + """ + Feature: Obfuscate MindIR format model with dynamic obfuscation (customized_func mode) in export(). + Description: Test wrong file_formar input. + Expectation: Success. + """ + net_5 = ObfuscateNet() + input_tensor = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32)) + + # obfuscate model + obf_config = {"obf_ratio": 0.8, "obf_password": 3423} + with pytest.raises(ValueError) as error_info: + export(net_5, input_tensor, file_name="obf_net_3", file_format="ONNX", obf_config=obf_config) + assert str(error_info.value) == "Dynamic obfuscation only support for MindIR format, but got ONNX format."