230 lines
8.3 KiB
Python
230 lines
8.3 KiB
Python
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""
|
|
Testing SlicePatches Python API
|
|
"""
|
|
import functools
|
|
import numpy as np
|
|
|
|
import mindspore.dataset as ds
|
|
import mindspore.dataset.vision as vision
|
|
import mindspore.dataset.vision.utils as mode
|
|
|
|
from mindspore import log as logger
|
|
from util import diff_mse, visualize_list
|
|
|
|
# TFRecord shard(s) fed to TFRecordDataset by the pipeline tests
# (presumably the three images suggested by the directory name).
DATA_DIR = [
    "../data/dataset/test_tf_file_3_images/train-0000-of-0001.data",
]
# Schema file passed alongside DATA_DIR to TFRecordDataset.
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
|
|
|
|
|
def test_slice_patches_01(plot=False):
    """
    Slice an RGB image resized to (100, 200) into a 2 x 2 grid (4 patches), pad mode.
    """
    slice_to_patches([100, 200], num_h=2, num_w=2, pad_or_drop=True, plot=plot)
|
|
|
|
|
|
def test_slice_patches_02(plot=False):
    """
    A 1 x 1 grid on a (100, 200) image: slicing should be a no-op.
    """
    slice_to_patches([100, 200], num_h=1, num_w=1, pad_or_drop=True, plot=plot)
|
|
|
|
|
|
def test_slice_patches_03(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 4 patches in pad mode.

    Neither dimension is divisible by 2, so the bottom/right edges are padded
    with the default fill value.
    """
    slice_to_patches([99, 199], num_h=2, num_w=2, pad_or_drop=True, plot=plot)
