update test for batch_to_space

This commit is contained in:
jiangjinsheng 2020-05-07 15:21:33 +08:00
parent 79a3f0821a
commit a945299e42
1 changed files with 28 additions and 15 deletions

View File

@ -95,21 +95,6 @@ def test_select():
expect = np.array([[1, 8, 9], [10, 5, 6]])
assert np.all(output.asnumpy() == expect)
def test_batch_to_space():
block_size = 2
crops = [[0, 0], [0, 0]]
batch_to_space = P.BatchToSpace(block_size, crops)
input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]).astype(np.float16))
output = batch_to_space(input_x)
assert output.shape() == (1, 1, 2, 2)
def test_space_to_batch():
block_size = 2
paddings = [[0, 0], [0, 0]]
space_to_batch = P.SpaceToBatch(block_size, paddings)
input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))
output = space_to_batch(input_x)
assert output.shape() == (4, 1, 1, 1)
def test_argmin_invalid_output_type():
P.Argmin(-1, mstype.int64)
@ -219,6 +204,28 @@ class MathBinaryNet2(Cell):
return self.logic_or(ret_less_equal, ret_greater)
class BatchToSpaceNet(Cell):
def __init__(self):
super(BatchToSpaceNet, self).__init__()
block_size = 2
crops = [[0, 0], [0, 0]]
self.batch_to_space = P.BatchToSpace(block_size, crops)
def construct(self, x):
return self.batch_to_space(x)
class SpaceToBatchNet(Cell):
def __init__(self):
super(SpaceToBatchNet, self).__init__()
block_size = 2
paddings = [[0, 0], [0, 0]]
self.space_to_batch = P.SpaceToBatch(block_size, paddings)
def construct(self, x):
return self.space_to_batch(x)
test_case_array_ops = [
('CustNet1', {
'block': CustNet1(),
@ -235,6 +242,12 @@ test_case_array_ops = [
('MathBinaryNet2', {
'block': MathBinaryNet2(),
'desc_inputs': [Tensor(np.ones([2, 2]), dtype=ms.int32)]}),
('BatchToSpaceNet', {
'block': BatchToSpaceNet(),
'desc_inputs': [Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]).astype(np.float16))]}),
('SpaceToBatchNet', {
'block': SpaceToBatchNet(),
'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}),
]
test_case_lists = [test_case_array_ops]