export node attr and primal attr

This commit is contained in:
lianliguang 2023-01-13 16:03:41 +08:00
parent c39ea4032e
commit 10de370dd8
10 changed files with 263 additions and 121 deletions

View File

@ -89,6 +89,7 @@
"mindspore/tests/vm_impl/array_ops_vm_impl.py" "unused-variable" "mindspore/tests/vm_impl/array_ops_vm_impl.py" "unused-variable"
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_compile.py" "unused-import" "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_compile.py" "unused-import"
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py" "super-init-not-called" "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py" "super-init-not-called"
"mindspore/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py" "unused-variable"
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called" "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called"
"mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable" "mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable"
"mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable" "mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable"

View File

@ -97,6 +97,7 @@ AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGra
auto new_cnode = NewCNode(new_inputs, func_graph); auto new_cnode = NewCNode(new_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_cnode); MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_primal_attrs(cnode->primal_attrs()); new_cnode->set_primal_attrs(cnode->primal_attrs());
new_cnode->set_attrs(cnode->attrs());
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) { if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
new_cnode->set_abstract(new_inputs[1]->abstract()); new_cnode->set_abstract(new_inputs[1]->abstract());
} else { } else {

View File

@ -130,7 +130,7 @@ class IrExportBuilder {
bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
bool BuildValueNode(const ValueNodePtr &node, const std::string &node_name, mind_ir::GraphProto *const graph_proto); bool BuildValueNode(const ValueNodePtr &node, const std::string &node_name, mind_ir::GraphProto *const graph_proto);
std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto); std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
bool BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto); bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter, bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
@ -1006,6 +1006,11 @@ bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
(void)std::for_each(input_names.begin(), input_names.end(), (void)std::for_each(input_names.begin(), input_names.end(),
[&node_proto](const string &name) { node_proto->add_input(name); }); [&node_proto](const string &name) { node_proto->add_input(name); });
if (!BuildCNodeAttr(node, node_proto)) {
MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
return false;
}
return true; return true;
} }
@ -1497,6 +1502,29 @@ bool IrExportBuilder::SetDictToAttributeProto(const ValueDictionaryPtr &value_di
return true; return true;
} }
bool IrExportBuilder::BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
for (const auto &attr : node->attrs()) {
mind_ir::AttributeProto *attr_proto = node_proto->add_node_attr();
attr_proto->set_name(attr.first);
if (!SetValueToAttributeProto(attr.second, attr_proto)) {
MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
return false;
}
}
for (const auto &attr : node->primal_attrs()) {
mind_ir::AttributeProto *attr_proto = node_proto->add_primal_attr();
attr_proto->set_name(attr.first);
if (!SetValueToAttributeProto(attr.second, attr_proto)) {
MS_LOG(ERROR) << "Set value to node primal attr to node proto failed.";
MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
return false;
}
}
return true;
}
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) { std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) {
auto builder = std::make_shared<IrExportBuilder>(incremental); auto builder = std::make_shared<IrExportBuilder>(incremental);
if (builder == nullptr) { if (builder == nullptr) {

View File

@ -267,6 +267,68 @@ AnfNodePtr NewValueNodeWithAbstract(const T &value) {
return node; return node;
} }
} // namespace } // namespace
ValuePtr MSANFModelParser::GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
auto attr_name = attr_proto.name();
switch (attr_proto.type()) {
case mind_ir::AttributeProto_AttributeType_TENSORS: {
mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
if (tensor_proto.has_raw_data()) {
// For real tensor.
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
return nullptr;
}
return tensor_info;
} else if (tensor_proto.name() == kQuantParam) {
auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
if (!quantization_param_vector.empty()) {
return quantization_param_vector[0];
}
} else {
// For data type.
const int attr_tensor_type = tensor_proto.data_type();
auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
if (iter == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
return nullptr;
}
return TypeIdToType(iter->second);
}
MS_LOG(ERROR) << "Failed to get the tensor for value.";
return nullptr;
}
case mind_ir::AttributeProto_AttributeType_NONE: {
return kNone;
}
case mind_ir::AttributeProto_AttributeType_TUPLE:
case mind_ir::AttributeProto_AttributeType_LIST: {
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
if (sequence_value == nullptr) {
MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
return nullptr;
}
return sequence_value;
}
case mind_ir::AttributeProto_AttributeType_DICT: {
auto dict_value = ObtainValueInDictionaryForm(attr_proto);
if (dict_value == nullptr) {
MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
return nullptr;
}
return dict_value;
}
default: {
ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
if (value == nullptr) {
MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name;
return nullptr;
}
return value;
}
}
return nullptr;
}
tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor) { tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor) {
ShapeVector shape; ShapeVector shape;
@ -981,79 +1043,25 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
bool MSANFModelParser::SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) { bool MSANFModelParser::SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
const std::string &attr_name = attr_proto.name(); const std::string &attr_name = attr_proto.name();
switch (attr_proto.type()) { auto value = GetValueFromAttributeProto(attr_proto);
case mind_ir::AttributeProto_AttributeType_TENSORS: {
mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
if (tensor_proto.has_raw_data()) {
// For real tensor.
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
return false;
}
(void)prim->AddAttr(attr_name, tensor_info);
} else if (tensor_proto.name() == kQuantParam) {
auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
if (!quantization_param_vector.empty()) {
(void)prim->AddAttr(kQuantParam, quantization_param_vector[0]);
}
} else {
// For data type.
const int attr_tensor_type = tensor_proto.data_type();
auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
if (iter == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
return false;
}
(void)prim->AddAttr(attr_name, TypeIdToType(iter->second));
}
break;
}
case mind_ir::AttributeProto_AttributeType_NONE: {
(void)prim->AddAttr(attr_name, kNone);
break;
}
case mind_ir::AttributeProto_AttributeType_TUPLE:
case mind_ir::AttributeProto_AttributeType_LIST: {
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
if (sequence_value == nullptr) {
MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
return false;
}
(void)prim->AddAttr(attr_name, sequence_value);
break;
}
case mind_ir::AttributeProto_AttributeType_DICT: {
auto dict_value = ObtainValueInDictionaryForm(attr_proto);
if (dict_value == nullptr) {
MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
return false;
}
(void)prim->AddAttr(attr_name, dict_value);
break;
}
default: {
ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
if (value == nullptr) { if (value == nullptr) {
MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name; MS_LOG(ERROR) << "Failed to get value from proto.\n proto info:" << attr_proto.name();
return false; return false;
} }
const std::string &op_type = prim->name(); const std::string &op_type = prim->name();
CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value); CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value);
// Compatible with older versions.
if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa<StringImm>()) { if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa<StringImm>()) {
auto str_dtype = GetValue<std::string>(value); auto str_dtype = GetValue<std::string>(value);
if (str_dtype == "int32") { if (str_dtype == "int32") {
int64_t index = 3; int64_t index = 3;
(void)prim->AddAttr(attr_name, MakeValue<int64_t>(index)); (void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
break;
} }
MS_EXCEPTION(NotSupportError) MS_EXCEPTION(NotSupportError)
<< "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got" << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
<< value->ToString(); << value->ToString();
} }
(void)prim->AddAttr(attr_name, value); (void)prim->AddAttr(attr_name, value);
}
}
return true; return true;
} }
@ -1655,6 +1663,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
// Set Abstract and prim attr for CNode // Set Abstract and prim attr for CNode
SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr); SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr);
BuildAttrForCNode(cnode_ptr, node_proto);
return cnode_ptr; return cnode_ptr;
} }
@ -1739,39 +1748,12 @@ bool MSANFModelParser::BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph
const mind_ir::GraphProto &importProto) { const mind_ir::GraphProto &importProto) {
for (auto i = 0; i < importProto.attribute_size(); ++i) { for (auto i = 0; i < importProto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = importProto.attribute(i); const mind_ir::AttributeProto &attr_proto = importProto.attribute(i);
const int attr_type = attr_proto.type(); auto value = GetValueFromAttributeProto(attr_proto);
switch (attr_type) { if (value == nullptr) {
case mind_ir::AttributeProto_AttributeType_STRING: { MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_string_string(attr_proto));
break;
}
case mind_ir::AttributeProto_AttributeType_BOOL: {
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_bool(attr_proto));
break;
}
case mind_ir::AttributeProto_AttributeType_INT32: {
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_int32_t(attr_proto));
break;
}
case mind_ir::AttributeProto_AttributeType_INT64: {
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int64_t_int64_t(attr_proto));
break;
}
case mind_ir::AttributeProto_AttributeType_TUPLE:
case mind_ir::AttributeProto_AttributeType_LIST: {
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
if (sequence_value == nullptr) {
MS_LOG(ERROR) << "Failed to get sequence value for " << attr_proto.name();
return false;
}
outputFuncGraph->set_attr(attr_proto.name(), sequence_value);
break;
}
default:
MS_LOG(ERROR) << "Obtain attr for graph has not support input type: " << attr_type
<< ", attr name: " << attr_proto.name();
return false; return false;
} }
outputFuncGraph->set_attr(attr_proto.name(), value);
} }
return true; return true;
} }
@ -1876,24 +1858,27 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
if (IsLite()) { if (IsLite()) {
abstract_valid_ = true; abstract_valid_ = true;
} }
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
if (!MSANFParseModelConfigureInfo(model_proto)) { if (!MSANFParseModelConfigureInfo(model_proto)) {
MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
} }
for (int i = 0; i < model_proto.primitives_size(); ++i) {
if (!BuildPrimitiveNode(model_proto.primitives(i))) {
MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
return nullptr;
}
}
if (model_proto.has_little_endian()) { if (model_proto.has_little_endian()) {
if (model_proto.little_endian() != this->little_endian()) { if (model_proto.little_endian() != this->little_endian()) {
MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!"; MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!";
return nullptr; return nullptr;
} }
} }
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
for (int i = 0; i < model_proto.primitives_size(); ++i) {
if (!BuildPrimitiveNode(model_proto.primitives(i))) {
MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
return nullptr;
}
}
const mind_ir::GraphProto &graphBuild = model_proto.graph(); const mind_ir::GraphProto &graphBuild = model_proto.graph();
// Forward declare FuncGraph name // Forward declare FuncGraph name
@ -2126,4 +2111,26 @@ void MSANFModelParser::CorrectFuncGraph(const FuncGraphPtr &root) {
} }
MS_LOG(DEBUG) << "End to correct the funcgraph."; MS_LOG(DEBUG) << "End to correct the funcgraph.";
} }
bool MSANFModelParser::BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto) {
for (auto i = 0; i < node_proto.node_attr_size(); ++i) {
const auto &attr_proto = node_proto.node_attr(i);
auto value = GetValueFromAttributeProto(attr_proto);
if (value == nullptr) {
MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
return false;
}
cnode->AddAttr(attr_proto.name(), value);
}
for (auto i = 0; i < node_proto.primal_attr_size(); ++i) {
const auto &attr_proto = node_proto.primal_attr(i);
auto value = GetValueFromAttributeProto(attr_proto);
if (value == nullptr) {
MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
return false;
}
cnode->AddPrimalAttr(attr_proto.name(), value);
}
return true;
}
} // namespace mindspore } // namespace mindspore

