forked from mindspore-Ecosystem/mindspore
!25650 Add check with the names of parameter in list or tuple.
Merge pull request !25650 from Margaret_wangrui/param_name
This commit is contained in:
commit
c4d837598d
|
@ -81,7 +81,24 @@ class Parameter(Tensor_):
|
|||
Args:
|
||||
default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data,
|
||||
to initialize the parameter data.
|
||||
name (str): Name of the child parameter. Default: None.
|
||||
name (str): Name of the parameter. Default: None.
|
||||
1) If the parameter is not given a name, the default name is its variable name. For example, the name of
|
||||
param_a below is name_a, and the name of param_b is the variable name param_b.
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32))
|
||||
2) If parameter in list or tuple is not given a name, will give it a unique name. For example, the names of
|
||||
parameters below are Parameter$1 and Parameter$2.
|
||||
self.param_list = [Parameter(Tensor([3], ms.float32)),
|
||||
Parameter(Tensor([4], ms.float32))]
|
||||
3) If the parameter is given a name, and the same name exists between different parameters, an exception
|
||||
will be thrown. For example, "its name 'name_a' already exists." will be thrown.
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_tuple = (Parameter(Tensor([5], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([6], ms.float32)))
|
||||
4) If a parameter appear multiple times in list or tuple, check the name of the object only once. For
|
||||
example, the following example will not throw an exception.
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_tuple = (self.param_a, self.param_a)
|
||||
requires_grad (bool): True if the parameter requires gradient. Default: True.
|
||||
layerwise_parallel (bool): When layerwise_parallel is true in data/hybrid parallel mode,
|
||||
broadcast and gradients communication would not be applied to parameters. Default: False.
|
||||
|
|
|
@ -101,6 +101,7 @@ class Cell(Cell_):
|
|||
self.arguments_key = ""
|
||||
self.compile_cache = set()
|
||||
self.parameter_broadcast_done = False
|
||||
self._id = 1
|
||||
init_pipeline()
|
||||
|
||||
# call gc to release GE session resources used by non-used cell objects
|
||||
|
@ -532,8 +533,22 @@ class Cell(Cell_):
|
|||
params_list = self.__dict__.get('_params_list')
|
||||
if params is None:
|
||||
raise AttributeError("Can not assign params before Cell.__init__() call.")
|
||||
exist_names = set("")
|
||||
exist_objs = set()
|
||||
for item in value:
|
||||
self.insert_param_to_cell(item.name, item, check_name=False)
|
||||
if item in exist_objs:
|
||||
# If there are multiple identical objects, their names only check once.
|
||||
continue
|
||||
exist_objs.add(item)
|
||||
if item.name == PARAMETER_NAME_DEFAULT:
|
||||
item.name = item.name + "$" + str(self._id)
|
||||
self._id += 1
|
||||
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
|
||||
if item.name in exist_names:
|
||||
raise ValueError("The value {} , its name '{}' already exists.".
|
||||
format(value, item.name))
|
||||
exist_names.add(item.name)
|
||||
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
if name in self.__dict__:
|
||||
del self.__dict__[name]
|
||||
|
@ -559,6 +574,17 @@ class Cell(Cell_):
|
|||
if hasattr(self, '_cell_init_args'):
|
||||
self.cell_init_args += str({name: value})
|
||||
|
||||
def _check_param_list_tuple(self, value):
|
||||
"""
|
||||
Check the type of input in list or tuple is Parameter.
|
||||
:param value: list or tuple.
|
||||
:return: The types of all inputs are parameter.
|
||||
"""
|
||||
for item in value:
|
||||
if not isinstance(item, Parameter):
|
||||
return False
|
||||
return True
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
cells = self.__dict__.get('_cells')
|
||||
params = self.__dict__.get('_params')
|
||||
|
@ -567,6 +593,8 @@ class Cell(Cell_):
|
|||
self._set_attr_for_parameter(name, value)
|
||||
elif isinstance(value, ParameterTuple):
|
||||
self._set_attr_for_parameter_tuple(name, value)
|
||||
elif isinstance(value, (list, tuple)) and value and self._check_param_list_tuple(value):
|
||||
self._set_attr_for_parameter_tuple(name, value)
|
||||
elif isinstance(value, Cell):
|
||||
self._set_attr_for_cell(name, value)
|
||||
elif params and name in params:
|
||||
|
@ -756,7 +784,7 @@ class Cell(Cell_):
|
|||
"""Executes saving checkpoint graph operation."""
|
||||
_cell_graph_executor(self, phase='save')
|
||||
|
||||
def insert_param_to_cell(self, param_name, param, check_name=True):
|
||||
def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
|
||||
"""
|
||||
Adds a parameter to the current cell.
|
||||
|
||||
|
@ -766,7 +794,7 @@ class Cell(Cell_):
|
|||
Args:
|
||||
param_name (str): Name of the parameter.
|
||||
param (Parameter): Parameter to be inserted to the cell.
|
||||
check_name (bool): Determines whether the name input is compatible. Default: True.
|
||||
check_name_contain_dot (bool): Determines whether the name input is compatible. Default: True.
|
||||
|
||||
Raises:
|
||||
KeyError: If the name of parameter is null or contains dot.
|
||||
|
@ -775,7 +803,7 @@ class Cell(Cell_):
|
|||
"""
|
||||
if not param_name:
|
||||
raise KeyError("The name of parameter should not be null.")
|
||||
if check_name and '.' in param_name:
|
||||
if check_name_contain_dot and '.' in param_name:
|
||||
raise KeyError("The name of parameter should not contain \".\"")
|
||||
if '_params' not in self.__dict__:
|
||||
raise AttributeError("You need call init() first.")
|
||||
|
|
|
@ -732,3 +732,71 @@ def test_assign_in_zip_loop():
|
|||
net = AssignInZipLoop()
|
||||
out = net(x)
|
||||
assert np.all(out.asnumpy() == 1)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: If parameter in list or tuple is not given a name, will give it a unique name.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
|
||||
self.param_c = Parameter(Tensor([3], ms.float32))
|
||||
self.param_d = Parameter(Tensor([4], ms.float32))
|
||||
self.param_tuple = (Parameter(Tensor([5], ms.float32)),
|
||||
Parameter(Tensor([6], ms.float32)))
|
||||
self.param_list = [Parameter(Tensor([5], ms.float32)),
|
||||
Parameter(Tensor([6], ms.float32))]
|
||||
|
||||
def construct(self, x):
|
||||
res1 = self.param_a + self.param_b + self.param_c + self.param_d
|
||||
res1 = res1 - self.param_list[0] + self.param_list[1] + x
|
||||
res2 = self.param_list[0] + self.param_list[1]
|
||||
return res1, res2
|
||||
|
||||
net = ParamNet()
|
||||
x = Tensor([10], ms.float32)
|
||||
output1, output2 = net(x)
|
||||
output1_expect = Tensor(21, ms.float32)
|
||||
output2_expect = Tensor(11, ms.float32)
|
||||
assert output1 == output1_expect
|
||||
assert output2 == output2_expect
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_same_name():
|
||||
"""
|
||||
Feature: Check the names of parameters.
|
||||
Description: If the same name exists between different parameters, an exception will be thrown.
|
||||
Expectation: Get the expected exception report.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
|
||||
self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
|
||||
self.param_tuple = (Parameter(Tensor([5], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([6], ms.float32)))
|
||||
|
||||
def construct(self, x):
|
||||
res1 = self.param_a + self.param_b - self.param_tuple[0] + self.param_tuple[1] + x
|
||||
return res1
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
x = Tensor([10], ms.float32)
|
||||
output = net(x)
|
||||
output_expect = Tensor(14, ms.float32)
|
||||
assert output == output_expect
|
||||
|
|
Loading…
Reference in New Issue