forked from mindspore-Ecosystem/mindspore
147 lines
4.3 KiB
Python
147 lines
4.3 KiB
Python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
|
|
""" test container """
|
|
from collections import OrderedDict
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import mindspore.nn as nn
|
|
from mindspore import Tensor
|
|
|
|
# Module-level fixtures shared by the container tests in this file.
weight = Tensor(np.ones((2, 2)))
conv2 = nn.Conv2d(3, 64, (3, 3), stride=2, padding=0)

# Pooling hyper-parameters. `padding` is defined here but is not passed to
# the pooling cell constructed below.
kernel_size = 3
stride = 2
padding = 1
avg_pool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride)
|
|
|
|
|
|
class TestSequentialCell:
    """Construction, indexing, item assignment and deletion on nn.SequentialCell."""

    @staticmethod
    def _named_sequential():
        """Return a SequentialCell built from an OrderedDict of conv2 then avg_pool."""
        return nn.SequentialCell(OrderedDict(
            [('cov2d', conv2), ('avg_pool', avg_pool)]))

    def test_SequentialCell_init(self):
        """A SequentialCell can be created without arguments."""
        m = nn.SequentialCell()
        assert type(m).__name__ == 'SequentialCell'

    def test_SequentialCell_init2(self):
        """A one-element list yields a container of length 1."""
        m = nn.SequentialCell([conv2])
        assert len(m) == 1

    def test_SequentialCell_init3(self):
        """A two-element list yields a container of length 2."""
        m = nn.SequentialCell([conv2, avg_pool])
        assert len(m) == 2

    def test_SequentialCell_init4(self):
        """An OrderedDict with two entries yields a container of length 2."""
        m = self._named_sequential()
        assert len(m) == 2

    def test_getitem1(self):
        """Integer indexing returns the cell stored at that position."""
        m = self._named_sequential()
        assert m[0] == conv2

    def test_getitem2(self):
        """Slicing returns a container that can itself be measured and indexed."""
        m = self._named_sequential()
        assert len(m[0:2]) == 2
        assert m[:2][1] == avg_pool

    def test_setitem1(self):
        """Assigning by integer index replaces the cell at that position."""
        m = self._named_sequential()
        m[1] = conv2
        assert m[1] == m[0]

    def test_setitem2(self):
        """Assigning with a float index raises TypeError."""
        m = self._named_sequential()
        with pytest.raises(TypeError):
            m[1.0] = conv2

    def test_delitem1(self):
        """Deleting by integer index shrinks the container by one."""
        m = self._named_sequential()
        del m[0]
        assert len(m) == 1

    def test_delitem2(self):
        """Deleting the full slice keeps the container object but empties it."""
        m = self._named_sequential()
        del m[:]
        assert type(m).__name__ == 'SequentialCell'
        # The type check alone passes even if nothing was removed, so also
        # verify that the deletion actually took effect.
        assert len(m) == 0
|
|
|
|
|
|
class TestCellList:
    """Construction, indexing, mutation and iteration on nn.CellList."""

    def test_init1(self):
        """A list of two cells yields a CellList of length 2."""
        cell_list = nn.CellList([conv2, avg_pool])
        assert len(cell_list) == 2

    def test_init2(self):
        """Constructing from a non-cell element raises TypeError."""
        with pytest.raises(TypeError):
            nn.CellList(["test"])

    def test_getitem(self):
        """Integer and slice indexing both return the stored cells."""
        cell_list = nn.CellList([conv2, avg_pool])
        assert cell_list[0] == conv2
        temp_cells = cell_list[:]
        assert temp_cells[1] == avg_pool

    def test_setitem(self):
        """Assigning by integer index replaces the cell at that position."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list[0] = avg_pool
        assert cell_list[0] == cell_list[1]

    def test_delitem(self):
        """Deleting by index and by full slice removes the expected cells."""
        cell_list = nn.CellList([conv2, avg_pool])
        del cell_list[0]
        assert len(cell_list) == 1
        del cell_list[:]
        assert type(cell_list).__name__ == 'CellList'
        # The type check alone passes even if nothing was removed, so also
        # verify that the slice deletion emptied the list.
        assert len(cell_list) == 0

    def test_iter(self):
        """Iteration visits every cell in insertion order."""
        cell_list = nn.CellList([conv2, avg_pool])
        # Check the whole sequence rather than only the last element seen, so
        # a skipped or reordered item is detected as well.
        visited = [type(cell).__name__ for cell in cell_list]
        assert visited == ['Conv2d', 'AvgPool2d']

    def test_add(self):
        """In-place addition of a list appends its cells."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list += [conv2]
        assert len(cell_list) == 3
        assert cell_list[0] == cell_list[2]

    def test_insert(self):
        """insert(0, cell) places the cell at the front and shifts the rest."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list.insert(0, avg_pool)
        assert len(cell_list) == 3
        assert cell_list[0] == cell_list[2]

    def test_append(self):
        """append(cell) adds the cell at the end."""
        cell_list = nn.CellList([conv2, avg_pool])
        cell_list.append(conv2)
        assert len(cell_list) == 3
        assert cell_list[0] == cell_list[2]

    def test_extend(self):
        """extend(list) on an empty CellList adds the cells in order."""
        cell_list = nn.CellList()
        cell_list.extend([conv2, avg_pool])
        assert len(cell_list) == 2
        assert cell_list[0] == conv2
|