forked from mindspore-Ecosystem/mindspore
fix pynative parameter list problem
This commit is contained in:
parent
3caa6cb09e
commit
0170463364
|
@ -102,6 +102,8 @@ class Cell(Cell_):
|
|||
self.compile_cache = set()
|
||||
self.parameter_broadcast_done = False
|
||||
self._id = 1
|
||||
self.exist_names = set("")
|
||||
self.exist_objs = set()
|
||||
init_pipeline()
|
||||
|
||||
# call gc to release GE session resources used by non-used cell objects
|
||||
|
@ -528,7 +530,7 @@ class Cell(Cell_):
|
|||
self.insert_param_to_cell(name, value)
|
||||
|
||||
def _set_attr_for_parameter_tuple(self, name, value):
|
||||
"""Set attr for parameter tuple."""
|
||||
"""Set attr for parameter in ParameterTuple."""
|
||||
params = self.__dict__.get('_params')
|
||||
params_list = self.__dict__.get('_params_list')
|
||||
if params is None:
|
||||
|
@ -558,6 +560,22 @@ class Cell(Cell_):
|
|||
else:
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
|
||||
"""Set attr for parameter in list or tuple."""
|
||||
for item in value:
|
||||
if item in self.exist_objs:
|
||||
# If there are multiple identical objects, their names only check once.
|
||||
continue
|
||||
self.exist_objs.add(item)
|
||||
if item.name == PARAMETER_NAME_DEFAULT:
|
||||
item.name = item.name + "$" + str(self._id)
|
||||
self._id += 1
|
||||
if item.name in self.exist_names:
|
||||
raise ValueError("The value {} , its name '{}' already exists.".
|
||||
format(value, item.name))
|
||||
self.exist_names.add(item.name)
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def _set_attr_for_cell(self, name, value):
|
||||
"""Set attr for cell."""
|
||||
cells = self.__dict__.get('_cells')
|
||||
|
@ -594,7 +612,7 @@ class Cell(Cell_):
|
|||
elif isinstance(value, ParameterTuple):
|
||||
self._set_attr_for_parameter_tuple(name, value)
|
||||
elif isinstance(value, (list, tuple)) and value and self._check_param_list_tuple(value):
|
||||
self._set_attr_for_parameter_tuple(name, value)
|
||||
self._set_attr_for_parameter_in_list_or_tuple(name, value)
|
||||
elif isinstance(value, Cell):
|
||||
self._set_attr_for_cell(name, value)
|
||||
elif params and name in params:
|
||||
|
|
|
@ -800,3 +800,33 @@ def test_parameter_same_name():
|
|||
output = net(x)
|
||||
output_expect = Tensor(14, ms.float32)
|
||||
assert output == output_expect
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_same_name_between_tuple_or_list():
|
||||
"""
|
||||
Feature: Check the names of parameters between tuple or list.
|
||||
Description: If the same name exists between tuple and list, an exception will be thrown.
|
||||
Expectation: Get the expected exception report.
|
||||
"""
|
||||
class ParamNet(Cell):
|
||||
def __init__(self):
|
||||
super(ParamNet, self).__init__()
|
||||
self.param_tuple = (Parameter(Tensor([1], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([2], ms.float32)))
|
||||
self.param_list = [Parameter(Tensor([3], ms.float32), name="name_a"),
|
||||
Parameter(Tensor([4], ms.float32))]
|
||||
|
||||
def construct(self, x):
|
||||
res = self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_listp[1] + x
|
||||
return res
|
||||
|
||||
with pytest.raises(ValueError, match="its name 'name_a' already exists."):
|
||||
net = ParamNet()
|
||||
x = Tensor([10], ms.float32)
|
||||
output = net(x)
|
||||
output_expect = Tensor(20, ms.float32)
|
||||
assert output == output_expect
|
||||
|
|
Loading…
Reference in New Issue