forked from mindspore-Ecosystem/mindspore
fix the problem of the probability of use_case failure when get or init GraphCell params
This commit is contained in:
parent
4d1266328b
commit
0b00f4cb55
|
@ -48,23 +48,21 @@ np_b = np.ones((2, 3), np.float32) + 3
|
|||
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
|
||||
input_a = Tensor(np_a)
|
||||
input_b = Tensor(np_b)
|
||||
mindir_path = "test_net.mindir"
|
||||
ckpt_path = "test_net.ckpt"
|
||||
|
||||
|
||||
def setup_module():
|
||||
net = Net()
|
||||
export(net, input_a, input_b, file_name=mindir_path[:-7], file_format='MINDIR')
|
||||
|
||||
|
||||
def load_mindir_and_update_params():
|
||||
load_net = nn.GraphCell(graph=load(mindir_path))
|
||||
net = Net()
|
||||
mindir_name = "net_0.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
load_net = nn.GraphCell(graph=load(mindir_name))
|
||||
ret = load_net(input_a, input_b)
|
||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
|
||||
|
||||
save_checkpoint(load_net, ckpt_path)
|
||||
params_init = load_checkpoint(ckpt_path)
|
||||
load_net_with_new_params = nn.GraphCell(graph=load(mindir_path), params_init=params_init)
|
||||
ckpt_name = "net_0.ckpt"
|
||||
save_checkpoint(load_net, ckpt_name)
|
||||
params_init = load_checkpoint(ckpt_name)
|
||||
load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
|
||||
return load_net_with_new_params
|
||||
|
||||
|
||||
|
@ -90,9 +88,13 @@ def test_init_graph_cell_parameters_with_wrong_type():
|
|||
Description: load mind ir and update parameters with wrong type.
|
||||
Expectation: raise a ValueError indicating the params type error.
|
||||
"""
|
||||
net = Net()
|
||||
mindir_name = "net_1.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
new_params = {"weight": np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)}
|
||||
with pytest.raises(TypeError) as err:
|
||||
graph = load(mindir_path)
|
||||
graph = load(mindir_name)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
|
@ -107,9 +109,13 @@ def test_init_graph_cell_parameters_with_wrong_shape():
|
|||
Description: load mind ir and update parameters with wrong tensor shape.
|
||||
Expectation: raise a ValueError indicating the tensor shape error.
|
||||
"""
|
||||
net = Net()
|
||||
mindir_name = "net_2.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
new_params = {"weight": Parameter(np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32))}
|
||||
with pytest.raises(ValueError) as err:
|
||||
graph = load(mindir_path)
|
||||
graph = load(mindir_name)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
|
@ -124,9 +130,13 @@ def test_init_graph_cell_parameters_with_wrong_dtype():
|
|||
Description: load mind ir and update parameters with wrong tensor dtype.
|
||||
Expectation: raise a ValueError indicating the tensor dtype error.
|
||||
"""
|
||||
net = Net()
|
||||
mindir_name = "net_3.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
|
||||
with pytest.raises(ValueError) as err:
|
||||
graph = load(mindir_path)
|
||||
graph = load(mindir_name)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
|
|
Loading…
Reference in New Issue