forked from mindspore-Ecosystem/mindspore
!3978 Added unit tests for ResizeNearestNeighbor gpu kernel
Merge pull request !3978 from Peilin/master
This commit is contained in:
commit
6c4b4f91d2
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,70 +12,561 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class ResizeNearestNeighborAlignCornerT(nn.Cell):
|
||||
def __init__(self, size):
|
||||
super(ResizeNearestNeighborAlignCornerT, self).__init__()
|
||||
self.ResizeNearestNeighborAlignCornerT = P.ResizeNearestNeighbor(size, align_corners=True)
|
||||
|
||||
def construct(self, x):
|
||||
return self.ResizeNearestNeighborAlignCornerT(x)
|
||||
|
||||
class ResizeNearestNeighborAlignCornerF(nn.Cell):
|
||||
def __init__(self, size):
|
||||
super(ResizeNearestNeighborAlignCornerF, self).__init__()
|
||||
self.ResizeNearestNeighborAlignCornerF = P.ResizeNearestNeighbor(size, align_corners=False)
|
||||
|
||||
def construct(self, x):
|
||||
return self.ResizeNearestNeighborAlignCornerF(x)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ResizeNearestNeighborAlignCornerT():
|
||||
def resize_nn_grayscale_integer_ratio(datatype):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
|
||||
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
|
||||
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32)
|
||||
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype))
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ResizeNearestNeighborAlignCornerF():
|
||||
# larger h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((9, 9))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3],
|
||||
[0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3],
|
||||
[0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3],
|
||||
[0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6],
|
||||
[0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6],
|
||||
[0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6],
|
||||
[0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9],
|
||||
[0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9],
|
||||
[0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((1, 1))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((1, 6))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((6, 1))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1], [0.1], [0.4], [0.4], [0.7], [0.7]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((1, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((6, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3],
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9],
|
||||
[0.7, 0.8, 0.9]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 1))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1], [0.4], [0.7]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 6))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3],
|
||||
[0.4, 0.4, 0.5, 0.5, 0.6, 0.6],
|
||||
[0.7, 0.7, 0.8, 0.8, 0.9, 0.9]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same w, same h (identity)
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
|
||||
|
||||
def resize_nn_grayscale_not_integer_ratio(datatype):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
|
||||
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
|
||||
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32)
|
||||
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.9, 0.0, 0.1, 0.2]]]]).astype(datatype))
|
||||
|
||||
# larger h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((7, 7))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4],
|
||||
[0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4],
|
||||
[0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4],
|
||||
[0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8],
|
||||
[0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8],
|
||||
[0.9, 0.9, 0.0, 0.0, 0.1, 0.1, 0.2],
|
||||
[0.9, 0.9, 0.0, 0.0, 0.1, 0.1, 0.2]]]]).astype(datatype))
|
||||
|
||||
# smaller h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((2, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3], [0.5, 0.6, 0.7]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((2, 7))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4],
|
||||
[0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((5, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3],
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.5, 0.6, 0.7],
|
||||
[0.5, 0.6, 0.7],
|
||||
[0.9, 0.0, 0.1]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((2, 4))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((8, 4))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4],
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.9, 0.0, 0.1, 0.2],
|
||||
[0.9, 0.0, 0.1, 0.2]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 2))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.3],
|
||||
[0.5, 0.7],
|
||||
[0.9, 0.1]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 6))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.3, 0.3, 0.4],
|
||||
[0.5, 0.5, 0.6, 0.7, 0.7, 0.8],
|
||||
[0.9, 0.9, 0.0, 0.1, 0.1, 0.2]]]]).astype(datatype))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same w, same h (identity)
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 4))
|
||||
output = resize_nn(input_tensor)
|
||||
np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
|
||||
|
||||
def test_resize_nn_rgb_integer_ratio():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array(
|
||||
[[[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
[[11, 12, 13], [14, 15, 16], [17, 18, 19]],
|
||||
[[111, 112, 113], [114, 115, 116], [117, 118, 119]]]]).astype(np.int32))
|
||||
|
||||
# larger h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((9, 9))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output_array = np.array([[[[1, 1, 1, 2, 2, 2, 3, 3, 3],
|
||||
[1, 1, 1, 2, 2, 2, 3, 3, 3],
|
||||
[1, 1, 1, 2, 2, 2, 3, 3, 3],
|
||||
[4, 4, 4, 5, 5, 5, 6, 6, 6],
|
||||
[4, 4, 4, 5, 5, 5, 6, 6, 6],
|
||||
[4, 4, 4, 5, 5, 5, 6, 6, 6],
|
||||
[7, 7, 7, 8, 8, 8, 9, 9, 9],
|
||||
[7, 7, 7, 8, 8, 8, 9, 9, 9],
|
||||
[7, 7, 7, 8, 8, 8, 9, 9, 9]],
|
||||
[[11, 11, 11, 12, 12, 12, 13, 13, 13],
|
||||
[11, 11, 11, 12, 12, 12, 13, 13, 13],
|
||||
[11, 11, 11, 12, 12, 12, 13, 13, 13],
|
||||
[14, 14, 14, 15, 15, 15, 16, 16, 16],
|
||||
[14, 14, 14, 15, 15, 15, 16, 16, 16],
|
||||
[14, 14, 14, 15, 15, 15, 16, 16, 16],
|
||||
[17, 17, 17, 18, 18, 18, 19, 19, 19],
|
||||
[17, 17, 17, 18, 18, 18, 19, 19, 19],
|
||||
[17, 17, 17, 18, 18, 18, 19, 19, 19]],
|
||||
[[111, 111, 111, 112, 112, 112, 113, 113, 113],
|
||||
[111, 111, 111, 112, 112, 112, 113, 113, 113],
|
||||
[111, 111, 111, 112, 112, 112, 113, 113, 113],
|
||||
[114, 114, 114, 115, 115, 115, 116, 116, 116],
|
||||
[114, 114, 114, 115, 115, 115, 116, 116, 116],
|
||||
[114, 114, 114, 115, 115, 115, 116, 116, 116],
|
||||
[117, 117, 117, 118, 118, 118, 119, 119, 119],
|
||||
[117, 117, 117, 118, 118, 118, 119, 119, 119],
|
||||
[117, 117, 117, 118, 118, 118, 119, 119, 119]]]])
|
||||
expected_output = Tensor(np.array(expected_output_array).astype(np.int32))
|
||||
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((1, 1))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1]], [[11]], [[111]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((1, 6))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3]],
|
||||
[[11, 11, 12, 12, 13, 13]],
|
||||
[[111, 111, 112, 112, 113, 113]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((6, 1))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1], [1], [4], [4], [7], [7]],
|
||||
[[11], [11], [14], [14], [17], [17]],
|
||||
[[111], [111], [114], [114], [117], [117]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((1, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 2, 3]],
|
||||
[[11, 12, 13]],
|
||||
[[111, 112, 113]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((6, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 2, 3],
|
||||
[1, 2, 3],
|
||||
[4, 5, 6],
|
||||
[4, 5, 6],
|
||||
[7, 8, 9],
|
||||
[7, 8, 9]],
|
||||
[[11, 12, 13],
|
||||
[11, 12, 13],
|
||||
[14, 15, 16],
|
||||
[14, 15, 16],
|
||||
[17, 18, 19],
|
||||
[17, 18, 19]],
|
||||
[[111, 112, 113],
|
||||
[111, 112, 113],
|
||||
[114, 115, 116],
|
||||
[114, 115, 116],
|
||||
[117, 118, 119],
|
||||
[117, 118, 119]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 1))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1], [4], [7]],
|
||||
[[11], [14], [17]],
|
||||
[[111], [114], [117]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 6))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3],
|
||||
[4, 4, 5, 5, 6, 6],
|
||||
[7, 7, 8, 8, 9, 9]],
|
||||
[[11, 11, 12, 12, 13, 13],
|
||||
[14, 14, 15, 15, 16, 16],
|
||||
[17, 17, 18, 18, 19, 19]],
|
||||
[[111, 111, 112, 112, 113, 113],
|
||||
[114, 114, 115, 115, 116, 116],
|
||||
[117, 117, 118, 118, 119, 119]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same w, same h (identity)
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
|
||||
|
||||
def test_resize_nn_rgb_not_integer_ratio():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 0, 1, 2]],
|
||||
[[11, 12, 13, 14],
|
||||
[15, 16, 17, 18],
|
||||
[19, 10, 11, 12]],
|
||||
[[111, 112, 113, 114],
|
||||
[115, 116, 117, 118],
|
||||
[119, 110, 111, 112]]]]).astype(np.int32))
|
||||
|
||||
# larger h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((7, 7))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3, 4],
|
||||
[1, 1, 2, 2, 3, 3, 4],
|
||||
[1, 1, 2, 2, 3, 3, 4],
|
||||
[5, 5, 6, 6, 7, 7, 8],
|
||||
[5, 5, 6, 6, 7, 7, 8],
|
||||
[9, 9, 0, 0, 1, 1, 2],
|
||||
[9, 9, 0, 0, 1, 1, 2]],
|
||||
[[11, 11, 12, 12, 13, 13, 14],
|
||||
[11, 11, 12, 12, 13, 13, 14],
|
||||
[11, 11, 12, 12, 13, 13, 14],
|
||||
[15, 15, 16, 16, 17, 17, 18],
|
||||
[15, 15, 16, 16, 17, 17, 18],
|
||||
[19, 19, 10, 10, 11, 11, 12],
|
||||
[19, 19, 10, 10, 11, 11, 12]],
|
||||
[[111, 111, 112, 112, 113, 113, 114],
|
||||
[111, 111, 112, 112, 113, 113, 114],
|
||||
[111, 111, 112, 112, 113, 113, 114],
|
||||
[115, 115, 116, 116, 117, 117, 118],
|
||||
[115, 115, 116, 116, 117, 117, 118],
|
||||
[119, 119, 110, 110, 111, 111, 112],
|
||||
[119, 119, 110, 110, 111, 111, 112]]]]).astype(np.int32))
|
||||
|
||||
# smaller h and w
|
||||
resize_nn = P.ResizeNearestNeighbor((2, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 2, 3], [5, 6, 7]],
|
||||
[[11, 12, 13], [15, 16, 17]],
|
||||
[[111, 112, 113], [115, 116, 117]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((2, 7))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3, 4],
|
||||
[5, 5, 6, 6, 7, 7, 8]],
|
||||
[[11, 11, 12, 12, 13, 13, 14],
|
||||
[15, 15, 16, 16, 17, 17, 18]],
|
||||
[[111, 111, 112, 112, 113, 113, 114],
|
||||
[115, 115, 116, 116, 117, 117, 118]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((5, 3))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 2, 3],
|
||||
[1, 2, 3],
|
||||
[5, 6, 7],
|
||||
[5, 6, 7],
|
||||
[9, 0, 1]],
|
||||
[[11, 12, 13],
|
||||
[11, 12, 13],
|
||||
[15, 16, 17],
|
||||
[15, 16, 17],
|
||||
[19, 10, 11]],
|
||||
[[111, 112, 113],
|
||||
[111, 112, 113],
|
||||
[115, 116, 117],
|
||||
[115, 116, 117],
|
||||
[119, 110, 111]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# smaller h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((2, 4))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 2, 3, 4],
|
||||
[5, 6, 7, 8]],
|
||||
[[11, 12, 13, 14],
|
||||
[15, 16, 17, 18]],
|
||||
[[111, 112, 113, 114],
|
||||
[115, 116, 117, 118]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# larger h, same w
|
||||
resize_nn = P.ResizeNearestNeighbor((8, 4))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 2, 3, 4],
|
||||
[1, 2, 3, 4],
|
||||
[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[5, 6, 7, 8],
|
||||
[5, 6, 7, 8],
|
||||
[9, 0, 1, 2],
|
||||
[9, 0, 1, 2]],
|
||||
[[11, 12, 13, 14],
|
||||
[11, 12, 13, 14],
|
||||
[11, 12, 13, 14],
|
||||
[15, 16, 17, 18],
|
||||
[15, 16, 17, 18],
|
||||
[15, 16, 17, 18],
|
||||
[19, 10, 11, 12],
|
||||
[19, 10, 11, 12]],
|
||||
[[111, 112, 113, 114],
|
||||
[111, 112, 113, 114],
|
||||
[111, 112, 113, 114],
|
||||
[115, 116, 117, 118],
|
||||
[115, 116, 117, 118],
|
||||
[115, 116, 117, 118],
|
||||
[119, 110, 111, 112],
|
||||
[119, 110, 111, 112]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, smaller w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 2))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 3], [5, 7], [9, 1]],
|
||||
[[11, 13], [15, 17], [19, 11]],
|
||||
[[111, 113], [115, 117], [119, 111]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same h, larger w
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 6))
|
||||
output = resize_nn(input_tensor)
|
||||
expected_output = Tensor(np.array([[[[1, 1, 2, 3, 3, 4],
|
||||
[5, 5, 6, 7, 7, 8],
|
||||
[9, 9, 0, 1, 1, 2]],
|
||||
[[11, 11, 12, 13, 13, 14],
|
||||
[15, 15, 16, 17, 17, 18],
|
||||
[19, 19, 10, 11, 11, 12]],
|
||||
[[111, 111, 112, 113, 113, 114],
|
||||
[115, 115, 116, 117, 117, 118],
|
||||
[119, 119, 110, 111, 111, 112]]]]).astype(np.int32))
|
||||
np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy())
|
||||
|
||||
# same w, same h (identity)
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 4))
|
||||
output = resize_nn(input_tensor)
|
||||
np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy())
|
||||
|
||||
|
||||
def resize_nn_grayscale_multiple_images(datatype):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]],
|
||||
[[[0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [0.1, 0.2, 0.3]]],
|
||||
[[[0.7, 0.8, 0.9], [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]]).astype(datatype))
|
||||
|
||||
resize_nn = P.ResizeNearestNeighbor((2, 6))
|
||||
output = resize_nn(input_tensor)
|
||||
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3],
|
||||
[0.4, 0.4, 0.5, 0.5, 0.6, 0.6]]],
|
||||
[[[0.4, 0.4, 0.5, 0.5, 0.6, 0.6],
|
||||
[0.7, 0.7, 0.8, 0.8, 0.9, 0.9]]],
|
||||
[[[0.7, 0.7, 0.8, 0.8, 0.9, 0.9],
|
||||
[0.1, 0.1, 0.2, 0.2, 0.3, 0.3]]]]).astype(datatype))
|
||||
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output.asnumpy())
|
||||
|
||||
|
||||
def resize_nn_grayscale_align_corners(datatype):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype))
|
||||
|
||||
resize_nn_corners_aligned = P.ResizeNearestNeighbor((3, 7), align_corners=True)
|
||||
output_corners_aligned = resize_nn_corners_aligned(input_tensor)
|
||||
|
||||
resize_nn = P.ResizeNearestNeighbor((3, 7))
|
||||
output = resize_nn(input_tensor)
|
||||
|
||||
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4],
|
||||
[0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8],
|
||||
[0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8]]]]).astype(datatype))
|
||||
|
||||
np.testing.assert_array_equal(output_corners_aligned.asnumpy(), expected_output.asnumpy())
|
||||
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, output.asnumpy(), expected_output.asnumpy())
|
||||
|
||||
|
||||
def test_resize_nn_rgb_multiple():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10]],
|
||||
[[11, 12, 13, 14, 15],
|
||||
[16, 17, 18, 19, 20]],
|
||||
[[111, 112, 113, 114, 115],
|
||||
[116, 117, 118, 119, 120]]],
|
||||
[[[11, 12, 13, 14, 15],
|
||||
[16, 17, 18, 19, 20]],
|
||||
[[111, 112, 113, 114, 115],
|
||||
[116, 117, 118, 119, 120]],
|
||||
[[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10]]],
|
||||
[[[111, 112, 113, 114, 115],
|
||||
[116, 117, 118, 119, 120]],
|
||||
[[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10]],
|
||||
[[11, 12, 13, 14, 15],
|
||||
[16, 17, 18, 19, 20]]]]).astype(np.int32))
|
||||
|
||||
resize_nn = P.ResizeNearestNeighbor((5, 2))
|
||||
output = resize_nn(input_tensor)
|
||||
|
||||
expected_output = Tensor(np.array([[[[1, 3], [1, 3], [1, 3], [6, 8], [6, 8]],
|
||||
[[11, 13], [11, 13], [11, 13], [16, 18], [16, 18]],
|
||||
[[111, 113], [111, 113], [111, 113], [116, 118], [116, 118]]],
|
||||
[[[11, 13], [11, 13], [11, 13], [16, 18], [16, 18]],
|
||||
[[111, 113], [111, 113], [111, 113], [116, 118], [116, 118]],
|
||||
[[1, 3], [1, 3], [1, 3], [6, 8], [6, 8]]],
|
||||
[[[111, 113], [111, 113], [111, 113], [116, 118], [116, 118]],
|
||||
[[1, 3], [1, 3], [1, 3], [6, 8], [6, 8]],
|
||||
[[11, 13], [11, 13], [11, 13], [16, 18], [16, 18]]]]).astype(np.int32))
|
||||
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output.asnumpy())
|
||||
|
||||
|
||||
def test_resize_nn_rgb_align_corners():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[1, 2, 3, 4], [5, 6, 7, 8]],
|
||||
[[11, 12, 13, 14], [15, 16, 17, 18]],
|
||||
[[21, 22, 23, 24], [25, 26, 27, 28]]]]).astype(np.int32))
|
||||
|
||||
resize_nn_corners_aligned = P.ResizeNearestNeighbor((5, 2), align_corners=True)
|
||||
output_corners_aligned = resize_nn_corners_aligned(input_tensor)
|
||||
|
||||
resize_nn = P.ResizeNearestNeighbor((5, 2))
|
||||
output = resize_nn(input_tensor)
|
||||
|
||||
expected_output = Tensor(np.array([[[[1, 4], [1, 4], [5, 8], [5, 8], [5, 8]],
|
||||
[[11, 14], [11, 14], [15, 18], [15, 18], [15, 18]],
|
||||
[[21, 24], [21, 24], [25, 28], [25, 28], [25, 28]]]]).astype(np.int32))
|
||||
|
||||
np.testing.assert_array_equal(output_corners_aligned.asnumpy(), expected_output.asnumpy())
|
||||
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, output.asnumpy(), expected_output.asnumpy())
|
||||
|
||||
|
||||
def test_resize_nn_grayscale_integer_ratio_half():
|
||||
resize_nn_grayscale_integer_ratio(np.float16)
|
||||
|
||||
def test_resize_nn_grayscale_integer_ratio_float():
|
||||
resize_nn_grayscale_integer_ratio(np.float32)
|
||||
|
||||
def test_resize_nn_grayscale_not_integer_ratio_half():
|
||||
resize_nn_grayscale_not_integer_ratio(np.float16)
|
||||
|
||||
def test_resize_nn_grayscale_not_integer_ratio_float():
|
||||
resize_nn_grayscale_not_integer_ratio(np.float32)
|
||||
|
||||
def test_resize_nn_grayscale_multiple_half():
|
||||
resize_nn_grayscale_multiple_images(np.float16)
|
||||
|
||||
def test_resize_nn_grayscale_multiple_float():
|
||||
resize_nn_grayscale_multiple_images(np.float32)
|
||||
|
||||
def test_resize_nn_grayscale_align_corners_half():
|
||||
resize_nn_grayscale_align_corners(np.float16)
|
||||
|
||||
def test_resize_nn_grayscale_align_corners_float():
|
||||
resize_nn_grayscale_align_corners(np.float32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_resize_nn_grayscale_integer_ratio_half()
|
||||
test_resize_nn_grayscale_integer_ratio_float()
|
||||
test_resize_nn_grayscale_not_integer_ratio_half()
|
||||
test_resize_nn_grayscale_not_integer_ratio_float()
|
||||
test_resize_nn_grayscale_multiple_half()
|
||||
test_resize_nn_grayscale_multiple_float()
|
||||
test_resize_nn_grayscale_align_corners_half()
|
||||
test_resize_nn_grayscale_align_corners_float()
|
||||
test_resize_nn_rgb_integer_ratio()
|
||||
test_resize_nn_rgb_not_integer_ratio()
|
||||
test_resize_nn_rgb_multiple()
|
||||
test_resize_nn_rgb_align_corners()
|
||||
|
|
Loading…
Reference in New Issue