forked from mindspore-Ecosystem/mindspore
!4443 [bug]fix bugs in parameters and add ut cases
Merge pull request !4443 from vlne-v1/I1RDCY-slice_shape_update_initializer
commit 659ed37812
@@ -280,15 +280,23 @@ class Parameter(MetaTensor):
         Set `default_input` of current `Parameter`.

         Args:
-            data (Union[Tensor, Initializer]): new data.
-            slice_shape (bool): If slice the Parameter. Default: False.
+            data (Union[Tensor, Initializer, int, float]): new data.
+            slice_shape (bool): If True, slice the Parameter and skip the shape check. Default: False.

         Returns:
             Parameter, the parameter after the data is set.
         """
-        if not isinstance(data, (MetaTensor, Initializer)):
-            raise ValueError(f"Parameter data must be `Initializer` or a kind of `MetaTensor` "
-                             f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
+        def raise_type_error(incoming):
+            raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.dtype}"
+                            f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.")
+
+        if not isinstance(data, (MetaTensor, Initializer, int, float)):
+            raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
+                            f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
+        if isinstance(data, (int, float)):
+            if self.dtype in mstype.int_type and isinstance(data, float):
+                raise_type_error(mstype.float_)
+            data = Tensor(data, self.dtype)
         # both not init.
         is_incoming_tensor = isinstance(data, Tensor)
         is_current_tensor = isinstance(self, Tensor)
@@ -300,25 +308,25 @@ class Parameter(MetaTensor):
                                  "network, then call this method.")
         if tuple(self.shape) != tuple(data.shape):
             # If the Parameter is sliced, the shape can be changed.
-            if slice_shape:
-                self._update_tensor_data(data)
-                self.sliced = True
-            else:
+            if not slice_shape:
                 raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                  f" Current shape is {self.shape}, and incoming is {data.shape}.")
         if self.dtype != data.dtype:
-            raise ValueError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
-                             f", and incoming is {data.dtype}. Use .set_dtype(xxx) to change the dtype.")
+            raise_type_error(data.dtype)
         if isinstance(data, Initializer):
             # The parameter has been initialized; update it with the data directly.
             if is_current_tensor:
                 self._update_tensor_data(data.to_tensor())
             else:
+                # also update the related inited parameter data
+                if self.inited_param is not None:
+                    self.inited_param.set_parameter_data(data)
                 self.init_mode = data
         elif is_incoming_tensor or is_current_tensor:
             self._update_tensor_data(data)
         else:
             raise ValueError(f"Not support to update the Parameter by {data}")
+        self.sliced = slice_shape
         return self

     def init_data(self, layout=None, set_sliced=False):
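Taken together, the two hunks above rework `set_parameter_data`: Python scalars are now accepted and coerced to the Parameter's dtype, every dtype mismatch is routed through `raise_type_error`, and `slice_shape=True` skips the shape check. A minimal usage sketch distilled from the new UT cases further down (variable names are illustrative, not part of the commit):

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor, Parameter
    from mindspore.common.initializer import initializer

    fp = Parameter(0.5, 'fp')      # scalar Parameter, inferred float32
    fp.set_parameter_data(0.8)     # scalar is wrapped into Tensor(0.8, mstype.float32)

    it = Parameter(1, 'it')        # scalar Parameter, inferred int32
    # it.set_parameter_data(1.2)   # float into an int32 Parameter: TypeError via raise_type_error

    # slice_shape=True skips the shape check and marks the parameter as sliced.
    w = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'w').init_data()
    w.set_parameter_data(Tensor(np.zeros((1, 2)).astype(np.float32)), slice_shape=True)
    assert w.sliced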
@@ -340,8 +348,6 @@ class Parameter(MetaTensor):
         """
         if self.init_mode is None:
             return self
         if self.inited_param is not None:
             return self.inited_param
         if layout is not None:
             if not isinstance(layout, list):
                 raise TypeError("The layout should be list! layout is {}.".format(layout))
@@ -362,8 +368,7 @@ class Parameter(MetaTensor):
         if id(obj) != id(self):
             self._inited_param = obj
         obj.init_mode = None
-        if set_sliced:
-            obj.sliced = True
+        obj.sliced = set_sliced
         return obj


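The hunk above makes `init_data` record the slice state unconditionally: `obj.sliced` now mirrors the `set_sliced` argument instead of only ever being flipped to True, so a stale True from an earlier sliced update no longer survives a plain `init_data()` call. A small sketch of the resulting behavior (assuming `sliced` defaults to False):

    para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'w')
    obj = para.init_data(set_sliced=False)
    assert not obj.sliced   # before this fix, an earlier True could linger here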
@@ -135,6 +135,40 @@ def test_check_str_by_regular():
     with pytest.raises(ValueError):
         _check_str_by_regular(str6)

+
+def test_parameter_compute():
+    para_1 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1')
+    para_2 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test2')
+
+    t3 = Tensor(np.ones((1, 2, 3)))
+
+    out = para_1 + para_2
+    assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3)) * 2)
+
+    out = para_1 * para_2
+    assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3)))
+
+    out = para_1 + t3
+    assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3)) * 2)
+
+    out = para_1 * t3
+    assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3)))
+
+    assert isinstance(para_1, Tensor)
+
+
+def test_scalar_parameter_update():
+    fp = Parameter(0.5, 'fp')
+    fp.default_input = 0.8
+    assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32))
+    fp.default_input = 1
+    assert np.array_equal(fp.default_input.asnumpy(), np.array(1.0, np.float32))
+    int_ = Parameter(1, 'fp')
+    int_.default_input = 2
+    assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32))
+    with pytest.raises(TypeError):
+        int_.default_input = 1.2
+
+
 def test_parameter_lazy_init():
     # support lazy init in SEMI_AUTO_PARALLEL mode
     context.reset_auto_parallel_context()
@@ -155,7 +189,7 @@ def test_parameter_lazy_init():
     # init then assign
     para = para.init_data()
     # check the type
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         para.default_input = Tensor(np.zeros((1, 2, 3)))
     # check the shape
     with pytest.raises(ValueError):
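The expectation changes from ValueError to TypeError because dtype mismatches in `set_parameter_data` (and hence in the `default_input` setter) now surface through `raise_type_error`. The trigger here is the dtype, not the shape, as a quick check shows (sketch, not part of the commit):

    # np.zeros defaults to float64, while the parameter under test holds float32,
    # so the new dtype guard fires before any shape comparison.
    assert Tensor(np.zeros((1, 2, 3))).dtype == mstype.float64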
@@ -170,4 +204,8 @@ def test_parameter_lazy_init():
     # expect no effect.
     para.init_data()
     assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
+    para.set_parameter_data(Tensor(np.zeros((1, 2)).astype(np.float32)), slice_shape=True)
+    assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2)))
+    para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True)
+    assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2)))
     context.reset_auto_parallel_context()
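Beyond what these asserts cover, the new `inited_param` forwarding in `set_parameter_data` means an `Initializer` update on a still-lazy parameter also refreshes the tensor that `init_data` produced. A hypothetical sketch of that path (reachable only while the parameter is in lazy-init form, e.g. under SEMI_AUTO_PARALLEL; names are illustrative):

    lazy = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'lazy')
    obj = lazy.init_data()    # records lazy._inited_param = obj
    lazy.set_parameter_data(initializer('zeros', [1, 2, 3], mstype.float32))
    # lazy is not yet a Tensor, so the else-branch runs:
    #   lazy.inited_param.set_parameter_data(data)  -> obj now holds zeros
    #   lazy.init_mode = data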