!25650 Add a check for the names of parameters in a list or tuple.

Merge pull request !25650 from Margaret_wangrui/param_name
commit c4d837598d
i-robot, 2021-11-04 11:33:44 +00:00, committed by Gitee
3 changed files with 118 additions and 5 deletions


@@ -81,7 +81,24 @@ class Parameter(Tensor_):
     Args:
         default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data,
             to initialize the parameter data.
-        name (str): Name of the child parameter. Default: None.
+        name (str): Name of the parameter. Default: None.
+            1) If the parameter is not given a name, its default name is its variable name. For example, the
+            name of param_a below is name_a, and the name of param_b is the variable name param_b.
+                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
+                self.param_b = Parameter(Tensor([2], ms.float32))
+            2) If a parameter in a list or tuple is not given a name, a unique name is generated for it. For
+            example, the names of the parameters below are Parameter$1 and Parameter$2.
+                self.param_list = [Parameter(Tensor([3], ms.float32)),
+                                   Parameter(Tensor([4], ms.float32))]
+            3) If the parameter is given a name that another parameter already uses, an exception is thrown,
+            e.g. "its name 'name_a' already exists.".
+                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
+                self.param_tuple = (Parameter(Tensor([5], ms.float32), name="name_a"),
+                                    Parameter(Tensor([6], ms.float32)))
+            4) If a parameter appears multiple times in a list or tuple, the name of that object is checked
+            only once. For example, the following does not throw an exception.
+                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
+                self.param_tuple = (self.param_a, self.param_a)
         requires_grad (bool): True if the parameter requires gradient. Default: True.
         layerwise_parallel (bool): When layerwise_parallel is true in data/hybrid parallel mode,
             broadcast and gradients communication would not be applied to parameters. Default: False.
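
Taken together, the four rules can be exercised in a single cell. The following is a minimal sketch assembled from the docstring examples above; it assumes MindSpore with this patch applied, and the class name Net is hypothetical:

import mindspore as ms
from mindspore import Tensor, Parameter
from mindspore.nn import Cell

class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        # Rule 1: an explicit name is kept; an unnamed parameter takes its
        # variable name, here "param_b".
        self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
        self.param_b = Parameter(Tensor([2], ms.float32))
        # Rule 2: unnamed parameters in a list get unique generated names,
        # here "Parameter$1" and "Parameter$2".
        self.param_list = [Parameter(Tensor([3], ms.float32)),
                           Parameter(Tensor([4], ms.float32))]
        # Rule 4: the same object may appear more than once; its name is
        # checked only once, so this tuple does not raise.
        self.param_tuple = (self.param_a, self.param_a)

Net()  # constructing the cell runs the name checks in __setattr__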


@@ -101,6 +101,7 @@ class Cell(Cell_):
         self.arguments_key = ""
         self.compile_cache = set()
         self.parameter_broadcast_done = False
+        self._id = 1
         init_pipeline()
         # call gc to release GE session resources used by non-used cell objects
@@ -532,8 +533,22 @@
         params_list = self.__dict__.get('_params_list')
         if params is None:
             raise AttributeError("Can not assign params before Cell.__init__() call.")
+        exist_names = set("")
+        exist_objs = set()
         for item in value:
-            self.insert_param_to_cell(item.name, item, check_name=False)
+            if item in exist_objs:
+                # If the same object appears several times, check its name only once.
+                continue
+            exist_objs.add(item)
+            if item.name == PARAMETER_NAME_DEFAULT:
+                item.name = item.name + "$" + str(self._id)
+                self._id += 1
+            self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
+            if item.name in exist_names:
+                raise ValueError("The value {}, its name '{}' already exists.".
+                                 format(value, item.name))
+            exist_names.add(item.name)
         if context.get_context("mode") == context.PYNATIVE_MODE:
             if name in self.__dict__:
                 del self.__dict__[name]
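
Stripped of the Cell machinery, the loop above deduplicates by object identity, auto-names default-named parameters, and rejects duplicate names. A framework-free sketch of that logic (the P class and check_names are hypothetical stand-ins, not part of the patch):

PARAMETER_NAME_DEFAULT = "Parameter"

class P:
    """Hypothetical stand-in for Parameter; only the name attribute matters."""
    def __init__(self, name=PARAMETER_NAME_DEFAULT):
        self.name = name

def check_names(value, next_id=1):
    exist_names = set()
    exist_objs = set()
    for item in value:
        if item in exist_objs:
            # The same object again: its name was already checked once.
            continue
        exist_objs.add(item)
        if item.name == PARAMETER_NAME_DEFAULT:
            # Auto-generate a unique name such as "Parameter$1".
            item.name = item.name + "$" + str(next_id)
            next_id += 1
        if item.name in exist_names:
            raise ValueError("The value {}, its name '{}' already exists.".format(value, item.name))
        exist_names.add(item.name)
    return [p.name for p in value]

shared = P("name_a")
assert check_names([P(), P(), shared, shared]) == ["Parameter$1", "Parameter$2", "name_a", "name_a"]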
@@ -559,6 +574,17 @@
         if hasattr(self, '_cell_init_args'):
             self.cell_init_args += str({name: value})

+    def _check_param_list_tuple(self, value):
+        """
+        Check that every element of a list or tuple is a Parameter.
+        :param value: list or tuple.
+        :return: True if all elements are Parameter, False otherwise.
+        """
+        for item in value:
+            if not isinstance(item, Parameter):
+                return False
+        return True
+
     def __setattr__(self, name, value):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
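
As a side note (not part of the patch), the new helper is behaviorally identical to a one-line all() check:

def _check_param_list_tuple(self, value):
    # Equivalent, more idiomatic form of the loop above.
    return all(isinstance(item, Parameter) for item in value)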
@@ -567,6 +593,8 @@
             self._set_attr_for_parameter(name, value)
         elif isinstance(value, ParameterTuple):
             self._set_attr_for_parameter_tuple(name, value)
+        elif isinstance(value, (list, tuple)) and value and self._check_param_list_tuple(value):
+            self._set_attr_for_parameter_tuple(name, value)
         elif isinstance(value, Cell):
             self._set_attr_for_cell(name, value)
         elif params and name in params:
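
With the new branch, a plain list or tuple whose elements are all Parameters is routed through _set_attr_for_parameter_tuple just like a ParameterTuple, while a mixed list falls through to the ordinary attribute path. A hypothetical sketch (ListNet is not from the patch):

import mindspore as ms
from mindspore import Tensor, Parameter
from mindspore.nn import Cell

class ListNet(Cell):
    def __init__(self):
        super(ListNet, self).__init__()
        # All elements are Parameters: handled by _set_attr_for_parameter_tuple,
        # so names are checked and missing ones are auto-generated.
        self.ps = [Parameter(Tensor([1], ms.float32)),
                   Parameter(Tensor([2], ms.float32))]
        # Mixed content: _check_param_list_tuple returns False and the list is
        # stored as an ordinary attribute, with no name check.
        self.misc = [Parameter(Tensor([3], ms.float32)), 42]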
@@ -756,7 +784,7 @@ class Cell(Cell_):
         """Executes saving checkpoint graph operation."""
         _cell_graph_executor(self, phase='save')

-    def insert_param_to_cell(self, param_name, param, check_name=True):
+    def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
         """
         Adds a parameter to the current cell.
@@ -766,7 +794,7 @@
         Args:
             param_name (str): Name of the parameter.
             param (Parameter): Parameter to be inserted to the cell.
-            check_name (bool): Determines whether the name input is compatible. Default: True.
+            check_name_contain_dot (bool): Determines whether to check that the name contains no dot. Default: True.

         Raises:
             KeyError: If the name of parameter is null or contains dot.
@@ -775,7 +803,7 @@
         """
         if not param_name:
             raise KeyError("The name of parameter should not be null.")
-        if check_name and '.' in param_name:
+        if check_name_contain_dot and '.' in param_name:
             raise KeyError("The name of parameter should not contain \".\"")
         if '_params' not in self.__dict__:
             raise AttributeError("You need call init() first.")
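
The rename from check_name to check_name_contain_dot does not change behavior; it only names what the flag actually guards. A usage sketch based on the signature above (instantiating a bare Cell here is for illustration only):

import mindspore as ms
from mindspore import Tensor, Parameter
from mindspore.nn import Cell

net = Cell()
w = Parameter(Tensor([1], ms.float32), name="w")
net.insert_param_to_cell("w", w)       # accepted
try:
    net.insert_param_to_cell("a.b", w) # raises KeyError: the name contains "."
except KeyError as err:
    print(err)
net.insert_param_to_cell("a.b", w, check_name_contain_dot=False)  # dot check skipped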


@@ -732,3 +732,71 @@ def test_assign_in_zip_loop():
     net = AssignInZipLoop()
     out = net(x)
     assert np.all(out.asnumpy() == 1)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_parameter():
+    """
+    Feature: Check the names of parameters.
+    Description: If a parameter in a list or tuple is not given a name, a unique name is generated for it.
+    Expectation: No exception.
+    """
+    class ParamNet(Cell):
+        def __init__(self):
+            super(ParamNet, self).__init__()
+            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
+            self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
+            self.param_c = Parameter(Tensor([3], ms.float32))
+            self.param_d = Parameter(Tensor([4], ms.float32))
+            self.param_tuple = (Parameter(Tensor([5], ms.float32)),
+                                Parameter(Tensor([6], ms.float32)))
+            self.param_list = [Parameter(Tensor([5], ms.float32)),
+                               Parameter(Tensor([6], ms.float32))]
+
+        def construct(self, x):
+            res1 = self.param_a + self.param_b + self.param_c + self.param_d
+            res1 = res1 - self.param_list[0] + self.param_list[1] + x
+            res2 = self.param_list[0] + self.param_list[1]
+            return res1, res2
+
+    net = ParamNet()
+    x = Tensor([10], ms.float32)
+    output1, output2 = net(x)
+    output1_expect = Tensor(21, ms.float32)
+    output2_expect = Tensor(11, ms.float32)
+    assert output1 == output1_expect
+    assert output2 == output2_expect
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_parameter_same_name():
+    """
+    Feature: Check the names of parameters.
+    Description: If the same name is shared by different parameters, an exception is thrown.
+    Expectation: Get the expected exception report.
+    """
+    class ParamNet(Cell):
+        def __init__(self):
+            super(ParamNet, self).__init__()
+            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
+            self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
+            self.param_tuple = (Parameter(Tensor([5], ms.float32), name="name_a"),
+                                Parameter(Tensor([6], ms.float32)))
+
+        def construct(self, x):
+            res1 = self.param_a + self.param_b - self.param_tuple[0] + self.param_tuple[1] + x
+            return res1
+
+    with pytest.raises(ValueError, match="its name 'name_a' already exists."):
+        net = ParamNet()
+        x = Tensor([10], ms.float32)
+        output = net(x)
+        output_expect = Tensor(14, ms.float32)
+        assert output == output_expect