forked from mindspore-Ecosystem/mindspore
add dtype shape and value in __str__ and __repr__ of Parameter
This commit is contained in:
parent
659b5d8e10
commit
7c4b7203b0
|
@ -201,10 +201,11 @@ class Parameter(Tensor_):
|
|||
return (Tensor, data)
|
||||
|
||||
def __str__(self):
|
||||
return f'Parameter (name={self._param_info.name})'
|
||||
return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
|
||||
f'requires_grad={self.requires_grad})'
|
||||
|
||||
def __repr__(self):
|
||||
return f'Parameter (name={self._param_info.name})'
|
||||
return self.__str__()
|
||||
|
||||
def __parameter__(self):
|
||||
"""For parse check."""
|
||||
|
@ -242,7 +243,6 @@ class Parameter(Tensor_):
|
|||
"""
|
||||
return self._inited_param
|
||||
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Get the name of the parameter."""
|
||||
|
@ -501,10 +501,8 @@ class Parameter(Tensor_):
|
|||
Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
|
||||
returns the same initialized `Parameter`.
|
||||
"""
|
||||
if self.is_default_input_init:
|
||||
is_current_in_parallel = _is_in_parallel_mode()
|
||||
if self.is_in_parallel != is_current_in_parallel:
|
||||
raise RuntimeError("Must set or change parallel mode before any Tensor created.")
|
||||
if self.is_default_input_init and self.is_in_parallel != _is_in_parallel_mode():
|
||||
raise RuntimeError("Must set or change parallel mode before any Tensor created.")
|
||||
if self.init_mode is None:
|
||||
return self
|
||||
if self.inited_param is not None:
|
||||
|
@ -512,29 +510,21 @@ class Parameter(Tensor_):
|
|||
if _is_role_worker() and self.cache_enable:
|
||||
global_seed, op_seed = _get_global_and_op_seed()
|
||||
_insert_weight_init_info(self.name, global_seed, op_seed)
|
||||
|
||||
init_data_args = ()
|
||||
if layout is not None:
|
||||
if not isinstance(layout, tuple):
|
||||
raise TypeError("The layout should be tuple! layout is {}.".format(layout))
|
||||
raise TypeError("The layout should be tuple, but got layout is {}.".format(layout))
|
||||
if len(layout) < 3:
|
||||
raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
|
||||
raise ValueError("The length of layout must be larger than 2, but got layout is {}.".format(layout))
|
||||
slice_index = int(_get_slice_index(layout[0], layout[1]))
|
||||
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor)
|
||||
and self.init_mode.init is not None):
|
||||
if _is_role_worker() or _is_role_sched():
|
||||
data = self.init_mode.init_data(0, [1])
|
||||
else:
|
||||
data = self.init_mode.init_data(slice_index, layout[2], layout[5])
|
||||
else:
|
||||
data = self.init_mode.init_data(slice_index, layout[2], layout[5])
|
||||
init_data_args += (slice_index, layout[2], layout[5])
|
||||
|
||||
if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
|
||||
self.init_mode.init is not None and (_is_role_worker() or _is_role_sched()):
|
||||
data = self.init_mode.init_data(0, [1])
|
||||
else:
|
||||
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor)
|
||||
and self.init_mode.init is not None):
|
||||
if _is_role_worker() or _is_role_sched():
|
||||
data = self.init_mode.init_data(0, [1])
|
||||
else:
|
||||
data = self.init_mode.init_data()
|
||||
else:
|
||||
data = self.init_mode.init_data()
|
||||
data = self.init_mode.init_data(*init_data_args)
|
||||
|
||||
obj = self._update_tensor_data(data)
|
||||
if id(obj) != id(self):
|
||||
|
|
|
@ -93,7 +93,8 @@ def test_outermost_net_pass_parameter():
|
|||
assert "The inputs types of the outermost network support bool, int, float, tensor, " \
|
||||
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
|
||||
"and tuple or list containing only these types, and dict whose values are these types, " \
|
||||
"but got 1th arg is Parameter (name=weight)" in str(err.value)
|
||||
"but got 1th arg is Parameter (name=weight, shape=(2, 2), dtype=Float32, requires_grad=True)" \
|
||||
in str(err.value)
|
||||
|
||||
|
||||
def test_outermost_net_pass_tuple_including_parameter():
|
||||
|
|
Loading…
Reference in New Issue