From 76cec2409864b83b9f5040ba4b89e6caa5f0b4c4 Mon Sep 17 00:00:00 2001 From: Jiaqi Date: Mon, 14 Sep 2020 17:13:47 +0800 Subject: [PATCH] modify sequentialcell --- mindspore/nn/layer/container.py | 13 ++++++++++++- tests/ut/python/nn/test_container.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/mindspore/nn/layer/container.py b/mindspore/nn/layer/container.py index bd2e82e6d72..a30634ce3b8 100644 --- a/mindspore/nn/layer/container.py +++ b/mindspore/nn/layer/container.py @@ -146,9 +146,20 @@ class SequentialCell(Cell): cell.set_grad(flag) def append(self, cell): - """Appends a given cell to the end of the list.""" + """Appends a given cell to the end of the list. + + Examples: + >>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid') + >>> bn = nn.BatchNorm2d(2) + >>> relu = nn.ReLU() + >>> seq = nn.SequentialCell([conv, bn]) + >>> seq.append(relu) + >>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32) + >>> seq(x) + """ if _valid_cell(cell): self._cells[str(len(self))] = cell + self.cell_list = list(self._cells.values()) return self def construct(self, input_data): diff --git a/tests/ut/python/nn/test_container.py b/tests/ut/python/nn/test_container.py index c381394c26d..c19cf573e99 100644 --- a/tests/ut/python/nn/test_container.py +++ b/tests/ut/python/nn/test_container.py @@ -84,6 +84,20 @@ class TestSequentialCell(): del m[:] assert type(m).__name__ == 'SequentialCell' + def test_sequentialcell_append(self): + input_np = np.ones((1, 3)).astype(np.float32) + input_me = Tensor(input_np) + relu = nn.ReLU() + tanh = nn.Tanh() + seq = nn.SequentialCell([relu]) + seq.append(tanh) + out_me = seq(input_me) + + seq1 = nn.SequentialCell([relu, tanh]) + out = seq1(input_me) + + assert out[0][0] == out_me[0][0] + class TestCellList(): """ TestCellList """