From 10de370dd8fd89bd767e1d72e0de8796b40b344e Mon Sep 17 00:00:00 2001 From: lianliguang Date: Fri, 13 Jan 2023 16:03:41 +0800 Subject: [PATCH] export node attr and primal attr --- .jenkins/check/config/filter_pylint.txt | 1 + .../convert_const_input_to_tensor_input.cc | 1 + .../transform/express_ir/mindir_exporter.cc | 30 ++- .../core/load_mindir/anf_model_parser.cc | 231 +++++++++--------- mindspore/core/load_mindir/anf_model_parser.h | 3 +- mindspore/core/proto/mind_ir.proto | 2 + tests/ut/cpp/CMakeLists.txt | 1 + tests/ut/cpp/mindir/test_node_attr_export.cc | 67 +++++ .../gtest_input/mindir/mindir_test.py | 41 ++++ tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc | 7 - 10 files changed, 263 insertions(+), 121 deletions(-) create mode 100644 tests/ut/cpp/mindir/test_node_attr_export.cc create mode 100644 tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index c484bbb9378..98052079d26 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -89,6 +89,7 @@ "mindspore/tests/vm_impl/array_ops_vm_impl.py" "unused-variable" "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_compile.py" "unused-import" "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py" "super-init-not-called" +"mindspore/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py" "unused-variable" "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called" "mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable" "mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable" diff --git a/mindspore/ccsrc/backend/common/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/common/pass/convert_const_input_to_tensor_input.cc index 04268d6cc07..12a1826db6c 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_const_input_to_tensor_input.cc +++ b/mindspore/ccsrc/backend/common/pass/convert_const_input_to_tensor_input.cc @@ -97,6 +97,7 @@ AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGra auto new_cnode = NewCNode(new_inputs, func_graph); MS_EXCEPTION_IF_NULL(new_cnode); new_cnode->set_primal_attrs(cnode->primal_attrs()); + new_cnode->set_attrs(cnode->attrs()); if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) { new_cnode->set_abstract(new_inputs[1]->abstract()); } else { diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index c4850696bd0..196d1f09c1b 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -130,7 +130,7 @@ class IrExportBuilder { bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); bool BuildValueNode(const ValueNodePtr &node, const std::string &node_name, mind_ir::GraphProto *const graph_proto); std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto); - + bool BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto); bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto); bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter, @@ -1006,6 +1006,11 @@ bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons (void)std::for_each(input_names.begin(), input_names.end(), [&node_proto](const string &name) { node_proto->add_input(name); }); + + if (!BuildCNodeAttr(node, node_proto)) { + MS_LOG(ERROR) << "Set value to node attr to node proto failed."; + return false; + } return true; } @@ -1497,6 +1502,29 @@ bool IrExportBuilder::SetDictToAttributeProto(const ValueDictionaryPtr &value_di return true; } +bool IrExportBuilder::BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto) { + for (const auto &attr : node->attrs()) { + mind_ir::AttributeProto *attr_proto = node_proto->add_node_attr(); + attr_proto->set_name(attr.first); + if (!SetValueToAttributeProto(attr.second, attr_proto)) { + MS_LOG(ERROR) << "Set value to node attr to node proto failed."; + MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}"; + return false; + } + } + + for (const auto &attr : node->primal_attrs()) { + mind_ir::AttributeProto *attr_proto = node_proto->add_primal_attr(); + attr_proto->set_name(attr.first); + if (!SetValueToAttributeProto(attr.second, attr_proto)) { + MS_LOG(ERROR) << "Set value to node primal attr to node proto failed."; + MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}"; + return false; + } + } + return true; +} + std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) { auto builder = std::make_shared(incremental); if (builder == nullptr) { diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index 997dc89d2cb..084ddeb7cea 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -267,6 +267,68 @@ AnfNodePtr NewValueNodeWithAbstract(const T &value) { return node; } } // namespace +ValuePtr MSANFModelParser::GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) { + auto attr_name = attr_proto.name(); + switch (attr_proto.type()) { + case mind_ir::AttributeProto_AttributeType_TENSORS: { + mind_ir::TensorProto tensor_proto = attr_proto.tensors(0); + if (tensor_proto.has_raw_data()) { + // For real tensor. + tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Failed to get the tensor for ValueNode."; + return nullptr; + } + return tensor_info; + } else if (tensor_proto.name() == kQuantParam) { + auto quantization_param_vector = GenerateQuantizationParam(tensor_proto); + if (!quantization_param_vector.empty()) { + return quantization_param_vector[0]; + } + } else { + // For data type. + const int attr_tensor_type = tensor_proto.data_type(); + auto iter = kDefaultValueSwitchMap.find(attr_tensor_type); + if (iter == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; + return nullptr; + } + return TypeIdToType(iter->second); + } + MS_LOG(ERROR) << "Failed to get the tensor for value."; + return nullptr; + } + case mind_ir::AttributeProto_AttributeType_NONE: { + return kNone; + } + case mind_ir::AttributeProto_AttributeType_TUPLE: + case mind_ir::AttributeProto_AttributeType_LIST: { + auto sequence_value = ObtainValueInSequenceForm(attr_proto); + if (sequence_value == nullptr) { + MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name; + return nullptr; + } + return sequence_value; + } + case mind_ir::AttributeProto_AttributeType_DICT: { + auto dict_value = ObtainValueInDictionaryForm(attr_proto); + if (dict_value == nullptr) { + MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name; + return nullptr; + } + return dict_value; + } + default: { + ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto); + if (value == nullptr) { + MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name; + return nullptr; + } + return value; + } + } + return nullptr; +} tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor) { ShapeVector shape; @@ -981,79 +1043,25 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, bool MSANFModelParser::SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) { MS_EXCEPTION_IF_NULL(prim); const std::string &attr_name = attr_proto.name(); - switch (attr_proto.type()) { - case mind_ir::AttributeProto_AttributeType_TENSORS: { - mind_ir::TensorProto tensor_proto = attr_proto.tensors(0); - if (tensor_proto.has_raw_data()) { - // For real tensor. - tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "Failed to get the tensor for ValueNode."; - return false; - } - (void)prim->AddAttr(attr_name, tensor_info); - } else if (tensor_proto.name() == kQuantParam) { - auto quantization_param_vector = GenerateQuantizationParam(tensor_proto); - if (!quantization_param_vector.empty()) { - (void)prim->AddAttr(kQuantParam, quantization_param_vector[0]); - } - } else { - // For data type. - const int attr_tensor_type = tensor_proto.data_type(); - auto iter = kDefaultValueSwitchMap.find(attr_tensor_type); - if (iter == kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; - return false; - } - (void)prim->AddAttr(attr_name, TypeIdToType(iter->second)); - } - break; - } - case mind_ir::AttributeProto_AttributeType_NONE: { - (void)prim->AddAttr(attr_name, kNone); - break; - } - case mind_ir::AttributeProto_AttributeType_TUPLE: - case mind_ir::AttributeProto_AttributeType_LIST: { - auto sequence_value = ObtainValueInSequenceForm(attr_proto); - if (sequence_value == nullptr) { - MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name; - return false; - } - (void)prim->AddAttr(attr_name, sequence_value); - break; - } - case mind_ir::AttributeProto_AttributeType_DICT: { - auto dict_value = ObtainValueInDictionaryForm(attr_proto); - if (dict_value == nullptr) { - MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name; - return false; - } - (void)prim->AddAttr(attr_name, dict_value); - break; - } - default: { - ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto); - if (value == nullptr) { - MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name; - return false; - } - const std::string &op_type = prim->name(); - CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value); - if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa()) { - auto str_dtype = GetValue(value); - if (str_dtype == "int32") { - int64_t index = 3; - (void)prim->AddAttr(attr_name, MakeValue(index)); - break; - } - MS_EXCEPTION(NotSupportError) - << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got" - << value->ToString(); - } - (void)prim->AddAttr(attr_name, value); - } + auto value = GetValueFromAttributeProto(attr_proto); + if (value == nullptr) { + MS_LOG(ERROR) << "Failed to get value from proto.\n proto info:" << attr_proto.name(); + return false; } + const std::string &op_type = prim->name(); + CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value); + // Compatible with older versions. + if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa()) { + auto str_dtype = GetValue(value); + if (str_dtype == "int32") { + int64_t index = 3; + (void)prim->AddAttr(attr_name, MakeValue(index)); + } + MS_EXCEPTION(NotSupportError) + << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got" + << value->ToString(); + } + (void)prim->AddAttr(attr_name, value); return true; } @@ -1655,6 +1663,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc // Set Abstract and prim attr for CNode SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr); + BuildAttrForCNode(cnode_ptr, node_proto); return cnode_ptr; } @@ -1739,39 +1748,12 @@ bool MSANFModelParser::BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph const mind_ir::GraphProto &importProto) { for (auto i = 0; i < importProto.attribute_size(); ++i) { const mind_ir::AttributeProto &attr_proto = importProto.attribute(i); - const int attr_type = attr_proto.type(); - switch (attr_type) { - case mind_ir::AttributeProto_AttributeType_STRING: { - outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_string_string(attr_proto)); - break; - } - case mind_ir::AttributeProto_AttributeType_BOOL: { - outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_bool(attr_proto)); - break; - } - case mind_ir::AttributeProto_AttributeType_INT32: { - outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_int32_t(attr_proto)); - break; - } - case mind_ir::AttributeProto_AttributeType_INT64: { - outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int64_t_int64_t(attr_proto)); - break; - } - case mind_ir::AttributeProto_AttributeType_TUPLE: - case mind_ir::AttributeProto_AttributeType_LIST: { - auto sequence_value = ObtainValueInSequenceForm(attr_proto); - if (sequence_value == nullptr) { - MS_LOG(ERROR) << "Failed to get sequence value for " << attr_proto.name(); - return false; - } - outputFuncGraph->set_attr(attr_proto.name(), sequence_value); - break; - } - default: - MS_LOG(ERROR) << "Obtain attr for graph has not support input type: " << attr_type - << ", attr name: " << attr_proto.name(); - return false; + auto value = GetValueFromAttributeProto(attr_proto); + if (value == nullptr) { + MS_LOG(ERROR) << "Failed set func_graph attr to func_graph"; + return false; } + outputFuncGraph->set_attr(attr_proto.name(), value); } return true; } @@ -1876,24 +1858,27 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto, if (IsLite()) { abstract_valid_ = true; } - FuncGraphPtr dstGraph = std::make_shared(); + if (!MSANFParseModelConfigureInfo(model_proto)) { MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; } - for (int i = 0; i < model_proto.primitives_size(); ++i) { - if (!BuildPrimitiveNode(model_proto.primitives(i))) { - MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString(); - return nullptr; - } - } - if (model_proto.has_little_endian()) { if (model_proto.little_endian() != this->little_endian()) { MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!"; return nullptr; } } + + FuncGraphPtr dstGraph = std::make_shared(); + + for (int i = 0; i < model_proto.primitives_size(); ++i) { + if (!BuildPrimitiveNode(model_proto.primitives(i))) { + MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString(); + return nullptr; + } + } + const mind_ir::GraphProto &graphBuild = model_proto.graph(); // Forward declare FuncGraph name @@ -2126,4 +2111,26 @@ void MSANFModelParser::CorrectFuncGraph(const FuncGraphPtr &root) { } MS_LOG(DEBUG) << "End to correct the funcgraph."; } + +bool MSANFModelParser::BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto) { + for (auto i = 0; i < node_proto.node_attr_size(); ++i) { + const auto &attr_proto = node_proto.node_attr(i); + auto value = GetValueFromAttributeProto(attr_proto); + if (value == nullptr) { + MS_LOG(ERROR) << "Failed set func_graph attr to func_graph"; + return false; + } + cnode->AddAttr(attr_proto.name(), value); + } + for (auto i = 0; i < node_proto.primal_attr_size(); ++i) { + const auto &attr_proto = node_proto.primal_attr(i); + auto value = GetValueFromAttributeProto(attr_proto); + if (value == nullptr) { + MS_LOG(ERROR) << "Failed set func_graph attr to func_graph"; + return false; + } + cnode->AddPrimalAttr(attr_proto.name(), value); + } + return true; +} } // namespace mindspore diff --git a/mindspore/core/load_mindir/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h index d3e0e6a4562..99352eb9975 100644 --- a/mindspore/core/load_mindir/anf_model_parser.h +++ b/mindspore/core/load_mindir/anf_model_parser.h @@ -90,6 +90,8 @@ class MSANFModelParser { void CorrectFuncGraph(const FuncGraphPtr &root); bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); + bool BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto); + ValuePtr GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto); bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); @@ -127,7 +129,6 @@ class MSANFModelParser { bool GetAttrValueForValueNodeWithType(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); bool ObtainValueNodeInNoneForm(const std::string &value_node_name); - bool ObtainValueNodeInTypeNullForm(const std::string &value_node_name); bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto); ValuePtr ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto); diff --git a/mindspore/core/proto/mind_ir.proto b/mindspore/core/proto/mind_ir.proto index 79bd30f2b71..282e73915b8 100644 --- a/mindspore/core/proto/mind_ir.proto +++ b/mindspore/core/proto/mind_ir.proto @@ -84,6 +84,8 @@ message NodeProto { repeated AttributeProto attribute = 5; optional string doc_string = 6; optional string domain = 7; + repeated AttributeProto node_attr = 8; + repeated AttributeProto primal_attr = 9; } diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 953ff56032a..3f9709f933f 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -80,6 +80,7 @@ if(ENABLE_MINDDATA) ./plugin/device/cpu/hal/*.cc ./place/*.cc ./ops/test_ops_fake_quant_param.cc + ./mindir/*.cc ) if(NOT ENABLE_SECURITY) file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} diff --git a/tests/ut/cpp/mindir/test_node_attr_export.cc b/tests/ut/cpp/mindir/test_node_attr_export.cc new file mode 100644 index 00000000000..b8b97304bcd --- /dev/null +++ b/tests/ut/cpp/mindir/test_node_attr_export.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common_test.h" +#include "common/py_func_graph_fetcher.h" + +#include "pipeline/jit/resource.h" +#include "pipeline/jit/action.h" +#include "include/common/debug/dump_proto.h" +#include "load_mindir/load_model.h" +#include "mindspore/core/ops/core_ops.h" +#include "ir/anf.h" +#include "ir/tensor.h" + +namespace mindspore { +class TestLoadExport : public UT::Common { + public: + TestLoadExport() : getPyFun("gtest_input.mindir.mindir_test") {} + ~TestLoadExport() override = default; + // Expectation: No Expectation + UT::PyFuncGraphFetcher getPyFun; +}; + +/// Feature: MindIR node attribute export and load. +/// Description: Node attribute export and load. +/// Expectation: success. +TEST_F(TestLoadExport, test_export_func) { + auto func_graph = getPyFun.CallAndParseRet("export_test", "add_node_attr_test"); + tensor::TensorPtr t = std::make_shared(kFloat32->type_id(), std::vector{1, 2, 3}); + + auto export_return_node = func_graph->output(); + auto export_relu = export_return_node->cast(); + export_relu->AddAttr("TestAttr", MakeValue(true)); + export_relu->AddPrimalAttr("TestPrimalAttr", MakeValue(true)); + if (func_graph->manager() == nullptr) { + std::vector graphs{func_graph}; + FuncGraphManagerPtr manager = std::make_shared(graphs); + manager->AddFuncGraph(func_graph); + } + // Renormalize func_graph to infer and set shape and type information. + pipeline::ResourcePtr resource_ = std::make_shared(); + auto graph = pipeline::Renormalize(resource_, func_graph, {t->ToAbstract()}); + auto str = GetBinaryProtoString(graph); + mind_ir::ModelProto model_; + model_.ParseFromString(str); + MSANFModelParser model_parser; + FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); + auto return_node = dstgraph_ptr->output(); + auto load_relu = return_node->cast(); + auto test_primal_attr = load_relu->GetPrimalAttr("TestPrimalAttr"); + auto test_attr = load_relu->GetAttr("TestAttr"); + ASSERT_TRUE(GetValue(test_attr)); + ASSERT_TRUE(GetValue(test_primal_attr)); +} +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py b/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py new file mode 100644 index 00000000000..28c995d224a --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import operations as P + +relu = P.ReLU() + + +class FnDict: + def __init__(self): + self.fn_dict = {} + + def __call__(self, fn): + self.fn_dict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fn_dict.get(name, "") + + + +def export_test(tag): + """ test_adam_apply_one_with_decay_rule """ + fns = FnDict() + + @fns + def add_node_attr_test(x): + return relu(x) + + return fns[tag] diff --git a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc index 2231b48c6cd..5af8c5d6eb6 100644 --- a/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc +++ b/tests/ut/cpp/stub/anf_ir/dump_proto_stub.cc @@ -23,11 +23,4 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { re std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; } std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; } - -std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) { return ""; } - -bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path, - const FuncGraphPtr ¶m_layout_fg) { - return true; -} } // namespace mindspore