forked from mindspore-Ecosystem/mindspore

export node attr and primal attr

parent c39ea4032e, commit 10de370dd8
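In short: this change serializes each CNode's attrs and primal attrs into MindIR (the new node_attr and primal_attr fields on NodeProto) during export, and restores them onto the CNode on load. A minimal sketch of the round trip, condensed from the TestLoadExport unit test added in this commit (variable names follow the test; graph construction and Renormalize setup are omitted):

// Tag a CNode before export; after this commit the attrs survive the MindIR round trip.
auto export_cnode = func_graph->output()->cast<CNodePtr>();
export_cnode->AddAttr("TestAttr", MakeValue(true));              // written to NodeProto.node_attr
export_cnode->AddPrimalAttr("TestPrimalAttr", MakeValue(true));  // written to NodeProto.primal_attr

// Export the graph to the MindIR protobuf form and parse it back.
auto proto_str = GetBinaryProtoString(func_graph);
mind_ir::ModelProto model;
model.ParseFromString(proto_str);
MSANFModelParser model_parser;
FuncGraphPtr loaded_graph = model_parser.Parse(model);

// The loaded CNode carries the same node attr and primal attr.
auto loaded_cnode = loaded_graph->output()->cast<CNodePtr>();
ASSERT_TRUE(GetValue<bool>(loaded_cnode->GetAttr("TestAttr")));
ASSERT_TRUE(GetValue<bool>(loaded_cnode->GetPrimalAttr("TestPrimalAttr")));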
@@ -89,6 +89,7 @@
 "mindspore/tests/vm_impl/array_ops_vm_impl.py" "unused-variable"
 "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_compile.py" "unused-import"
 "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py" "super-init-not-called"
+"mindspore/tests/ut/cpp/python_input/gtest_input/mindir/mindir_test.py" "unused-variable"
 "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called"
 "mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable"
 "mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable"

@@ -97,6 +97,7 @@ AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGra
   auto new_cnode = NewCNode(new_inputs, func_graph);
   MS_EXCEPTION_IF_NULL(new_cnode);
   new_cnode->set_primal_attrs(cnode->primal_attrs());
+  new_cnode->set_attrs(cnode->attrs());
   if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
     new_cnode->set_abstract(new_inputs[1]->abstract());
   } else {

@@ -130,7 +130,7 @@ class IrExportBuilder {
   bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
   bool BuildValueNode(const ValueNodePtr &node, const std::string &node_name, mind_ir::GraphProto *const graph_proto);
   std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
+  bool BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
   bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
   bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
   bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,

@@ -1006,6 +1006,11 @@ bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
   (void)std::for_each(input_names.begin(), input_names.end(),
                       [&node_proto](const string &name) { node_proto->add_input(name); });
 
+  if (!BuildCNodeAttr(node, node_proto)) {
+    MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
+    return false;
+  }
   return true;
 }
 
@@ -1497,6 +1502,29 @@ bool IrExportBuilder::SetDictToAttributeProto(const ValueDictionaryPtr &value_di
   return true;
 }
 
+bool IrExportBuilder::BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
+  for (const auto &attr : node->attrs()) {
+    mind_ir::AttributeProto *attr_proto = node_proto->add_node_attr();
+    attr_proto->set_name(attr.first);
+    if (!SetValueToAttributeProto(attr.second, attr_proto)) {
+      MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
+      MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
+      return false;
+    }
+  }
+
+  for (const auto &attr : node->primal_attrs()) {
+    mind_ir::AttributeProto *attr_proto = node_proto->add_primal_attr();
+    attr_proto->set_name(attr.first);
+    if (!SetValueToAttributeProto(attr.second, attr_proto)) {
+      MS_LOG(ERROR) << "Set value to node primal attr to node proto failed.";
+      MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
+      return false;
+    }
+  }
+  return true;
+}
+
 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) {
   auto builder = std::make_shared<IrExportBuilder>(incremental);
   if (builder == nullptr) {

@@ -267,6 +267,68 @@ AnfNodePtr NewValueNodeWithAbstract(const T &value) {
   return node;
 }
 } // namespace
+ValuePtr MSANFModelParser::GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
+  auto attr_name = attr_proto.name();
+  switch (attr_proto.type()) {
+    case mind_ir::AttributeProto_AttributeType_TENSORS: {
+      mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
+      if (tensor_proto.has_raw_data()) {
+        // For real tensor.
+        tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
+        if (tensor_info == nullptr) {
+          MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
+          return nullptr;
+        }
+        return tensor_info;
+      } else if (tensor_proto.name() == kQuantParam) {
+        auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
+        if (!quantization_param_vector.empty()) {
+          return quantization_param_vector[0];
+        }
+      } else {
+        // For data type.
+        const int attr_tensor_type = tensor_proto.data_type();
+        auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
+        if (iter == kDefaultValueSwitchMap.end()) {
+          MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
+          return nullptr;
+        }
+        return TypeIdToType(iter->second);
+      }
+      MS_LOG(ERROR) << "Failed to get the tensor for value.";
+      return nullptr;
+    }
+    case mind_ir::AttributeProto_AttributeType_NONE: {
+      return kNone;
+    }
+    case mind_ir::AttributeProto_AttributeType_TUPLE:
+    case mind_ir::AttributeProto_AttributeType_LIST: {
+      auto sequence_value = ObtainValueInSequenceForm(attr_proto);
+      if (sequence_value == nullptr) {
+        MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
+        return nullptr;
+      }
+      return sequence_value;
+    }
+    case mind_ir::AttributeProto_AttributeType_DICT: {
+      auto dict_value = ObtainValueInDictionaryForm(attr_proto);
+      if (dict_value == nullptr) {
+        MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
+        return nullptr;
+      }
+      return dict_value;
+    }
+    default: {
+      ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
+      if (value == nullptr) {
+        MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name;
+        return nullptr;
+      }
+      return value;
+    }
+  }
+  return nullptr;
+}
+
 tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor) {
   ShapeVector shape;

@@ -981,79 +1043,25 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
 bool MSANFModelParser::SetPrimitiveAttrWithType(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
   MS_EXCEPTION_IF_NULL(prim);
   const std::string &attr_name = attr_proto.name();
-  switch (attr_proto.type()) {
-    case mind_ir::AttributeProto_AttributeType_TENSORS: {
-      mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
-      if (tensor_proto.has_raw_data()) {
-        // For real tensor.
-        tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
-        if (tensor_info == nullptr) {
-          MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
-          return false;
-        }
-        (void)prim->AddAttr(attr_name, tensor_info);
-      } else if (tensor_proto.name() == kQuantParam) {
-        auto quantization_param_vector = GenerateQuantizationParam(tensor_proto);
-        if (!quantization_param_vector.empty()) {
-          (void)prim->AddAttr(kQuantParam, quantization_param_vector[0]);
-        }
-      } else {
-        // For data type.
-        const int attr_tensor_type = tensor_proto.data_type();
-        auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
-        if (iter == kDefaultValueSwitchMap.end()) {
-          MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
-          return false;
-        }
-        (void)prim->AddAttr(attr_name, TypeIdToType(iter->second));
-      }
-      break;
-    }
-    case mind_ir::AttributeProto_AttributeType_NONE: {
-      (void)prim->AddAttr(attr_name, kNone);
-      break;
-    }
-    case mind_ir::AttributeProto_AttributeType_TUPLE:
-    case mind_ir::AttributeProto_AttributeType_LIST: {
-      auto sequence_value = ObtainValueInSequenceForm(attr_proto);
-      if (sequence_value == nullptr) {
-        MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
-        return false;
-      }
-      (void)prim->AddAttr(attr_name, sequence_value);
-      break;
-    }
-    case mind_ir::AttributeProto_AttributeType_DICT: {
-      auto dict_value = ObtainValueInDictionaryForm(attr_proto);
-      if (dict_value == nullptr) {
-        MS_LOG(ERROR) << "Failed to get dictionary value for " << attr_name;
-        return false;
-      }
-      (void)prim->AddAttr(attr_name, dict_value);
-      break;
-    }
-    default: {
-      ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
+  auto value = GetValueFromAttributeProto(attr_proto);
   if (value == nullptr) {
-        MS_LOG(ERROR) << "Can not get the value for attr: " << attr_name;
+    MS_LOG(ERROR) << "Failed to get value from proto.\n proto info:" << attr_proto.name();
     return false;
   }
   const std::string &op_type = prim->name();
   CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &value);
+  // Compatible with older versions.
   if (op_type == "HistogramFixedWidth" && attr_name == "dtype" && value->isa<StringImm>()) {
     auto str_dtype = GetValue<std::string>(value);
     if (str_dtype == "int32") {
       int64_t index = 3;
       (void)prim->AddAttr(attr_name, MakeValue<int64_t>(index));
-      break;
     }
     MS_EXCEPTION(NotSupportError)
       << "The primtive[HistogramFixedWidth] not supported only support attribute[dtype] is 'int32',but got"
       << value->ToString();
   }
   (void)prim->AddAttr(attr_name, value);
-    }
-  }
   return true;
 }

@@ -1655,6 +1663,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
 
   // Set Abstract and prim attr for CNode
   SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr);
+  BuildAttrForCNode(cnode_ptr, node_proto);
   return cnode_ptr;
 }
 
@@ -1739,39 +1748,12 @@ bool MSANFModelParser::BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph
                                              const mind_ir::GraphProto &importProto) {
   for (auto i = 0; i < importProto.attribute_size(); ++i) {
     const mind_ir::AttributeProto &attr_proto = importProto.attribute(i);
-    const int attr_type = attr_proto.type();
-    switch (attr_type) {
-      case mind_ir::AttributeProto_AttributeType_STRING: {
-        outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_string_string(attr_proto));
-        break;
-      }
-      case mind_ir::AttributeProto_AttributeType_BOOL: {
-        outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_bool(attr_proto));
-        break;
-      }
-      case mind_ir::AttributeProto_AttributeType_INT32: {
-        outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int32_t_int32_t(attr_proto));
-        break;
-      }
-      case mind_ir::AttributeProto_AttributeType_INT64: {
-        outputFuncGraph->set_attr(attr_proto.name(), ParseAttrInSingleScalar_int64_t_int64_t(attr_proto));
-        break;
-      }
-      case mind_ir::AttributeProto_AttributeType_TUPLE:
-      case mind_ir::AttributeProto_AttributeType_LIST: {
-        auto sequence_value = ObtainValueInSequenceForm(attr_proto);
-        if (sequence_value == nullptr) {
-          MS_LOG(ERROR) << "Failed to get sequence value for " << attr_proto.name();
-          return false;
-        }
-        outputFuncGraph->set_attr(attr_proto.name(), sequence_value);
-        break;
-      }
-      default:
-        MS_LOG(ERROR) << "Obtain attr for graph has not support input type: " << attr_type
-                      << ", attr name: " << attr_proto.name();
+    auto value = GetValueFromAttributeProto(attr_proto);
+    if (value == nullptr) {
+      MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
       return false;
     }
+    outputFuncGraph->set_attr(attr_proto.name(), value);
   }
   return true;
 }

@@ -1876,24 +1858,27 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto,
   if (IsLite()) {
     abstract_valid_ = true;
   }
-  FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
   if (!MSANFParseModelConfigureInfo(model_proto)) {
     MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
   }
-
-  for (int i = 0; i < model_proto.primitives_size(); ++i) {
-    if (!BuildPrimitiveNode(model_proto.primitives(i))) {
-      MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
-      return nullptr;
-    }
-  }
-
   if (model_proto.has_little_endian()) {
     if (model_proto.little_endian() != this->little_endian()) {
       MS_LOG(ERROR) << "The byte order of export MindIr device and load MindIr device is not same!";
       return nullptr;
     }
   }
+
+  FuncGraphPtr dstGraph = std::make_shared<FuncGraph>();
+
+  for (int i = 0; i < model_proto.primitives_size(); ++i) {
+    if (!BuildPrimitiveNode(model_proto.primitives(i))) {
+      MS_LOG(ERROR) << "Parse primitives info for pb file failed! " << model_proto.primitives(i).DebugString();
+      return nullptr;
+    }
+  }
+
   const mind_ir::GraphProto &graphBuild = model_proto.graph();
 
   // Forward declare FuncGraph name

@@ -2126,4 +2111,26 @@ void MSANFModelParser::CorrectFuncGraph(const FuncGraphPtr &root) {
   }
   MS_LOG(DEBUG) << "End to correct the funcgraph.";
 }
+
+bool MSANFModelParser::BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto) {
+  for (auto i = 0; i < node_proto.node_attr_size(); ++i) {
+    const auto &attr_proto = node_proto.node_attr(i);
+    auto value = GetValueFromAttributeProto(attr_proto);
+    if (value == nullptr) {
+      MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
+      return false;
+    }
+    cnode->AddAttr(attr_proto.name(), value);
+  }
+  for (auto i = 0; i < node_proto.primal_attr_size(); ++i) {
+    const auto &attr_proto = node_proto.primal_attr(i);
+    auto value = GetValueFromAttributeProto(attr_proto);
+    if (value == nullptr) {
+      MS_LOG(ERROR) << "Failed set func_graph attr to func_graph";
+      return false;
+    }
+    cnode->AddPrimalAttr(attr_proto.name(), value);
+  }
+  return true;
+}
 } // namespace mindspore

@@ -90,6 +90,8 @@ class MSANFModelParser {
   void CorrectFuncGraph(const FuncGraphPtr &root);
   bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
   bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
+  bool BuildAttrForCNode(const CNodePtr &cnode, const mind_ir::NodeProto &node_proto);
+  ValuePtr GetValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
   bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
   bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
   bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);

@@ -127,7 +129,6 @@ class MSANFModelParser {
   bool GetAttrValueForValueNodeWithType(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
   bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
   bool ObtainValueNodeInNoneForm(const std::string &value_node_name);
-  bool ObtainValueNodeInTypeNullForm(const std::string &value_node_name);
   bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
   ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto);
   ValuePtr ObtainValueInDictionaryForm(const mind_ir::AttributeProto &attr_proto);

@@ -84,6 +84,8 @@ message NodeProto {
   repeated AttributeProto attribute = 5;
   optional string doc_string = 6;
   optional string domain = 7;
+  repeated AttributeProto node_attr = 8;
+  repeated AttributeProto primal_attr = 9;
 }
 
@@ -80,6 +80,7 @@ if(ENABLE_MINDDATA)
         ./plugin/device/cpu/hal/*.cc
         ./place/*.cc
         ./ops/test_ops_fake_quant_param.cc
+        ./mindir/*.cc
         )
     if(NOT ENABLE_SECURITY)
         file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}

@@ -0,0 +1,67 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "common/common_test.h"
+#include "common/py_func_graph_fetcher.h"
+
+#include "pipeline/jit/resource.h"
+#include "pipeline/jit/action.h"
+#include "include/common/debug/dump_proto.h"
+#include "load_mindir/load_model.h"
+#include "mindspore/core/ops/core_ops.h"
+#include "ir/anf.h"
+#include "ir/tensor.h"
+
+namespace mindspore {
+class TestLoadExport : public UT::Common {
+ public:
+  TestLoadExport() : getPyFun("gtest_input.mindir.mindir_test") {}
+  ~TestLoadExport() override = default;
+  // Expectation: No Expectation
+  UT::PyFuncGraphFetcher getPyFun;
+};
+
+/// Feature: MindIR node attribute export and load.
+/// Description: Node attribute export and load.
+/// Expectation: success.
+TEST_F(TestLoadExport, test_export_func) {
+  auto func_graph = getPyFun.CallAndParseRet("export_test", "add_node_attr_test");
+  tensor::TensorPtr t = std::make_shared<tensor::Tensor>(kFloat32->type_id(), std::vector<int64_t>{1, 2, 3});
+
+  auto export_return_node = func_graph->output();
+  auto export_relu = export_return_node->cast<CNodePtr>();
+  export_relu->AddAttr("TestAttr", MakeValue(true));
+  export_relu->AddPrimalAttr("TestPrimalAttr", MakeValue(true));
+  if (func_graph->manager() == nullptr) {
+    std::vector<FuncGraphPtr> graphs{func_graph};
+    FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(graphs);
+    manager->AddFuncGraph(func_graph);
+  }
+  // Renormalize func_graph to infer and set shape and type information.
+  pipeline::ResourcePtr resource_ = std::make_shared<pipeline::Resource>();
+  auto graph = pipeline::Renormalize(resource_, func_graph, {t->ToAbstract()});
+  auto str = GetBinaryProtoString(graph);
+  mind_ir::ModelProto model_;
+  model_.ParseFromString(str);
+  MSANFModelParser model_parser;
+  FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
+  auto return_node = dstgraph_ptr->output();
+  auto load_relu = return_node->cast<CNodePtr>();
+  auto test_primal_attr = load_relu->GetPrimalAttr("TestPrimalAttr");
+  auto test_attr = load_relu->GetAttr("TestAttr");
+  ASSERT_TRUE(GetValue<bool>(test_attr));
+  ASSERT_TRUE(GetValue<bool>(test_primal_attr));
+}
+} // namespace mindspore

@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from mindspore.ops import operations as P
+
+relu = P.ReLU()
+
+
+class FnDict:
+    def __init__(self):
+        self.fn_dict = {}
+
+    def __call__(self, fn):
+        self.fn_dict[fn.__name__] = fn
+
+    def __getitem__(self, name):
+        return self.fn_dict.get(name, "")
+
+
+def export_test(tag):
+    """ test_adam_apply_one_with_decay_rule """
+    fns = FnDict()
+
+    @fns
+    def add_node_attr_test(x):
+        return relu(x)
+
+    return fns[tag]

@@ -23,11 +23,4 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { re
 std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; }
 
 std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
 
-std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) { return ""; }
-
-bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path,
-                     const FuncGraphPtr &param_layout_fg) {
-  return true;
-}
 } // namespace mindspore