forked from mindspore-Ecosystem/mindspore
export node attr and primal attr
This commit is contained in:
parent
c39ea4032e
commit
10de370dd8
|
@ -89,6 +89,7 @@
|
|||
"mindspore/tests/vm_impl/array_ops_vm_impl.py" "unused-variable"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_compile.py" "unused-import"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py" "super-init-not-called"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py" "unused-variable"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable"
|
||||
"mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable"
|
||||
|
|
|
@ -97,6 +97,7 @@ AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGra
|
|||
auto new_cnode = NewCNode(new_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_primal_attrs(cnode->primal_attrs());
|
||||
new_cnode->set_attrs(cnode->attrs());
|
||||
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
|
||||
new_cnode->set_abstract(new_inputs[1]->abstract());
|
||||
} else {
|
||||
|
|
|
@ -130,7 +130,7 @@ class IrExportBuilder {
|
|||
bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
||||
bool BuildValueNode(const ValueNodePtr &node, const std::string &node_name, mind_ir::GraphProto *const graph_proto);
|
||||
std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
|
||||
|
||||
bool BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
|
||||
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
||||
bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
|
||||
|
@ -1006,6 +1006,11 @@ bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
|||
|
||||
(void)std::for_each(input_names.begin(), input_names.end(),
|
||||
[&node_proto](const string &name) { node_proto->add_input(name); });
|
||||
|
||||
if (!BuildCNodeAttr(node, node_proto)) {
|
||||
MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1497,6 +1502,29 @@ bool IrExportBuilder::SetDictToAttributeProto(const ValueDictionaryPtr &value_di
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
|
||||
for (const auto &attr : node->attrs()) {
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_node_attr();
|
||||
attr_proto->set_name(attr.first);
|
||||
if (!SetValueToAttributeProto(attr.second, attr_proto)) {
|
||||
MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
|
||||
MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &attr : node->primal_attrs()) {
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_primal_attr();
|
||||
attr_proto->set_name(attr.first);
|
||||
if (!SetValueToAttributeProto(attr.second, attr_proto)) {
|
||||
MS_LOG(ERROR) << "Set value to node primal attr to node proto failed.";
|
||||
MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) {
|
||||
auto builder = std::make_shared<IrExportBuilder>(incremental);
|
||||
if (builder == nullptr) {
|
||||
|
|
|
@ -267,6 +267,68 @@ AnfNodePtr NewValueNodeWithAbstract(const T &value) {
|
|||
return node;
|
||||
}
|
||||
} // namespace
|
||||
ValuePtr MSANFModelParser::GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
|
||||
auto attr_name = attr_proto.name();
|
||||
switch (attr_proto.type()) {
|
||||
case mind_ir::AttributeProto_AttributeType_TENSORS: {
|
||||
mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
|
||||
if (tensor_proto.has_raw_data()) {
|
||||
// For real tensor.
|
||||
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
|
||||
return nullptr;
|
||||
}
|
||||
return tensor_info;
|
||||
} else if (tensor_proto.name() == kQuantParam) {
|
||||
auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
|
||||
if (!quantization_param_vector.empty()) {
|
||||
return quantization_param_vector[0];
|
||||
}
|
||||
} else {
|
||||
// For data type.
|
||||
const int attr_tensor_type = tensor_proto.data_type();
|
||||
auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
|
||||
if (iter == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
return nullptr;
|
||||
}
|
||||
return TypeIdToType(iter->second);
|
||||
}
|
||||
MS_LOG(ERROR) << "Failed to get the tensor for value.";
|
||||
return nullptr;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_NONE: {
|
||||
return kNone;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_TUPLE:
|
||||
case mind_ir::AttributeProto_AttributeType_LIST: {
|
||||
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
|
||||
if (sequence_value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
|
||||
return nullptr;
|
||||
}
|
||||
return sequence_value;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_DICT: {
|
||||
auto dict_value = ObtainValueInDictionaryForm(attr_proto);
|
||||
if (dict_value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
|
||||
return nullptr;
|
||||
}
|
||||
return dict_value;
|
||||
}
|
||||
default: {
|
||||
ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name;
|
||||
return nullptr;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor) {
|
||||
ShapeVector shape;
|
||||
|
@ -981,79 +1043,25 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
|
|||
bool MSANFModelParser::SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const std::string &attr_name = attr_proto.name();
|
||||
switch (attr_proto.type()) {
|
||||
case mind_ir::AttributeProto_AttributeType_TENSORS: {
|
||||
mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
|
||||
if (tensor_proto.has_raw_data()) {
|
||||
// For real tensor.
|
||||
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
|
||||
return false;
|
||||
}
|
||||
(void)prim->AddAttr(attr_name, tensor_info);
|
||||
} else if (tensor_proto.name() == kQuantParam) {
|
||||
auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
|
||||
if (!quantization_param_vector.empty()) {
|
||||
(void)prim->AddAttr(kQuantParam, quantization_param_vector[0]);
|
||||
}
|
||||
} else {
|
||||
// For data type.
|
||||
const int attr_tensor_type = tensor_proto.data_type();
|
||||
auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
|
||||
if (iter == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
(void)prim->AddAttr(attr_name, TypeIdToType(iter->second));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_NONE: {
|
||||
(void)prim->AddAttr(attr_name, kNone);
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_TUPLE:
|
||||
case mind_ir::AttributeProto_AttributeType_LIST: {
|
||||
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
|
||||
if (sequence_value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
|
||||
return false;
|
||||
}
|
||||
(void)prim->AddAttr(attr_name, sequence_value);
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_DICT: {
|
||||
auto dict_value = ObtainValueInDictionaryForm(attr_proto);
|
||||
if (dict_value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
|
||||
return false;
|
||||
}
|
||||
(void)prim->AddAttr(attr_name, dict_value);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name;
|
||||
return false;
|
||||
}
|
||||
const std::string &op_type = prim->name();
|
||||
CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value);
|
||||
if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa<StringImm>()) {
|
||||
auto str_dtype = GetValue<std::string>(value);
|
||||
if (str_dtype == "int32") {
|
||||
int64_t index = 3;
|
||||
(void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
|
||||
break;
|
||||
}
|
||||
MS_EXCEPTION(NotSupportError)
|
||||
<< "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
|
||||
<< value->ToString();
|
||||
}
|
||||
(void)prim->AddAttr(attr_name, value);
|
||||
}
|
||||
auto value = GetValueFromAttributeProto(attr_proto);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get value from proto.\n proto info:" << attr_proto.name();
|
||||
return false;
|
||||
}
|
||||
const std::string &op_type = prim->name();
|
||||
CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value);
|
||||
// Compatible with older versions.
|
||||
if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa<StringImm>()) {
|
||||
auto str_dtype = GetValue<std::string>(value);
|
||||
if (str_dtype == "int32") {
|
||||
int64_t index = 3;
|
||||
(void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
|
||||
}
|
||||
MS_EXCEPTION(NotSupportError)
|
||||
<< "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
|
||||
<< value->ToString();
|
||||
}
|
||||
(void)prim->AddAttr(attr_name, value);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1655,6 +1663,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
|||
|
||||
// Set Abstract and prim attr for CNode
|
||||
SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr);
|
||||
BuildAttrForCNode(cnode_ptr, node_proto);
|
||||
return cnode_ptr;
|
||||
}
|
||||
|
||||
|
@ -1739,39 +1748,12 @@ bool MSANFModelParser::BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph
|
|||
const mind_ir::GraphProto &importProto) {
|
||||
for (auto i = 0; i < importProto.attribute_size(); ++i) {
|
||||
const mind_ir::AttributeProto &attr_proto = importProto.attribute(i);
|
||||
const int attr_type = attr_proto.type();
|
||||
switch (attr_type) {
|
||||
case mind_ir::AttributeProto_AttributeType_STRING: {
|
||||
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_string_string(attr_proto));
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_BOOL: {
|
||||
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_bool(attr_proto));
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_INT32: {
|
||||
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_int32_t(attr_proto));
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_INT64: {
|
||||
outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int64_t_int64_t(attr_proto));
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_TUPLE:
|
||||
case mind_ir::AttributeProto_AttributeType_LIST: {
|
||||
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
|
||||
if (sequence_value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get sequence value for " << attr_proto.name();
|
||||
return false;
|
||||
}
|
||||
outputFuncGraph->set_attr(attr_proto.name(), sequence_value);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Obtain attr for graph has not support input type: " << attr_type
|
||||
<< ", attr name: " << attr_proto.name();
|
||||
return false;
|
||||
auto value = GetValueFromAttributeProto(attr_proto);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
|
||||
return false;
|
||||
}
|
||||
outputFuncGraph->set_attr(attr_proto.name(), value);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -1876,24 +1858,27 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
|
|||
if (IsLite()) {
|
||||
abstract_valid_ = true;
|
||||
}
|
||||
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
|
||||
|
||||
if (!MSANFParseModelConfigureInfo(model_proto)) {
|
||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||
}
|
||||
|
||||
for (int i = 0; i < model_proto.primitives_size(); ++i) {
|
||||
if (!BuildPrimitiveNode(model_proto.primitives(i))) {
|
||||
MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (model_proto.has_little_endian()) {
|
||||
if (model_proto.little_endian() != this->little_endian()) {
|
||||
MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
|
||||
|
||||
for (int i = 0; i < model_proto.primitives_size(); ++i) {
|
||||
if (!BuildPrimitiveNode(model_proto.primitives(i))) {
|
||||
MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
const mind_ir::GraphProto &graphBuild = model_proto.graph();
|
||||
|
||||
// Forward declare FuncGraph name
|
||||
|
@ -2126,4 +2111,26 @@ void MSANFModelParser::CorrectFuncGraph(const FuncGraphPtr &root) {
|
|||
}
|
||||
MS_LOG(DEBUG) << "End to correct the funcgraph.";
|
||||
}
|
||||
|
||||
bool MSANFModelParser::BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto) {
|
||||
for (auto i = 0; i < node_proto.node_attr_size(); ++i) {
|
||||
const auto &attr_proto = node_proto.node_attr(i);
|
||||
auto value = GetValueFromAttributeProto(attr_proto);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
|
||||
return false;
|
||||
}
|
||||
cnode->AddAttr(attr_proto.name(), value);
|
||||
}
|
||||
for (auto i = 0; i < node_proto.primal_attr_size(); ++i) {
|
||||
const auto &attr_proto = node_proto.primal_attr(i);
|
||||
auto value = GetValueFromAttributeProto(attr_proto);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
|
||||
return false;
|
||||
}
|
||||
cnode->AddPrimalAttr(attr_proto.name(), value);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -90,6 +90,8 @@ class MSANFModelParser {
|
|||
void CorrectFuncGraph(const FuncGraphPtr &root);
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto);
|
||||
ValuePtr GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
|
@ -127,7 +129,6 @@ class MSANFModelParser {
|
|||
bool GetAttrValueForValueNodeWithType(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInNoneForm(const std::string &value_node_name);
|
||||
bool ObtainValueNodeInTypeNullForm(const std::string &value_node_name);
|
||||
bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto);
|
||||
ValuePtr ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto);
|
||||
|
|
|
@ -84,6 +84,8 @@ message NodeProto {
|
|||
repeated AttributeProto attribute = 5;
|
||||
optional string doc_string = 6;
|
||||
optional string domain = 7;
|
||||
repeated AttributeProto node_attr = 8;
|
||||
repeated AttributeProto primal_attr = 9;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -80,6 +80,7 @@ if(ENABLE_MINDDATA)
|
|||
./plugin/device/cpu/hal/*.cc
|
||||
./place/*.cc
|
||||
./ops/test_ops_fake_quant_param.cc
|
||||
./mindir/*.cc
|
||||
)
|
||||
if(NOT ENABLE_SECURITY)
|
||||
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "pipeline/jit/action.h"
|
||||
#include "include/common/debug/dump_proto.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestLoadExport : public UT::Common {
|
||||
public:
|
||||
TestLoadExport() : getPyFun("gtest_input.mindir.mindir_test") {}
|
||||
~TestLoadExport() override = default;
|
||||
// Expectation: No Expectation
|
||||
UT::PyFuncGraphFetcher getPyFun;
|
||||
};
|
||||
|
||||
/// Feature: MindIR node attribute export and load.
|
||||
/// Description: Node attribute export and load.
|
||||
/// Expectation: success.
|
||||
TEST_F(TestLoadExport, test_export_func) {
|
||||
auto func_graph = getPyFun.CallAndParseRet("export_test", "add_node_attr_test");
|
||||
tensor::TensorPtr t = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{1, 2, 3});
|
||||
|
||||
auto export_return_node = func_graph->output();
|
||||
auto export_relu = export_return_node->cast<CNodePtr>();
|
||||
export_relu->AddAttr("TestAttr", MakeValue(true));
|
||||
export_relu->AddPrimalAttr("TestPrimalAttr", MakeValue(true));
|
||||
if (func_graph->manager() == nullptr) {
|
||||
std::vector<FuncGraphPtr> graphs{func_graph};
|
||||
FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(graphs);
|
||||
manager->AddFuncGraph(func_graph);
|
||||
}
|
||||
// Renormalize func_graph to infer and set shape and type information.
|
||||
pipeline::ResourcePtr resource_ = std::make_shared<pipeline::Resource>();
|
||||
auto graph = pipeline::Renormalize(resource_, func_graph, {t->ToAbstract()});
|
||||
auto str = GetBinaryProtoString(graph);
|
||||
mind_ir::ModelProto model_;
|
||||
model_.ParseFromString(str);
|
||||
MSANFModelParser model_parser;
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||
auto return_node = dstgraph_ptr->output();
|
||||
auto load_relu = return_node->cast<CNodePtr>();
|
||||
auto test_primal_attr = load_relu->GetPrimalAttr("TestPrimalAttr");
|
||||
auto test_attr = load_relu->GetAttr("TestAttr");
|
||||
ASSERT_TRUE(GetValue<bool>(test_attr));
|
||||
ASSERT_TRUE(GetValue<bool>(test_primal_attr));
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
relu = P.ReLU()
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fn_dict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fn_dict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fn_dict.get(name, "")
|
||||
|
||||
|
||||
|
||||
def export_test(tag):
|
||||
""" test_adam_apply_one_with_decay_rule """
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def add_node_attr_test(x):
|
||||
return relu(x)
|
||||
|
||||
return fns[tag]
|
|
@ -23,11 +23,4 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { re
|
|||
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; }
|
||||
|
||||
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
|
||||
|
||||
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) { return ""; }
|
||||
|
||||
bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path,
|
||||
const FuncGraphPtr ¶m_layout_fg) {
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue