forked from mindspore-Ecosystem/mindspore
Added UniformAugment + Python Augmentation Ops
This commit is contained in:
parent
862d23fe90
commit
56e7a7deb5
|
@ -1312,3 +1312,177 @@ class HsvToRgb:
|
|||
rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs.
|
||||
"""
|
||||
return util.hsv_to_rgbs(hsv_imgs, self.is_hwc)
|
||||
|
||||
|
||||
class RandomColor:
|
||||
"""
|
||||
Adjust the color of the input PIL image by a random degree.
|
||||
|
||||
Args:
|
||||
degrees (sequence): Range of random color adjustment degrees.
|
||||
It should be in (min, max) format (default=(0.1,1.9)).
|
||||
|
||||
Examples:
|
||||
>>> py_transforms.ComposeOp([py_transforms.Decode(),
|
||||
>>> py_transforms.RandomColor(0.5,1.5),
|
||||
>>> py_transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, degrees=(0.1, 1.9)):
|
||||
self.degrees = degrees
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be color adjusted.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Color adjusted image.
|
||||
"""
|
||||
|
||||
return util.random_color(img, self.degrees)
|
||||
|
||||
class RandomSharpness:
|
||||
"""
|
||||
Adjust the sharpness of the input PIL image by a random degree.
|
||||
|
||||
Args:
|
||||
degrees (sequence): Range of random sharpness adjustment degrees.
|
||||
It should be in (min, max) format (default=(0.1,1.9)).
|
||||
|
||||
Examples:
|
||||
>>> py_transforms.ComposeOp([py_transforms.Decode(),
|
||||
>>> py_transforms.RandomColor(0.5,1.5),
|
||||
>>> py_transforms.ToTensor()])
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, degrees=(0.1, 1.9)):
|
||||
self.degrees = degrees
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be sharpness adjusted.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Color adjusted image.
|
||||
"""
|
||||
|
||||
return util.random_sharpness(img, self.degrees)
|
||||
|
||||
|
||||
class AutoContrast:
|
||||
"""
|
||||
Automatically maximize the contrast of the input PIL image.
|
||||
|
||||
Examples:
|
||||
>>> py_transforms.ComposeOp([py_transforms.Decode(),
|
||||
>>> py_transforms.AutoContrast(),
|
||||
>>> py_transforms.ToTensor()])
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be augmented with AutoContrast.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Augmented image.
|
||||
"""
|
||||
|
||||
return util.auto_contrast(img)
|
||||
|
||||
|
||||
class Invert:
|
||||
"""
|
||||
Invert colors of input PIL image.
|
||||
|
||||
Examples:
|
||||
>>> py_transforms.ComposeOp([py_transforms.Decode(),
|
||||
>>> py_transforms.Invert(),
|
||||
>>> py_transforms.ToTensor()])
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be color Inverted.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Color inverted image.
|
||||
"""
|
||||
|
||||
return util.invert_color(img)
|
||||
|
||||
|
||||
class Equalize:
|
||||
"""
|
||||
Equalize the histogram of input PIL image.
|
||||
|
||||
Examples:
|
||||
>>> py_transforms.ComposeOp([py_transforms.Decode(),
|
||||
>>> py_transforms.Equalize(),
|
||||
>>> py_transforms.ToTensor()])
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be equalized.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Equalized image.
|
||||
"""
|
||||
|
||||
return util.equalize(img)
|
||||
|
||||
|
||||
class UniformAugment:
|
||||
"""
|
||||
Uniformly select and apply a number of transforms sequentially from
|
||||
a list of transforms. Randomly assigns a probability to each transform for
|
||||
each image to decide whether apply it or not.
|
||||
|
||||
Args:
|
||||
transforms (list): List of transformations to be chosen from to apply.
|
||||
num_ops (int, optional): number of transforms to sequentially apply (default=2).
|
||||
|
||||
Examples:
|
||||
>>> transforms_list = [py_transforms.CenterCrop(64),
|
||||
>>> py_transforms.RandomColor(),
|
||||
>>> py_transforms.RandomSharpness(),
|
||||
>>> py_transforms.RandomRotation(30)]
|
||||
>>> py_transforms.ComposeOp([py_transforms.Decode(),
|
||||
>>> py_transforms.UniformAugment(transforms_list),
|
||||
>>> py_transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, transforms, num_ops=2):
|
||||
self.transforms = transforms
|
||||
self.num_ops = num_ops
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Call method.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be applied transformation.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Transformed image.
|
||||
"""
|
||||
return util.uniform_augment(img, self.transforms, self.num_ops)
|
||||
|
|
|
@ -1408,3 +1408,160 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
|
|||
if batch_size == 0:
|
||||
return hsv_to_rgb(np_hsv_imgs, is_hwc)
|
||||
return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs])
|
||||
|
||||
|
||||
def random_color(img, degrees):
|
||||
|
||||
"""
|
||||
Adjust the color of the input PIL image by a random degree.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be color adjusted.
|
||||
degrees (sequence): Range of random color adjustment degrees.
|
||||
It should be in (min, max) format (default=(0.1,1.9)).
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Color adjusted image.
|
||||
"""
|
||||
|
||||
if not is_pil(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
if isinstance(degrees, (list, tuple)):
|
||||
if len(degrees) != 2:
|
||||
raise ValueError("Degrees must be a sequence length 2.")
|
||||
if degrees[0] < 0:
|
||||
raise ValueError("Degree value must be non-negative.")
|
||||
if degrees[0] > degrees[1]:
|
||||
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
|
||||
|
||||
else:
|
||||
raise TypeError("Degrees must be a sequence in (min,max) format.")
|
||||
|
||||
v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
|
||||
return ImageEnhance.Color(img).enhance(v)
|
||||
|
||||
|
||||
def random_sharpness(img, degrees):
|
||||
|
||||
"""
|
||||
Adjust the sharpness of the input PIL image by a random degree.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be sharpness adjusted.
|
||||
degrees (sequence): Range of random sharpness adjustment degrees.
|
||||
It should be in (min, max) format (default=(0.1,1.9)).
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Sharpness adjusted image.
|
||||
"""
|
||||
|
||||
if not is_pil(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
if isinstance(degrees, (list, tuple)):
|
||||
if len(degrees) != 2:
|
||||
raise ValueError("Degrees must be a sequence length 2.")
|
||||
if degrees[0] < 0:
|
||||
raise ValueError("Degree value must be non-negative.")
|
||||
if degrees[0] > degrees[1]:
|
||||
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
|
||||
|
||||
else:
|
||||
raise TypeError("Degrees must be a sequence in (min,max) format.")
|
||||
|
||||
v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
|
||||
return ImageEnhance.Sharpness(img).enhance(v)
|
||||
|
||||
|
||||
def auto_contrast(img):
|
||||
|
||||
"""
|
||||
Automatically maximize the contrast of the input PIL image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be augmented with AutoContrast.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Augmented image.
|
||||
|
||||
"""
|
||||
|
||||
if not is_pil(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
return ImageOps.autocontrast(img)
|
||||
|
||||
|
||||
def invert_color(img):
|
||||
|
||||
"""
|
||||
Invert colors of input PIL image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be color inverted.
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Color inverted image.
|
||||
|
||||
"""
|
||||
|
||||
if not is_pil(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
return ImageOps.invert(img)
|
||||
|
||||
|
||||
def equalize(img):
|
||||
|
||||
"""
|
||||
Equalize the histogram of input PIL image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be equalized
|
||||
|
||||
Returns:
|
||||
img (PIL Image), Equalized image.
|
||||
|
||||
"""
|
||||
|
||||
if not is_pil(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
return ImageOps.equalize(img)
|
||||
|
||||
|
||||
def uniform_augment(img, transforms, num_ops):
|
||||
|
||||
"""
|
||||
Uniformly select and apply a number of transforms sequentially from
|
||||
a list of transforms. Randomly assigns a probability to each transform for
|
||||
each image to decide whether apply it or not.
|
||||
|
||||
Args:
|
||||
img: Image to be applied transformation.
|
||||
transforms (list): List of transformations to be chosen from to apply.
|
||||
num_ops (int): number of transforms to sequentially aaply.
|
||||
|
||||
Returns:
|
||||
img, Transformed image.
|
||||
"""
|
||||
|
||||
if transforms is None:
|
||||
raise ValueError("transforms is not provided.")
|
||||
if not isinstance(transforms, list):
|
||||
raise ValueError("The transforms needs to be a list.")
|
||||
|
||||
if not isinstance(num_ops, int):
|
||||
raise ValueError("Number of operations should be a positive integer.")
|
||||
if num_ops < 1:
|
||||
raise ValueError("Number of operators should equal or greater than one.")
|
||||
|
||||
for _ in range(num_ops):
|
||||
AugmentOp = random.choice(transforms)
|
||||
pr = random.random()
|
||||
if random.random() < pr:
|
||||
img = AugmentOp(img.copy())
|
||||
transforms.remove(AugmentOp)
|
||||
|
||||
return img
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
|
||||
def visualize(image_original, image_auto_contrast):
|
||||
"""
|
||||
visualizes the image using DE op and Numpy op
|
||||
"""
|
||||
num = len(image_auto_contrast)
|
||||
for i in range(num):
|
||||
plt.subplot(2, num, i + 1)
|
||||
plt.imshow(image_original[i])
|
||||
plt.title("Original image")
|
||||
|
||||
plt.subplot(2, num, i + num + 1)
|
||||
plt.imshow(image_auto_contrast[i])
|
||||
plt.title("DE AutoContrast image")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_auto_contrast(plot=False):
|
||||
"""
|
||||
Test AutoContrast
|
||||
"""
|
||||
logger.info("Test AutoContrast")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_original = ds.map(input_columns="image",
|
||||
operations=transforms_original())
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
# AutoContrast Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_auto_contrast = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.AutoContrast(),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_auto_contrast = ds.map(input_columns="image",
|
||||
operations=transforms_auto_contrast())
|
||||
|
||||
ds_auto_contrast = ds_auto_contrast.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_auto_contrast):
|
||||
if idx == 0:
|
||||
images_auto_contrast = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_auto_contrast = np.append(images_auto_contrast,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = np.mean((images_auto_contrast[i]-images_original[i])**2)
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize(images_original, images_auto_contrast)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_auto_contrast(plot=True)
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
|
||||
def visualize(image_original, image_equalize):
|
||||
"""
|
||||
visualizes the image using DE op and Numpy op
|
||||
"""
|
||||
num = len(image_equalize)
|
||||
for i in range(num):
|
||||
plt.subplot(2, num, i + 1)
|
||||
plt.imshow(image_original[i])
|
||||
plt.title("Original image")
|
||||
|
||||
plt.subplot(2, num, i + num + 1)
|
||||
plt.imshow(image_equalize[i])
|
||||
plt.title("DE Color Equalized image")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_equalize(plot=False):
|
||||
"""
|
||||
Test Equalize
|
||||
"""
|
||||
logger.info("Test Equalize")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_original = ds.map(input_columns="image",
|
||||
operations=transforms_original())
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
# Color Equalized Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_equalize = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.Equalize(),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_equalize = ds.map(input_columns="image",
|
||||
operations=transforms_equalize())
|
||||
|
||||
ds_equalize = ds_equalize.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_equalize):
|
||||
if idx == 0:
|
||||
images_equalize = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_equalize = np.append(images_equalize,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = np.mean((images_equalize[i]-images_original[i])**2)
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize(images_original, images_equalize)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_equalize(plot=True)
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
def visualize(image_original, image_invert):
|
||||
"""
|
||||
visualizes the image using DE op and Numpy op
|
||||
"""
|
||||
num = len(image_invert)
|
||||
for i in range(num):
|
||||
plt.subplot(2, num, i + 1)
|
||||
plt.imshow(image_original[i])
|
||||
plt.title("Original image")
|
||||
|
||||
plt.subplot(2, num, i + num + 1)
|
||||
plt.imshow(image_invert[i])
|
||||
plt.title("DE Color Inverted image")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_invert(plot=False):
|
||||
"""
|
||||
Test Invert
|
||||
"""
|
||||
logger.info("Test Invert")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_original = ds.map(input_columns="image",
|
||||
operations=transforms_original())
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
# Color Inverted Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_invert = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.Invert(),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_invert = ds.map(input_columns="image",
|
||||
operations=transforms_invert())
|
||||
|
||||
ds_invert = ds_invert.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_invert):
|
||||
if idx == 0:
|
||||
images_invert = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_invert = np.append(images_invert,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = np.mean((images_invert[i]-images_original[i])**2)
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize(images_original, images_invert)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_invert(plot=True)
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
|
||||
def visualize(image_original, image_random_color):
|
||||
"""
|
||||
visualizes the image using DE op and Numpy op
|
||||
"""
|
||||
num = len(image_random_color)
|
||||
for i in range(num):
|
||||
plt.subplot(2, num, i + 1)
|
||||
plt.imshow(image_original[i])
|
||||
plt.title("Original image")
|
||||
|
||||
plt.subplot(2, num, i + num + 1)
|
||||
plt.imshow(image_random_color[i])
|
||||
plt.title("DE Random Color image")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_random_color(degrees=(0.1,1.9), plot=False):
|
||||
"""
|
||||
Test RandomColor
|
||||
"""
|
||||
logger.info("Test RandomColor")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_original = ds.map(input_columns="image",
|
||||
operations=transforms_original())
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
# Random Color Adjusted Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_random_color = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.RandomColor(degrees=degrees),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_random_color = ds.map(input_columns="image",
|
||||
operations=transforms_random_color())
|
||||
|
||||
ds_random_color = ds_random_color.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_random_color):
|
||||
if idx == 0:
|
||||
images_random_color = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_random_color = np.append(images_random_color,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = np.mean((images_random_color[i]-images_original[i])**2)
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize(images_original, images_random_color)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_color()
|
||||
test_random_color(plot=True)
|
||||
test_random_color(degrees=(0.5,1.5), plot=True)
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
|
||||
def visualize(image_original, image_random_sharpness):
|
||||
"""
|
||||
visualizes the image using DE op and Numpy op
|
||||
"""
|
||||
num = len(image_random_sharpness)
|
||||
for i in range(num):
|
||||
plt.subplot(2, num, i + 1)
|
||||
plt.imshow(image_original[i])
|
||||
plt.title("Original image")
|
||||
|
||||
plt.subplot(2, num, i + num + 1)
|
||||
plt.imshow(image_random_sharpness[i])
|
||||
plt.title("DE Random Sharpness image")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_random_sharpness(degrees=(0.1,1.9), plot=False):
|
||||
"""
|
||||
Test RandomSharpness
|
||||
"""
|
||||
logger.info("Test RandomSharpness")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_original = ds.map(input_columns="image",
|
||||
operations=transforms_original())
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
# Random Sharpness Adjusted Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_random_sharpness = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.RandomSharpness(degrees=degrees),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_random_sharpness = ds.map(input_columns="image",
|
||||
operations=transforms_random_sharpness())
|
||||
|
||||
ds_random_sharpness = ds_random_sharpness.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_random_sharpness):
|
||||
if idx == 0:
|
||||
images_random_sharpness = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_random_sharpness = np.append(images_random_sharpness,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = np.mean((images_random_sharpness[i]-images_original[i])**2)
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize(images_original, images_random_sharpness)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_sharpness()
|
||||
test_random_sharpness(plot=True)
|
||||
test_random_sharpness(degrees=(0.5,1.5), plot=True)
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.py_transforms as F
|
||||
|
||||
DATA_DIR = "../data/dataset/testImageNetData/train/"
|
||||
|
||||
def visualize(image_original, image_ua):
|
||||
"""
|
||||
visualizes the image using DE op and Numpy op
|
||||
"""
|
||||
num = len(image_ua)
|
||||
for i in range(num):
|
||||
plt.subplot(2, num, i + 1)
|
||||
plt.imshow(image_original[i])
|
||||
plt.title("Original image")
|
||||
|
||||
plt.subplot(2, num, i + num + 1)
|
||||
plt.imshow(image_ua[i])
|
||||
plt.title("DE UniformAugment image")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_uniform_augment(plot=False, num_ops=2):
|
||||
"""
|
||||
Test UniformAugment
|
||||
"""
|
||||
logger.info("Test UniformAugment")
|
||||
|
||||
# Original Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms_original = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_original = ds.map(input_columns="image",
|
||||
operations=transforms_original())
|
||||
|
||||
ds_original = ds_original.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_original):
|
||||
if idx == 0:
|
||||
images_original = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_original = np.append(images_original,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
# UniformAugment Images
|
||||
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transform_list = [F.RandomRotation(45),
|
||||
F.RandomColor(),
|
||||
F.RandomSharpness(),
|
||||
F.Invert(),
|
||||
F.AutoContrast(),
|
||||
F.Equalize()]
|
||||
|
||||
transforms_ua = F.ComposeOp([F.Decode(),
|
||||
F.Resize((224,224)),
|
||||
F.UniformAugment(transforms=transform_list, num_ops=num_ops),
|
||||
F.ToTensor()])
|
||||
|
||||
ds_ua = ds.map(input_columns="image",
|
||||
operations=transforms_ua())
|
||||
|
||||
ds_ua = ds_ua.batch(512)
|
||||
|
||||
for idx, (image,label) in enumerate(ds_ua):
|
||||
if idx == 0:
|
||||
images_ua = np.transpose(image, (0, 2,3,1))
|
||||
else:
|
||||
images_ua = np.append(images_ua,
|
||||
np.transpose(image, (0, 2,3,1)),
|
||||
axis=0)
|
||||
|
||||
num_samples = images_original.shape[0]
|
||||
mse = np.zeros(num_samples)
|
||||
for i in range(num_samples):
|
||||
mse[i] = np.mean((images_ua[i]-images_original[i])**2)
|
||||
logger.info("MSE= {}".format(str(np.mean(mse))))
|
||||
|
||||
if plot:
|
||||
visualize(images_original, images_ua)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_uniform_augment(num_ops=1)
|
||||
|
Loading…
Reference in New Issue