!13176 fix minddata transform issue

From: @luoyang42
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-12 17:01:56 +08:00 committed by Gitee
commit 682a8926bd
10 changed files with 120 additions and 28 deletions

View File

@ -1137,8 +1137,8 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or channel is not 3.");
if (input_cv->Rank() == 1 || input_cv->Rank() > 3) {
RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or <H,W>.");
}
cv::Mat affine_mat(mat);

View File

@ -13,9 +13,9 @@
# limitations under the License.
"""
This module provides APIs to load and process various common datasets such as MNIST,
CIFAR-10, CIFAR-100, VOC, ImageNet, CelebA, etc. It also supports datasets in standard
format, including MindRecord, TFRecord, Manifest, etc. Users can also define their own
datasets with this module.
CIFAR-10, CIFAR-100, VOC, COCO, ImageNet, CelebA, CLUE, etc. It also supports datasets
in standard format, including MindRecord, TFRecord, Manifest, etc. Users can also define
their owndatasets with this module.
Besides, this module provides APIs to sample data while loading.

View File

@ -74,6 +74,14 @@ def check_value(value, valid_range, arg_name=""):
valid_range[1]))
def check_value_cutoff(value, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value < valid_range[0] or value >= valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of [{1}, {2}).".format(arg_name, valid_range[0],
valid_range[1]))
def check_value_normalize_std(value, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value <= valid_range[0] or value > valid_range[1]:
@ -404,7 +412,7 @@ def check_tensor_op(param, param_name):
def check_c_tensor_op(param, param_name):
"""check whether param is a tensor op or a callable Python function but not a py_transform"""
if callable(param) and getattr(param, 'parse', True):
if callable(param) and str(param).find("py_transform") >= 0:
raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name))
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))

View File

@ -531,7 +531,8 @@ class PythonTokenizer:
self.random = False
def __call__(self, in_array):
in_array = to_str(in_array)
if not isinstance(in_array, str):
in_array = to_str(in_array)
tokens = self.tokenizer(in_array)
return tokens

View File

@ -104,7 +104,8 @@ class AutoContrast(ImageTensorOperation):
Apply automatic contrast on input image.
Args:
cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
cutoff (float, optional): Percent of pixels to cut off from the histogram,
the value must be in the range [0.0, 50.0) (default=0.0).
ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).
Examples:
@ -770,7 +771,7 @@ class RandomCropDecodeResize(ImageTensorOperation):
if img.ndim != 1 or img.dtype.type is not np.uint8:
raise TypeError("Input should be an encoded image with uint8 type in 1-D NumPy format, " +
"got format:{}, dtype:{}.".format(type(img), img.dtype.type))
super().__call__(img=img)
return super().__call__(img)
class RandomCropWithBBox(ImageTensorOperation):

View File

@ -1031,7 +1031,7 @@ class RandomErasing:
class Cutout:
"""
Randomly cut (mask) out a given number of square patches from the input NumPy image array.
Randomly cut (mask) out a given number of square patches from the input NumPy image array of shape (C, H, W).
Terrance DeVries and Graham W. Taylor 'Improved Regularization of Convolutional Neural Networks with Cutout' 2017
See https://arxiv.org/pdf/1708.04552.pdf
@ -1068,6 +1068,9 @@ class Cutout:
"""
if not isinstance(np_img, np.ndarray):
raise TypeError("img should be NumPy array. Got {}.".format(type(np_img)))
if np_img.ndim != 3:
raise TypeError('img dimension should be 3. Got {}.'.format(np_img.ndim))
_, image_h, image_w = np_img.shape
scale = (self.length * self.length) / (image_h * image_w)
bounded = False
@ -1426,7 +1429,8 @@ class AutoContrast:
Automatically maximize the contrast of the input PIL image.
Args:
cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
cutoff (float, optional): Percent of pixels to cut off from the histogram,
the value must be in the range [0.0, 50.0) (default=0.0).
ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).
Examples:

View File

@ -56,13 +56,16 @@ def normalize(img, mean, std, pad_channel=False, dtype="float32"):
Returns:
img (numpy.ndarray), Normalized image.
"""
if not is_numpy(img):
raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
if img.ndim != 3:
raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
if np.issubdtype(img.dtype, np.integer):
raise NotImplementedError("Unsupported image datatype: [{}], pls execute [ToTensor] before [Normalize]."
.format(img.dtype))
if not is_numpy(img):
raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
num_channels = img.shape[0] # shape is (C, H, W)
if len(mean) != len(std):
@ -119,9 +122,11 @@ def hwc_to_chw(img):
Returns:
img (numpy.ndarray), Converted image.
"""
if is_numpy(img):
return img.transpose(2, 0, 1).copy()
raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
if not is_numpy(img):
raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
if img.ndim != 3:
raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
return img.transpose(2, 0, 1).copy()
def to_tensor(img, output_type):
@ -140,7 +145,7 @@ def to_tensor(img, output_type):
img = np.asarray(img)
if img.ndim not in (2, 3):
raise ValueError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
raise TypeError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
if img.ndim == 2:
img = img[:, :, None]
@ -856,8 +861,8 @@ def pad(img, padding, fill_value, padding_mode):
elif isinstance(padding, (tuple, list)):
if len(padding) == 2:
left = right = padding[0]
top = bottom = padding[1]
left = top = padding[0]
right = bottom = padding[1]
elif len(padding) == 4:
left = padding[0]
top = padding[1]
@ -877,10 +882,10 @@ def pad(img, padding, fill_value, padding_mode):
if padding_mode == 'constant':
if img.mode == 'P':
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, fill=fill_value)
image = ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)
image.putpalette(palette)
return image
return ImageOps.expand(img, border=padding, fill=fill_value)
return ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)
if img.mode == 'P':
palette = img.getpalette()
@ -1254,6 +1259,9 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
if not is_numpy(np_rgb_imgs):
raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))
if not isinstance(is_hwc, bool):
raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
shape_size = len(np_rgb_imgs.shape)
if not shape_size in (3, 4):
@ -1322,6 +1330,9 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
if not is_numpy(np_hsv_imgs):
raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs)))
if not isinstance(is_hwc, bool):
raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
shape_size = len(np_hsv_imgs.shape)
if not shape_size in (3, 4):

View File

@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
check_c_tensor_op, UINT8_MAX, check_value_normalize_std
check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff
from .utils import Inter, Border, ImageBatchFormat
@ -650,7 +650,7 @@ def check_auto_contrast(method):
def new_method(self, *args, **kwargs):
[cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
type_check(cutoff, (int, float), "cutoff")
check_value(cutoff, [0, 100], "cutoff")
check_value_cutoff(cutoff, [0, 50], "cutoff")
if ignore is not None:
type_check(ignore, (list, tuple, int), "ignore")
if isinstance(ignore, int):

View File

@ -270,7 +270,7 @@ def test_auto_contrast_invalid_cutoff_param_c():
data_set = data_set.map(operations=C.AutoContrast(cutoff=-10.0), input_columns="image")
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data_set = data_set.map(operations=[C.Decode(),
@ -280,7 +280,7 @@ def test_auto_contrast_invalid_cutoff_param_c():
data_set = data_set.map(operations=C.AutoContrast(cutoff=120.0), input_columns="image")
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
def test_auto_contrast_invalid_ignore_param_py():
@ -327,7 +327,7 @@ def test_auto_contrast_invalid_cutoff_param_py():
input_columns=["image"])
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
try:
data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
data_set = data_set.map(
@ -338,7 +338,7 @@ def test_auto_contrast_invalid_cutoff_param_py():
input_columns=["image"])
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
if __name__ == "__main__":

View File

@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset.text.transforms as T
import mindspore.common.dtype as mstype
from mindspore import log as logger
def test_sliding_window():
txt = ["Welcome", "to", "Beijing", "!"]
sliding_window = T.SlidingWindow(width=2)
txt = sliding_window(txt)
logger.info("Result: {}".format(txt))
expected = [['Welcome', 'to'], ['to', 'Beijing'], ['Beijing', '!']]
np.testing.assert_equal(txt, expected)
def test_to_number():
txt = ["123456"]
to_number = T.ToNumber(mstype.int32)
txt = to_number(txt)
logger.info("Result: {}, type: {}".format(txt, type(txt[0])))
assert txt == 123456
def test_whitespace_tokenizer():
txt = "Welcome to Beijing !"
txt = T.WhitespaceTokenizer()(txt)
logger.info("Tokenize result: {}".format(txt))
expected = ['Welcome', 'to', 'Beijing', '!']
np.testing.assert_equal(txt, expected)
def test_python_tokenizer():
# whitespace tokenizer
def my_tokenizer(line):
words = line.split()
if not words:
return [""]
return words
txt = "Welcome to Beijing !"
txt = T.PythonTokenizer(my_tokenizer)(txt)
logger.info("Tokenize result: {}".format(txt))
expected = ['Welcome', 'to', 'Beijing', '!']
np.testing.assert_equal(txt, expected)
if __name__ == '__main__':
test_sliding_window()
test_to_number()
test_whitespace_tokenizer()
test_python_tokenizer()