forked from OSSInnovation/mindspore
!3525 Fix a bug for Parameter
Merge pull request !3525 from hewei/fix_parameter_bug
This commit is contained in:
commit
5f7d2ba396
|
@ -210,7 +210,6 @@ class Parameter:
|
|||
|
||||
def set_parameter_data(self, data):
|
||||
"""Set `default_input` of current `Parameter`."""
|
||||
self.init_mode = None
|
||||
if isinstance(data, bool):
|
||||
raise ValueError('Parameter data can not be `bool`')
|
||||
if isinstance(data, Tensor):
|
||||
|
@ -243,7 +242,8 @@ class Parameter:
|
|||
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
|
||||
Default: False.
|
||||
"""
|
||||
if self.init_mode is None:
|
||||
if isinstance(self.default_input, Tensor):
|
||||
# skip if data already initialized.
|
||||
return
|
||||
if layout is not None:
|
||||
if not isinstance(layout, list):
|
||||
|
|
|
@ -134,3 +134,19 @@ def test_check_str_by_regular():
|
|||
_check_str_by_regular(str5)
|
||||
with pytest.raises(ValueError):
|
||||
_check_str_by_regular(str6)
|
||||
|
||||
def test_parameter_lazy_init():
|
||||
# Call init_data() without set default_input.
|
||||
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
|
||||
assert not isinstance(para.default_input, Tensor)
|
||||
para.init_data()
|
||||
assert isinstance(para.default_input, Tensor)
|
||||
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
|
||||
|
||||
# Call init_data() after default_input is set.
|
||||
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
|
||||
assert not isinstance(para.default_input, Tensor)
|
||||
para.default_input = Tensor(np.zeros((1, 2, 3)))
|
||||
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
|
||||
para.init_data() # expect no effect.
|
||||
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
|
||||
|
|
Loading…
Reference in New Issue