View File

@ -90,6 +90,8 @@ class MSANFModelParser {
void CorrectFuncGraph(const FuncGraphPtr &root); void CorrectFuncGraph(const FuncGraphPtr &root);
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto);
ValuePtr GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
@ -127,7 +129,6 @@ class MSANFModelParser {
bool GetAttrValueForValueNodeWithType(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); bool GetAttrValueForValueNodeWithType(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
bool ObtainValueNodeInNoneForm(const std::string &value_node_name); bool ObtainValueNodeInNoneForm(const std::string &value_node_name);
bool ObtainValueNodeInTypeNullForm(const std::string &value_node_name);
bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto); ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto);
ValuePtr ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto); ValuePtr ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto);

View File

@ -84,6 +84,8 @@ message NodeProto {
repeated AttributeProto attribute = 5; repeated AttributeProto attribute = 5;
optional string doc_string = 6; optional string doc_string = 6;
optional string domain = 7; optional string domain = 7;
repeated AttributeProto node_attr = 8;
repeated AttributeProto primal_attr = 9;
} }

View File

@ -80,6 +80,7 @@ if(ENABLE_MINDDATA)
./plugin/device/cpu/hal/*.cc ./plugin/device/cpu/hal/*.cc
./place/*.cc ./place/*.cc
./ops/test_ops_fake_quant_param.cc ./ops/test_ops_fake_quant_param.cc
./mindir/*.cc
) )
if(NOT ENABLE_SECURITY) if(NOT ENABLE_SECURITY)
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}

View File

@ -0,0 +1,67 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/action.h"
#include "include/common/debug/dump_proto.h"
#include "load_mindir/load_model.h"
#include "mindspore/core/ops/core_ops.h"
#include "ir/anf.h"
#include "ir/tensor.h"
namespace mindspore {
class TestLoadExport : public UT::Common {
public:
TestLoadExport() : getPyFun("gtest_input.mindir.mindir_test") {}
~TestLoadExport() override = default;
// Expectation: No Expectation
UT::PyFuncGraphFetcher getPyFun;
};
/// Feature: MindIR node attribute export and load.
/// Description: Node attribute export and load.
/// Expectation: success.
TEST_F(TestLoadExport, test_export_func) {
auto func_graph = getPyFun.CallAndParseRet("export_test", "add_node_attr_test");
tensor::TensorPtr t = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{1, 2, 3});
auto export_return_node = func_graph->output();
auto export_relu = export_return_node->cast<CNodePtr>();
export_relu->AddAttr("TestAttr", MakeValue(true));
export_relu->AddPrimalAttr("TestPrimalAttr", MakeValue(true));
if (func_graph->manager() == nullptr) {
std::vector<FuncGraphPtr> graphs{func_graph};
FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(graphs);
manager->AddFuncGraph(func_graph);
}
// Renormalize func_graph to infer and set shape and type information.
pipeline::ResourcePtr resource_ = std::make_shared<pipeline::Resource>();
auto graph = pipeline::Renormalize(resource_, func_graph, {t->ToAbstract()});
auto str = GetBinaryProtoString(graph);
mind_ir::ModelProto model_;
model_.ParseFromString(str);
MSANFModelParser model_parser;
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
auto return_node = dstgraph_ptr->output();
auto load_relu = return_node->cast<CNodePtr>();
auto test_primal_attr = load_relu->GetPrimalAttr("TestPrimalAttr");
auto test_attr = load_relu->GetAttr("TestAttr");
ASSERT_TRUE(GetValue<bool>(test_attr));
ASSERT_TRUE(GetValue<bool>(test_primal_attr));
}
} // namespace mindspore

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
relu = P.ReLU()
class FnDict:
def __init__(self):
self.fn_dict = {}
def __call__(self, fn):
self.fn_dict[fn.__name__] = fn
def __getitem__(self, name):
return self.fn_dict.get(name, "")
def export_test(tag):
""" test_adam_apply_one_with_decay_rule """
fns = FnDict()
@fns
def add_node_attr_test(x):
return relu(x)
return fns[tag]

View File

@ -23,11 +23,4 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { re
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; } std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; } std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) { return ""; }
bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path,
const FuncGraphPtr &param_layout_fg) {
return true;
}
} // namespace mindspore } // namespace mindspore