From c5a8ffe4f4bcfb499bd25463af8cdf76d9159dd1 Mon Sep 17 00:00:00 2001
From: Junhan Hu
Date: Wed, 29 Apr 2020 11:52:58 -0400
Subject: [PATCH] Add schema support for GeneratorDataset

---
 mindspore/dataset/engine/datasets.py      | 21 +++++++++----
 mindspore/dataset/engine/validators.py    | 11 +++++--
 tests/ut/python/dataset/test_generator.py | 36 +++++++++++++++++++++++
 3 files changed, 60 insertions(+), 8 deletions(-)

diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 5504cc3362d..b56e2ce4ae0 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -2504,11 +2504,12 @@ class GeneratorDataset(SourceDataset):
             Iterable source is required to return a tuple of numpy array as a row of the dataset on
             iter(source).next().
             Random accessible source is required to return a tuple of numpy array as a row of the dataset on
             source[idx].
-        column_names (list[str]): List of column names of the dataset.
+        column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to
+            provide either column_names or schema.
         column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
             If provided, sanity check will be performed on generator output.
-        schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
-            If the schema is not provided, the meta data from column_names and column_types is considered the schema.
+        schema (Schema/String, optional): Path to the json schema file or schema object (default=None). Users are
+            required to provide either column_names or schema. If both are provided, schema will be used.
         num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images).
         num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
@@ -2555,8 +2556,8 @@ class GeneratorDataset(SourceDataset):
     """

     @check_generatordataset
-    def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
-                 shuffle=None, sampler=None, num_shards=None, shard_id=None):
+    def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
+                 num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
         super().__init__(num_parallel_workers)
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         if self.sampler is not None and hasattr(source, "__getitem__"):
@@ -2598,6 +2599,16 @@ class GeneratorDataset(SourceDataset):
         else:
             self.column_types = column_types

+        if schema is not None:
+            self.schema = schema
+            if not isinstance(schema, Schema):
+                self.schema = Schema(schema)
+            self.column_names = []
+            self.column_types = []
+            for col in self.schema.columns:
+                self.column_names.append(col["name"])
+                self.column_types.append(DataType(col["type"]))
+
     def get_args(self):
         args = super().get_args()
         args["source"] = self.source
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index 4f1b394634f..dbe8e47d031 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -555,10 +555,15 @@ def check_generatordataset(method):
         except TypeError:
             raise TypeError("source should be callable, iterable or random accessible")

-        # check column_names; required argument
+        # check column_names and schema; at least one of them is required
        column_names = param_dict.get('column_names')
-        if column_names is None:
-            raise ValueError("column_names is not provided.")
+        schema = param_dict.get('schema')
+        if column_names is None and schema is None:
+            raise ValueError("Neither column_names nor schema is provided.")
+
+        if schema is not None:
+            if not isinstance(schema, (datasets.Schema, str)):
+                raise ValueError("schema should be a path to schema file or a schema object.")

         # check optional argument
         nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
diff --git a/tests/ut/python/dataset/test_generator.py b/tests/ut/python/dataset/test_generator.py
index 4daf952eba8..529788fcaaf 100644
--- a/tests/ut/python/dataset/test_generator.py
+++ b/tests/ut/python/dataset/test_generator.py
@@ -580,6 +580,41 @@ def test_num_samples_underflow():
         count = count + 1
     assert count == 64
+
+def type_tester_with_type_check_2c_schema(t, c):
+    logger.info("Test with Type {}".format(t.__name__))
+
+    schema = ds.Schema()
+    schema.add_column("data0", c[0])
+    schema.add_column("data1", c[1])
+
+    # apply dataset operations
+    data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)), schema=schema)
+
+    data1 = data1.batch(4)
+
+    i = 0
+    for item in data1.create_dict_iterator():  # each item is a dictionary
+        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
+        assert np.array_equal(item["data0"], golden)
+        i = i + 4
+
+
+def test_schema():
+    """
+    Test 2-column Generator on different data types with type check, using schema input
+    """
+    logger.info("Test 2-column Generator on all data types with type check using schema")
+
+    np_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float32,
+                np.float64]
+    de_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint16, mstype.uint32,
+                mstype.uint64, mstype.float32, mstype.float64]
+
+    for i in range(len(np_types)):
+        type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
+
+
 def manual_test_keyborad_interrupt():
     """
     Test keyborad_interrupt
     """
@@ -626,5 +661,6 @@ if __name__ == "__main__":
     test_sequential_sampler()
     test_distributed_sampler()
     test_random_sampler()
+    test_schema()
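
Usage sketch (illustrative, not part of the patch): with this change applied, a
GeneratorDataset can be constructed from a schema instead of explicit
column_names, and the schema supplies both the column names and the column
types. The generator gen_2c and the column names below are hypothetical; the
imports follow the conventions of the test file above.

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.common.dtype as mstype

    # hypothetical two-column source; a callable returning an iterable is one
    # of the source kinds GeneratorDataset accepts
    def gen_2c():
        for i in range(64):
            yield (np.array([i], dtype=np.int32), np.array([i], dtype=np.int32))

    # build a schema in place of column_names/column_types
    schema = ds.Schema()
    schema.add_column("data0", mstype.int32)
    schema.add_column("data1", mstype.int32)

    # per the patched docstring, schema takes precedence if column_names is
    # also given
    data = ds.GeneratorDataset(gen_2c, schema=schema)
    for row in data.create_dict_iterator():  # each row is a dict keyed by column name
        print(row["data0"], row["data1"])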