Changed description

Added validator checks

Added builtin sampler check

Added proper input check for samplers
This commit is contained in:
Eric 2021-02-08 15:59:59 -05:00
parent e489b67a3a
commit 5d91fa6d77
3 changed files with 76 additions and 3 deletions

View File

@ -16,12 +16,13 @@
General Validators.
"""
import inspect
import numbers
from multiprocessing import cpu_count
import os
import numpy as np
import mindspore._c_dataengine as cde
from ..engine import samplers
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
@ -288,6 +289,7 @@ def check_sampler_shuffle_shard_options(param_dict):
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
num_samples = param_dict.get('num_samples')
check_sampler(sampler)
if sampler is not None:
if shuffle is not None:
@ -384,6 +386,37 @@ def check_tensor_op(param, param_name):
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
def check_sampler(sampler):
    """
    Check that a sampler argument is of a valid type.

    Args:
        sampler (Union[list, samplers.Sampler, samplers.BuiltinSampler, None]):
            The sampler to validate. None is accepted (meaning no sampler).
            A single number or a list of numbers is treated as a
            subset-sampler specification.

    Raises:
        TypeError: If `sampler` is not None, not a Sampler or BuiltinSampler
            instance, and not a number or a list of numbers.
    """
    if sampler is None:
        return
    if isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
        return
    # Anything else must be a number or a list of numbers (subset sampler).
    # NOTE: an empty list passes this check, matching the original behavior.
    candidates = sampler if isinstance(sampler, list) else [sampler]
    if not all(isinstance(item, numbers.Number) for item in candidates):
        raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
def replace_none(value, default):
    """Return `default` when `value` is None, otherwise return `value` unchanged."""
    if value is None:
        return default
    return value

View File

@ -3535,7 +3535,8 @@ class TFRecordDataset(SourceDataset):
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
is false, number of rows of each shard may be not equal.
is false, the number of rows of each shard may not be equal. This
argument should only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,6 +20,10 @@ DATA_DIR = "../data/dataset/testCelebAData/"
def test_celeba_dataset_label():
"""
Test CelebA dataset with labels
"""
logger.info("Test CelebA labels")
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
expect_labels = [
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
@ -43,6 +47,10 @@ def test_celeba_dataset_label():
def test_celeba_dataset_op():
"""
Test CelebA dataset with decode
"""
logger.info("Test CelebA with decode")
data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
crop_size = (80, 80)
resize_size = (24, 24)
@ -62,6 +70,10 @@ def test_celeba_dataset_op():
def test_celeba_dataset_ext():
"""
Test CelebA dataset with extension
"""
logger.info("Test CelebA extension option")
ext = [".JPEG"]
data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
expect_labels = [
@ -82,6 +94,10 @@ def test_celeba_dataset_ext():
def test_celeba_dataset_distribute():
"""
Test CelebA dataset with distributed options
"""
logger.info("Test CelebA with sharding")
data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
count = 0
for item in data.create_dict_iterator(num_epochs=1):
@ -94,6 +110,10 @@ def test_celeba_dataset_distribute():
def test_celeba_get_dataset_size():
"""
Test CelebA dataset get dataset size
"""
logger.info("Test CelebA get dataset size")
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
size = data.get_dataset_size()
assert size == 4
@ -112,6 +132,10 @@ def test_celeba_get_dataset_size():
def test_celeba_dataset_exception_file_path():
"""
Test CelebA dataset with bad file path
"""
logger.info("Test CelebA with bad file path")
def exception_func(item):
raise Exception("Error occur!")
@ -144,6 +168,20 @@ def test_celeba_dataset_exception_file_path():
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
def test_celeba_sampler_exception():
    """
    Test CelebA with bad sampler input
    """
    logger.info("Test CelebA with bad sampler input")
    raised = False
    try:
        dataset = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in dataset.create_dict_iterator():
            pass
    except TypeError as err:
        raised = True
        assert "Argument" in str(err)
    # Constructing or iterating the dataset must reject the bad sampler.
    assert raised
if __name__ == '__main__':
test_celeba_dataset_label()
test_celeba_dataset_op()
@ -151,3 +189,4 @@ if __name__ == '__main__':
test_celeba_dataset_distribute()
test_celeba_get_dataset_size()
test_celeba_dataset_exception_file_path()
test_celeba_sampler_exception()