From bd4add384f14d0bf75c72069b5a188b601eda57c Mon Sep 17 00:00:00 2001
From: lanzhineng
Date: Fri, 3 Sep 2021 09:14:51 +0800
Subject: [PATCH] The args must be compatible with the loaded graph

---
 mindspore/ccsrc/pipeline/jit/action.cc              | 13 ++++
 .../transform/express_ir/mindir_exporter.cc         |  3 +-
 tests/st/control/test_recursive_mindir.py           | 67 +++++++++++++++++++
 3 files changed, 81 insertions(+), 2 deletions(-)
 create mode 100644 tests/st/control/test_recursive_mindir.py

diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc
index 10205e15663..73f859a55fc 100644
--- a/mindspore/ccsrc/pipeline/jit/action.cc
+++ b/mindspore/ccsrc/pipeline/jit/action.cc
@@ -945,6 +945,19 @@ bool SetMindIRGraphAction(const ResourcePtr &res) {
                            return arg;
                          });
+  abstract::AbstractBasePtrList func_args;
+  const auto inputs = fg->get_inputs();
+  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
+                       [](const AnfNodePtr &arg) -> AbstractBasePtr {
+                         MS_EXCEPTION_IF_NULL(arg);
+                         return arg->abstract()->Broaden();
+                       });
+  if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
+    MS_LOG(EXCEPTION) << "The args are not compatible with the loaded function graph."
+                      << " Expected args: " << abstract::ArgsToString(func_args)
+                      << " Given args: " << abstract::ArgsToString(broaded_args);
+  }
+
   // suppose that there is not KeywordArgument for the top graph
   // get the hyper parameter
   for (const auto &param : fg->parameters()) {
diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
index 3819b02821d..ef8981cb153 100644
--- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
@@ -306,8 +306,7 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
   mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
   tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
   if (dims.size() == 0) {
-    MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
-    tensor_proto->add_dims(1);
+    MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
   } else {
     for (const auto &dim : dims) {
       MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
diff --git a/tests/st/control/test_recursive_mindir.py b/tests/st/control/test_recursive_mindir.py
new file mode 100644
index 00000000000..fbc91464ad1
--- /dev/null
+++ b/tests/st/control/test_recursive_mindir.py
@@ -0,0 +1,67 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.common.tensor import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.train.serialization import export, load
+
+ZERO = Tensor([0], mstype.int32)
+ONE = Tensor([1], mstype.int32)
+
+
+class RecursiveNet(nn.Cell):
+    def construct(self, x, z):
+        def f(x, z):
+            y = ZERO
+            if x < 0:
+                y = ONE
+            elif x < 3:
+                y = x * f(x - 1, z)
+            elif x < 5:
+                y = x * f(x - 2, z)
+            else:
+                y = f(x - 4, z)
+            z = y + 1 + z
+            return z
+
+        return f(x, z)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+def test_recursive():
+    context.set_context(mode=context.GRAPH_MODE)
+    network = RecursiveNet()
+
+    x = Tensor(np.array([1]).astype(np.float32))
+    y = Tensor(np.array([2]).astype(np.float32))
+    origin_out = network(x, y)
+
+    file_name = "recursive_net"
+    export(network, x, y, file_name=file_name, file_format='MINDIR')
+    mindir_name = file_name + ".mindir"
+    assert os.path.exists(mindir_name)
+
+    graph = load(mindir_name)
+    loaded_net = nn.GraphCell(graph)
+    outputs_after_load = loaded_net(x, y)
+    assert origin_out == outputs_after_load
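
Note (not part of the patch): the sketch below illustrates the user-facing effect of the new compatibility check in SetMindIRGraphAction, namely that a graph loaded from MindIR should reject arguments whose shape or dtype differ from the exported signature. The file name recursive_net.mindir, the helper name check_incompatible_args, and the expectation that MS_LOG(EXCEPTION) surfaces in Python as a RuntimeError are illustrative assumptions, not something this change adds.

# Illustrative sketch only; assumes recursive_net.mindir was exported by the
# test above and that the C++ exception thrown by the new check surfaces in
# Python as a RuntimeError.
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load


def check_incompatible_args(mindir_name="recursive_net.mindir"):
    context.set_context(mode=context.GRAPH_MODE)
    graph = load(mindir_name)
    loaded_net = nn.GraphCell(graph)
    # The graph was exported with two float32 tensors of shape (1,);
    # float64 inputs with shape (2, 2) should fail the
    # AbstractBasePtrListDeepEqual comparison instead of running silently.
    bad_x = Tensor(np.ones((2, 2)).astype(np.float64))
    bad_y = Tensor(np.ones((2, 2)).astype(np.float64))
    with pytest.raises(RuntimeError):
        loaded_net(bad_x, bad_y)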