!38106 fix buf of mindir parameter counter

Merge pull request !38106 from lianliguang/master
This commit is contained in:
i-robot 2022-07-19 01:42:09 +00:00 committed by Gitee
commit 4c686602aa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 83 additions and 1 deletions

View File

@ -149,6 +149,7 @@
"mindspore/tests/ut/python/optimizer/test_auto_grad.py" "broad-except"
"mindspore/tests/st/fallback/control_flow/test_fallback_100_if_after_if.py" "unused-variable"
"mindspore/tests/st/numpy_native/test_array_ops.py" "useless-super-delegation"
"mindspore/tests/ut/python/mindir/test_mindir_export.py" "no-else-return"
#MindSpore Lite
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"

View File

@ -629,6 +629,7 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
return false;
}
}
outputFuncGraph->set_fv_param_count(importProto.parameter_size());
return true;
}

View File

@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore.nn import Cell, GraphCell
from mindspore import Tensor, export, load, Parameter, dtype
def test_export_control_flow():
"""
Feature: Test MindIR Export model
Description: test mindir export when parameter is not use
Expectation: No exception.
"""
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.w = Parameter(Tensor([-2], dtype.float32), name="weight")
self.b = Parameter(Tensor([-5], dtype.float32), name="bias")
def construct(self, x, y):
if len(x.shape) == 1:
return y
while y >= x:
if self.b <= x:
return y
elif self.w < x:
return x
x += y
return x + y
x = np.array([3], np.float32)
y = np.array([0], np.float32)
net = Net()
export(net, Tensor(x), Tensor(y), file_name="ctrl", file_format='MINDIR')
graph = load('ctrl.mindir')
g_net = GraphCell(graph)
export_out = g_net(Tensor(x), Tensor(y))
correct_out = net(Tensor(x), Tensor(y))
assert np.allclose(export_out.asnumpy(), correct_out.asnumpy())

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -200,3 +200,29 @@ def test_ms_mindir_enc_2g_0001():
export(net, *inputs, file_name=os.path.join(mindir_dir, "AddNet.mindir"), file_format="MINDIR", enc_key=key)
graph = load(os.path.join(mindir_dir, "AddNet_graph.mindir"), dec_key=key)
assert graph is not None
def test_mindir_export_remove_parameter():
"""
Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0)
Description: MindIR Export model is exceed TOTAL_SAVE should be split save as model file and data file
Expectation: No exception.
"""
ms.train.serialization.TOTAL_SAVE = 0
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.addn = ops.AddN()
self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")
self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z")
def construct(self, x):
return self.addn((x, self.y, self.z))
x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
add_net = Net()
export(add_net, x, file_name="mindir_export_split", file_format="MINDIR")
shutil.rmtree("./mindir_export_split_variables/")
with pytest.raises(RuntimeError, match=" please check the correct of the file."):
load("mindir_export_split_graph.mindir")