|
|
|
|
|
|
def test_slice_patches_04(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 4 patches in drop mode.

    The trailing row/column that does not fit the 2 x 2 grid is discarded.
    """
    slice_to_patches([99, 199], num_h=2, num_w=2, pad_or_drop=False, plot=plot)
|
|
|
|
|
|
def test_slice_patches_05(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 4 patches in pad mode,
    padding with fill value 255 instead of the default.
    """
    slice_to_patches([99, 199], num_h=2, num_w=2, pad_or_drop=True, fill_value=255, plot=plot)
|
|
|
|
|
|
def slice_to_patches(ori_size, num_h, num_w, pad_or_drop, fill_value=0, plot=False):
    """
    Tool function for slice patches: compare vision.SlicePatches against the NumPy reference.

    Two identical pipelines read DATA_DIR/SCHEMA_DIR, decode the "image"
    column and resize it to ``ori_size``. The first pipeline applies
    ``vision.SlicePatches``; the second applies the NumPy helper
    ``slice_patches`` with the same parameters. Every patch of every row must
    match exactly (MSE == 0).

    Args:
        ori_size (list): [H, W] size the images are resized to before slicing.
        num_h (int): number of patches along the height.
        num_w (int): number of patches along the width.
        pad_or_drop (bool): True selects SliceMode.PAD (using ``fill_value``),
            False selects SliceMode.DROP.
        fill_value (int): padding value, only used in PAD mode.
        plot (bool): if True, visualize the reference patches next to the op's patches.

    Raises:
        AssertionError: if any patch differs, or if no patch was compared at all.
    """
    logger.info("test_slice_patches_pipeline")

    cols = ['img' + str(x) for x in range(num_h * num_w)]

    decode_op = vision.Decode()
    resize_op = vision.Resize(ori_size)  # H, W
    # Build only the op for the requested mode instead of creating a PAD op
    # and then overwriting it in DROP mode.
    if pad_or_drop:
        slice_patches_op = vision.SlicePatches(num_h, num_w, mode.SliceMode.PAD, fill_value)
    else:
        slice_patches_op = vision.SlicePatches(num_h, num_w, mode.SliceMode.DROP)

    # First dataset: vision.SlicePatches op
    dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
    dataset1 = dataset1.map(operations=resize_op, input_columns=["image"])
    dataset1 = dataset1.map(operations=slice_patches_op,
                            input_columns=["image"], output_columns=cols, column_order=cols)

    # Second dataset: NumPy reference implementation
    dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
    dataset2 = dataset2.map(operations=resize_op, input_columns=["image"])
    func_slice_patches = functools.partial(
        slice_patches, num_h=num_h, num_w=num_w, pad_or_drop=pad_or_drop, fill_value=fill_value)
    dataset2 = dataset2.map(operations=func_slice_patches,
                            input_columns=["image"], output_columns=cols, column_order=cols)

    num_iter = 0  # counts compared patches (used only for logging and the final guard)
    patches_c = []
    patches_py = []
    for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for col in cols:
            mse = diff_mse(data1[col], data2[col])
            logger.info("slice_patches_{}, mse: {}".format(num_iter + 1, mse))
            assert mse == 0
            patches_c.append(data1[col])
            patches_py.append(data2[col])
            num_iter += 1

    # Without this, an empty pipeline would make the comparison pass vacuously.
    assert num_iter > 0, "no patches were compared"

    if plot:
        visualize_list(patches_py, patches_c)
|
|
|
|
|
|
def test_slice_patches_exception_01():
    """
    Test SlicePatches with invalid parameters.

    Each case must raise the expected exception type with the expected message
    fragment. Previously a construction that did NOT raise let the test pass
    silently; the ``else`` branch below now fails in that situation.
    """
    logger.info("test_Slice_Patches_exception")

    # (constructor args, expected exception type, expected message fragment)
    invalid_cases = [
        ((0, 2), ValueError, "Input num_height is not within"),
        ((2, 0), ValueError, "Input num_width is not within"),
        ((2, 2, 1), TypeError, "Argument slice_mode with value"),
        ((2, 2, mode.SliceMode.PAD, -1), ValueError, "Input fill_value is not within"),
    ]

    for args, expected_error, expected_msg in invalid_cases:
        try:
            _ = vision.SlicePatches(*args)
        except expected_error as e:
            logger.info("Got an exception in SlicePatches: {}".format(str(e)))
            assert expected_msg in str(e)
        else:
            raise AssertionError(
                "SlicePatches{} did not raise {}".format(args, expected_error.__name__))
|
|
|
|
def test_slice_patches_06():
    """
    Eager SlicePatches(2, 8) on a random (158, 126, 1) int32 image, default mode.

    The expected patch width of 16 equals ceil(126 / 8), i.e. the default mode
    is expected to pad the width.
    """
    rand_img = np.random.randint(0, 255, (158, 126, 1)).astype(np.int32)
    result = vision.SlicePatches(2, 8)(rand_img)
    assert len(result) == 2 * 8
    assert result[0].shape == (79, 16, 1)
|
|
|
|
def test_slice_patches_07():
    """
    Same as test_slice_patches_06, but on a 2-D (158, 126) int32 image with no channel axis.
    """
    rand_img = np.random.randint(0, 255, (158, 126)).astype(np.int32)
    result = vision.SlicePatches(2, 8)(rand_img)
    assert len(result) == 2 * 8
    assert result[0].shape == (79, 16)
|
|
|
|
def test_slice_patches_08():
    """
    SlicePatches(2, 2) in a NumpySlicesDataset pipeline on one (56, 82, 256) uint8 image.
    """
    batch = np.random.randint(0, 255, (1, 56, 82, 256)).astype(np.uint8)
    pipeline = ds.NumpySlicesDataset(batch, column_names=["image"])
    patch_cols = ["img0", "img1", "img2", "img3"]
    pipeline = pipeline.map(operations=vision.SlicePatches(2, 2),
                            input_columns=["image"],
                            output_columns=patch_cols,
                            column_order=["img0", "img1", "img2", "img3"])
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert row['img0'].shape == (28, 41, 256)
|
|
|
|
def test_slice_patches_09():
    """
    Eager SlicePatches(4, 3, PAD) on a (56, 82, 256) uint8 image.

    The expected patch width 28 equals ceil(82 / 3), consistent with padding.
    """
    rand_img = np.random.randint(0, 255, (56, 82, 256)).astype(np.uint8)
    pad_op = vision.SlicePatches(4, 3, mode.SliceMode.PAD)
    result = pad_op(rand_img)
    assert len(result) == 4 * 3
    assert result[0].shape == (14, 28, 256)
|
|
|
|
def skip_test_slice_patches_10():
    """
    Large eager case: SlicePatches(10, 13, DROP) on a (7000, 7000, 255) uint8 image.

    The ``skip_`` prefix keeps it out of default ``test_*`` collection
    (presumably deliberate because of the input size).
    """
    huge_img = np.random.randint(0, 255, (7000, 7000, 255)).astype(np.uint8)
    drop_op = vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    first = drop_op(huge_img)[0]
    assert first.shape == (700, 538, 255)
|
|
|
|
def skip_test_slice_patches_11():
    """
    Large pipeline case: SlicePatches(10, 13, DROP) over a (7000, 7000, 256) uint8 image.

    The ``skip_`` prefix keeps it out of default ``test_*`` collection.
    """
    batch = np.random.randint(0, 255, (1, 7000, 7000, 256)).astype(np.uint8)
    pipeline = ds.NumpySlicesDataset(batch, column_names=["image"])
    drop_op = vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    patch_cols = ['img' + str(idx) for idx in range(10 * 13)]
    pipeline = pipeline.map(input_columns=["image"], output_columns=patch_cols,
                            column_order=patch_cols, operations=drop_op)
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert row['img0'].shape == (700, 538, 256)
|
|
|
|
def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
    """
    Help function which slices patches with NumPy (reference implementation).

    The image is cut into a ``num_h`` x ``num_w`` grid of equally sized
    patches, returned in row-major order (left to right, then top to bottom).

    Args:
        image (numpy.ndarray): array whose first two axes are (H, W). Any
            trailing axes (e.g. channels) are carried through unchanged, so both
            (H, W) and (H, W, C) inputs are supported.
        num_h (int): number of patches along the height.
        num_w (int): number of patches along the width.
        pad_or_drop (bool): True rounds the patch size up and pads the
            bottom/right with ``fill_value``. False rounds it down and drops
            the trailing rows/columns that do not fit.
        fill_value: padding value; ignored when ``pad_or_drop`` is False.

    Returns:
        ``image`` itself for a 1 x 1 grid. Otherwise a tuple of
        ``num_h * num_w`` patches with the same dtype as ``image`` (previously
        the padded buffer was always uint8).
    """
    if num_h == 1 and num_w == 1:
        return image

    height, width = image.shape[:2]
    if pad_or_drop:
        # Ceil-divide so the grid covers the whole image, then pad the remainder.
        patch_h = (height + num_h - 1) // num_h
        patch_w = (width + num_w - 1) // num_w
        padded_shape = (patch_h * num_h, patch_w * num_w) + image.shape[2:]
        img = np.full(padded_shape, fill_value, dtype=image.dtype)
        img[:height, :width] = image
    else:
        # Floor-divide; rows/columns beyond the grid are simply never sliced.
        patch_h = height // num_h
        patch_w = width // num_w
        img = image

    return tuple(img[top * patch_h:(top + 1) * patch_h, left * patch_w:(left + 1) * patch_w]
                 for top in range(num_h) for left in range(num_w))
|
|
|
|
|
|
if __name__ == "__main__":
    # Pipeline comparison tests, with visualization enabled.
    for plotted_test in (test_slice_patches_01, test_slice_patches_02, test_slice_patches_03,
                         test_slice_patches_04, test_slice_patches_05):
        plotted_test(plot=True)
    # Shape and parameter-validation tests.
    for plain_test in (test_slice_patches_06, test_slice_patches_07, test_slice_patches_08,
                       test_slice_patches_09, test_slice_patches_exception_01):
        plain_test()
|