From d18ce957145568591b733b34d836c10ec2db154f Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Fri, 14 Aug 2020 14:49:58 +0800 Subject: [PATCH] * add support to update by scalar * fix update error for Initializer when slice_shape is True * add ut for compute and update --- mindspore/common/parameter.py | 35 +++++++++++++----------- tests/ut/python/nn/test_parameter.py | 40 +++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 091f4bc9673..cf0e9e1dc9a 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -280,15 +280,23 @@ class Parameter(MetaTensor): Set `default_input` of current `Parameter`. Args: - data (Union[Tensor, Initializer]): new data. - slice_shape (bool): If slice the Parameter. Default: False. + data (Union[Tensor, Initializer, int, float]): new data. + slice_shape (bool): If slice the Parameter, will not check if the shape matches. Default: False. Retruns: Parameter, the parameter after set data. """ - if not isinstance(data, (MetaTensor, Initializer)): - raise ValueError(f"Parameter data must be `Initializer` or a kind of `MetaTensor` " - f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.") + def raise_type_error(incoming): + raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.dtype}" + f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.") + + if not isinstance(data, (MetaTensor, Initializer, int, float)): + raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` " + f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.") + if isinstance(data, (int, float)): + if self.dtype in mstype.int_type and isinstance(data, float): + raise_type_error(mstype.float_) + data = Tensor(data, self.dtype) # both not init. 
is_incoming_tensor = isinstance(data, Tensor) is_current_tensor = isinstance(self, Tensor) @@ -300,25 +308,25 @@ class Parameter(MetaTensor): "network, then call this method.") if tuple(self.shape) != tuple(data.shape): # If Slice create Parameter shape can be change. - if slice_shape: - self._update_tensor_data(data) - self.sliced = True - else: + if not slice_shape: raise ValueError(f"Can not change the shape of Parameter which has been initialized." f" Current shape is {self.shape}, and incoming is {data.shape}.") if self.dtype != data.dtype: - raise ValueError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}" - f", and incoming is {data.dtype}. Use .set_dtype(xxx) to change the dtype.") + raise_type_error(data.dtype) if isinstance(data, Initializer): # The parameter has been initializered, directly update by the data if is_current_tensor: self._update_tensor_data(data.to_tensor()) else: + # also update the related inited parameter data + if self.inited_param is not None: + self.inited_param.set_parameter_data(data) self.init_mode = data elif is_incoming_tensor or is_current_tensor: self._update_tensor_data(data) else: raise ValueError(f"Not support to update the Parameter by {data}") + self.sliced = slice_shape return self def init_data(self, layout=None, set_sliced=False): @@ -340,8 +348,6 @@ class Parameter(MetaTensor): """ if self.init_mode is None: return self - if self.inited_param is not None: - return self.inited_param if layout is not None: if not isinstance(layout, list): raise TypeError("The layout should be list! 
layout is {}.".format(layout)) @@ -362,8 +368,7 @@ class Parameter(MetaTensor): if id(obj) != id(self): self._inited_param = obj obj.init_mode = None - if set_sliced: - obj.sliced = True + obj.sliced = set_sliced return obj diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index f4ab8734f88..0ff0949b3d2 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -135,6 +135,40 @@ def test_check_str_by_regular(): with pytest.raises(ValueError): _check_str_by_regular(str6) +def test_parameter_compute(): + para_1 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1') + para_2 = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test2') + + t3 = Tensor(np.ones((1, 2, 3))) + + out = para_1 + para_2 + assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3)) * 2) + + out = para_1 * para_2 + assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3))) + + out = para_1 + t3 + assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3)) * 2) + + out = para_1 * t3 + assert np.array_equal(out.asnumpy(), np.ones((1, 2, 3))) + + assert isinstance(para_1, Tensor) + + +def test_scalar_parameter_update(): + fp = Parameter(0.5, 'fp') + fp.default_input = 0.8 + assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32)) + fp.default_input = 1 + assert np.array_equal(fp.default_input.asnumpy(), np.array(1.0, np.float32)) + int_ = Parameter(1, 'fp') + int_.default_input = 2 + assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32)) + with pytest.raises(TypeError): + int_.default_input = 1.2 + + def test_parameter_lazy_init(): # support lazy init in SEMI_AUTO_PARALLEL mode context.reset_auto_parallel_context() @@ -155,7 +189,7 @@ def test_parameter_lazy_init(): # init then assign para = para.init_data() # check the type - with pytest.raises(ValueError): + with pytest.raises(TypeError): para.default_input = Tensor(np.zeros((1, 2, 3))) # check the shape with 
pytest.raises(ValueError): @@ -170,4 +204,8 @@ def test_parameter_lazy_init(): # expect no effect. para.init_data() assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) + para.set_parameter_data(Tensor(np.zeros((1, 2)).astype(np.float32)), slice_shape=True) + assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2))) + para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True) + assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2))) context.reset_auto_parallel_context()