From 77dcdd89ecdd6b074ed56c8708dcdbc01d67fe96 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Fri, 21 Aug 2020 15:14:23 +0800
Subject: [PATCH] support parameter update with implicit type conversion

---
 mindspore/common/dtype.py            |  3 +++
 mindspore/common/parameter.py        | 10 +++++++---
 mindspore/common/tensor.py           |  2 +-
 tests/ut/python/nn/test_parameter.py | 21 +++++++++++++++++++++
 4 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py
index 2b1c692ac1..d11d35fccd 100644
--- a/mindspore/common/dtype.py
+++ b/mindspore/common/dtype.py
@@ -119,6 +119,9 @@
 int_type = (int8, int16, int32, int64,)
 uint_type = (uint8, uint16, uint32, uint64)
 float_type = (float16, float32, float64,)
+
+implicit_conversion_seq = {t: idx for idx, t in enumerate((
+    bool_, int8, uint8, int16, int32, int64, float16, float32, float64))}
 _simple_types = {
     list: list_,
     tuple: tuple_,
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index dfe03a75e8..493168ebf4 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -313,8 +313,9 @@ class Parameter(MetaTensor):
             Parameter, the parameter after set data.
         """
         def raise_type_error(incoming):
-            raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
-                            f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.")
+            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
+                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
+                            f"Use .set_dtype(xxx) to change the dtype.")
 
         if not isinstance(data, (MetaTensor, Initializer, int, float)):
             raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
@@ -338,7 +339,10 @@ class Parameter(MetaTensor):
                 raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                  f" Current shape is {self.shape}, and incoming is {data.shape}.")
             if self.dtype != data.dtype:
-                raise_type_error(data.dtype)
+                if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
+                    raise_type_error(data.dtype)
+                else:
+                    data = Tensor(data, self.dtype)
         if isinstance(data, Initializer):
             # The parameter has been initializered, directly update by the data
             if is_current_tensor:
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 155f62e5b5..c007097ca6 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -74,7 +74,7 @@ class Tensor(Tensor_):
         self._virtual_flag = False
 
     def __repr__(self):
-        return str(Tensor_.__str__(self))
+        return Tensor_.__repr__(self)
 
     def __add__(self, other):
         out = tensor_operator_registry.get('__add__')(self, other)
diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py
index 58a118fded..2447112bd8 100644
--- a/tests/ut/python/nn/test_parameter.py
+++ b/tests/ut/python/nn/test_parameter.py
@@ -157,6 +157,7 @@ def test_parameter_compute():
 
 
 def test_scalar_parameter_update():
+    # float
     fp = Parameter(0.5, 'fp')
     fp.default_input = 0.8
     assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32))
@@ -167,6 +168,26 @@ def test_scalar_parameter_update():
     assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32))
     with pytest.raises(TypeError):
         int_.default_input = 1.2
+    # Tensor
+    fp32 = Tensor(0.5, mstype.float32)
+    int32 = Tensor(2, mstype.int32)
+    fp16 = Tensor(0.6, mstype.float16)
+    int16 = Tensor(3, mstype.int16)
+    bool_ = Tensor(np.array(True, dtype=np.bool_))
+    # update_by_tensor: every incoming dtype here ranks at or below float32
+    fp32_p = Parameter(fp32, 'fp32')
+    fp32_p.default_input = 0.8
+    fp32_p.default_input = 1
+    fp32_p.default_input = int32
+    fp32_p.default_input = fp32
+    fp32_p.default_input = int16
+    fp32_p.default_input = fp16
+    fp32_p.default_input = bool_
+
+    # update_by_tensor: float32 ranks above float16, so this must raise
+    fp16_p = Parameter(fp16, 'fp16')
+    with pytest.raises(TypeError):
+        fp16_p.default_input = fp32
 
 
 def test_parameter_lazy_init():