!3525 Fix a bug for Parameter

Merge pull request !3525 from hewei/fix_parameter_bug
This commit is contained in:
mindspore-ci-bot 2020-07-27 17:17:16 +08:00 committed by Gitee
commit 5f7d2ba396
2 changed files with 18 additions and 2 deletions

View File

@ -210,7 +210,6 @@ class Parameter:
def set_parameter_data(self, data):
"""Set `default_input` of current `Parameter`."""
self.init_mode = None
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, Tensor):
@ -243,7 +242,8 @@ class Parameter:
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False.
"""
if self.init_mode is None:
if isinstance(self.default_input, Tensor):
# skip if data already initialized.
return
if layout is not None:
if not isinstance(layout, list):

View File

@ -134,3 +134,19 @@ def test_check_str_by_regular():
_check_str_by_regular(str5)
with pytest.raises(ValueError):
_check_str_by_regular(str6)
def test_parameter_lazy_init():
# Call init_data() without set default_input.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
assert not isinstance(para.default_input, Tensor)
para.init_data()
assert isinstance(para.default_input, Tensor)
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
# Call init_data() after default_input is set.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
assert not isinstance(para.default_input, Tensor)
para.default_input = Tensor(np.zeros((1, 2, 3)))
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
para.init_data() # expect no effect.
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))