add dtype shape and value in __str__ and __repr__ of Parameter

This commit is contained in:
buxue 2021-01-26 12:21:45 +08:00
parent 659b5d8e10
commit 7c4b7203b0
2 changed files with 17 additions and 26 deletions

View File

@ -201,10 +201,11 @@ class Parameter(Tensor_):
return (Tensor, data)
def __str__(self):
return f'Parameter (name={self._param_info.name})'
return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
f'requires_grad={self.requires_grad})'
def __repr__(self):
return f'Parameter (name={self._param_info.name})'
return self.__str__()
def __parameter__(self):
"""For parse check."""
@ -242,7 +243,6 @@ class Parameter(Tensor_):
"""
return self._inited_param
@property
def name(self):
"""Get the name of the parameter."""
@ -501,10 +501,8 @@ class Parameter(Tensor_):
Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
returns the same initialized `Parameter`.
"""
if self.is_default_input_init:
is_current_in_parallel = _is_in_parallel_mode()
if self.is_in_parallel != is_current_in_parallel:
raise RuntimeError("Must set or change parallel mode before any Tensor created.")
if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
raise RuntimeError("Must set or change parallel mode before any Tensor created.")
if self.init_mode is None:
return self
if self.inited_param is not None:
@ -512,29 +510,21 @@ class Parameter(Tensor_):
if _is_role_worker() and self.cache_enable:
global_seed, op_seed = _get_global_and_op_seed()
_insert_weight_init_info(self.name, global_seed, op_seed)
init_data_args = ()
if layout is not None:
if not isinstance(layout, tuple):
raise TypeError("The layout should be tuple! layout is {}.".format(layout))
raise TypeError("The layout should be tuple, but got layout is {}.".format(layout))
if len(layout) < 3:
raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
raise ValueError("The length of layout must be larger than 2, but got layout is {}.".format(layout))
slice_index = int(_get_slice_index(layout[0], layout[1]))
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor)
and self.init_mode.init is not None):
if _is_role_worker() or _is_role_sched():
data = self.init_mode.init_data(0, [1])
else:
data = self.init_mode.init_data(slice_index, layout[2], layout[5])
else:
data = self.init_mode.init_data(slice_index, layout[2], layout[5])
init_data_args += (slice_index, layout[2], layout[5])
if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()):
data = self.init_mode.init_data(0, [1])
else:
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor)
and self.init_mode.init is not None):
if _is_role_worker() or _is_role_sched():
data = self.init_mode.init_data(0, [1])
else:
data = self.init_mode.init_data()
else:
data = self.init_mode.init_data()
data = self.init_mode.init_data(*init_data_args)
obj = self._update_tensor_data(data)
if id(obj) != id(self):

View File

@ -93,7 +93,8 @@ def test_outermost_net_pass_parameter():
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but got 1th arg is Parameter (name=weight)" in str(err.value)
"but got 1th arg is Parameter (name=weight, shape=(2, 2), dtype=Float32, requires_grad=True)" \
in str(err.value)
def test_outermost_net_pass_tuple_including_parameter():