forked from mindspore-Ecosystem/mindspore
Added TenCrop test
Added unit tests for both testing the functinality of the TenCrop and its error messages. Due to the similarity of this method to FiveCrop the test cases are similar to FiveCrop test cases. Signed-off-by: Mahdi <mahdi.rahmani.hanzaki@huawei.com> added error_msg function call in the main method refactored the test and added visual representation of the results Separated the two error cases into two different functions and used the visualize function in util.py to plot the result of TenCrop. Signed-off-by: Mahdi <mahdi.rahmani.hanzaki@huawei.com> Added new test cases Added new test cases including test case for checking the error message when the size variable is not a positive integer, test case for rectangle crop, test case for vertical flip setting, and testing for similarity of the result of TenCrop for the same input data in different runs. Signed-off-by: Mahdi <mahdi.rahmani.hanzaki@huawei.com> changed visualize in test_five_crop Changed the visualize function in test_five_crop to use the already existing function in util.py Signed-off-by: Mahdi <mahdi.rahmani.hanzaki@huawei.com> made generate_golden variable global
This commit is contained in:
parent
4ce1cf4529
commit
dfc097019b
Binary file not shown.
|
@ -21,29 +21,13 @@ import pytest
|
|||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.py_transforms as vision
|
||||
from mindspore import log as logger
|
||||
from util import visualize
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def visualize(image_1, image_2):
|
||||
"""
|
||||
visualizes the image using FiveCrop
|
||||
"""
|
||||
plt.subplot(161)
|
||||
plt.imshow(image_1)
|
||||
plt.title("Original")
|
||||
|
||||
for i, image in enumerate(image_2):
|
||||
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
plt.subplot(162 + i)
|
||||
plt.imshow(image)
|
||||
plt.title("image {} in FiveCrop".format(i + 1))
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_five_crop_op():
|
||||
def test_five_crop_op(plot=False):
|
||||
"""
|
||||
Test FiveCrop
|
||||
"""
|
||||
|
@ -79,8 +63,8 @@ def test_five_crop_op():
|
|||
|
||||
logger.info("dtype of image_1: {}".format(image_1.dtype))
|
||||
logger.info("dtype of image_2: {}".format(image_2.dtype))
|
||||
|
||||
# visualize(image_1, image_2)
|
||||
if plot:
|
||||
visualize(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
|
||||
|
||||
# The output data should be of a 4D tensor shape, a stack of 5 images.
|
||||
assert len(image_2.shape) == 4
|
||||
|
@ -111,5 +95,5 @@ def test_five_crop_error_msg():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_five_crop_op()
|
||||
test_five_crop_op(plot=True)
|
||||
test_five_crop_error_msg()
|
||||
|
|
|
@ -0,0 +1,190 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Testing TenCrop in DE
|
||||
"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.py_transforms as vision
|
||||
from util import visualize, save_and_check_md5
|
||||
from mindspore import log as logger
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
|
||||
"""
|
||||
Utility function for testing TenCrop. Input arguments are given by other tests
|
||||
"""
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
transforms_1 = [
|
||||
vision.Decode(),
|
||||
vision.ToTensor(),
|
||||
]
|
||||
transform_1 = vision.ComposeOp(transforms_1)
|
||||
data1 = data1.map(input_columns=["image"], operations=transform_1())
|
||||
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
transforms_2 = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop(crop_size, use_vertical_flip=vertical_flip),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
transform_2 = vision.ComposeOp(transforms_2)
|
||||
data2 = data2.map(input_columns=["image"], operations=transform_2())
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
num_iter += 1
|
||||
image_1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
image_2 = item2["image"]
|
||||
|
||||
logger.info("shape of image_1: {}".format(image_1.shape))
|
||||
logger.info("shape of image_2: {}".format(image_2.shape))
|
||||
|
||||
logger.info("dtype of image_1: {}".format(image_1.dtype))
|
||||
logger.info("dtype of image_2: {}".format(image_2.dtype))
|
||||
|
||||
if plot:
|
||||
visualize(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
|
||||
|
||||
# The output data should be of a 4D tensor shape, a stack of 10 images.
|
||||
assert len(image_2.shape) == 4
|
||||
assert image_2.shape[0] == 10
|
||||
|
||||
|
||||
def test_ten_crop_op_square(plot=False):
|
||||
"""
|
||||
Tests TenCrop for a square crop
|
||||
"""
|
||||
|
||||
logger.info("test_ten_crop_op_square")
|
||||
util_test_ten_crop(200, plot=plot)
|
||||
|
||||
|
||||
def test_ten_crop_op_rectangle(plot=False):
|
||||
"""
|
||||
Tests TenCrop for a rectangle crop
|
||||
"""
|
||||
|
||||
logger.info("test_ten_crop_op_rectangle")
|
||||
util_test_ten_crop((200, 150), plot=plot)
|
||||
|
||||
|
||||
def test_ten_crop_op_vertical_flip(plot=False):
|
||||
"""
|
||||
Tests TenCrop with vertical flip set to True
|
||||
"""
|
||||
|
||||
logger.info("test_ten_crop_op_vertical_flip")
|
||||
util_test_ten_crop(200, vertical_flip=True, plot=plot)
|
||||
|
||||
|
||||
def test_ten_crop_md5():
|
||||
"""
|
||||
Tests TenCrops for giving the same results in multiple runs.
|
||||
Since TenCrop is a deterministic function, we expect it to return the same result for a specific input every time
|
||||
"""
|
||||
logger.info("test_ten_crop_md5")
|
||||
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
transforms_2 = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop((200, 100), use_vertical_flip=True),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
transform_2 = vision.ComposeOp(transforms_2)
|
||||
data2 = data2.map(input_columns=["image"], operations=transform_2())
|
||||
# Compare with expected md5 from images
|
||||
filename = "ten_crop_01_result.npz"
|
||||
save_and_check_md5(data2, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_ten_crop_list_size_error_msg():
|
||||
"""
|
||||
Tests TenCrop error message when the size arg has more than 2 elements
|
||||
"""
|
||||
logger.info("test_ten_crop_list_size_error_msg")
|
||||
|
||||
with pytest.raises(TypeError) as info:
|
||||
transforms = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop([200, 200, 200]),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
error_msg = "Size should be a single integer or a list/tuple (h, w) of length 2."
|
||||
assert error_msg == str(info.value)
|
||||
|
||||
|
||||
def test_ten_crop_invalid_size_error_msg():
|
||||
"""
|
||||
Tests TenCrop error message when the size arg is not positive
|
||||
"""
|
||||
logger.info("test_ten_crop_invalid_size_error_msg")
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
transforms = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop(0),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
error_msg = "Input is not within the required range"
|
||||
assert error_msg == str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
transforms = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop(-10),
|
||||
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
|
||||
]
|
||||
|
||||
assert error_msg == str(info.value)
|
||||
|
||||
|
||||
def test_ten_crop_wrong_img_error_msg():
|
||||
"""
|
||||
Tests TenCrop error message when the image is not in the correct format.
|
||||
"""
|
||||
logger.info("test_ten_crop_wrong_img_error_msg")
|
||||
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
transforms = [
|
||||
vision.Decode(),
|
||||
vision.TenCrop(200),
|
||||
vision.ToTensor()
|
||||
]
|
||||
transform = vision.ComposeOp(transforms)
|
||||
data = data.map(input_columns=["image"], operations=transform())
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
data.create_tuple_iterator().get_next()
|
||||
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
|
||||
|
||||
# error msg comes from ToTensor()
|
||||
assert error_msg in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_ten_crop_op_square(plot=True)
|
||||
test_ten_crop_op_rectangle(plot=True)
|
||||
test_ten_crop_op_vertical_flip(plot=True)
|
||||
test_ten_crop_md5()
|
||||
test_ten_crop_list_size_error_msg()
|
||||
test_ten_crop_invalid_size_error_msg()
|
||||
test_ten_crop_wrong_img_error_msg()
|
Loading…
Reference in New Issue