forked from mindspore-Ecosystem/mindspore
!872 [Dataset] Add schema support for GeneratorDataset
Merge pull request !872 from JunhanHu/generator_schema
This commit is contained in:
commit
a606c2e4da
|
@ -2504,11 +2504,12 @@ class GeneratorDataset(SourceDataset):
|
||||||
Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next().
|
Iterable source is required to return a tuple of numpy array as a row of the dataset on iter(source).next().
|
||||||
Random accessible source is required to return a tuple of numpy array as a row of the dataset on
|
Random accessible source is required to return a tuple of numpy array as a row of the dataset on
|
||||||
source[idx].
|
source[idx].
|
||||||
column_names (list[str]): List of column names of the dataset.
|
column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to
|
||||||
|
provide either column_names or schema.
|
||||||
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
|
column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
|
||||||
If provided, sanity check will be performed on generator output.
|
If provided, sanity check will be performed on generator output.
|
||||||
schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
|
schema (Schema/String, optional): Path to the json schema file or schema object (default=None). Users are
|
||||||
If the schema is not provided, the meta data from column_names and column_types is considered the schema.
|
required to provide either column_names or schema. If both are provided, schema will be used.
|
||||||
num_samples (int, optional): The number of samples to be included in the dataset
|
num_samples (int, optional): The number of samples to be included in the dataset
|
||||||
(default=None, all images).
|
(default=None, all images).
|
||||||
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
|
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
|
||||||
|
@ -2555,8 +2556,8 @@ class GeneratorDataset(SourceDataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@check_generatordataset
|
@check_generatordataset
|
||||||
def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
|
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
|
||||||
shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||||
if self.sampler is not None and hasattr(source, "__getitem__"):
|
if self.sampler is not None and hasattr(source, "__getitem__"):
|
||||||
|
@ -2598,6 +2599,16 @@ class GeneratorDataset(SourceDataset):
|
||||||
else:
|
else:
|
||||||
self.column_types = column_types
|
self.column_types = column_types
|
||||||
|
|
||||||
|
if schema is not None:
|
||||||
|
self.schema = schema
|
||||||
|
if not isinstance(schema, Schema):
|
||||||
|
self.schema = Schema(schema)
|
||||||
|
self.column_names = []
|
||||||
|
self.column_types = []
|
||||||
|
for col in self.schema.columns:
|
||||||
|
self.column_names.append(col["name"])
|
||||||
|
self.column_types.append(DataType(col["type"]))
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
args = super().get_args()
|
args = super().get_args()
|
||||||
args["source"] = self.source
|
args["source"] = self.source
|
||||||
|
|
|
@ -555,10 +555,15 @@ def check_generatordataset(method):
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise TypeError("source should be callable, iterable or random accessible")
|
raise TypeError("source should be callable, iterable or random accessible")
|
||||||
|
|
||||||
# check column_names; required argument
|
# check column_names or schema; required argument
|
||||||
column_names = param_dict.get('column_names')
|
column_names = param_dict.get('column_names')
|
||||||
if column_names is None:
|
schema = param_dict.get('schema')
|
||||||
raise ValueError("column_names is not provided.")
|
if column_names is None and schema is None:
|
||||||
|
raise ValueError("Neither columns_names not schema are provided.")
|
||||||
|
|
||||||
|
if schema is not None:
|
||||||
|
if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
|
||||||
|
raise ValueError("schema should be a path to schema file or a schema object.")
|
||||||
|
|
||||||
# check optional argument
|
# check optional argument
|
||||||
nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
|
nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
|
||||||
|
|
|
@ -580,6 +580,41 @@ def test_num_samples_underflow():
|
||||||
count = count + 1
|
count = count + 1
|
||||||
assert count == 64
|
assert count == 64
|
||||||
|
|
||||||
|
|
||||||
|
def type_tester_with_type_check_2c_schema(t, c):
|
||||||
|
logger.info("Test with Type {}".format(t.__name__))
|
||||||
|
|
||||||
|
schema = ds.Schema()
|
||||||
|
schema.add_column("data0", c[0])
|
||||||
|
schema.add_column("data1", c[1])
|
||||||
|
|
||||||
|
# apply dataset operations
|
||||||
|
data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema)
|
||||||
|
|
||||||
|
data1 = data1.batch(4)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
for item in data1.create_dict_iterator(): # each data is a dictionary
|
||||||
|
golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
|
||||||
|
assert np.array_equal(item["data0"], golden)
|
||||||
|
i = i + 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_schema():
|
||||||
|
"""
|
||||||
|
Test 2 column Generator on different data type with type check with schema input
|
||||||
|
"""
|
||||||
|
logger.info("Test 2 column Generator on all data types with type check")
|
||||||
|
|
||||||
|
np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
|
||||||
|
np.float64]
|
||||||
|
de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
|
||||||
|
mstype.uint64, mstype.float32, mstype.float64]
|
||||||
|
|
||||||
|
for i in range(len(np_types)):
|
||||||
|
type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
|
||||||
|
|
||||||
|
|
||||||
def manual_test_keyborad_interrupt():
|
def manual_test_keyborad_interrupt():
|
||||||
"""
|
"""
|
||||||
Test keyborad_interrupt
|
Test keyborad_interrupt
|
||||||
|
@ -626,5 +661,6 @@ if __name__ == "__main__":
|
||||||
test_sequential_sampler()
|
test_sequential_sampler()
|
||||||
test_distributed_sampler()
|
test_distributed_sampler()
|
||||||
test_random_sampler()
|
test_random_sampler()
|
||||||
|
test_schema()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue