fix RandomAffine degree error and some examples in docs

This commit is contained in:
Xiao Tianci 2021-03-11 11:57:24 +08:00
parent 3b74db3fad
commit 4a7100c7f9
25 changed files with 64 additions and 57 deletions

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <algorithm>
#include <cmath>
#include <random>
#include <utility>
#include <vector>
@ -72,7 +73,7 @@ Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared
float_t shear_y = 0.0;
RETURN_IF_NOT_OK(GenerateRealNumber(shear_ranges_[2], shear_ranges_[3], &rnd_, &shear_y));
// assign to base class variables
degrees_ = degrees;
degrees_ = fmod(degrees, 360.0);
scale_ = scale;
translation_[0] = translation_x;
translation_[1] = translation_y;

View File

@ -44,6 +44,7 @@ valid_detype = [
"uint32", "uint64", "float16", "float32", "float64", "string"
]
def is_iterable(obj):
"""
Helper function to check if object is iterable.
@ -60,6 +61,7 @@ def is_iterable(obj):
return False
return True
def pad_arg_name(arg_name):
if arg_name != "":
arg_name = arg_name + " "
@ -70,7 +72,7 @@ def check_value(value, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value < valid_range[0] or value > valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
"Input {0}is not within the required interval of [{1}, {2}].".format(arg_name, valid_range[0],
valid_range[1]))
@ -86,7 +88,7 @@ def check_value_normalize_std(value, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if value <= valid_range[0] or value > valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
"Input {0}is not within the required interval of ({1}, {2}].".format(arg_name, valid_range[0],
valid_range[1]))
@ -94,7 +96,7 @@ def check_range(values, valid_range, arg_name=""):
arg_name = pad_arg_name(arg_name)
if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
raise ValueError(
"Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0],
"Input {0}is not within the required interval of [{1}, {2}].".format(arg_name, valid_range[0],
valid_range[1]))

View File

@ -506,12 +506,12 @@ class Dataset:
Dataset, dataset applied by the function.
Examples:
>>> # use NumpySliceDataset as an example
>>> # use NumpySlicesDataset as an example
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
>>>
>>> def flat_map_func(array):
... # create a NumpySliceDataset with the array
... dataset = ds.NumpySliceDataset(array)
... # create a NumpySlicesDataset with the array
... dataset = ds.NumpySlicesDataset(array)
... # repeat the dataset twice
... dataset = dataset.repeat(2)
... return dataset
@ -3429,6 +3429,8 @@ class GeneratorDataset(MappableDataset):
option could be beneficial if the Python operation is computational heavy (default=True).
Examples:
>>> import numpy as np
>>>
>>> # 1) Multidimensional generator function as callable input.
>>> def generator_multidimensional():
... for i in range(64):

View File

@ -24,18 +24,16 @@ and use Lookup to find the index of tokens in Vocab.
class attributes (self.xxx) to support save() and load().
Examples:
>>> text_file_dataset_dir = "/path/to/text_file_dataset_file"
>>> text_file_dataset_dir = ["/path/to/text_file_dataset_file"] # contains 1 or multiple text files
>>> # Create a dataset for text sentences saved as line data in a file
>>> text_file_dataset = ds.TextFileDataset(text_file_dataset_dir, shuffle=False)
>>> text_file_dataset = ds.TextFileDataset(dataset_files=text_file_dataset_dir, shuffle=False)
>>> # Tokenize sentences to unicode characters
>>> tokenizer = text.UnicodeCharTokenizer()
>>> # Load vocabulary from list
>>> vocab = text.Vocab.from_list(['', '', '', '', ''])
>>> vocab = text.Vocab.from_list(word_list=['', '', '', '', ''])
>>> # Use Lookup operator to map tokens to ids
>>> lookup = text.Lookup(vocab)
>>> lookup = text.Lookup(vocab=vocab)
>>> text_file_dataset = text_file_dataset.map(operations=[tokenizer, lookup])
>>> for i in text_file_dataset.create_dict_iterator():
... print(i)
>>> # if text line in dataset_file is:
>>> # 深圳欢迎您
>>> # then the output will be:

View File

@ -125,10 +125,12 @@ def check_degrees(degrees):
"""Check if the degrees is legal."""
type_check(degrees, (numbers.Number, list, tuple), "degrees")
if isinstance(degrees, numbers.Number):
check_value(degrees, (0, float("inf")), "degrees")
check_pos_float32(degrees, "degrees")
elif isinstance(degrees, (list, tuple)):
if len(degrees) == 2:
type_check_list(degrees, (numbers.Number,), "degrees")
for value in degrees:
check_float32(value, "degrees")
if degrees[0] > degrees[1]:
raise ValueError("degrees should be in (min,max) format. Got (max,min).")
else:

View File

@ -477,7 +477,7 @@ def test_batch_exception_15():
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
except ValueError as e:
err_msg = str(e)
assert "batch_size is not within the required interval of (1 to 2147483647)" in err_msg
assert "batch_size is not within the required interval of [1, 2147483647]" in err_msg
if __name__ == '__main__':

View File

@ -243,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c():
column_order=["image", "bbox"]) # Add column for "bbox"
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error)
assert "Input ratio is not within the required interval of [0.0, 1.0]." in str(error)
def test_bounding_box_augment_invalid_bounds_c():

View File

@ -181,7 +181,7 @@ def test_numpy_slices_distributed_shard_limit():
num = sys.maxsize
with pytest.raises(ValueError) as err:
de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False)
assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value)
assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)
def test_numpy_slices_distributed_zero_shard():
@ -190,7 +190,7 @@ def test_numpy_slices_distributed_zero_shard():
np_data = [1, 2, 3]
with pytest.raises(ValueError) as err:
de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False)
assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value)
assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)
def test_numpy_slices_sequential_sampler():

View File

@ -201,7 +201,7 @@ def test_minddataset_invalidate_num_shards():
for _ in data_set.create_dict_iterator(num_epochs=1):
num_iter += 1
try:
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
assert 'Input shard_id is not within the required interval of [0, 0].' in str(error_info.value)
except Exception as error:
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
@ -221,7 +221,7 @@ def test_minddataset_invalidate_shard_id():
for _ in data_set.create_dict_iterator(num_epochs=1):
num_iter += 1
try:
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
assert 'Input shard_id is not within the required interval of [0, 0].' in str(error_info.value)
except Exception as error:
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
@ -241,7 +241,7 @@ def test_minddataset_shard_id_bigger_than_num_shard():
for _ in data_set.create_dict_iterator(num_epochs=1):
num_iter += 1
try:
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)
assert 'Input shard_id is not within the required interval of [0, 1].' in str(error_info.value)
except Exception as error:
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
@ -253,7 +253,7 @@ def test_minddataset_shard_id_bigger_than_num_shard():
for _ in data_set.create_dict_iterator(num_epochs=1):
num_iter += 1
try:
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)
assert 'Input shard_id is not within the required interval of [0, 1].' in str(error_info.value)
except Exception as error:
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
@ -277,6 +277,7 @@ def test_cv_minddataset_partition_num_samples_equals_0():
num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1):
num_iter += 1
with pytest.raises(ValueError) as error_info:
partitions(5)
try:
@ -289,8 +290,10 @@ def test_cv_minddataset_partition_num_samples_equals_0():
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_mindrecord_exception():
"""tutorial for exception scenario of minderdataset + map would print error info."""
def exception_func(item):
raise Exception("Error occur!")

View File

@ -33,7 +33,7 @@ def normalize_np(image, mean, std):
"""
Apply the normalization
"""
# DE decodes the image in RGB by deafult, hence
# DE decodes the image in RGB by default, hence
# the values here are in RGB
image = np.array(image, np.float32)
image = image - np.array(mean)
@ -300,7 +300,7 @@ def test_normalize_exception_invalid_range_py():
_ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e)
assert "Input mean_value is not within the required interval of [0.0, 1.0]." in str(e)
def test_normalize_grayscale_md5_01():

View File

@ -33,7 +33,7 @@ def normalizepad_np(image, mean, std):
"""
Apply the normalize+pad
"""
# DE decodes the image in RGB by deafult, hence
# DE decodes the image in RGB by default, hence
# the values here are in RGB
image = np.array(image, np.float32)
image = image - np.array(mean)
@ -198,4 +198,4 @@ def test_normalizepad_exception_invalid_range_py():
_ = py_vision.NormalizePad([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e)
assert "Input mean_value is not within the required interval of [0.0, 1.0]." in str(e)

View File

@ -211,7 +211,7 @@ def test_random_affine_exception_negative_degrees():
_ = py_vision.RandomAffine(degrees=-15)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input degrees is not within the required interval of (0 to inf)."
assert str(e) == "Input degrees is not within the required interval of [0, 16777216]."
def test_random_affine_exception_translation_range():
@ -223,13 +223,13 @@ def test_random_affine_exception_translation_range():
_ = c_vision.RandomAffine(degrees=15, translate=(0.1, 1.5))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input translate at 1 is not within the required interval of (-1.0 to 1.0)."
assert str(e) == "Input translate at 1 is not within the required interval of [-1.0, 1.0]."
logger.info("test_random_affine_exception_translation_range")
try:
_ = c_vision.RandomAffine(degrees=15, translate=(-2, 1.5))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input translate at 0 is not within the required interval of (-1.0 to 1.0)."
assert str(e) == "Input translate at 0 is not within the required interval of [-1.0, 1.0]."
def test_random_affine_exception_scale_value():

View File

@ -260,7 +260,7 @@ def test_random_crop_with_bbox_op_bad_padding():
break
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input padding is not within the required interval of (0 to 2147483647)." in str(err)
assert "Input padding is not within the required interval of [0, 2147483647]." in str(err)
try:
test_op = c_vision.RandomCropWithBBox([512, 512], padding=[16777216, 16777216, 16777216, 16777216])

View File

@ -188,7 +188,7 @@ def test_random_grayscale_invalid_param():
data = data.map(operations=transform, input_columns=["image"])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
if __name__ == "__main__":

View File

@ -143,7 +143,7 @@ def test_random_horizontal_invalid_prob_c():
data = data.map(operations=random_horizontal_op, input_columns=["image"])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
def test_random_horizontal_invalid_prob_py():
@ -166,7 +166,7 @@ def test_random_horizontal_invalid_prob_py():
data = data.map(operations=transform, input_columns=["image"])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e)
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
def test_random_horizontal_comp(plot=False):

View File

@ -185,7 +185,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
column_order=["image", "bbox"]) # Add column for "bbox"
except ValueError as error:
logger.info("Got an exception in DE: {}".format(str(error)))
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error)
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(error)
def test_random_horizontal_flip_with_bbox_invalid_bounds_c():

View File

@ -24,7 +24,6 @@ from mindspore import log as logger
from util import visualize_list, save_and_check_md5, \
config_get_set_seed, config_get_set_num_parallel_workers
GENERATE_GOLDEN = False
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@ -109,7 +108,7 @@ def test_random_perspective_exception_distortion_scale_range():
_ = py_vision.RandomPerspective(distortion_scale=1.5)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)."
assert str(e) == "Input distortion_scale is not within the required interval of [0.0, 1.0]."
def test_random_perspective_exception_prob_range():
@ -121,7 +120,7 @@ def test_random_perspective_exception_prob_range():
_ = py_vision.RandomPerspective(prob=1.2)
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)."
assert str(e) == "Input prob is not within the required interval of [0.0, 1.0]."
if __name__ == "__main__":

View File

@ -168,19 +168,19 @@ def test_random_posterize_exception_bit():
_ = c_vision.RandomPosterize((1, 9))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required interval of (1 to 8)."
assert str(e) == "Input is not within the required interval of [1, 8]."
# Test min < 1
try:
_ = c_vision.RandomPosterize((0, 7))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required interval of (1 to 8)."
assert str(e) == "Input is not within the required interval of [1, 8]."
# Test max < min
try:
_ = c_vision.RandomPosterize((8, 1))
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert str(e) == "Input is not within the required interval of (1 to 8)."
assert str(e) == "Input is not within the required interval of [1, 8]."
# Test wrong type (not uint8)
try:
_ = c_vision.RandomPosterize(1.1)

View File

@ -160,7 +160,7 @@ def test_random_resize_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
assert "Input is not within the required interval of [1, 16777216]." in str(err)
try:
# one of the size values is zero
@ -168,7 +168,7 @@ def test_random_resize_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err)
assert "Input size at dim 0 is not within the required interval of [1, 2147483647]." in str(err)
try:
# negative value for resize
@ -176,7 +176,7 @@ def test_random_resize_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input is not within the required interval of (1 to 16777216)." in str(err)
assert "Input is not within the required interval of [1, 16777216]." in str(err)
try:
# invalid input shape

View File

@ -43,7 +43,7 @@ def test_random_select_subpolicy():
assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]])
assert "op of (op, prob) in policy[1][0] is neither a c_transform op (TensorOperation) nor a callable pyfunc" \
in test_config([[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [
assert "prob of (op, prob) policy[1][0] is not within the required interval of [0, 1]" in test_config([[1]], [
[(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]])

View File

@ -110,7 +110,7 @@ def test_random_solarize_errors():
with pytest.raises(ValueError) as error_info:
vision.RandomSolarize((12, 1000))
assert "Input is not within the required interval of (0 to 255)." in str(error_info.value)
assert "Input is not within the required interval of [0, 255]." in str(error_info.value)
with pytest.raises(TypeError) as error_info:
vision.RandomSolarize((122.1, 140))

View File

@ -143,7 +143,7 @@ def test_random_vertical_invalid_prob_c():
data = data.map(operations=random_horizontal_op, input_columns=["image"])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
assert 'Input prob is not within the required interval of [0.0, 1.0].' in str(e)
def test_random_vertical_invalid_prob_py():
@ -165,7 +165,7 @@ def test_random_vertical_invalid_prob_py():
data = data.map(operations=transform, input_columns=["image"])
except ValueError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e)
assert 'Input prob is not within the required interval of [0.0, 1.0].' in str(e)
def test_random_vertical_comp(plot=False):

View File

@ -187,7 +187,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err)
assert "Input prob is not within the required interval of [0.0, 1.0]." in str(err)
def test_random_vertical_flip_with_bbox_op_bad_c():

View File

@ -149,7 +149,7 @@ def test_shuffle_exception_01():
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
assert "Input buffer_size is not within the required interval of [2, 2147483647]" in str(e)
def test_shuffle_exception_02():
@ -167,7 +167,7 @@ def test_shuffle_exception_02():
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
assert "Input buffer_size is not within the required interval of [2, 2147483647]" in str(e)
def test_shuffle_exception_03():
@ -185,7 +185,7 @@ def test_shuffle_exception_03():
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e)
assert "Input buffer_size is not within the required interval of [2, 2147483647]" in str(e)
def test_shuffle_exception_05():

View File

@ -146,7 +146,7 @@ def test_ten_crop_invalid_size_error_msg():
vision.TenCrop(0),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
error_msg = "Input is not within the required interval of (1 to 16777216)."
error_msg = "Input is not within the required interval of [1, 16777216]."
assert error_msg == str(info.value)
with pytest.raises(ValueError) as info: