dataset fixes: Update OneHot API docs; fixup Python UTs

This commit is contained in:
Cathy Wong 2020-08-26 17:53:28 -04:00
parent 39e2791149
commit 7f6782be2a
12 changed files with 61 additions and 54 deletions

View File

@ -36,7 +36,7 @@ from .. import callback
def check_imagefolderdatasetv2(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
"""A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDatasetV2)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -62,7 +62,7 @@ def check_imagefolderdatasetv2(method):
def check_mnist_cifar_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -85,7 +85,7 @@ def check_mnist_cifar_dataset(method):
def check_manifestdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -112,7 +112,7 @@ def check_manifestdataset(method):
def check_tfrecorddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -138,7 +138,7 @@ def check_tfrecorddataset(method):
def check_vocdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(VOCDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -179,7 +179,7 @@ def check_vocdataset(method):
def check_cocodataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CocoDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -215,7 +215,7 @@ def check_cocodataset(method):
def check_celebadataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CelebADataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -251,7 +251,7 @@ def check_celebadataset(method):
def check_save(method):
"""A wrapper that wrap a parameter checker to the save op."""
"""A wrapper that wraps a parameter checker around the saved operator."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -271,7 +271,7 @@ def check_save(method):
def check_minddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -303,7 +303,7 @@ def check_minddataset(method):
def check_generatordataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -369,7 +369,7 @@ def check_generatordataset(method):
def check_random_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -794,7 +794,7 @@ def check_add_column(method):
def check_cluedataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -824,7 +824,7 @@ def check_cluedataset(method):
def check_csvdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CSVDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -871,7 +871,7 @@ def check_csvdataset(method):
def check_textfiledataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -964,7 +964,7 @@ def check_gnn_graphdata(method):
def check_gnn_get_all_nodes(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -977,7 +977,7 @@ def check_gnn_get_all_nodes(method):
def check_gnn_get_all_edges(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_edges` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -990,7 +990,7 @@ def check_gnn_get_all_edges(method):
def check_gnn_get_nodes_from_edges(method):
"""A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1003,7 +1003,7 @@ def check_gnn_get_nodes_from_edges(method):
def check_gnn_get_all_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1018,7 +1018,7 @@ def check_gnn_get_all_neighbors(method):
def check_gnn_get_sampled_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1046,7 +1046,7 @@ def check_gnn_get_sampled_neighbors(method):
def check_gnn_get_neg_sampled_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1062,7 +1062,7 @@ def check_gnn_get_neg_sampled_neighbors(method):
def check_gnn_random_walk(method):
"""A wrapper that wraps a parameter checker to the GNN `random_walk` function."""
"""A wrapper that wraps a parameter checker around the GNN `random_walk` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1110,7 +1110,7 @@ def check_aligned_list(param, param_name, member_type):
def check_gnn_get_node_feature(method):
"""A wrapper that wraps a parameter checker to the GNN `get_node_feature` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_node_feature` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1132,7 +1132,7 @@ def check_gnn_get_node_feature(method):
def check_gnn_get_edge_feature(method):
"""A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1154,7 +1154,7 @@ def check_gnn_get_edge_feature(method):
def check_numpyslicesdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -1195,17 +1195,17 @@ def check_numpyslicesdataset(method):
def check_paddeddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(PaddedDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
paddedSamples = param_dict.get("padded_samples")
if not paddedSamples:
padded_samples = param_dict.get("padded_samples")
if not padded_samples:
raise ValueError("Argument padded_samples cannot be empty")
type_check(paddedSamples, (list,), "padded_samples")
type_check(paddedSamples[0], (dict,), "padded_element")
type_check(padded_samples, (list,), "padded_samples")
type_check(padded_samples[0], (dict,), "padded_element")
return method(self, *args, **kwargs)
return new_method

View File

@ -328,7 +328,7 @@ def check_from_dataset(method):
return new_method
def check_slidingwindow(method):
"""A wrapper that wrap a parameter checker to the original function(sliding window operation)."""
"""A wrapper that wraps a parameter checker to the original function(sliding window operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -496,4 +496,3 @@ def check_save_model(method):
return method(self, *args, **kwargs)
return new_method

View File

@ -31,8 +31,8 @@ class OneHot(cde.OneHotOp):
Tensor operation to apply one hot encoding.
Args:
num_classes (int): Number of classes of the label
it should be bigger than largest label number in dataset.
num_classes (int): Number of classes of the label.
It should be larger than the largest label number in the dataset.
Raises:
RuntimeError: feature size is bigger than num_classes.

View File

@ -27,8 +27,9 @@ class OneHotOp:
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
Args:
num_classes (int): Num class of object in dataset, type is int and value over 0.
smoothing_rate (float): The adjustable Hyper parameter decides the label smoothing level , 0.0 means not do it.
num_classes (int): Number of classes of objects in dataset. Value must be larger than 0.
smoothing_rate (float, optional): Adjustable hyperparameter for label smoothing level.
(Default=0.0 means no smoothing is applied.)
"""
@check_one_hot_op

View File

@ -152,7 +152,7 @@ def check_erasing_value(value):
def check_crop(method):
"""A wrapper that wraps a parameter checker to the original function(crop operation)."""
"""A wrapper that wraps a parameter checker around the original function(crop operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -165,7 +165,7 @@ def check_crop(method):
def check_posterize(method):
""""A wrapper that wraps a parameter checker to the original function(posterize operation)."""
"""A wrapper that wraps a parameter checker around the original function(posterize operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -187,7 +187,7 @@ def check_posterize(method):
def check_resize_interpolation(method):
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation)."""
"""A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -202,7 +202,7 @@ def check_resize_interpolation(method):
def check_resize(method):
"""A wrapper that wraps a parameter checker to the original function(resize operation)."""
"""A wrapper that wraps a parameter checker around the original function(resize operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -235,7 +235,7 @@ def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
def check_random_resize_crop(method):
"""A wrapper that wraps a parameter checker to the original function(random resize crop operation)."""
"""A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -250,7 +250,7 @@ def check_random_resize_crop(method):
def check_prob(method):
"""A wrapper that wraps a parameter checker(check the probability) to the original function."""
"""A wrapper that wraps a parameter checker (to confirm probability) around the original function."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -264,7 +264,7 @@ def check_prob(method):
def check_normalize_c(method):
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in C++)."""
"""A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -277,7 +277,7 @@ def check_normalize_c(method):
def check_normalize_py(method):
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in Python)."""
"""A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""
@wraps(method)
def new_method(self, *args, **kwargs):

View File

@ -86,8 +86,13 @@ def test_five_crop_error_msg():
transform = vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform())
with pytest.raises(RuntimeError):
data.create_tuple_iterator().__next__()
with pytest.raises(RuntimeError) as info:
for _ in data:
pass
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
# error msg comes from ToTensor()
assert error_msg in str(info.value)
def test_five_crop_md5():

View File

@ -149,7 +149,7 @@ def test_random_color_py_md5():
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms = F.ComposeOp([F.Decode(),
F.RandomColor((0.1, 1.9)),
F.RandomColor((2.0, 2.5)),
F.ToTensor()])
data = data.map(input_columns="image", operations=transforms())
@ -244,12 +244,12 @@ def test_random_color_c_errors():
if __name__ == "__main__":
test_random_color_py()
test_random_color_py(plot=True)
test_random_color_py(degrees=(0.5, 1.5), plot=True)
test_random_color_py(degrees=(2.0, 2.5), plot=True) # Test with degree values that show more obvious transformation
test_random_color_py_md5()
test_random_color_c()
test_random_color_c(plot=True)
test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False)
test_random_color_c(degrees=(2.0, 2.5), plot=True, run_golden=False) # Test with degree values that show more obvious transformation
test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False)
test_compare_random_color_op(plot=True)
test_random_color_c_errors()

View File

@ -103,7 +103,7 @@ def test_random_sharpness_py_md5():
# define map operations
transforms = [
F.Decode(),
F.RandomSharpness((0.1, 1.9)),
F.RandomSharpness((20.0, 25.0)),
F.ToTensor()
]
transform = F.ComposeOp(transforms)
@ -193,7 +193,7 @@ def test_random_sharpness_c_md5():
# define map operations
transforms = [
C.Decode(),
C.RandomSharpness((0.1, 1.9))
C.RandomSharpness((10.0, 15.0))
]
# Generate dataset
@ -337,14 +337,16 @@ def test_random_sharpness_invalid_params():
if __name__ == "__main__":
test_random_sharpness_py(plot=True)
test_random_sharpness_py(None, plot=True) # test with default values
test_random_sharpness_py(None, plot=True) # Test with default values
test_random_sharpness_py(degrees=(20.0, 25.0), plot=True) # Test with degree values that show more obvious transformation
test_random_sharpness_py_md5()
test_random_sharpness_c(plot=True)
test_random_sharpness_c(None, plot=True) # test with default values
test_random_sharpness_c(degrees=[10, 15], plot=True) # Test with degrees values that show more obvious transformation
test_random_sharpness_c_md5()
test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True)
test_random_sharpness_c_py(degrees=[1, 1], plot=True)
test_random_sharpness_c_py(degrees=[10, 10], plot=True)
test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True)
test_random_sharpness_one_channel_c(degrees=None, plot=True) # test with default values
test_random_sharpness_one_channel_c(degrees=None, plot=True) # Test with default values
test_random_sharpness_invalid_params()

View File

@ -303,7 +303,7 @@ def test_repeat_count0():
with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(0)
assert "count" in str(info)
assert "count" in str(info.value)
def test_repeat_countneg2():
"""
@ -313,7 +313,7 @@ def test_repeat_countneg2():
with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(-2)
assert "count" in str(info)
assert "count" in str(info.value)
if __name__ == "__main__":
test_tf_repeat_01()