!49756 fix dyn tuple tuple export

Merge pull request !49756 from lianliguang/fix-dyn-tuple-export
i-robot 2023-03-04 14:10:46 +00:00 committed by Gitee
commit 13123bf8e3
4 changed files with 176 additions and 9 deletions

View File

@@ -140,6 +140,7 @@ class IrExportBuilder {
bool SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
bool ExportTuple(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto);
bool SetAbstractToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetAbstractToNodeProto(const abstract::AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
@@ -906,18 +907,36 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
return type_name;
}
bool IrExportBuilder::ExportTuple(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
// Record the tuple's structural info: whether its length is dynamic and, if so, the element abstract.
auto tuple_info_proto = attr_proto->mutable_tuple_info();
tuple_info_proto->set_is_dyn_len(tuple_abs->dynamic_len());
auto elem_abs = tuple_abs->dynamic_len_element_abs();
if (elem_abs != nullptr) {
mind_ir::AttributeProto *tuple_elem_proto = tuple_info_proto->mutable_tuple_elem_item();
if (!SetAbstractToNodeProto(elem_abs, tuple_elem_proto)) {
return false;
}
}
// Serialize each concrete element into values, as before.
const auto &elems = tuple_abs->elements();
for (const auto &item : elems) {
mind_ir::AttributeProto *attr_values = attr_proto->add_values();
if (!SetAbstractToNodeProto(item, attr_values)) {
return false;
}
}
return true;
}
bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto) {
auto type = abs->BuildType();
auto shape = abs->BuildShape();
if (type->isa<Tuple>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
for (size_t i = 0; i < tuple_abs->size(); i++) {
mind_ir::AttributeProto *attr_values = attr_proto->add_values();
if (!SetAbstractToNodeProto((*tuple_abs)[i], attr_values)) {
return false;
}
}
return ExportTuple(abs, attr_proto);
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
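
To make the new export path easier to follow, here is a minimal Python sketch of the logic ExportTuple implements above. The dict layout, the FakeAbstractTuple stand-in, and the export_abstract helper are illustrative assumptions for this note, not MindSpore APIs.

from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class FakeAbstractTuple:
    """Toy stand-in for abstract::AbstractTuple: per-element abstracts plus dynamic-length info."""
    elements: List[Any] = field(default_factory=list)
    dynamic_len: bool = False
    dynamic_len_element_abs: Optional[Any] = None

def export_abstract(abs_item) -> dict:
    """Stand-in for SetAbstractToNodeProto: tuples go through export_tuple, everything else is a leaf."""
    if isinstance(abs_item, FakeAbstractTuple):
        return export_tuple(abs_item)
    return {"type": "TENSORS", "value": abs_item}

def export_tuple(abs_tuple: FakeAbstractTuple) -> dict:
    """Mirror of IrExportBuilder::ExportTuple: set type TUPLE, fill tuple_info, then export each element."""
    attr = {"type": "TUPLE", "tuple_info": {"is_dyn_len": abs_tuple.dynamic_len}, "values": []}
    # A dynamic-length tuple carries one shared element abstract; record it in tuple_elem_item.
    if abs_tuple.dynamic_len_element_abs is not None:
        attr["tuple_info"]["tuple_elem_item"] = export_abstract(abs_tuple.dynamic_len_element_abs)
    # Any concrete elements are still recorded in values, exactly as before this change.
    for item in abs_tuple.elements:
        attr["values"].append(export_abstract(item))
    return attr

# Example: a dynamic-length tuple whose elements are float32 tensors (represented here by a string tag).
print(export_tuple(FakeAbstractTuple(dynamic_len=True, dynamic_len_element_abs="float32[2, 3]")))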

View File

@@ -450,7 +450,17 @@ abstract::AbstractBasePtr MSANFModelParser::GetNodeAbstractFromAttrProtoWithType
}
(void)vec.emplace_back(abs);
}
return std::make_shared<abstract::AbstractTuple>(vec);
auto tuple_abs = std::make_shared<abstract::AbstractTuple>(vec);
// Restore the dynamic-length info recorded by the exporter, if present.
if (attr_proto.has_tuple_info()) {
auto tuple_info = attr_proto.tuple_info();
tuple_abs->set_dynamic_len(tuple_info.is_dyn_len());
if (tuple_info.has_tuple_elem_item()) {
auto elem_proto = tuple_info.tuple_elem_item();
auto elem_abs = GetNodeAbstractFromAttrProtoWithType(elem_proto);
tuple_abs->set_dynamic_len_element_abs(elem_abs);
}
}
return tuple_abs;
}
case mind_ir::AttributeProto_AttributeType_UMONAD: {
return kUMonad->ToAbstract();
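
The matching load-side logic, again as an illustrative Python sketch rather than the real MSANFModelParser code; the dict layout is the same illustrative one used in the export sketch above.

def parse_abstract(attr: dict):
    """Sketch of GetNodeAbstractFromAttrProtoWithType, tuple branch only; returns a dict standing in for AbstractTuple."""
    if attr.get("type") != "TUPLE":
        return attr.get("value")
    # Rebuild the per-element abstracts first, as the parser already did before this change.
    tuple_abs = {
        "elements": [parse_abstract(v) for v in attr.get("values", [])],
        "dynamic_len": False,
        "dynamic_len_element_abs": None,
    }
    # New in this change: restore the dynamic-length flag and the element abstract
    # from tuple_info, mirroring set_dynamic_len / set_dynamic_len_element_abs.
    tuple_info = attr.get("tuple_info")
    if tuple_info is not None:
        tuple_abs["dynamic_len"] = tuple_info.get("is_dyn_len", False)
        if "tuple_elem_item" in tuple_info:
            tuple_abs["dynamic_len_element_abs"] = parse_abstract(tuple_info["tuple_elem_item"])
    return tuple_abs

# Example: the attribute written for a dynamic-length tuple round-trips its dynamic-length info.
print(parse_abstract({"type": "TUPLE",
                      "tuple_info": {"is_dyn_len": True,
                                     "tuple_elem_item": {"type": "TENSORS", "value": "float32[2, 3]"}},
                      "values": []}))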

View File

@@ -47,6 +47,10 @@ message AttributeProto {
TYPE_NULL = 36;
MAP_TENSOR = 37;
}
message TupleInfoProto {
optional bool is_dyn_len = 18; // whether the tuple has a dynamic length
optional AttributeProto tuple_elem_item = 19; // the element abstract of a dynamic-length tuple
}
optional string name = 1;
optional float f = 2;
optional int64 i = 3;
@@ -64,6 +68,7 @@ message AttributeProto {
optional string ref_attr_name = 15;
optional AttributeType type = 16;
repeated AttributeProto values = 17; // tuple, list, or dict of values
optional TupleInfoProto tuple_info = 18; // structural info of the tuple
}
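
For a concrete picture of the two new fields, the attribute exported for a dynamic-length tuple would carry roughly the structure below. This is a hypothetical Python-dict rendering for illustration, not real protobuf text; tensor details are elided, and whether values is empty or lists concrete elements depends on how the abstract was built.

dyn_tuple_attr = {
    "type": "TUPLE",                # AttributeProto.type
    "tuple_info": {                 # new TupleInfoProto field (18)
        "is_dyn_len": True,         # the tuple length may vary at run time
        "tuple_elem_item": {        # element abstract shared by every element
            "type": "TENSORS",      # dtype/shape of the element tensor would go here
        },
    },
    "values": [],                   # per-element abstracts (field 17), unchanged by this commit
}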

View File

@@ -0,0 +1,133 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export and load mindir in dynamic length of sequence and dynamic shape."""
import os
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore.common import mutable
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_dynamic_shape_tuple():
"""
Feature: export dynamic shape to MindIR file
Description: Test export API to export network into MindIR
Expectation: run successfully
"""
class TestCell(nn.Cell):
def construct(self, x):
return x.shape + (1,)
test_cell = TestCell()
file_name = "test"
export(test_cell, Tensor(shape=[None, 2, 3], dtype=ms.float32), file_name=file_name, file_format="MINDIR")
verify_name = file_name + ".mindir"
assert os.path.exists(verify_name)
x = Tensor(input_np_x)
file_name = "net"
export(test_cell, x, file_name=file_name, file_format='MINDIR')
verify_name = file_name + ".mindir"
assert os.path.exists(verify_name)
graph = load(verify_name)
net_mindir = nn.GraphCell(graph)
result_mindir = net_mindir(x)
out_net = test_cell(x)
assert out_net == result_mindir
os.remove(verify_name)
input_np_x = np.random.rand(2, 3, 3).astype(np.float32)
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.relu = nn.ReLU()
def construct(self, x):
x = x[0] + x[1]
return self.relu(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_mutable_tuple():
"""
Feature: export mutable tuple size to MindIR file
Description: Test export API to export network into MindIR
Expectation: run successfully
"""
x = [Tensor(input_np_x), Tensor(input_np_x)]
net = Net()
file_name = "net"
export(net, mutable(x), file_name=file_name, file_format='MINDIR')
verify_name = file_name + ".mindir"
assert os.path.exists(verify_name)
graph = load(verify_name)
net_mindir = nn.GraphCell(graph)
result_mindir = net_mindir(mutable(x))
out_net = net(x)
assert np.allclose(result_mindir.asnumpy(), out_net.asnumpy(), 0.0001, 0.0001)
os.remove(verify_name)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_mutable_dynamic_tuple():
"""
Feature: export dynamic tuple size to MindIR file
Description: Test export API to export network into MindIR
Expectation: run successfully
"""
x = [Tensor(input_np_x), Tensor(input_np_x)]
y = [Tensor(input_np_x), Tensor(input_np_x), Tensor(input_np_x), Tensor(input_np_x)]
net = Net()
file_name = "net"
export(net, mutable(x, dynamic_len=True), file_name=file_name, file_format='MINDIR')
verify_name = file_name + ".mindir"
assert os.path.exists(verify_name)
graph = load(verify_name)
net_mindir = nn.GraphCell(graph)
result_mindir = net_mindir(mutable(x))
out_net = net(x)
assert np.allclose(result_mindir.asnumpy(), out_net.asnumpy(), 0.0001, 0.0001)
out_net = net(y)
result_mindir = net_mindir(mutable(y))
assert np.allclose(result_mindir.asnumpy(), out_net.asnumpy(), 0.0001, 0.0001)
os.remove(verify_name)