!13260 fix SequentialCell and CellList parameter name bug

From: @caozhou_huawei
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-15 17:34:36 +08:00 committed by Gitee
commit 5c39c33c92
2 changed files with 87 additions and 3 deletions

View File

@ -35,6 +35,43 @@ def _valid_cell(cell):
raise TypeError('Cell {} is not subclass of Cell'.format(cell))
def _get_prefix_and_index(cells):
"""get prefix and index of parameter name in sequential cell or cell list"""
prefix = ""
index = 0
if not cells:
return prefix, index
cell_list = list(cells.items())
first_param, first_key = None, None
second_param, second_key = None, None
for key, cell in cell_list:
try:
_, param = next(cell.parameters_and_names())
except StopIteration:
continue
if first_param is None:
first_param = param
first_key = key
continue
second_param = param
second_key = key
break
if first_param is None:
return prefix, index
split_names = first_param.name.split(".")
for idx, name in enumerate(split_names):
if name == first_key:
prefix = ".".join(split_names[:idx])
prefix = prefix + "." if prefix else prefix
index = idx
if second_param is not None and second_param.name.split(".")[idx] == second_key:
break
return prefix, index
class _CellListBase():
"""
An interface for base the cell as list.
@ -97,19 +134,26 @@ class SequentialCell(Cell):
"""
def __init__(self, *args):
super(SequentialCell, self).__init__()
self._is_dynamic_name = []
if len(args) == 1:
cells = args[0]
if isinstance(cells, list):
for index, cell in enumerate(cells):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
elif isinstance(cells, OrderedDict):
for name, cell in cells.items():
self.insert_child_to_cell(name, cell)
cell.update_parameters_name(name + ".")
self._is_dynamic_name.append(False)
else:
raise TypeError('Cells must be list or orderedDict')
else:
for index, cell in enumerate(args):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
self.cell_list = list(self._cells.values())
def __getitem__(self, index):
@ -121,9 +165,11 @@ class SequentialCell(Cell):
def __setitem__(self, index, cell):
if _valid_cell(cell):
prefix, _ = _get_prefix_and_index(self._cells)
index = _valid_index(len(self), index)
key = list(self._cells.keys())[index]
self._cells[key] = cell
cell.update_parameters_name(prefix + key + ".")
self.cell_list = list(self._cells.values())
def __delitem__(self, index):
@ -131,12 +177,25 @@ class SequentialCell(Cell):
index = _valid_index(len(self), index)
key = list(self._cells.keys())[index]
del self._cells[key]
del self._is_dynamic_name[index]
elif isinstance(index, slice):
keys = list(self._cells.keys())[index]
for key in keys:
del self._cells[key]
del self._is_dynamic_name[index]
else:
raise TypeError('Index {} is not int type or slice type'.format(index))
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict()
for idx, key in enumerate(self._cells.keys()):
cell = self._cells[key]
if self._is_dynamic_name[idx]:
for _, param in cell.parameters_and_names():
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(idx)] = cell
else:
temp_dict[key] = cell
self._cells = temp_dict
self.cell_list = list(self._cells.values())
def __len__(self):
@ -165,6 +224,9 @@ class SequentialCell(Cell):
[26.999863 26.999863]]]]
"""
if _valid_cell(cell):
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + str(len(self)) + ".")
self._is_dynamic_name.append(True)
self._cells[str(len(self))] = cell
self.cell_list = list(self._cells.values())
@ -202,9 +264,10 @@ class CellList(_CellListBase, Cell):
(2): ReLU<>
>
"""
def __init__(self, *args):
def __init__(self, *args, **kwargs):
auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True
_CellListBase.__init__(self)
Cell.__init__(self)
Cell.__init__(self, auto_prefix)
if len(args) == 1:
self.extend(args[0])
@ -220,6 +283,9 @@ class CellList(_CellListBase, Cell):
if not isinstance(index, int) and _valid_cell(cell):
raise TypeError('Index {} is not int type'.format(index))
index = _valid_index(len(self), index)
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + str(index) + ".")
self._cells[str(index)] = cell
def __delitem__(self, index):
@ -233,8 +299,12 @@ class CellList(_CellListBase, Cell):
else:
raise TypeError('Index {} is not int type or slice type'.format(index))
# adjust orderedDict
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict()
for idx, cell in enumerate(self._cells.values()):
if self._auto_prefix:
for _, param in cell.parameters_and_names():
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(idx)] = cell
self._cells = temp_dict
@ -253,10 +323,17 @@ class CellList(_CellListBase, Cell):
idx = _valid_index(len(self), index)
_valid_cell(cell)
length = len(self)
prefix, key_index = _get_prefix_and_index(self._cells)
while length > idx:
if self._auto_prefix:
tmp_cell = self._cells[str(length-1)]
for _, param in tmp_cell.parameters_and_names():
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
self._cells[str(length)] = self._cells[str(length - 1)]
length -= 1
self._cells[str(idx)] = cell
if self._auto_prefix:
cell.update_parameters_name(prefix + str(idx) + ".")
def extend(self, cells):
"""
@ -267,14 +344,20 @@ class CellList(_CellListBase, Cell):
"""
if not isinstance(cells, list):
raise TypeError('Cells {} should be list of subcells'.format(cells))
prefix, _ = _get_prefix_and_index(self._cells)
for cell in cells:
if _valid_cell(cell):
if self._auto_prefix:
cell.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = cell
return self
def append(self, cell):
"""Appends a given cell to the end of the list."""
if _valid_cell(cell):
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = cell
def set_grad(self, flag=True):

View File

@ -146,7 +146,8 @@ class Optimizer(Cell):
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
if self.is_group_lr:
self.learning_rate = CellList(self.group_lr) if self.dynamic_lr else ParameterTuple(self.group_lr)
self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
else ParameterTuple(self.group_lr)
else:
self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')