forked from mindspore-Ecosystem/mindspore
changed description
Added validator checks. Added builtin sampler check. Added proper input check for samplers.
This commit is contained in:
parent
e489b67a3a
commit
5d91fa6d77
|
@ -16,12 +16,13 @@
|
|||
General Validators.
|
||||
"""
|
||||
import inspect
|
||||
import numbers
|
||||
from multiprocessing import cpu_count
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from ..engine import samplers
|
||||
# POS_INT_MIN is used to limit values from starting from 0
|
||||
POS_INT_MIN = 1
|
||||
UINT8_MAX = 255
|
||||
|
@ -288,6 +289,7 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
|
||||
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
||||
num_samples = param_dict.get('num_samples')
|
||||
check_sampler(sampler)
|
||||
|
||||
if sampler is not None:
|
||||
if shuffle is not None:
|
||||
|
@ -384,6 +386,37 @@ def check_tensor_op(param, param_name):
|
|||
if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
|
||||
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
|
||||
|
||||
def check_sampler(sampler):
    """
    Check that the sampler argument is of a supported type.

    Args:
        sampler (Union[list, samplers.Sampler, samplers.BuiltinSampler, None]): Sampler to validate.
            None is accepted. A value that is not a sampler instance must be a number or a list
            whose items are all numbers.

    Raises:
        TypeError: If sampler is not None, is not a samplers.BuiltinSampler or samplers.Sampler
            instance, and is not a number or a list containing only numbers.
    """
    # No sampler supplied: nothing to validate.
    if sampler is None:
        return

    if isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
        return

    # Subset sampler check: a bare (non-list) value is validated as a one-element list.
    candidates = sampler if isinstance(sampler, list) else [sampler]
    if not all(isinstance(item, numbers.Number) for item in candidates):
        raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
|
||||
|
||||
def replace_none(value, default):
    """
    Return value unless it is None, in which case return default.

    Only None triggers the substitution; other falsy values (0, "", []) are returned unchanged.
    """
    if value is None:
        return default
    return value
|
||||
|
|
|
@ -3535,7 +3535,8 @@ class TFRecordDataset(SourceDataset):
|
|||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument can only be specified when num_shards is also specified.
|
||||
shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows
|
||||
is false, number of rows of each shard may be not equal.
|
||||
is false, number of rows of each shard may be not equal. This
|
||||
argument should only be specified when num_shards is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd.
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -20,6 +20,10 @@ DATA_DIR = "../data/dataset/testCelebAData/"
|
|||
|
||||
|
||||
def test_celeba_dataset_label():
|
||||
"""
|
||||
Test CelebA dataset with labels
|
||||
"""
|
||||
logger.info("Test CelebA labels")
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
|
||||
expect_labels = [
|
||||
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
|
||||
|
@ -43,6 +47,10 @@ def test_celeba_dataset_label():
|
|||
|
||||
|
||||
def test_celeba_dataset_op():
|
||||
"""
|
||||
Test CelebA dataset with decode
|
||||
"""
|
||||
logger.info("Test CelebA with decode")
|
||||
data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
|
||||
crop_size = (80, 80)
|
||||
resize_size = (24, 24)
|
||||
|
@ -62,6 +70,10 @@ def test_celeba_dataset_op():
|
|||
|
||||
|
||||
def test_celeba_dataset_ext():
|
||||
"""
|
||||
Test CelebA dataset with extension
|
||||
"""
|
||||
logger.info("Test CelebA extension option")
|
||||
ext = [".JPEG"]
|
||||
data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
|
||||
expect_labels = [
|
||||
|
@ -82,6 +94,10 @@ def test_celeba_dataset_ext():
|
|||
|
||||
|
||||
def test_celeba_dataset_distribute():
|
||||
"""
|
||||
Test CelebA dataset with distributed options
|
||||
"""
|
||||
logger.info("Test CelebA with sharding")
|
||||
data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
|
||||
count = 0
|
||||
for item in data.create_dict_iterator(num_epochs=1):
|
||||
|
@ -94,6 +110,10 @@ def test_celeba_dataset_distribute():
|
|||
|
||||
|
||||
def test_celeba_get_dataset_size():
|
||||
"""
|
||||
Test CelebA dataset get dataset size
|
||||
"""
|
||||
logger.info("Test CelebA get dataset size")
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
|
||||
size = data.get_dataset_size()
|
||||
assert size == 4
|
||||
|
@ -112,6 +132,10 @@ def test_celeba_get_dataset_size():
|
|||
|
||||
|
||||
def test_celeba_dataset_exception_file_path():
|
||||
"""
|
||||
Test CelebA dataset with bad file path
|
||||
"""
|
||||
logger.info("Test CelebA with bad file path")
|
||||
def exception_func(item):
|
||||
raise Exception("Error occur!")
|
||||
|
||||
|
@ -144,6 +168,20 @@ def test_celeba_dataset_exception_file_path():
|
|||
assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
|
||||
|
||||
|
||||
def test_celeba_sampler_exception():
    """
    Test CelebA with bad sampler input

    Passes an empty string as the sampler and expects a TypeError whose message
    contains "Argument".
    """
    logger.info("Test CelebA with bad sampler input")
    try:
        dataset = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in dataset.create_dict_iterator():
            pass
    except TypeError as err:
        assert "Argument" in str(err)
    else:
        # No TypeError was raised for the invalid sampler, so the test must fail.
        assert False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_celeba_dataset_label()
|
||||
test_celeba_dataset_op()
|
||||
|
@ -151,3 +189,4 @@ if __name__ == '__main__':
|
|||
test_celeba_dataset_distribute()
|
||||
test_celeba_get_dataset_size()
|
||||
test_celeba_dataset_exception_file_path()
|
||||
test_celeba_sampler_exception()
|
||||
|
|
Loading…
Reference in New Issue