!22830 args must be compatible with loaded graph

Merge pull request !22830 from lanzhineng/mindir_control_flow
This commit is contained in:
i-robot 2021-09-06 01:28:49 +00:00 committed by Gitee
commit 44775dca4b
3 changed files with 81 additions and 2 deletions

View File

@ -945,6 +945,19 @@ bool SetMindIRGraphAction(const ResourcePtr &res) {
return arg;
});
abstract::AbstractBasePtrList func_args;
const auto inputs = fg->get_inputs();
(void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
[](const AnfNodePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->abstract()->Broaden();
});
if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
MS_LOG(EXCEPTION) << "The args is not compatible with the function graph."
<< " Please check the args is compatible with the follow: " << abstract::ArgsToString(func_args)
<< " The input args:" << abstract::ArgsToString(broaded_args);
}
// suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
for (const auto &param : fg->parameters()) {

View File

@ -306,8 +306,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
if (dims.size() == 0) {
MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
tensor_proto->add_dims(1);
MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
} else {
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;

View File

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export, load
ZERO = Tensor([0], mstype.int32)
ONE = Tensor([1], mstype.int32)
class RecrusiveNet(nn.Cell):
def construct(self, x, z):
def f(x, z):
y = ZERO
if x < 0:
y = ONE
elif x < 3:
y = x * f(x - 1, z)
elif x < 5:
y = x * f(x - 2, z)
else:
y = f(x - 4, z)
z = y + 1 + z
return z
return f(x, z)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_recrusive():
context.set_context(mode=context.GRAPH_MODE)
network = RecrusiveNet()
x = Tensor(np.array([1]).astype(np.float32))
y = Tensor(np.array([2]).astype(np.float32))
origin_out = network(x, y)
file_name = "recrusive_net"
export(network, x, y, file_name=file_name, file_format='MINDIR')
mindir_name = file_name + ".mindir"
assert os.path.exists(mindir_name)
graph = load(mindir_name)
loaded_net = nn.GraphCell(graph)
outputs_after_load = loaded_net(x, y)
assert origin_out == outputs_after_load