forked from mindspore-Ecosystem/mindspore
update test for batch_to_space
This commit is contained in:
parent
79a3f0821a
commit
a945299e42
|
@ -95,21 +95,6 @@ def test_select():
|
|||
expect = np.array([[1, 8, 9], [10, 5, 6]])
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
def test_batch_to_space():
|
||||
block_size = 2
|
||||
crops = [[0, 0], [0, 0]]
|
||||
batch_to_space = P.BatchToSpace(block_size, crops)
|
||||
input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]).astype(np.float16))
|
||||
output = batch_to_space(input_x)
|
||||
assert output.shape() == (1, 1, 2, 2)
|
||||
|
||||
def test_space_to_batch():
|
||||
block_size = 2
|
||||
paddings = [[0, 0], [0, 0]]
|
||||
space_to_batch = P.SpaceToBatch(block_size, paddings)
|
||||
input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))
|
||||
output = space_to_batch(input_x)
|
||||
assert output.shape() == (4, 1, 1, 1)
|
||||
|
||||
def test_argmin_invalid_output_type():
|
||||
P.Argmin(-1, mstype.int64)
|
||||
|
@ -219,6 +204,28 @@ class MathBinaryNet2(Cell):
|
|||
return self.logic_or(ret_less_equal, ret_greater)
|
||||
|
||||
|
||||
class BatchToSpaceNet(Cell):
|
||||
def __init__(self):
|
||||
super(BatchToSpaceNet, self).__init__()
|
||||
block_size = 2
|
||||
crops = [[0, 0], [0, 0]]
|
||||
self.batch_to_space = P.BatchToSpace(block_size, crops)
|
||||
|
||||
def construct(self, x):
|
||||
return self.batch_to_space(x)
|
||||
|
||||
|
||||
class SpaceToBatchNet(Cell):
|
||||
def __init__(self):
|
||||
super(SpaceToBatchNet, self).__init__()
|
||||
block_size = 2
|
||||
paddings = [[0, 0], [0, 0]]
|
||||
self.space_to_batch = P.SpaceToBatch(block_size, paddings)
|
||||
|
||||
def construct(self, x):
|
||||
return self.space_to_batch(x)
|
||||
|
||||
|
||||
test_case_array_ops = [
|
||||
('CustNet1', {
|
||||
'block': CustNet1(),
|
||||
|
@ -235,6 +242,12 @@ test_case_array_ops = [
|
|||
('MathBinaryNet2', {
|
||||
'block': MathBinaryNet2(),
|
||||
'desc_inputs': [Tensor(np.ones([2, 2]), dtype=ms.int32)]}),
|
||||
('BatchToSpaceNet', {
|
||||
'block': BatchToSpaceNet(),
|
||||
'desc_inputs': [Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]).astype(np.float16))]}),
|
||||
('SpaceToBatchNet', {
|
||||
'block': SpaceToBatchNet(),
|
||||
'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}),
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_array_ops]
|
||||
|
|
Loading…
Reference in New Issue