forked from mindspore-Ecosystem/mindspore
!13260 fix SequentialCell and CellList parameter name bug
From: @caozhou_huawei Reviewed-by: Signed-off-by:
This commit is contained in:
commit
5c39c33c92
|
@ -35,6 +35,43 @@ def _valid_cell(cell):
|
|||
raise TypeError('Cell {} is not subclass of Cell'.format(cell))
|
||||
|
||||
|
||||
def _get_prefix_and_index(cells):
|
||||
"""get prefix and index of parameter name in sequential cell or cell list"""
|
||||
prefix = ""
|
||||
index = 0
|
||||
if not cells:
|
||||
return prefix, index
|
||||
|
||||
cell_list = list(cells.items())
|
||||
first_param, first_key = None, None
|
||||
second_param, second_key = None, None
|
||||
for key, cell in cell_list:
|
||||
try:
|
||||
_, param = next(cell.parameters_and_names())
|
||||
except StopIteration:
|
||||
continue
|
||||
if first_param is None:
|
||||
first_param = param
|
||||
first_key = key
|
||||
continue
|
||||
second_param = param
|
||||
second_key = key
|
||||
break
|
||||
|
||||
if first_param is None:
|
||||
return prefix, index
|
||||
|
||||
split_names = first_param.name.split(".")
|
||||
for idx, name in enumerate(split_names):
|
||||
if name == first_key:
|
||||
prefix = ".".join(split_names[:idx])
|
||||
prefix = prefix + "." if prefix else prefix
|
||||
index = idx
|
||||
if second_param is not None and second_param.name.split(".")[idx] == second_key:
|
||||
break
|
||||
return prefix, index
|
||||
|
||||
|
||||
class _CellListBase():
|
||||
"""
|
||||
An interface for base the cell as list.
|
||||
|
@ -97,19 +134,26 @@ class SequentialCell(Cell):
|
|||
"""
|
||||
def __init__(self, *args):
|
||||
super(SequentialCell, self).__init__()
|
||||
self._is_dynamic_name = []
|
||||
if len(args) == 1:
|
||||
cells = args[0]
|
||||
if isinstance(cells, list):
|
||||
for index, cell in enumerate(cells):
|
||||
self.insert_child_to_cell(str(index), cell)
|
||||
cell.update_parameters_name(str(index) + ".")
|
||||
self._is_dynamic_name.append(True)
|
||||
elif isinstance(cells, OrderedDict):
|
||||
for name, cell in cells.items():
|
||||
self.insert_child_to_cell(name, cell)
|
||||
cell.update_parameters_name(name + ".")
|
||||
self._is_dynamic_name.append(False)
|
||||
else:
|
||||
raise TypeError('Cells must be list or orderedDict')
|
||||
else:
|
||||
for index, cell in enumerate(args):
|
||||
self.insert_child_to_cell(str(index), cell)
|
||||
cell.update_parameters_name(str(index) + ".")
|
||||
self._is_dynamic_name.append(True)
|
||||
self.cell_list = list(self._cells.values())
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
@ -121,9 +165,11 @@ class SequentialCell(Cell):
|
|||
|
||||
def __setitem__(self, index, cell):
|
||||
if _valid_cell(cell):
|
||||
prefix, _ = _get_prefix_and_index(self._cells)
|
||||
index = _valid_index(len(self), index)
|
||||
key = list(self._cells.keys())[index]
|
||||
self._cells[key] = cell
|
||||
cell.update_parameters_name(prefix + key + ".")
|
||||
self.cell_list = list(self._cells.values())
|
||||
|
||||
def __delitem__(self, index):
|
||||
|
@ -131,12 +177,25 @@ class SequentialCell(Cell):
|
|||
index = _valid_index(len(self), index)
|
||||
key = list(self._cells.keys())[index]
|
||||
del self._cells[key]
|
||||
del self._is_dynamic_name[index]
|
||||
elif isinstance(index, slice):
|
||||
keys = list(self._cells.keys())[index]
|
||||
for key in keys:
|
||||
del self._cells[key]
|
||||
del self._is_dynamic_name[index]
|
||||
else:
|
||||
raise TypeError('Index {} is not int type or slice type'.format(index))
|
||||
prefix, key_index = _get_prefix_and_index(self._cells)
|
||||
temp_dict = OrderedDict()
|
||||
for idx, key in enumerate(self._cells.keys()):
|
||||
cell = self._cells[key]
|
||||
if self._is_dynamic_name[idx]:
|
||||
for _, param in cell.parameters_and_names():
|
||||
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
|
||||
temp_dict[str(idx)] = cell
|
||||
else:
|
||||
temp_dict[key] = cell
|
||||
self._cells = temp_dict
|
||||
self.cell_list = list(self._cells.values())
|
||||
|
||||
def __len__(self):
|
||||
|
@ -165,6 +224,9 @@ class SequentialCell(Cell):
|
|||
[26.999863 26.999863]]]]
|
||||
"""
|
||||
if _valid_cell(cell):
|
||||
prefix, _ = _get_prefix_and_index(self._cells)
|
||||
cell.update_parameters_name(prefix + str(len(self)) + ".")
|
||||
self._is_dynamic_name.append(True)
|
||||
self._cells[str(len(self))] = cell
|
||||
self.cell_list = list(self._cells.values())
|
||||
|
||||
|
@ -202,9 +264,10 @@ class CellList(_CellListBase, Cell):
|
|||
(2): ReLU<>
|
||||
>
|
||||
"""
|
||||
def __init__(self, *args):
|
||||
def __init__(self, *args, **kwargs):
|
||||
auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True
|
||||
_CellListBase.__init__(self)
|
||||
Cell.__init__(self)
|
||||
Cell.__init__(self, auto_prefix)
|
||||
if len(args) == 1:
|
||||
self.extend(args[0])
|
||||
|
||||
|
@ -220,6 +283,9 @@ class CellList(_CellListBase, Cell):
|
|||
if not isinstance(index, int) and _valid_cell(cell):
|
||||
raise TypeError('Index {} is not int type'.format(index))
|
||||
index = _valid_index(len(self), index)
|
||||
if self._auto_prefix:
|
||||
prefix, _ = _get_prefix_and_index(self._cells)
|
||||
cell.update_parameters_name(prefix + str(index) + ".")
|
||||
self._cells[str(index)] = cell
|
||||
|
||||
def __delitem__(self, index):
|
||||
|
@ -233,8 +299,12 @@ class CellList(_CellListBase, Cell):
|
|||
else:
|
||||
raise TypeError('Index {} is not int type or slice type'.format(index))
|
||||
# adjust orderedDict
|
||||
prefix, key_index = _get_prefix_and_index(self._cells)
|
||||
temp_dict = OrderedDict()
|
||||
for idx, cell in enumerate(self._cells.values()):
|
||||
if self._auto_prefix:
|
||||
for _, param in cell.parameters_and_names():
|
||||
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
|
||||
temp_dict[str(idx)] = cell
|
||||
self._cells = temp_dict
|
||||
|
||||
|
@ -253,10 +323,17 @@ class CellList(_CellListBase, Cell):
|
|||
idx = _valid_index(len(self), index)
|
||||
_valid_cell(cell)
|
||||
length = len(self)
|
||||
prefix, key_index = _get_prefix_and_index(self._cells)
|
||||
while length > idx:
|
||||
if self._auto_prefix:
|
||||
tmp_cell = self._cells[str(length-1)]
|
||||
for _, param in tmp_cell.parameters_and_names():
|
||||
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
|
||||
self._cells[str(length)] = self._cells[str(length - 1)]
|
||||
length -= 1
|
||||
self._cells[str(idx)] = cell
|
||||
if self._auto_prefix:
|
||||
cell.update_parameters_name(prefix + str(idx) + ".")
|
||||
|
||||
def extend(self, cells):
|
||||
"""
|
||||
|
@ -267,14 +344,20 @@ class CellList(_CellListBase, Cell):
|
|||
"""
|
||||
if not isinstance(cells, list):
|
||||
raise TypeError('Cells {} should be list of subcells'.format(cells))
|
||||
prefix, _ = _get_prefix_and_index(self._cells)
|
||||
for cell in cells:
|
||||
if _valid_cell(cell):
|
||||
if self._auto_prefix:
|
||||
cell.update_parameters_name(prefix + str(len(self)) + ".")
|
||||
self._cells[str(len(self))] = cell
|
||||
return self
|
||||
|
||||
def append(self, cell):
|
||||
"""Appends a given cell to the end of the list."""
|
||||
if _valid_cell(cell):
|
||||
if self._auto_prefix:
|
||||
prefix, _ = _get_prefix_and_index(self._cells)
|
||||
cell.update_parameters_name(prefix + str(len(self)) + ".")
|
||||
self._cells[str(len(self))] = cell
|
||||
|
||||
def set_grad(self, flag=True):
|
||||
|
|
|
@ -146,7 +146,8 @@ class Optimizer(Cell):
|
|||
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
|
||||
|
||||
if self.is_group_lr:
|
||||
self.learning_rate = CellList(self.group_lr) if self.dynamic_lr else ParameterTuple(self.group_lr)
|
||||
self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
|
||||
else ParameterTuple(self.group_lr)
|
||||
else:
|
||||
self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')
|
||||
|
||||
|
|
Loading…
Reference in New Issue