fix pynative parameter list problem

This commit is contained in:
Margaret_wangrui 2021-11-05 16:48:58 +08:00
parent 3caa6cb09e
commit 0170463364
2 changed files with 50 additions and 2 deletions

View File

@ -102,6 +102,8 @@ class Cell(Cell_):
self.compile_cache = set()
self.parameter_broadcast_done = False
self._id = 1
self.exist_names = set("")
self.exist_objs = set()
init_pipeline()
# call gc to release GE session resources used by non-used cell objects
@ -528,7 +530,7 @@ class Cell(Cell_):
self.insert_param_to_cell(name, value)
def _set_attr_for_parameter_tuple(self, name, value):
"""Set attr for parameter tuple."""
"""Set attr for parameter in ParameterTuple."""
params = self.__dict__.get('_params')
params_list = self.__dict__.get('_params_list')
if params is None:
@ -558,6 +560,22 @@ class Cell(Cell_):
else:
object.__setattr__(self, name, value)
def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
"""Set attr for parameter in list or tuple."""
for item in value:
if item in self.exist_objs:
# If there are multiple identical objects, their names only check once.
continue
self.exist_objs.add(item)
if item.name == PARAMETER_NAME_DEFAULT:
item.name = item.name + "$" + str(self._id)
self._id += 1
if item.name in self.exist_names:
raise ValueError("The value {} , its name '{}' already exists.".
format(value, item.name))
self.exist_names.add(item.name)
object.__setattr__(self, name, value)
def _set_attr_for_cell(self, name, value):
"""Set attr for cell."""
cells = self.__dict__.get('_cells')
@ -594,7 +612,7 @@ class Cell(Cell_):
elif isinstance(value, ParameterTuple):
self._set_attr_for_parameter_tuple(name, value)
elif isinstance(value, (list, tuple)) and value and self._check_param_list_tuple(value):
self._set_attr_for_parameter_tuple(name, value)
self._set_attr_for_parameter_in_list_or_tuple(name, value)
elif isinstance(value, Cell):
self._set_attr_for_cell(name, value)
elif params and name in params:

View File

@ -800,3 +800,33 @@ def test_parameter_same_name():
output = net(x)
output_expect = Tensor(14, ms.float32)
assert output == output_expect
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_same_name_between_tuple_or_list():
"""
Feature: Check the names of parameters between tuple or list.
Description: If the same name exists between tuple and list, an exception will be thrown.
Expectation: Get the expected exception report.
"""
class ParamNet(Cell):
def __init__(self):
super(ParamNet, self).__init__()
self.param_tuple = (Parameter(Tensor([1], ms.float32), name="name_a"),
Parameter(Tensor([2], ms.float32)))
self.param_list = [Parameter(Tensor([3], ms.float32), name="name_a"),
Parameter(Tensor([4], ms.float32))]
def construct(self, x):
res = self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_listp[1] + x
return res
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
net = ParamNet()
x = Tensor([10], ms.float32)
output = net(x)
output_expect = Tensor(20, ms.float32)
assert output == output_expect