fix bug of mindir parameter counter and check if all parameter has been inited in mindir
This commit is contained in:
parent
fecb44a1bb
commit
b5a09933a2
|
@ -149,6 +149,7 @@
|
|||
"mindspore/tests/ut/python/optimizer/test_auto_grad.py" "broad-except"
|
||||
"mindspore/tests/st/fallback/control_flow/test_fallback_100_if_after_if.py" "unused-variable"
|
||||
"mindspore/tests/st/numpy_native/test_array_ops.py" "useless-super-delegation"
|
||||
"mindspore/tests/ut/python/mindir/test_mindir_export.py" "no-else-return"
|
||||
|
||||
#MindSpore Lite
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"
|
||||
|
|
|
@ -629,6 +629,7 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
|
|||
return false;
|
||||
}
|
||||
}
|
||||
outputFuncGraph->set_fv_param_count(importProto.parameter_size());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
from mindspore.nn import Cell, GraphCell
|
||||
|
||||
from mindspore import Tensor, export, load, Parameter, dtype
|
||||
|
||||
|
||||
def test_export_control_flow():
|
||||
"""
|
||||
Feature: Test MindIR Export model
|
||||
Description: test mindir export when parameter is not use
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.w = Parameter(Tensor([-2], dtype.float32), name="weight")
|
||||
self.b = Parameter(Tensor([-5], dtype.float32), name="bias")
|
||||
|
||||
def construct(self, x, y):
|
||||
if len(x.shape) == 1:
|
||||
return y
|
||||
while y >= x:
|
||||
if self.b <= x:
|
||||
return y
|
||||
elif self.w < x:
|
||||
return x
|
||||
x += y
|
||||
|
||||
return x + y
|
||||
|
||||
|
||||
x = np.array([3], np.float32)
|
||||
y = np.array([0], np.float32)
|
||||
net = Net()
|
||||
export(net, Tensor(x), Tensor(y), file_name="ctrl", file_format='MINDIR')
|
||||
graph = load('ctrl.mindir')
|
||||
g_net = GraphCell(graph)
|
||||
export_out = g_net(Tensor(x), Tensor(y))
|
||||
correct_out = net(Tensor(x), Tensor(y))
|
||||
assert np.allclose(export_out.asnumpy(), correct_out.asnumpy())
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -200,3 +200,29 @@ def test_ms_mindir_enc_2g_0001():
|
|||
export(net, *inputs, file_name=os.path.join(mindir_dir, "AddNet.mindir"), file_format="MINDIR", enc_key=key)
|
||||
graph = load(os.path.join(mindir_dir, "AddNet_graph.mindir"), dec_key=key)
|
||||
assert graph is not None
|
||||
|
||||
|
||||
def test_mindir_export_remove_parameter():
|
||||
"""
|
||||
Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0)
|
||||
Description: MindIR Export model is exceed TOTAL_SAVE should be split save as model file and data file
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms.train.serialization.TOTAL_SAVE = 0
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.addn = ops.AddN()
|
||||
self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")
|
||||
self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z")
|
||||
|
||||
def construct(self, x):
|
||||
return self.addn((x, self.y, self.z))
|
||||
|
||||
x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
|
||||
add_net = Net()
|
||||
export(add_net, x, file_name="mindir_export_split", file_format="MINDIR")
|
||||
shutil.rmtree("./mindir_export_split_variables/")
|
||||
with pytest.raises(RuntimeError, match=" please check the correct of the file."):
|
||||
load("mindir_export_split_graph.mindir")
|
||||
|
|
Loading…
Reference in New Issue