fix bug in checkpoint when save scaler

This commit is contained in:
wangnan39@huawei.com 2020-04-22 12:06:05 +08:00
parent 3fa31b34da
commit f38d18c665
3 changed files with 62 additions and 4 deletions

View File

@ -344,5 +344,5 @@ class ParameterUpdate(Cell):
self._param = param
def construct(self, x):
self._param = x
F.assign(self._param, x)
return x

View File

@ -408,10 +408,11 @@ def _fill_param_into_net(net, parameter_list):
for each_param in parameter_list:
param_name = each_param["name"]
np_val = each_param["data"].asnumpy()
if np_val.shape == (1,): # to scalar
parameter_dict[param_name] = Parameter(np_val[0], name=param_name)
if np_val.shape == (1,):
parameter_dict[param_name] = Parameter(np_val, name=param_name)
elif np_val.shape == ():
parameter_dict[param_name] = Parameter(np_val.tolist(), name=param_name)
parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
name=param_name)
else:
parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)

View File

@ -52,12 +52,69 @@ def test_parameter_tuple_illegal():
def test_parameter_init_illegal():
import numpy as np
dat = np.array([[1, 2, 3], [2, 3, 4]])
tensor = Tensor(dat)
data_none = None
data_bool = True
data_str = "nicai"
data_int = 3
data_list = [1, "2", True]
data_tuple = (1, 2, 3)
# test data
Parameter(tensor, name=data_str)
Parameter(data_int, name=data_str)
Parameter(dat, name=data_str)
with pytest.raises(ValueError):
Parameter(data_bool, name=data_str)
# test name
Parameter(tensor, name=data_none)
with pytest.raises(ValueError):
Parameter(tensor, name=dat)
with pytest.raises(ValueError):
Parameter(tensor, name=tensor)
with pytest.raises(ValueError):
Parameter(tensor, name=data_bool)
with pytest.raises(ValueError):
Parameter(tensor, name=data_int)
with pytest.raises(ValueError):
Parameter(tensor, name=data_list)
with pytest.raises(ValueError):
Parameter(tensor, name=data_tuple)
Parameter(tensor, name=data_str, requires_grad=data_bool)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_none)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=dat)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=tensor)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_str)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_int)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_list)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_tuple)
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_bool)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=dat)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=tensor)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_none)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_str)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_int)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_list)
with pytest.raises(TypeError):
Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_tuple)
def test_check_str_by_regular():