!5385 dataset fixes: Update OneHot API doc; fixup UTs
Merge pull request !5385 from cathwong/ckw_dataset_ut_cleanup8
This commit is contained in:
commit
ab29dbf98b
|
@ -36,7 +36,7 @@ from .. import callback
|
|||
|
||||
|
||||
def check_imagefolderdatasetv2(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDatasetV2)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -62,7 +62,7 @@ def check_imagefolderdatasetv2(method):
|
|||
|
||||
|
||||
def check_mnist_cifar_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -85,7 +85,7 @@ def check_mnist_cifar_dataset(method):
|
|||
|
||||
|
||||
def check_manifestdataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -112,7 +112,7 @@ def check_manifestdataset(method):
|
|||
|
||||
|
||||
def check_tfrecorddataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -138,7 +138,7 @@ def check_tfrecorddataset(method):
|
|||
|
||||
|
||||
def check_vocdataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(VOCDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -179,7 +179,7 @@ def check_vocdataset(method):
|
|||
|
||||
|
||||
def check_cocodataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(CocoDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -215,7 +215,7 @@ def check_cocodataset(method):
|
|||
|
||||
|
||||
def check_celebadataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(CelebADataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -251,7 +251,7 @@ def check_celebadataset(method):
|
|||
|
||||
|
||||
def check_save(method):
|
||||
"""A wrapper that wrap a parameter checker to the save op."""
|
||||
"""A wrapper that wraps a parameter checker around the saved operator."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -271,7 +271,7 @@ def check_save(method):
|
|||
|
||||
|
||||
def check_minddataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -303,7 +303,7 @@ def check_minddataset(method):
|
|||
|
||||
|
||||
def check_generatordataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -369,7 +369,7 @@ def check_generatordataset(method):
|
|||
|
||||
|
||||
def check_random_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -794,7 +794,7 @@ def check_add_column(method):
|
|||
|
||||
|
||||
def check_cluedataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -824,7 +824,7 @@ def check_cluedataset(method):
|
|||
|
||||
|
||||
def check_csvdataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(CSVDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -871,7 +871,7 @@ def check_csvdataset(method):
|
|||
|
||||
|
||||
def check_textfiledataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -964,7 +964,7 @@ def check_gnn_graphdata(method):
|
|||
|
||||
|
||||
def check_gnn_get_all_nodes(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -977,7 +977,7 @@ def check_gnn_get_all_nodes(method):
|
|||
|
||||
|
||||
def check_gnn_get_all_edges(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `get_all_edges` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -990,7 +990,7 @@ def check_gnn_get_all_edges(method):
|
|||
|
||||
|
||||
def check_gnn_get_nodes_from_edges(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1003,7 +1003,7 @@ def check_gnn_get_nodes_from_edges(method):
|
|||
|
||||
|
||||
def check_gnn_get_all_neighbors(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1018,7 +1018,7 @@ def check_gnn_get_all_neighbors(method):
|
|||
|
||||
|
||||
def check_gnn_get_sampled_neighbors(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1046,7 +1046,7 @@ def check_gnn_get_sampled_neighbors(method):
|
|||
|
||||
|
||||
def check_gnn_get_neg_sampled_neighbors(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1062,7 +1062,7 @@ def check_gnn_get_neg_sampled_neighbors(method):
|
|||
|
||||
|
||||
def check_gnn_random_walk(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `random_walk` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `random_walk` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1110,7 +1110,7 @@ def check_aligned_list(param, param_name, member_type):
|
|||
|
||||
|
||||
def check_gnn_get_node_feature(method):
|
||||
"""A wrapper that wraps a parameter checker to the GNN `get_node_feature` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_node_feature` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1132,7 +1132,7 @@ def check_gnn_get_node_feature(method):
|
|||
|
||||
|
||||
def check_gnn_get_edge_feature(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function."""
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1154,7 +1154,7 @@ def check_gnn_get_edge_feature(method):
|
|||
|
||||
|
||||
def check_numpyslicesdataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -1195,17 +1195,17 @@ def check_numpyslicesdataset(method):
|
|||
|
||||
|
||||
def check_paddeddataset(method):
|
||||
"""A wrapper that wraps a parameter checker to the original Dataset(PaddedDataset)."""
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
paddedSamples = param_dict.get("padded_samples")
|
||||
if not paddedSamples:
|
||||
padded_samples = param_dict.get("padded_samples")
|
||||
if not padded_samples:
|
||||
raise ValueError("Argument padded_samples cannot be empty")
|
||||
type_check(paddedSamples, (list,), "padded_samples")
|
||||
type_check(paddedSamples[0], (dict,), "padded_element")
|
||||
type_check(padded_samples, (list,), "padded_samples")
|
||||
type_check(padded_samples[0], (dict,), "padded_element")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -328,7 +328,7 @@ def check_from_dataset(method):
|
|||
return new_method
|
||||
|
||||
def check_slidingwindow(method):
|
||||
"""A wrapper that wrap a parameter checker to the original function(sliding window operation)."""
|
||||
"""A wrapper that wraps a parameter checker to the original function(sliding window operation)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -496,4 +496,3 @@ def check_save_model(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
@ -31,8 +31,8 @@ class OneHot(cde.OneHotOp):
|
|||
Tensor operation to apply one hot encoding.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes of the label
|
||||
it should be bigger than largest label number in dataset.
|
||||
num_classes (int): Number of classes of the label.
|
||||
It should be larger than the largest label number in the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: feature size is bigger than num_classes.
|
||||
|
|
|
@ -27,8 +27,9 @@ class OneHotOp:
|
|||
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
|
||||
|
||||
Args:
|
||||
num_classes (int): Num class of object in dataset, type is int and value over 0.
|
||||
smoothing_rate (float): The adjustable Hyper parameter decides the label smoothing level , 0.0 means not do it.
|
||||
num_classes (int): Number of classes of objects in dataset. Value must be larger than 0.
|
||||
smoothing_rate (float, optional): Adjustable hyperparameter for label smoothing level.
|
||||
(Default=0.0 means no smoothing is applied.)
|
||||
"""
|
||||
|
||||
@check_one_hot_op
|
||||
|
|
|
@ -152,7 +152,7 @@ def check_erasing_value(value):
|
|||
|
||||
|
||||
def check_crop(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function(crop operation)."""
|
||||
"""A wrapper that wraps a parameter checker around the original function(crop operation)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -165,7 +165,7 @@ def check_crop(method):
|
|||
|
||||
|
||||
def check_posterize(method):
|
||||
""""A wrapper that wraps a parameter checker to the original function(posterize operation)."""
|
||||
"""A wrapper that wraps a parameter checker around the original function(posterize operation)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -187,7 +187,7 @@ def check_posterize(method):
|
|||
|
||||
|
||||
def check_resize_interpolation(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation)."""
|
||||
"""A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -202,7 +202,7 @@ def check_resize_interpolation(method):
|
|||
|
||||
|
||||
def check_resize(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function(resize operation)."""
|
||||
"""A wrapper that wraps a parameter checker around the original function(resize operation)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -235,7 +235,7 @@ def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
|
|||
|
||||
|
||||
def check_random_resize_crop(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function(random resize crop operation)."""
|
||||
"""A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -250,7 +250,7 @@ def check_random_resize_crop(method):
|
|||
|
||||
|
||||
def check_prob(method):
|
||||
"""A wrapper that wraps a parameter checker(check the probability) to the original function."""
|
||||
"""A wrapper that wraps a parameter checker (to confirm probability) around the original function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -264,7 +264,7 @@ def check_prob(method):
|
|||
|
||||
|
||||
def check_normalize_c(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in C++)."""
|
||||
"""A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
@ -277,7 +277,7 @@ def check_normalize_c(method):
|
|||
|
||||
|
||||
def check_normalize_py(method):
|
||||
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in Python)."""
|
||||
"""A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -86,8 +86,13 @@ def test_five_crop_error_msg():
|
|||
transform = vision.ComposeOp(transforms)
|
||||
data = data.map(input_columns=["image"], operations=transform())
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data.create_tuple_iterator().__next__()
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
for _ in data:
|
||||
pass
|
||||
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
|
||||
|
||||
# error msg comes from ToTensor()
|
||||
assert error_msg in str(info.value)
|
||||
|
||||
|
||||
def test_five_crop_md5():
|
||||
|
|
|
@ -149,7 +149,7 @@ def test_random_color_py_md5():
|
|||
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
|
||||
|
||||
transforms = F.ComposeOp([F.Decode(),
|
||||
F.RandomColor((0.1, 1.9)),
|
||||
F.RandomColor((2.0, 2.5)),
|
||||
F.ToTensor()])
|
||||
|
||||
data = data.map(input_columns="image", operations=transforms())
|
||||
|
@ -244,12 +244,12 @@ def test_random_color_c_errors():
|
|||
if __name__ == "__main__":
|
||||
test_random_color_py()
|
||||
test_random_color_py(plot=True)
|
||||
test_random_color_py(degrees=(0.5, 1.5), plot=True)
|
||||
test_random_color_py(degrees=(2.0, 2.5), plot=True) # Test with degree values that show more obvious transformation
|
||||
test_random_color_py_md5()
|
||||
|
||||
test_random_color_c()
|
||||
test_random_color_c(plot=True)
|
||||
test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False)
|
||||
test_random_color_c(degrees=(2.0, 2.5), plot=True, run_golden=False) # Test with degree values that show more obvious transformation
|
||||
test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False)
|
||||
test_compare_random_color_op(plot=True)
|
||||
test_random_color_c_errors()
|
||||
|
|
|
@ -103,7 +103,7 @@ def test_random_sharpness_py_md5():
|
|||
# define map operations
|
||||
transforms = [
|
||||
F.Decode(),
|
||||
F.RandomSharpness((0.1, 1.9)),
|
||||
F.RandomSharpness((20.0, 25.0)),
|
||||
F.ToTensor()
|
||||
]
|
||||
transform = F.ComposeOp(transforms)
|
||||
|
@ -193,7 +193,7 @@ def test_random_sharpness_c_md5():
|
|||
# define map operations
|
||||
transforms = [
|
||||
C.Decode(),
|
||||
C.RandomSharpness((0.1, 1.9))
|
||||
C.RandomSharpness((10.0, 15.0))
|
||||
]
|
||||
|
||||
# Generate dataset
|
||||
|
@ -337,14 +337,16 @@ def test_random_sharpness_invalid_params():
|
|||
|
||||
if __name__ == "__main__":
|
||||
test_random_sharpness_py(plot=True)
|
||||
test_random_sharpness_py(None, plot=True) # test with default values
|
||||
test_random_sharpness_py(None, plot=True) # Test with default values
|
||||
test_random_sharpness_py(degrees=(20.0, 25.0), plot=True) # Test with degree values that show more obvious transformation
|
||||
test_random_sharpness_py_md5()
|
||||
test_random_sharpness_c(plot=True)
|
||||
test_random_sharpness_c(None, plot=True) # test with default values
|
||||
test_random_sharpness_c(degrees=[10, 15], plot=True) # Test with degrees values that show more obvious transformation
|
||||
test_random_sharpness_c_md5()
|
||||
test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True)
|
||||
test_random_sharpness_c_py(degrees=[1, 1], plot=True)
|
||||
test_random_sharpness_c_py(degrees=[10, 10], plot=True)
|
||||
test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True)
|
||||
test_random_sharpness_one_channel_c(degrees=None, plot=True) # test with default values
|
||||
test_random_sharpness_one_channel_c(degrees=None, plot=True) # Test with default values
|
||||
test_random_sharpness_invalid_params()
|
||||
|
|
|
@ -303,7 +303,7 @@ def test_repeat_count0():
|
|||
with pytest.raises(ValueError) as info:
|
||||
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
|
||||
data1.repeat(0)
|
||||
assert "count" in str(info)
|
||||
assert "count" in str(info.value)
|
||||
|
||||
def test_repeat_countneg2():
|
||||
"""
|
||||
|
@ -313,7 +313,7 @@ def test_repeat_countneg2():
|
|||
with pytest.raises(ValueError) as info:
|
||||
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
|
||||
data1.repeat(-2)
|
||||
assert "count" in str(info)
|
||||
assert "count" in str(info.value)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tf_repeat_01()
|
||||
|
|
Loading…
Reference in New Issue