forked from mindspore-Ecosystem/mindspore
!4909 [bug]support implicit type conversion for parameter
Merge pull request !4909 from vlne-v1/I1QM7L-implicit-type-conversion-parameter
commit 0d1a7ac654
@@ -119,6 +119,9 @@ int_type = (int8, int16, int32, int64,)
 uint_type = (uint8, uint16, uint32, uint64)
 float_type = (float16, float32, float64,)

+implicit_conversion_seq = {t: idx for idx, t in enumerate((
+    bool_, int8, uint8, int16, int32, int64, float16, float32, float64))}
+
 _simple_types = {
     list: list_,
     tuple: tuple_,
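The enumerate order above drives the new conversion rule: a smaller index means a narrower type. A minimal sketch of how the table can be queried, assuming a MindSpore build that contains this change (mstype is the conventional alias for mindspore.common.dtype; can_convert_implicitly is a hypothetical helper, not part of the patch):

    from mindspore.common import dtype as mstype

    # Index in the tuple is the implicit-conversion rank: bool_ = 0 ... float64 = 8.
    print(mstype.implicit_conversion_seq[mstype.int16])    # 3
    print(mstype.implicit_conversion_seq[mstype.float32])  # 7

    def can_convert_implicitly(current, incoming):
        # Mirrors the check added to Parameter.set_data below: an incoming dtype is
        # accepted only if it does not rank higher than the current dtype.
        return mstype.implicit_conversion_seq[current] >= mstype.implicit_conversion_seq[incoming]

    can_convert_implicitly(mstype.float32, mstype.int16)    # True  -> cast to float32
    can_convert_implicitly(mstype.float16, mstype.float32)  # False -> TypeError in set_data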
@@ -313,8 +313,9 @@ class Parameter(MetaTensor):
             Parameter, the parameter after set data.
         """
         def raise_type_error(incoming):
-            raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
-                            f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.")
+            raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
+                            f"Current dtype is {self.dtype}, and incoming is {incoming}. "
+                            f"Use .set_dtype(xxx) to change the dtype.")

         if not isinstance(data, (MetaTensor, Initializer, int, float)):
             raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
@@ -338,7 +339,10 @@ class Parameter(MetaTensor):
                 raise ValueError(f"Can not change the shape of Parameter which has been initialized."
                                  f" Current shape is {self.shape}, and incoming is {data.shape}.")
             if self.dtype != data.dtype:
-                raise_type_error(data.dtype)
+                if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
+                    raise_type_error(data.dtype)
+                else:
+                    data = Tensor(data, self.dtype)
         if isinstance(data, Initializer):
            # The parameter has been initializered, directly update by the data
            if is_current_tensor:
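Taken together, the two hunks above change what happens when a Parameter is given data of a different dtype: a lower-ranked incoming dtype is now cast to the parameter's dtype, while a higher-ranked one still raises. A usage sketch that mirrors the tests added below (assumes a build with this patch applied):

    import numpy as np
    from mindspore import Tensor, Parameter
    from mindspore.common import dtype as mstype

    fp32_p = Parameter(Tensor(0.5, mstype.float32), 'fp32')
    fp32_p.default_input = Tensor(3, mstype.int16)            # int16 ranks below float32: cast to float32
    fp32_p.default_input = Tensor(np.array(True, np.bool_))   # bool_ ranks below float32: also accepted

    fp16_p = Parameter(Tensor(0.6, mstype.float16), 'fp16')
    try:
        fp16_p.default_input = Tensor(0.5, mstype.float32)    # float32 ranks above float16: rejected
    except TypeError as e:
        print(e)  # "Incoming Parameter dtype can not be converted to current dtype implicitly. ..."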
@@ -74,7 +74,7 @@ class Tensor(Tensor_):
         self._virtual_flag = False

     def __repr__(self):
-        return str(Tensor_.__str__(self))
+        return Tensor_.__repr__(self)

     def __add__(self, other):
         out = tensor_operator_registry.get('__add__')(self, other)
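A quick way to see the effect of the __repr__ change (illustrative; the exact string produced by the underlying Tensor_ type depends on the MindSpore build):

    import numpy as np
    from mindspore import Tensor

    t = Tensor(np.ones((2, 2), np.float32))
    print(repr(t))  # now delegates to Tensor_.__repr__ instead of wrapping Tensor_.__str__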
@@ -157,6 +157,7 @@ def test_parameter_compute():


 def test_scalar_parameter_update():
+    # float
     fp = Parameter(0.5, 'fp')
     fp.default_input = 0.8
     assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32))
@@ -167,6 +168,26 @@ def test_scalar_parameter_update():
     assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32))
     with pytest.raises(TypeError):
         int_.default_input = 1.2
+    # Tensor
+    fp32 = Tensor(0.5, mstype.float32)
+    int32 = Tensor(2, mstype.int32)
+    fp16 = Tensor(0.6, mstype.float16)
+    int16 = Tensor(3, mstype.int16)
+    bool_ = Tensor(np.array(True, dtype=np.bool_))
+    # updata_by_tensor
+    fp32_p = Parameter(fp32, 'fp32')
+    fp32_p.default_input = 0.8
+    fp32_p.default_input = 1
+    fp32_p.default_input = int32
+    fp32_p.default_input = fp32
+    fp32_p.default_input = int16
+    fp32_p.default_input = fp16
+    fp32_p.default_input = bool_
+
+    # updata_by_tensor
+    fp16_p = Parameter(fp16, 'fp16')
+    with pytest.raises(TypeError):
+        fp16_p.default_input = fp32


 def test_parameter_lazy_init():
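The updated test can be run on its own with pytest; "test_parameter.py" below is an assumed file name for the test module shown above, so adjust the path to the actual repository layout:

    import pytest

    # -k selects only the scalar-parameter test touched by this change.
    pytest.main(["-q", "-k", "test_scalar_parameter_update", "test_parameter.py"])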