fix the problem of the probability of use_case failure when get or init GraphCell params

This commit is contained in:
buxue 2021-11-18 11:41:16 +08:00
parent 4d1266328b
commit 0b00f4cb55
1 changed files with 24 additions and 14 deletions

View File

@ -48,23 +48,21 @@ np_b = np.ones((2, 3), np.float32) + 3
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32) np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
input_a = Tensor(np_a) input_a = Tensor(np_a)
input_b = Tensor(np_b) input_b = Tensor(np_b)
mindir_path = "test_net.mindir"
ckpt_path = "test_net.ckpt"
def setup_module():
net = Net()
export(net, input_a, input_b, file_name=mindir_path[:-7], file_format='MINDIR')
def load_mindir_and_update_params(): def load_mindir_and_update_params():
load_net = nn.GraphCell(graph=load(mindir_path)) net = Net()
mindir_name = "net_0.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
load_net = nn.GraphCell(graph=load(mindir_name))
ret = load_net(input_a, input_b) ret = load_net(input_a, input_b)
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param) assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
save_checkpoint(load_net, ckpt_path) ckpt_name = "net_0.ckpt"
params_init = load_checkpoint(ckpt_path) save_checkpoint(load_net, ckpt_name)
load_net_with_new_params = nn.GraphCell(graph=load(mindir_path), params_init=params_init) params_init = load_checkpoint(ckpt_name)
load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
return load_net_with_new_params return load_net_with_new_params
@ -90,9 +88,13 @@ def test_init_graph_cell_parameters_with_wrong_type():
Description: load mind ir and update parameters with wrong type. Description: load mind ir and update parameters with wrong type.
Expectation: raise a ValueError indicating the params type error. Expectation: raise a ValueError indicating the params type error.
""" """
net = Net()
mindir_name = "net_1.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
new_params = {"weight": np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)} new_params = {"weight": np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)}
with pytest.raises(TypeError) as err: with pytest.raises(TypeError) as err:
graph = load(mindir_path) graph = load(mindir_name)
load_net = nn.GraphCell(graph, params_init=new_params) load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b) load_net(input_a, input_b)
@ -107,9 +109,13 @@ def test_init_graph_cell_parameters_with_wrong_shape():
Description: load mind ir and update parameters with wrong tensor shape. Description: load mind ir and update parameters with wrong tensor shape.
Expectation: raise a ValueError indicating the tensor shape error. Expectation: raise a ValueError indicating the tensor shape error.
""" """
net = Net()
mindir_name = "net_2.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
new_params = {"weight": Parameter(np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32))} new_params = {"weight": Parameter(np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32))}
with pytest.raises(ValueError) as err: with pytest.raises(ValueError) as err:
graph = load(mindir_path) graph = load(mindir_name)
load_net = nn.GraphCell(graph, params_init=new_params) load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b) load_net(input_a, input_b)
@ -124,9 +130,13 @@ def test_init_graph_cell_parameters_with_wrong_dtype():
Description: load mind ir and update parameters with wrong tensor dtype. Description: load mind ir and update parameters with wrong tensor dtype.
Expectation: raise a ValueError indicating the tensor dtype error. Expectation: raise a ValueError indicating the tensor dtype error.
""" """
net = Net()
mindir_name = "net_3.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))} new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
with pytest.raises(ValueError) as err: with pytest.raises(ValueError) as err:
graph = load(mindir_path) graph = load(mindir_name)
load_net = nn.GraphCell(graph, params_init=new_params) load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b) load_net(input_a, input_b)