forked from mindspore-Ecosystem/mindspore
!13176 fix minddata transform issue
From: @luoyang42
Reviewed-by:
Signed-off-by:

commit 682a8926bd

@@ -1137,8 +1137,8 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
               InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
   try {
     std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
-    if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
-      RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or channel is not 3.");
+    if (input_cv->Rank() == 1 || input_cv->Rank() > 3) {
+      RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or <H,W>.");
     }

     cv::Mat affine_mat(mat);
@@ -13,9 +13,9 @@
 # limitations under the License.
 """
 This module provides APIs to load and process various common datasets such as MNIST,
-CIFAR-10, CIFAR-100, VOC, ImageNet, CelebA, etc. It also supports datasets in standard
-format, including MindRecord, TFRecord, Manifest, etc. Users can also define their own
-datasets with this module.
+CIFAR-10, CIFAR-100, VOC, COCO, ImageNet, CelebA, CLUE, etc. It also supports datasets
+in standard format, including MindRecord, TFRecord, Manifest, etc. Users can also define
+their own datasets with this module.

 Besides, this module provides APIs to sample data while loading.

@@ -74,6 +74,14 @@ def check_value(value, valid_range, arg_name=""):
                                                                                  valid_range[1]))


+def check_value_cutoff(value, valid_range, arg_name=""):
+    arg_name = pad_arg_name(arg_name)
+    if value < valid_range[0] or value >= valid_range[1]:
+        raise ValueError(
+            "Input {0}is not within the required interval of [{1}, {2}).".format(arg_name, valid_range[0],
+                                                                                 valid_range[1]))
+
+
 def check_value_normalize_std(value, valid_range, arg_name=""):
     arg_name = pad_arg_name(arg_name)
     if value <= valid_range[0] or value > valid_range[1]:
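
Note: the new validator enforces a half-open interval, unlike check_value's closed one. A minimal sketch of the boundary behavior, assuming this patched mindspore.dataset.core.validator_helpers module is installed (values are illustrative):

    from mindspore.dataset.core.validator_helpers import check_value_cutoff

    check_value_cutoff(0.0, [0, 50], "cutoff")       # passes: lower bound is inclusive
    check_value_cutoff(49.9, [0, 50], "cutoff")      # passes
    try:
        check_value_cutoff(50.0, [0, 50], "cutoff")  # upper bound is exclusive
    except ValueError as error:
        print(error)  # Input cutoff is not within the required interval of [0, 50).
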
@@ -404,7 +412,7 @@ def check_tensor_op(param, param_name):

 def check_c_tensor_op(param, param_name):
     """check whether param is a tensor op or a callable Python function but not a py_transform"""
-    if callable(param) and getattr(param, 'parse', True):
+    if callable(param) and str(param).find("py_transform") >= 0:
         raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name))
     if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
         raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
@@ -531,7 +531,8 @@ class PythonTokenizer:
         self.random = False

     def __call__(self, in_array):
-        in_array = to_str(in_array)
+        if not isinstance(in_array, str):
+            in_array = to_str(in_array)
         tokens = self.tokenizer(in_array)
         return tokens

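
Note: the isinstance guard lets PythonTokenizer accept a plain Python string in eager mode as well as the NumPy string arrays seen inside a pipeline, calling to_str only when conversion is needed. A minimal eager-mode sketch mirroring the new test added at the bottom of this diff (tokenizer and text are illustrative):

    import mindspore.dataset.text.transforms as T

    def my_tokenizer(line):
        # simple whitespace tokenizer, for illustration only
        return line.split() or [""]

    tokens = T.PythonTokenizer(my_tokenizer)("Welcome to Beijing !")
    print(tokens)  # ['Welcome', 'to', 'Beijing', '!']
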
@@ -104,7 +104,8 @@ class AutoContrast(ImageTensorOperation):
     Apply automatic contrast on input image.

     Args:
-        cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
+        cutoff (float, optional): Percent of pixels to cut off from the histogram,
+            the value must be in the range [0.0, 50.0) (default=0.0).
         ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).

     Examples:
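
Note: a hedged usage sketch of the documented range for the C-backed AutoContrast op (dataset path and cutoff value are illustrative only):

    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as C

    # cutoff must lie in [0.0, 50.0); 10.0 is a valid example value
    data_set = ds.ImageFolderDataset(dataset_dir="path/to/images", shuffle=False)  # hypothetical path
    data_set = data_set.map(operations=[C.Decode(), C.AutoContrast(cutoff=10.0)],
                            input_columns="image")
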
@@ -770,7 +771,7 @@ class RandomCropDecodeResize(ImageTensorOperation):
         if img.ndim != 1 or img.dtype.type is not np.uint8:
             raise TypeError("Input should be an encoded image with uint8 type in 1-D NumPy format, " +
                             "got format:{}, dtype:{}.".format(type(img), img.dtype.type))
-        super().__call__(img=img)
+        return super().__call__(img)


 class RandomCropWithBBox(ImageTensorOperation):
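
Note: with the return statement in place, the eager call hands the decoded, cropped and resized array back to the caller instead of discarding it. A hedged sketch of that eager path (the file name is illustrative):

    import numpy as np
    import mindspore.dataset.vision.c_transforms as C

    # the eager API expects an encoded image as a 1-D uint8 buffer
    raw = np.fromfile("sample.jpg", dtype=np.uint8)  # hypothetical file
    out = C.RandomCropDecodeResize(size=(256, 256))(raw)
    print(out.shape)  # e.g. (256, 256, 3)
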
@@ -1031,7 +1031,7 @@ class RandomErasing:

 class Cutout:
     """
-    Randomly cut (mask) out a given number of square patches from the input NumPy image array.
+    Randomly cut (mask) out a given number of square patches from the input NumPy image array of shape (C, H, W).

     Terrance DeVries and Graham W. Taylor 'Improved Regularization of Convolutional Neural Networks with Cutout' 2017
     See https://arxiv.org/pdf/1708.04552.pdf
@@ -1068,6 +1068,9 @@ class Cutout:
         """
         if not isinstance(np_img, np.ndarray):
             raise TypeError("img should be NumPy array. Got {}.".format(type(np_img)))
+        if np_img.ndim != 3:
+            raise TypeError('img dimension should be 3. Got {}.'.format(np_img.ndim))
+
         _, image_h, image_w = np_img.shape
         scale = (self.length * self.length) / (image_h * image_w)
         bounded = False
@@ -1426,7 +1429,8 @@ class AutoContrast:
     Automatically maximize the contrast of the input PIL image.

     Args:
-        cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
+        cutoff (float, optional): Percent of pixels to cut off from the histogram,
+            the value must be in the range [0.0, 50.0) (default=0.0).
         ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).

     Examples:
@@ -56,13 +56,16 @@ def normalize(img, mean, std, pad_channel=False, dtype="float32"):
     Returns:
         img (numpy.ndarray), Normalized image.
     """
+    if not is_numpy(img):
+        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
+
+    if img.ndim != 3:
+        raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
+
     if np.issubdtype(img.dtype, np.integer):
         raise NotImplementedError("Unsupported image datatype: [{}], pls execute [ToTensor] before [Normalize]."
                                   .format(img.dtype))

-    if not is_numpy(img):
-        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
-
     num_channels = img.shape[0]  # shape is (C, H, W)

     if len(mean) != len(std):
@@ -119,9 +122,11 @@ def hwc_to_chw(img):
     Returns:
         img (numpy.ndarray), Converted image.
     """
-    if is_numpy(img):
-        return img.transpose(2, 0, 1).copy()
-    raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
+    if not is_numpy(img):
+        raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
+    if img.ndim != 3:
+        raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
+    return img.transpose(2, 0, 1).copy()


 def to_tensor(img, output_type):
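
Note: for valid input the conversion itself is unchanged; the reordering only makes the error cases explicit. A small NumPy sketch of what hwc_to_chw returns for an (H, W, C) array:

    import numpy as np

    hwc = np.zeros((32, 48, 3), dtype=np.float32)  # (H, W, C)
    chw = hwc.transpose(2, 0, 1).copy()            # same reordering hwc_to_chw performs
    print(chw.shape)                               # (3, 32, 48)
    # a 2-D input now hits the explicit ndim check and raises
    # "img dimension should be 3." instead of an opaque NumPy transpose error
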
@@ -140,7 +145,7 @@ def to_tensor(img, output_type):

     img = np.asarray(img)
     if img.ndim not in (2, 3):
-        raise ValueError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
+        raise TypeError("img dimension should be 2 or 3. Got {}.".format(img.ndim))

     if img.ndim == 2:
         img = img[:, :, None]
@@ -856,8 +861,8 @@ def pad(img, padding, fill_value, padding_mode):

     elif isinstance(padding, (tuple, list)):
         if len(padding) == 2:
-            left = right = padding[0]
-            top = bottom = padding[1]
+            left = top = padding[0]
+            right = bottom = padding[1]
         elif len(padding) == 4:
             left = padding[0]
             top = padding[1]
@@ -877,10 +882,10 @@ def pad(img, padding, fill_value, padding_mode):
     if padding_mode == 'constant':
         if img.mode == 'P':
             palette = img.getpalette()
-            image = ImageOps.expand(img, border=padding, fill=fill_value)
+            image = ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)
             image.putpalette(palette)
             return image
-        return ImageOps.expand(img, border=padding, fill=fill_value)
+        return ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)

     if img.mode == 'P':
         palette = img.getpalette()
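
Note: PIL's ImageOps.expand treats a 4-tuple border as (left, top, right, bottom), so the rewritten calls pass the already-resolved values explicitly rather than forwarding the raw padding argument. A hedged illustration of that call (image size and border values are arbitrary):

    from PIL import Image, ImageOps

    img = Image.new("RGB", (4, 4))
    left, top, right, bottom = 1, 2, 3, 4
    padded = ImageOps.expand(img, border=(left, top, right, bottom), fill=0)
    print(padded.size)  # (4 + 1 + 3, 4 + 2 + 4) == (8, 10)
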
@@ -1254,6 +1259,9 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
     if not is_numpy(np_rgb_imgs):
         raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))

+    if not isinstance(is_hwc, bool):
+        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
+
     shape_size = len(np_rgb_imgs.shape)

     if not shape_size in (3, 4):
@@ -1322,6 +1330,9 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
     if not is_numpy(np_hsv_imgs):
         raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs)))

+    if not isinstance(is_hwc, bool):
+        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
+
     shape_size = len(np_hsv_imgs.shape)

     if not shape_size in (3, 4):
@@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation

 from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
     check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
-    check_c_tensor_op, UINT8_MAX, check_value_normalize_std
+    check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff
 from .utils import Inter, Border, ImageBatchFormat


@@ -650,7 +650,7 @@ def check_auto_contrast(method):
     def new_method(self, *args, **kwargs):
         [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
         type_check(cutoff, (int, float), "cutoff")
-        check_value(cutoff, [0, 100], "cutoff")
+        check_value_cutoff(cutoff, [0, 50], "cutoff")
         if ignore is not None:
             type_check(ignore, (list, tuple, int), "ignore")
             if isinstance(ignore, int):
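
Note: with the validator swapped in, an out-of-range cutoff now fails when the operation is constructed, producing the interval message asserted by the updated tests below. A short sketch using a value taken from those tests:

    import mindspore.dataset.vision.c_transforms as C

    try:
        C.AutoContrast(cutoff=120.0)  # outside [0, 50)
    except ValueError as error:
        print(error)  # Input cutoff is not within the required interval of [0, 50).
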
@@ -270,7 +270,7 @@ def test_auto_contrast_invalid_cutoff_param_c():
         data_set = data_set.map(operations=C.AutoContrast(cutoff=-10.0), input_columns="image")
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
     try:
         data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
         data_set = data_set.map(operations=[C.Decode(),
@@ -280,7 +280,7 @@ def test_auto_contrast_invalid_cutoff_param_c():
         data_set = data_set.map(operations=C.AutoContrast(cutoff=120.0), input_columns="image")
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)


 def test_auto_contrast_invalid_ignore_param_py():
@@ -327,7 +327,7 @@ def test_auto_contrast_invalid_cutoff_param_py():
                                 input_columns=["image"])
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
     try:
         data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
         data_set = data_set.map(
@@ -338,7 +338,7 @@ def test_auto_contrast_invalid_cutoff_param_py():
                                 input_columns=["image"])
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)


 if __name__ == "__main__":
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import mindspore.dataset.text.transforms as T
+import mindspore.common.dtype as mstype
+from mindspore import log as logger
+
+def test_sliding_window():
+    txt = ["Welcome", "to", "Beijing", "!"]
+    sliding_window = T.SlidingWindow(width=2)
+    txt = sliding_window(txt)
+    logger.info("Result: {}".format(txt))
+
+    expected = [['Welcome', 'to'], ['to', 'Beijing'], ['Beijing', '!']]
+    np.testing.assert_equal(txt, expected)
+
+
+def test_to_number():
+    txt = ["123456"]
+    to_number = T.ToNumber(mstype.int32)
+    txt = to_number(txt)
+    logger.info("Result: {}, type: {}".format(txt, type(txt[0])))
+
+    assert txt == 123456
+
+
+def test_whitespace_tokenizer():
+    txt = "Welcome to Beijing !"
+    txt = T.WhitespaceTokenizer()(txt)
+    logger.info("Tokenize result: {}".format(txt))
+
+    expected = ['Welcome', 'to', 'Beijing', '!']
+    np.testing.assert_equal(txt, expected)
+
+
+def test_python_tokenizer():
+    # whitespace tokenizer
+    def my_tokenizer(line):
+        words = line.split()
+        if not words:
+            return [""]
+        return words
+    txt = "Welcome to Beijing !"
+    txt = T.PythonTokenizer(my_tokenizer)(txt)
+    logger.info("Tokenize result: {}".format(txt))
+
+    expected = ['Welcome', 'to', 'Beijing', '!']
+    np.testing.assert_equal(txt, expected)
+
+
+if __name__ == '__main__':
+    test_sliding_window()
+    test_to_number()
+    test_whitespace_tokenizer()
+    test_python_tokenizer()