!4909 [bug]support implicit type conversion for parameter

Merge pull request !4909 from vlne-v1/I1QM7L-implicit-type-conversion-parameter
This commit is contained in:
mindspore-ci-bot 2020-08-21 17:08:23 +08:00 committed by Gitee
commit 0d1a7ac654
4 changed files with 32 additions and 4 deletions

View File

@ -119,6 +119,9 @@ int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64) uint_type = (uint8, uint16, uint32, uint64)
float_type = (float16, float32, float64,) float_type = (float16, float32, float64,)
implicit_conversion_seq = {t: idx for idx, t in enumerate((
bool_, int8, uint8, int16, int32, int64, float16, float32, float64))}
_simple_types = { _simple_types = {
list: list_, list: list_,
tuple: tuple_, tuple: tuple_,

View File

@ -313,8 +313,9 @@ class Parameter(MetaTensor):
Parameter, the parameter after set data. Parameter, the parameter after set data.
""" """
def raise_type_error(incoming): def raise_type_error(incoming):
raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}" raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.") f"Current dtype is {self.dtype}, and incoming is {incoming}. "
f"Use .set_dtype(xxx) to change the dtype.")
if not isinstance(data, (MetaTensor, Initializer, int, float)): if not isinstance(data, (MetaTensor, Initializer, int, float)):
raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` " raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
@ -338,7 +339,10 @@ class Parameter(MetaTensor):
raise ValueError(f"Can not change the shape of Parameter which has been initialized." raise ValueError(f"Can not change the shape of Parameter which has been initialized."
f" Current shape is {self.shape}, and incoming is {data.shape}.") f" Current shape is {self.shape}, and incoming is {data.shape}.")
if self.dtype != data.dtype: if self.dtype != data.dtype:
raise_type_error(data.dtype) if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
raise_type_error(data.dtype)
else:
data = Tensor(data, self.dtype)
if isinstance(data, Initializer): if isinstance(data, Initializer):
# The parameter has been initializered, directly update by the data # The parameter has been initializered, directly update by the data
if is_current_tensor: if is_current_tensor:

View File

@ -74,7 +74,7 @@ class Tensor(Tensor_):
self._virtual_flag = False self._virtual_flag = False
def __repr__(self): def __repr__(self):
return str(Tensor_.__str__(self)) return Tensor_.__repr__(self)
def __add__(self, other): def __add__(self, other):
out = tensor_operator_registry.get('__add__')(self, other) out = tensor_operator_registry.get('__add__')(self, other)

View File

@ -157,6 +157,7 @@ def test_parameter_compute():
def test_scalar_parameter_update(): def test_scalar_parameter_update():
# float
fp = Parameter(0.5, 'fp') fp = Parameter(0.5, 'fp')
fp.default_input = 0.8 fp.default_input = 0.8
assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32)) assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32))
@ -167,6 +168,26 @@ def test_scalar_parameter_update():
assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32)) assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32))
with pytest.raises(TypeError): with pytest.raises(TypeError):
int_.default_input = 1.2 int_.default_input = 1.2
# Tensor
fp32 = Tensor(0.5, mstype.float32)
int32 = Tensor(2, mstype.int32)
fp16 = Tensor(0.6, mstype.float16)
int16 = Tensor(3, mstype.int16)
bool_ = Tensor(np.array(True, dtype=np.bool_))
# updata_by_tensor
fp32_p = Parameter(fp32, 'fp32')
fp32_p.default_input = 0.8
fp32_p.default_input = 1
fp32_p.default_input = int32
fp32_p.default_input = fp32
fp32_p.default_input = int16
fp32_p.default_input = fp16
fp32_p.default_input = bool_
# updata_by_tensor
fp16_p = Parameter(fp16, 'fp16')
with pytest.raises(TypeError):
fp16_p.default_input = fp32
def test_parameter_lazy_init(): def test_parameter_lazy_init():