modify sequentialcell

This commit is contained in:
Jiaqi 2020-09-14 17:13:47 +08:00
parent 2f14c40934
commit 76cec24098
2 changed files with 26 additions and 1 deletions

View File

@ -146,9 +146,20 @@ class SequentialCell(Cell):
cell.set_grad(flag)
def append(self, cell):
"""Appends a given cell to the end of the list."""
"""Appends a given cell to the end of the list.
Examples:
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid')
>>> bn = nn.BatchNorm2d(2)
>>> relu = nn.ReLU()
>>> seq = nn.SequentialCell([conv, bn])
>>> seq.append(relu)
>>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
>>> seq(x)
"""
if _valid_cell(cell):
self._cells[str(len(self))] = cell
self.cell_list = list(self._cells.values())
return self
def construct(self, input_data):

View File

@ -84,6 +84,20 @@ class TestSequentialCell():
del m[:]
assert type(m).__name__ == 'SequentialCell'
def test_sequentialcell_append(self):
input_np = np.ones((1, 3)).astype(np.float32)
input_me = Tensor(input_np)
relu = nn.ReLU()
tanh = nn.Tanh()
seq = nn.SequentialCell([relu])
seq.append(tanh)
out_me = seq(input_me)
seq1 = nn.SequentialCell([relu, tanh])
out = seq1(input_me)
assert out[0][0] == out_me[0][0]
class TestCellList():
""" TestCellList """