diff --git a/mindspore/nn/layer/container.py b/mindspore/nn/layer/container.py index 843a0784d54..e148da00567 100644 --- a/mindspore/nn/layer/container.py +++ b/mindspore/nn/layer/container.py @@ -19,6 +19,7 @@ from ..cell import Cell __all__ = ['SequentialCell', 'CellList'] + def _valid_index(cell_num, index): if not isinstance(index, int): raise TypeError("Index {} is not int type") @@ -145,6 +146,12 @@ class SequentialCell(Cell): for cell in self._cells.values(): cell.set_grad(flag) + def append(self, cell): + """Appends a given cell to the end of the list.""" + if _valid_cell(cell): + self._cells[str(len(self))] = cell + return self + def construct(self, input_data): for cell in self.cell_list: input_data = cell(input_data)