forked from mindspore-Ecosystem/mindspore
fix the problem of the probability of use_case failure when get or init GraphCell params
This commit is contained in:
parent
4d1266328b
commit
0b00f4cb55
|
@ -48,23 +48,21 @@ np_b = np.ones((2, 3), np.float32) + 3
|
||||||
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
|
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
|
||||||
input_a = Tensor(np_a)
|
input_a = Tensor(np_a)
|
||||||
input_b = Tensor(np_b)
|
input_b = Tensor(np_b)
|
||||||
mindir_path = "test_net.mindir"
|
|
||||||
ckpt_path = "test_net.ckpt"
|
|
||||||
|
|
||||||
|
|
||||||
def setup_module():
|
|
||||||
net = Net()
|
|
||||||
export(net, input_a, input_b, file_name=mindir_path[:-7], file_format='MINDIR')
|
|
||||||
|
|
||||||
|
|
||||||
def load_mindir_and_update_params():
|
def load_mindir_and_update_params():
|
||||||
load_net = nn.GraphCell(graph=load(mindir_path))
|
net = Net()
|
||||||
|
mindir_name = "net_0.mindir"
|
||||||
|
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||||
|
|
||||||
|
load_net = nn.GraphCell(graph=load(mindir_name))
|
||||||
ret = load_net(input_a, input_b)
|
ret = load_net(input_a, input_b)
|
||||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
|
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
|
||||||
|
|
||||||
save_checkpoint(load_net, ckpt_path)
|
ckpt_name = "net_0.ckpt"
|
||||||
params_init = load_checkpoint(ckpt_path)
|
save_checkpoint(load_net, ckpt_name)
|
||||||
load_net_with_new_params = nn.GraphCell(graph=load(mindir_path), params_init=params_init)
|
params_init = load_checkpoint(ckpt_name)
|
||||||
|
load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
|
||||||
return load_net_with_new_params
|
return load_net_with_new_params
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,9 +88,13 @@ def test_init_graph_cell_parameters_with_wrong_type():
|
||||||
Description: load mind ir and update parameters with wrong type.
|
Description: load mind ir and update parameters with wrong type.
|
||||||
Expectation: raise a ValueError indicating the params type error.
|
Expectation: raise a ValueError indicating the params type error.
|
||||||
"""
|
"""
|
||||||
|
net = Net()
|
||||||
|
mindir_name = "net_1.mindir"
|
||||||
|
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||||
|
|
||||||
new_params = {"weight": np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)}
|
new_params = {"weight": np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)}
|
||||||
with pytest.raises(TypeError) as err:
|
with pytest.raises(TypeError) as err:
|
||||||
graph = load(mindir_path)
|
graph = load(mindir_name)
|
||||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||||
load_net(input_a, input_b)
|
load_net(input_a, input_b)
|
||||||
|
|
||||||
|
@ -107,9 +109,13 @@ def test_init_graph_cell_parameters_with_wrong_shape():
|
||||||
Description: load mind ir and update parameters with wrong tensor shape.
|
Description: load mind ir and update parameters with wrong tensor shape.
|
||||||
Expectation: raise a ValueError indicating the tensor shape error.
|
Expectation: raise a ValueError indicating the tensor shape error.
|
||||||
"""
|
"""
|
||||||
|
net = Net()
|
||||||
|
mindir_name = "net_2.mindir"
|
||||||
|
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||||
|
|
||||||
new_params = {"weight": Parameter(np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32))}
|
new_params = {"weight": Parameter(np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32))}
|
||||||
with pytest.raises(ValueError) as err:
|
with pytest.raises(ValueError) as err:
|
||||||
graph = load(mindir_path)
|
graph = load(mindir_name)
|
||||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||||
load_net(input_a, input_b)
|
load_net(input_a, input_b)
|
||||||
|
|
||||||
|
@ -124,9 +130,13 @@ def test_init_graph_cell_parameters_with_wrong_dtype():
|
||||||
Description: load mind ir and update parameters with wrong tensor dtype.
|
Description: load mind ir and update parameters with wrong tensor dtype.
|
||||||
Expectation: raise a ValueError indicating the tensor dtype error.
|
Expectation: raise a ValueError indicating the tensor dtype error.
|
||||||
"""
|
"""
|
||||||
|
net = Net()
|
||||||
|
mindir_name = "net_3.mindir"
|
||||||
|
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||||
|
|
||||||
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
|
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
|
||||||
with pytest.raises(ValueError) as err:
|
with pytest.raises(ValueError) as err:
|
||||||
graph = load(mindir_path)
|
graph = load(mindir_name)
|
||||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||||
load_net(input_a, input_b)
|
load_net(input_a, input_b)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue