forked from mindspore-Ecosystem/mindspore
modify sequentialcell
This commit is contained in:
parent
2f14c40934
commit
76cec24098
|
@ -146,9 +146,20 @@ class SequentialCell(Cell):
|
|||
cell.set_grad(flag)
|
||||
|
||||
def append(self, cell):
|
||||
"""Appends a given cell to the end of the list."""
|
||||
"""Appends a given cell to the end of the list.
|
||||
|
||||
Examples:
|
||||
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid')
|
||||
>>> bn = nn.BatchNorm2d(2)
|
||||
>>> relu = nn.ReLU()
|
||||
>>> seq = nn.SequentialCell([conv, bn])
|
||||
>>> seq.append(relu)
|
||||
>>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
|
||||
>>> seq(x)
|
||||
"""
|
||||
if _valid_cell(cell):
|
||||
self._cells[str(len(self))] = cell
|
||||
self.cell_list = list(self._cells.values())
|
||||
return self
|
||||
|
||||
def construct(self, input_data):
|
||||
|
|
|
@ -84,6 +84,20 @@ class TestSequentialCell():
|
|||
del m[:]
|
||||
assert type(m).__name__ == 'SequentialCell'
|
||||
|
||||
def test_sequentialcell_append(self):
|
||||
input_np = np.ones((1, 3)).astype(np.float32)
|
||||
input_me = Tensor(input_np)
|
||||
relu = nn.ReLU()
|
||||
tanh = nn.Tanh()
|
||||
seq = nn.SequentialCell([relu])
|
||||
seq.append(tanh)
|
||||
out_me = seq(input_me)
|
||||
|
||||
seq1 = nn.SequentialCell([relu, tanh])
|
||||
out = seq1(input_me)
|
||||
|
||||
assert out[0][0] == out_me[0][0]
|
||||
|
||||
|
||||
class TestCellList():
|
||||
""" TestCellList """
|
||||
|
|
Loading…
Reference in New Issue