From 138dec9a55cb6cb636c84320711d1f33f195b458 Mon Sep 17 00:00:00 2001 From: jonyguo Date: Tue, 11 Jan 2022 11:46:06 +0800 Subject: [PATCH] enhance datasets.py --- .../mindspore/dataset/engine/__init__.py | 126 +++++++-------- .../mindspore/dataset/engine/datasets.py | 149 ++++++++++++------ .../dataset/engine/datasets_audio.py | 23 +-- .../engine/datasets_standard_format.py | 81 +++++++++- .../mindspore/dataset/engine/datasets_text.py | 111 ++++--------- .../dataset/engine/datasets_user_defined.py | 17 +- .../dataset/engine/datasets_vision.py | 65 ++++---- 7 files changed, 321 insertions(+), 251 deletions(-) diff --git a/mindspore/python/mindspore/dataset/engine/__init__.py b/mindspore/python/mindspore/dataset/engine/__init__.py index e1f0906f45d..94f0c7c69ad 100644 --- a/mindspore/python/mindspore/dataset/engine/__init__.py +++ b/mindspore/python/mindspore/dataset/engine/__init__.py @@ -36,69 +36,69 @@ from .iterators import * from .samplers import * from .serializer_deserializer import compare, deserialize, serialize, show -__all__ = ["Caltech101Dataset", # vision dataset - "Caltech256Dataset", # vision dataset - "CelebADataset", # vision dataset - "Cifar10Dataset", # vision dataset - "Cifar100Dataset", # vision dataset - "CityscapesDataset", # vision dataset - "CocoDataset", # vision dataset - "DIV2KDataset", # vision dataset - "EMnistDataset", # vision dataset - "FakeImageDataset", # vision dataset - "FashionMnistDataset", # vision dataset - "FlickrDataset", # vision dataset - "Flowers102Dataset", # vision dataset - "ImageFolderDataset", # vision dataset - "KMnistDataset", # vision dataset - "ManifestDataset", # vision dataset - "MnistDataset", # vision dataset - "PhotoTourDataset", # vision dataset - "Places365Dataset", # vision dataset - "QMnistDataset", # vision dataset - "RandomDataset", # vision dataset - "SBDataset", # vision dataset - "SBUDataset", # vision dataset - "SemeionDataset", # vision dataset - "STL10Dataset", # vision dataset - "SVHNDataset", # vision dataset - "USPSDataset", # vision dataset - "VOCDataset", # vision dataset - "WIDERFaceDataset", # vision dataset - "AGNewsDataset", # text dataset - "AmazonReviewDataset", # text dataset - "CLUEDataset", # text dataset - "CoNLL2000Dataset", # text dataset - "CSVDataset", # text dataset - "DBpediaDataset", # text dataset - "EnWik9Dataset", # text dataset - "IMDBDataset", # text dataset - "IWSLT2016Dataset", # text dataset - "IWSLT2017Dataset", # text dataset - "PennTreebankDataset", # text dataset - "SogouNewsDataset", # text dataset - "TextFileDataset", # text dataset - "UDPOSDataset", # text dataset - "WikiTextDataset", # text dataset - "YahooAnswersDataset", # text dataset - "YelpReviewDataset", # text dataset - "LJSpeechDataset", # audio dataset - "SpeechCommandsDataset", # audio dataset - "TedliumDataset", # audio dataset - "YesNoDataset", # audio dataset - "MindDataset", # standard format dataset - "TFRecordDataset", # standard format dataset - "GeneratorDataset", # user defined dataset - "NumpySlicesDataset", # user defined dataset - "PaddedDataset", # user defined dataset - "GraphData", # graph data - "DistributedSampler", # sampler - "RandomSampler", # sampler - "SequentialSampler", # sampler - "SubsetRandomSampler", # sampler - "SubsetSampler", # sampler - "PKSampler", # sampler - "WeightedRandomSampler", # sampler +__all__ = ["Caltech101Dataset", # Vision + "Caltech256Dataset", # Vision + "CelebADataset", # Vision + "Cifar10Dataset", # Vision + "Cifar100Dataset", # Vision + "CityscapesDataset", # Vision + 
"CocoDataset", # Vision + "DIV2KDataset", # Vision + "EMnistDataset", # Vision + "FakeImageDataset", # Vision + "FashionMnistDataset", # Vision + "FlickrDataset", # Vision + "Flowers102Dataset", # Vision + "ImageFolderDataset", # Vision + "KMnistDataset", # Vision + "ManifestDataset", # Vision + "MnistDataset", # Vision + "PhotoTourDataset", # Vision + "Places365Dataset", # Vision + "QMnistDataset", # Vision + "RandomDataset", # Vision + "SBDataset", # Vision + "SBUDataset", # Vision + "SemeionDataset", # Vision + "STL10Dataset", # Vision + "SVHNDataset", # Vision + "USPSDataset", # Vision + "VOCDataset", # Vision + "WIDERFaceDataset", # Vision + "AGNewsDataset", # Text + "AmazonReviewDataset", # Text + "CLUEDataset", # Text + "CoNLL2000Dataset", # Text + "DBpediaDataset", # Text + "EnWik9Dataset", # Text + "IMDBDataset", # Text + "IWSLT2016Dataset", # Text + "IWSLT2017Dataset", # Text + "PennTreebankDataset", # Text + "SogouNewsDataset", # Text + "TextFileDataset", # Text + "UDPOSDataset", # Text + "WikiTextDataset", # Text + "YahooAnswersDataset", # Text + "YelpReviewDataset", # Text + "LJSpeechDataset", # Audio + "SpeechCommandsDataset", # Audio + "TedliumDataset", # Audio + "YesNoDataset", # Audio + "CSVDataset", # Standard Format + "MindDataset", # Standard Format + "TFRecordDataset", # Standard Format + "GeneratorDataset", # User Defined + "NumpySlicesDataset", # User Defined + "PaddedDataset", # User Defined + "GraphData", # Graph Data + "DistributedSampler", # Sampler + "RandomSampler", # Sampler + "SequentialSampler", # Sampler + "SubsetRandomSampler", # Sampler + "SubsetSampler", # Sampler + "PKSampler", # Sampler + "WeightedRandomSampler", # Sampler "DatasetCache", "DSCallback", "WaitedDSCallback", diff --git a/mindspore/python/mindspore/dataset/engine/datasets.py b/mindspore/python/mindspore/dataset/engine/datasets.py index e66720068f9..d065b749fdf 100644 --- a/mindspore/python/mindspore/dataset/engine/datasets.py +++ b/mindspore/python/mindspore/dataset/engine/datasets.py @@ -13,15 +13,21 @@ # limitations under the License. # ============================================================================== """ -This dataset module supports various formats of datasets, including ImageNet, TFData, -MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with -high performance and parses data precisely. Some of the operations that are -provided to users to preprocess data include shuffle, batch, repeat, map, and zip. +1. This file is an abstraction of the dataset loading class. It contains +some basic dataset operations(skip, filter, map, batch, ...). +2. Specific dataset loading classes can be found in datasets_vision.py, datasets_text.py, +datasets_audio.py, datasets_standard_format.py and dataets_user_defined.py files. + datasets_vision.py: contains vision dataset loading classes. + datasets_text.py: contains text dataset loading classes. + datasets_audio.py: contains audio dataset loading classes. + datasets_standard_format.py: contains standard format loading classes which + any other kinds of datasets can be converted to. + dataets_user_defined.py: contains basic classes that help users to define + flexible ways to load dataset. """ import atexit import glob import json -import math import os import signal import stat @@ -215,6 +221,44 @@ class Dataset: This class is the base class of SourceDataset and Dataset, and represents a node in the data flow graph. 
+ Dataset + ----------------------------------------------------------- + | | | | + VisionBaseDataset TextBaseDataset AudioBaseDataset | + - - - | + | | | | + ---------------------------------------- | + UnionBaseDataset | + | + SourceDataset + - + | + MappableDataset + + DatasetOperator: MapDataset(UnionBaseDataset) + BatchDataset(UnionBaseDataset) + BucketBatchByLengthDataset(UnionBaseDataset) + ShuffleDataset(UnionBaseDataset) + FilterDataset(UnionBaseDataset) + RepeatDataset(UnionBaseDataset) + SkipDataset(UnionBaseDataset) + TakeDataset(UnionBaseDataset) + ZipDataset(UnionBaseDataset) + ConcatDataset(UnionBaseDataset) + RenameDataset(UnionBaseDataset) + ProjectDataset(UnionBaseDataset) + SyncWaitDataset(UnionBaseDataset) + + Impl Dataset - vision: ImageFolderDataset(MappableDataset, VisionBaseDataset) + USPSDataset(SourceDataset, VisionBaseDataset) + Impl Dataset - text: TextFileDataset(SourceDataset, TextBaseDataset) + YahooAnswersDataset(SourceDataset, TextBaseDataset) + Impl Dataset - audio: LJSpeechDataset(MappableDataset, AudioBaseDataset) + TedliumDataset(MappableDataset, AudioBaseDataset) + Impl Dataset - standard: MindDataset(MappableDataset, UnionBaseDataset) + TFRecordDataset(SourceDataset, UnionBaseDataset) + Impl Dataset - user defined: GeneratorDataset(MappableDataset, UnionBaseDataset) + NumpySlicesDataset(GeneratorDataset) Args: num_parallel_workers (int, optional): Number of workers to process the dataset in parallel @@ -1796,6 +1840,18 @@ class Dataset: return ir_node +class VisionBaseDataset(Dataset): + """ + Abstract class to represent a vision source dataset which produces content to the data pipeline. + """ + + def __init__(self, children=None, num_parallel_workers=None, cache=None): + super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache) + + def parse(self, children=None): + raise NotImplementedError("Dataset has to implement parse method.") + + class TextBaseDataset(Dataset): """ Abstract class to represent a text source dataset which produces content to the data pipeline. @@ -1928,6 +1984,30 @@ class TextBaseDataset(Dataset): return vocab +class AudioBaseDataset(Dataset): + """ + Abstract class to represent a audio source dataset which produces content to the data pipeline. + """ + + def __init__(self, children=None, num_parallel_workers=None, cache=None): + super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache) + + def parse(self, children=None): + raise NotImplementedError("Dataset has to implement parse method.") + + +class UnionBaseDataset(VisionBaseDataset, TextBaseDataset, AudioBaseDataset): + """ + Abstract class to represent a union source dataset which produces content to the data pipeline. + """ + + def __init__(self, children=None, num_parallel_workers=None, cache=None): + super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache) + + def parse(self, children=None): + raise NotImplementedError("Dataset has to implement parse method.") + + class SourceDataset(Dataset): """ Abstract class to represent a source dataset which produces content to the data pipeline. @@ -2159,7 +2239,7 @@ class MappableDataset(SourceDataset): return tuple(splits) -class BucketBatchByLengthDataset(Dataset): +class BucketBatchByLengthDataset(UnionBaseDataset): """ The result of applying BucketBatchByLength operator to the input dataset. 
""" @@ -2211,7 +2291,7 @@ def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1): raise RuntimeError("Expected /dev/shm to exist.") -class BatchDataset(Dataset): +class BatchDataset(UnionBaseDataset): """ The result of applying Batch operator to the input dataset. @@ -2482,7 +2562,7 @@ class BlockReleasePair: self.cv.notify_all() -class SyncWaitDataset(Dataset): +class SyncWaitDataset(UnionBaseDataset): """ The result of adding a blocking condition to the input Dataset. @@ -2551,7 +2631,7 @@ class SyncWaitDataset(Dataset): self._pair.reset() -class ShuffleDataset(Dataset): +class ShuffleDataset(UnionBaseDataset): """ The result of applying Shuffle operator to the input Dataset. @@ -2830,7 +2910,7 @@ class _ExceptHookHandler: _mp_pool_exit_preprocess() -class MapDataset(TextBaseDataset, Dataset): +class MapDataset(UnionBaseDataset): """ The result of applying the Map operator to the input Dataset. @@ -3003,7 +3083,7 @@ class MapDataset(TextBaseDataset, Dataset): self._abort_watchdog() -class FilterDataset(Dataset): +class FilterDataset(UnionBaseDataset): """ The result of applying filter predicate to the input Dataset. @@ -3025,7 +3105,7 @@ class FilterDataset(Dataset): return cde.FilterNode(children[0], self.predicate, self.input_columns) -class RepeatDataset(Dataset): +class RepeatDataset(UnionBaseDataset): """ The result of applying Repeat operator to the input Dataset. @@ -3042,7 +3122,7 @@ class RepeatDataset(Dataset): return cde.RepeatNode(children[0], self.count) -class SkipDataset(Dataset): +class SkipDataset(UnionBaseDataset): """ The result of applying Skip operator to the input Dataset. @@ -3059,7 +3139,7 @@ class SkipDataset(Dataset): return cde.SkipNode(children[0], self.count) -class TakeDataset(Dataset): +class TakeDataset(UnionBaseDataset): """ The result of applying Take operator to the input Dataset. @@ -3076,7 +3156,7 @@ class TakeDataset(Dataset): return cde.TakeNode(children[0], self.count) -class ZipDataset(Dataset): +class ZipDataset(UnionBaseDataset): """ The result of applying Zip operator to the input Dataset. @@ -3097,7 +3177,7 @@ class ZipDataset(Dataset): return any([c.is_sync() for c in self.children]) -class ConcatDataset(Dataset): +class ConcatDataset(UnionBaseDataset): """ The result of applying concat dataset operator to the input Dataset. @@ -3126,7 +3206,7 @@ class ConcatDataset(Dataset): child_index += 1 # _children_flag_and_nums: A list of pair.The first element of pair is flag that characterizes - # whether the data set is mappable. The second element of pair is length of the dataset + # whether the dataset is mappable. The second element of pair is length of the dataset self._children_flag_and_nums = [] # _children_start_end_index_: A list of pair.The elements of pair are used to characterize @@ -3208,7 +3288,7 @@ class ConcatDataset(Dataset): cumulative_samples_nums %= sampler.num_shards -class RenameDataset(Dataset): +class RenameDataset(UnionBaseDataset): """ The result of applying Rename operator to the input Dataset. @@ -3237,7 +3317,7 @@ def to_list(items): return items -class ProjectDataset(Dataset): +class ProjectDataset(UnionBaseDataset): """ The result of applying Project operator to the input Dataset. @@ -3403,37 +3483,6 @@ class TransferDataset(Dataset): self._to_device.release() -class RangeDataset(MappableDataset): - """ - A source dataset that reads and parses datasets stored on disk in a range. - - Args: - start (int): Starting index. - stop (int): Ending index. 
-        step (int): Step size in the range specified by start and stop.
-    """
-
-    def __init__(self, start, stop, step):
-        super().__init__()
-        self.start = start
-        self.stop = stop
-        self.step = step
-
-    def parse(self, children=None):
-        raise NotImplementedError("Dataset has to implement parse method.")
-
-    def is_shuffled(self):
-        return False
-
-    def is_sharded(self):
-        return False
-
-    def get_dataset_size(self):
-        if self.dataset_size is None:
-            self.dataset_size = math.ceil((self.stop - self.start) / self.step)
-        return self.dataset_size
-
-
 class Schema:
     """
     Class to represent a schema of a dataset.
diff --git a/mindspore/python/mindspore/dataset/engine/datasets_audio.py b/mindspore/python/mindspore/dataset/engine/datasets_audio.py
index c87a266569a..8389f7b6f1b 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets_audio.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets_audio.py
@@ -13,21 +13,26 @@
 # limitations under the License.
 # ==============================================================================
 """
-This dataset module supports various formats of datasets, including ImageNet, TFData,
-MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
-high performance and parses data precisely. Some of the operations that are
-provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
+This file contains audio dataset loading classes. You can use them to easily
+load prepared audio datasets, for example:
+    LJSpeechDataset: loads the LJSpeech dataset.
+    YesNoDataset: loads the YesNo dataset.
+    SpeechCommandsDataset: loads the SpeechCommands dataset.
+    TedliumDataset: loads the Tedlium dataset.
+    ...
+After declaring the dataset object, you can further apply dataset operations
+(e.g. filter, skip, concat, map, batch) on it.
 """
 import mindspore._c_dataengine as cde
 
-from .datasets import MappableDataset
+from .datasets import AudioBaseDataset, MappableDataset
 from .validators import check_lj_speech_dataset, check_yes_no_dataset, check_speech_commands_dataset, \
     check_tedlium_dataset
 from ..core.validator_helpers import replace_none
 
 
-class LJSpeechDataset(MappableDataset):
+class LJSpeechDataset(MappableDataset, AudioBaseDataset):
     """
     A source dataset for reading and parsing LJSpeech dataset.
 
@@ -163,7 +168,7 @@ class LJSpeechDataset(MappableDataset):
         return cde.LJSpeechNode(self.dataset_dir, self.sampler)
 
 
-class SpeechCommandsDataset(MappableDataset):
+class SpeechCommandsDataset(MappableDataset, AudioBaseDataset):
     """
     A source dataset for reading and parsing the SpeechCommands dataset.
 
@@ -287,7 +292,7 @@ class SpeechCommandsDataset(MappableDataset):
         return cde.SpeechCommandsNode(self.dataset_dir, self.usage, self.sampler)
 
 
-class TedliumDataset(MappableDataset):
+class TedliumDataset(MappableDataset, AudioBaseDataset):
     """
     A source dataset for reading and parsing Tedlium dataset.
     The columns of generated dataset depend on the source SPH files and the corresponding STM files.
 
@@ -499,7 +504,7 @@ class TedliumDataset(MappableDataset):
         return cde.TedliumNode(self.dataset_dir, self.release, self.usage, self.extensions, self.sampler)
 
 
-class YesNoDataset(MappableDataset):
+class YesNoDataset(MappableDataset, AudioBaseDataset):
     """
     A source dataset for reading and parsing the YesNo dataset.
diff --git a/mindspore/python/mindspore/dataset/engine/datasets_standard_format.py b/mindspore/python/mindspore/dataset/engine/datasets_standard_format.py
index 13d0f21e878..b8fb1ac9b4c 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets_standard_format.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets_standard_format.py
@@ -13,25 +13,90 @@
 # limitations under the License.
 # ==============================================================================
 """
-This dataset module supports various formats of datasets, including ImageNet, TFData,
-MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
-high performance and parses data precisely. Some of the operations that are
-provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
+This file contains standard format dataset loading classes.
+You can convert a dataset to a standard format using the following steps:
+    1. Use the mindspore.mindrecord.FileWriter / tf.io.TFRecordWriter APIs to
+       convert the dataset to MindRecord / TFRecord.
+    2. Use MindDataset / TFRecordDataset to load the MindRecord / TFRecord files.
+After declaring the dataset object, you can further apply dataset operations
+(e.g. filter, skip, concat, map, batch) on it.
 """
 import numpy as np
 
 import mindspore._c_dataengine as cde
 
 from mindspore import log as logger
-from .datasets import MappableDataset, SourceDataset, TextBaseDataset, Shuffle, Schema, \
+from .datasets import UnionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema, \
     shuffle_to_shuffle_mode, shuffle_to_bool
-from .validators import check_minddataset, check_tfrecorddataset
+from .validators import check_minddataset, check_tfrecorddataset, check_csvdataset
 from ..core.validator_helpers import replace_none
 from . import samplers
 
 
-class MindDataset(MappableDataset, TextBaseDataset):
+class CSVDataset(SourceDataset, UnionBaseDataset):
+    """
+    A source dataset that reads and parses comma-separated values (CSV) datasets.
+    The columns of generated dataset depend on the source CSV files.
+
+    Args:
+        dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search
+            for a pattern of files. The list will be sorted in a lexicographical order.
+        field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
+        column_defaults (list, optional): List of default values for the CSV field (default=None). Each item
+            in the list is either a valid type (float, int, or string). If this is not provided, treats all
+            columns as string type.
+        column_names (list[str], optional): List of column names of the dataset (default=None). If this
+            is not provided, infers the column_names from the first row of CSV file.
+        num_samples (int, optional): The number of samples to be included in the dataset
+            (default=None, will include all samples).
+        num_parallel_workers (int, optional): Number of workers to read the data
+            (default=None, number set in the config).
+        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
+            (default=Shuffle.GLOBAL).
+            If shuffle is False, no shuffling will be performed;
+            If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
+            Otherwise, there are two levels of shuffling:
+
+            - Shuffle.GLOBAL: Shuffle both the files and samples.
+
+            - Shuffle.FILES: Shuffle files only.
+
+        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
+            When this argument is specified, `num_samples` reflects the maximum sample number per shard.
+        shard_id (int, optional): The shard ID within num_shards (default=None). This
+            argument can only be specified when num_shards is also specified.
+        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
+            (default=None, which means no cache is used).
+
+    Raises:
+        RuntimeError: If dataset_files are not valid or do not exist.
+        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
+        RuntimeError: If num_shards is specified but shard_id is None.
+        RuntimeError: If shard_id is specified but num_shards is None.
+
+    Examples:
+        >>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files
+        >>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4'])
+    """
+
+    @check_csvdataset
+    def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
+                 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
+        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
+                         num_shards=num_shards, shard_id=shard_id, cache=cache)
+        self.dataset_files = self._find_files(dataset_files)
+        self.dataset_files.sort()
+        self.field_delim = replace_none(field_delim, ',')
+        self.column_defaults = replace_none(column_defaults, [])
+        self.column_names = replace_none(column_names, [])
+
+    def parse(self, children=None):
+        return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names,
+                           self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)
+
+
+class MindDataset(MappableDataset, UnionBaseDataset):
     """
     A source dataset for reading and parsing MindRecord dataset.
 
@@ -160,7 +225,7 @@ class MindDataset(MappableDataset, TextBaseDataset):
             self.new_padded_sample[k] = v
 
 
-class TFRecordDataset(SourceDataset, TextBaseDataset):
+class TFRecordDataset(SourceDataset, UnionBaseDataset):
     """
     A source dataset for reading and parsing datasets stored on disk in TFData format.
 
diff --git a/mindspore/python/mindspore/dataset/engine/datasets_text.py b/mindspore/python/mindspore/dataset/engine/datasets_text.py
index 447d1a04314..836d95703d7 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets_text.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets_text.py
@@ -13,17 +13,22 @@
 # limitations under the License.
 # ==============================================================================
 """
-This dataset module supports various formats of datasets, including ImageNet, TFData,
-MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
-high performance and parses data precisely. Some of the operations that are
-provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
+This file contains text dataset loading classes. You can use them to easily
+load prepared text datasets, for example:
+    IMDBDataset: loads the IMDB dataset.
+    WikiTextDataset: loads the WikiText dataset.
+    CLUEDataset: loads the CLUE dataset.
+    YelpReviewDataset: loads the YelpReview dataset.
+    ...
+After declaring the dataset object, you can further apply dataset operations
+(e.g. filter, skip, concat, map, batch) on it.
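+
+A minimal usage sketch (the file path below is a placeholder):
+
+    import mindspore.dataset as ds
+
+    dataset = ds.TextFileDataset(dataset_files=["/path/to/text_file.txt"], shuffle=False)
+    dataset = dataset.skip(1)       # e.g. skip a header line
+    dataset = dataset.repeat(2)     # then apply further operations such as repeat
+    for item in dataset.create_dict_iterator(output_numpy=True):
+        print(item["text"])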
""" import mindspore._c_dataengine as cde -from .datasets import MappableDataset, SourceDataset, TextBaseDataset, Shuffle +from .datasets import TextBaseDataset, SourceDataset, MappableDataset, Shuffle from .validators import check_imdb_dataset, check_iwslt2016_dataset, check_iwslt2017_dataset, \ check_penn_treebank_dataset, check_ag_news_dataset, check_amazon_review_dataset, check_udpos_dataset, \ - check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, check_csvdataset, \ + check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, \ check_sogou_news_dataset, check_textfiledataset, check_dbpedia_dataset, check_yelp_review_dataset, \ check_en_wik9_dataset, check_yahoo_answers_dataset @@ -117,7 +122,7 @@ class AGNewsDataset(SourceDataset, TextBaseDataset): self.shard_id) -class AmazonReviewDataset(SourceDataset): +class AmazonReviewDataset(SourceDataset, TextBaseDataset): """ A source dataset that reads and parses Amazon Review Polarity and Amazon Review Full datasets. @@ -356,7 +361,7 @@ class CLUEDataset(SourceDataset, TextBaseDataset): self.num_shards, self.shard_id) -class CoNLL2000Dataset(SourceDataset): +class CoNLL2000Dataset(SourceDataset, TextBaseDataset): """ A source dataset that reads and parses CoNLL2000 dataset. @@ -414,68 +419,6 @@ class CoNLL2000Dataset(SourceDataset): self.shard_id) -class CSVDataset(SourceDataset, TextBaseDataset): - """ - A source dataset that reads and parses comma-separated values (CSV) datasets. - The columns of generated dataset depend on the source CSV files. - - Args: - dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search - for a pattern of files. The list will be sorted in a lexicographical order. - field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=','). - column_defaults (list, optional): List of default values for the CSV field (default=None). Each item - in the list is either a valid type (float, int, or string). If this is not provided, treats all - columns as string type. - column_names (list[str], optional): List of column names of the dataset (default=None). If this - is not provided, infers the column_names from the first row of CSV file. - num_samples (int, optional): The number of samples to be included in the dataset - (default=None, will include all images). - num_parallel_workers (int, optional): Number of workers to read the data - (default=None, number set in the config). - shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch - (default=Shuffle.GLOBAL). - If shuffle is False, no shuffling will be performed; - If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL - Otherwise, there are two levels of shuffling: - - - Shuffle.GLOBAL: Shuffle both the files and samples. - - - Shuffle.FILES: Shuffle files only. - - num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). - When this argument is specified, `num_samples` reflects the maximum sample number of per shard. - shard_id (int, optional): The shard ID within num_shards (default=None). This - argument can only be specified when num_shards is also specified. - cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. - (default=None, which means no cache is used). - - Raises: - RuntimeError: If dataset_files are not valid or do not exist. - RuntimeError: If num_parallel_workers exceeds the max thread numbers. 
- RuntimeError: If num_shards is specified but shard_id is None. - RuntimeError: If shard_id is specified but num_shards is None. - - Examples: - >>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files - >>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4']) - """ - - @check_csvdataset - def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, - num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, - num_shards=num_shards, shard_id=shard_id, cache=cache) - self.dataset_files = self._find_files(dataset_files) - self.dataset_files.sort() - self.field_delim = replace_none(field_delim, ',') - self.column_defaults = replace_none(column_defaults, []) - self.column_names = replace_none(column_names, []) - - def parse(self, children=None): - return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names, - self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id) - - class DBpediaDataset(SourceDataset, TextBaseDataset): """ A source dataset that reads and parses the DBpedia dataset. @@ -572,7 +515,7 @@ class DBpediaDataset(SourceDataset, TextBaseDataset): self.shard_id) -class EnWik9Dataset(SourceDataset): +class EnWik9Dataset(SourceDataset, TextBaseDataset): """ A source dataset that reads and parses EnWik9 dataset. @@ -647,7 +590,7 @@ class EnWik9Dataset(SourceDataset): return cde.EnWik9Node(self.dataset_dir, self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id) -class IMDBDataset(MappableDataset): +class IMDBDataset(MappableDataset, TextBaseDataset): """ A source dataset for reading and parsing Internet Movie Database (IMDb). @@ -726,7 +669,7 @@ class IMDBDataset(MappableDataset): About IMDBDataset: - The IMDB dataset contains 50, 000 highly polarized reviews from the Internet Movie Database (IMDB). The data set + The IMDB dataset contains 50, 000 highly polarized reviews from the Internet Movie Database (IMDB). The dataset was divided into 25 000 comments for training and 25 000 comments for testing, with both the training set and test set containing 50% positive and 50% negative comments. Train labels and test labels are all lists of 0 and 1, where 0 stands for negative and 1 for positive. @@ -841,13 +784,13 @@ class IWSLT2016Dataset(SourceDataset, TextBaseDataset): About IWSLT2016 dataset: IWSLT is an international oral translation conference, a major annual scientific conference dedicated to all aspects - of oral translation. The MT task of the IWSLT evaluation activity constitutes a data set, which can be publicly - obtained through the WIT3 website wit3.fbk.eu. The IWSLT2016 data set includes translations from English to Arabic, + of oral translation. The MT task of the IWSLT evaluation activity constitutes a dataset, which can be publicly + obtained through the WIT3 website wit3.fbk.eu. The IWSLT2016 dataset includes translations from English to Arabic, Czech, French, and German, and translations from Arabic, Czech, French, and German to English. You can unzip the original IWSLT2016 dataset files into this directory structure and read by MindSpore's API. After - decompression, you also need to decompress the data set to be read in the specified folder. 
For example, if you want
-    to read the data set of de-en, you need to unzip the tgz file in the de/en directory, the data set is in the
+    decompression, you also need to decompress the dataset to be read in the specified folder. For example, if you want
+    to read the dataset of de-en, you need to unzip the tgz file in the de/en directory, the dataset is in the
     unzipped folder.
 
     .. code-block::
@@ -961,9 +904,9 @@ class IWSLT2017Dataset(SourceDataset, TextBaseDataset):
     About IWSLT2017 dataset:
 
     IWSLT is an international oral translation conference, a major annual scientific conference dedicated to all aspects
-    of oral translation. The MT task of the IWSLT evaluation activity constitutes a data set, which can be publicly
-    obtained through the WIT3 website wit3.fbk.eu. The IWSLT2017 data set involves German, English, Italian, Dutch, and
-    Romanian. The data set includes translations in any two different languages.
+    of oral translation. The MT task of the IWSLT evaluation activity constitutes a dataset, which can be publicly
+    obtained through the WIT3 website wit3.fbk.eu. The IWSLT2017 dataset involves German, English, Italian, Dutch, and
+    Romanian. The dataset includes translations in any two different languages.
 
     You can unzip the original IWSLT2017 dataset files into this directory structure and read by MindSpore's API. You
     need to decompress the dataset package in texts/DeEnItNlRo/DeEnItNlRo directory to get the DeEnItNlRo-DeEnItNlRo
@@ -1099,7 +1042,7 @@ class PennTreebankDataset(SourceDataset, TextBaseDataset):
                               self.shard_id)
 
 
-class SogouNewsDataset(SourceDataset):
+class SogouNewsDataset(SourceDataset, TextBaseDataset):
     """
     A source dataset that reads and parses Sogou News dataset.
 
@@ -1239,7 +1182,7 @@ class TextFileDataset(SourceDataset, TextBaseDataset):
                               self.shard_id)
 
 
-class UDPOSDataset(SourceDataset):
+class UDPOSDataset(SourceDataset, TextBaseDataset):
     """
     A source dataset that reads and parses UDPOS dataset.
 
@@ -1297,7 +1240,7 @@ class UDPOSDataset(SourceDataset):
                               self.shard_id)
 
 
-class WikiTextDataset(SourceDataset):
+class WikiTextDataset(SourceDataset, TextBaseDataset):
     """
     A source dataset that reads and parses WikiText2 and WikiText103 datasets.
 
@@ -1375,7 +1318,7 @@ class WikiTextDataset(SourceDataset):
                               self.shard_id)
 
 
-class YahooAnswersDataset(SourceDataset):
+class YahooAnswersDataset(SourceDataset, TextBaseDataset):
     """
     A source dataset that reads and parses the YahooAnswers dataset.
 
diff --git a/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py b/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py
index 027d75e32dd..e6e6593fe0e 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py
@@ -13,10 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 """
-This dataset module supports various formats of datasets, including ImageNet, TFData,
-MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
-high performance and parses data precisely. Some of the operations that are
-provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
+This file contains basic classes that help users load datasets in flexible ways.
+You can define your own dataset loading class and use GeneratorDataset to load the data.
+You can refer to the
+`tutorial `
+to learn how to define and load your own dataset.
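+
+For example, a minimal sketch of a random-accessible source loaded through
+GeneratorDataset (the class and column names below are only illustrative):
+
+    import numpy as np
+    import mindspore.dataset as ds
+
+    class MyDataset:
+        def __init__(self):
+            self._data = np.random.sample((5, 2))
+            self._label = np.random.sample((5, 1))
+
+        def __getitem__(self, index):
+            return self._data[index], self._label[index]
+
+        def __len__(self):
+            return len(self._data)
+
+    dataset = ds.GeneratorDataset(source=MyDataset(), column_names=["data", "label"])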
+After declaring the dataset object, you can further apply dataset operations
+(e.g. filter, skip, concat, map, batch) on it.
 """
 import builtins
 import math
@@ -38,7 +41,7 @@ import mindspore._c_dataengine as cde
 from mindspore.common import Tensor
 from mindspore import log as logger
 
-from .datasets import MappableDataset, TextBaseDataset, Schema, to_list, _watch_dog, _check_shm_usage
+from .datasets import UnionBaseDataset, MappableDataset, Schema, to_list, _watch_dog, _check_shm_usage
 from . import samplers
 from .queue import _SharedQueue
 from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset
@@ -427,7 +430,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
         return True
 
 
-class GeneratorDataset(MappableDataset, TextBaseDataset):
+class GeneratorDataset(MappableDataset, UnionBaseDataset):
     """
     A source dataset that generates data from Python by invoking Python data source each epoch.
 
@@ -873,7 +876,7 @@ class _PaddedDataset:
 class PaddedDataset(GeneratorDataset):
     """
-    Creates a dataset with filler data provided by user. Mainly used to add to the original data set
+    Creates a dataset with filler data provided by user. Mainly used to add to the original dataset
     and assign it to the corresponding shard.
 
     Args:
diff --git a/mindspore/python/mindspore/dataset/engine/datasets_vision.py b/mindspore/python/mindspore/dataset/engine/datasets_vision.py
index e82f4c8dd92..d925dfe7a3e 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets_vision.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets_vision.py
@@ -13,10 +13,15 @@
 # limitations under the License.
 # ==============================================================================
 """
-This dataset module supports various formats of datasets, including ImageNet, TFData,
-MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
-high performance and parses data precisely. Some of the operations that are
-provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
+This file contains vision dataset loading classes. You can use them to easily
+load prepared vision datasets, for example:
+    ImageFolderDataset: loads datasets organized in an ImageNet-style folder structure.
+    Cifar10Dataset: loads the Cifar10 dataset (binary version).
+    Cifar100Dataset: loads the Cifar100 dataset (binary version).
+    MnistDataset: loads the MNIST dataset.
+    ...
+After declaring the dataset object, you can further apply dataset operations
+(e.g. filter, skip, concat, map, batch) on it.
 """
 import os
 import numpy as np
@@ -25,7 +30,7 @@ from PIL import Image
 
 import mindspore._c_dataengine as cde
 
-from .datasets import MappableDataset, SourceDataset, Shuffle, Schema
+from .datasets import VisionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema
 from .datasets_user_defined import GeneratorDataset
 from .validators import check_imagefolderdataset, \
     check_mnist_cifar_dataset, check_manifestdataset, check_vocdataset, check_cocodataset, \
@@ -268,7 +273,7 @@ class Caltech101Dataset(GeneratorDataset):
         return class_dict
 
 
-class Caltech256Dataset(MappableDataset):
+class Caltech256Dataset(MappableDataset, VisionBaseDataset):
     """
     A source dataset that reads and parses Caltech256 dataset.
 
@@ -394,7 +399,7 @@ class Caltech256Dataset(MappableDataset):
         return cde.Caltech256Node(self.dataset_dir, self.decode, self.sampler)
 
 
-class CelebADataset(MappableDataset):
+class CelebADataset(MappableDataset, VisionBaseDataset):
     """
     A source dataset for reading and parsing CelebA dataset.
Only support to read `list_attr_celeba.txt` currently, which is the attribute annotations of the dataset. @@ -558,7 +563,7 @@ class CelebADataset(MappableDataset): -class Cifar10Dataset(MappableDataset): +class Cifar10Dataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing Cifar10 dataset. This api only supports parsing Cifar10 file in binary version now. @@ -689,7 +694,7 @@ class Cifar10Dataset(MappableDataset): return cde.Cifar10Node(self.dataset_dir, self.usage, self.sampler) -class Cifar100Dataset(MappableDataset): +class Cifar100Dataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing Cifar100 dataset. @@ -813,7 +818,7 @@ class Cifar100Dataset(MappableDataset): return cde.Cifar100Node(self.dataset_dir, self.usage, self.sampler) -class CityscapesDataset(MappableDataset): +class CityscapesDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing Cityscapes dataset. @@ -983,7 +988,7 @@ class CityscapesDataset(MappableDataset): return cde.CityscapesNode(self.dataset_dir, self.usage, self.quality_mode, self.task, self.decode, self.sampler) -class CocoDataset(MappableDataset): +class CocoDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing COCO dataset. @@ -1194,7 +1199,7 @@ class CocoDataset(MappableDataset): return self._class_indexing -class DIV2KDataset(MappableDataset): +class DIV2KDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing DIV2KDataset dataset. @@ -1381,7 +1386,7 @@ class DIV2KDataset(MappableDataset): return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler) -class EMnistDataset(MappableDataset): +class EMnistDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the EMNIST dataset. @@ -1513,7 +1518,7 @@ class EMnistDataset(MappableDataset): return cde.EMnistNode(self.dataset_dir, self.name, self.usage, self.sampler) -class FakeImageDataset(MappableDataset): +class FakeImageDataset(MappableDataset, VisionBaseDataset): """ A source dataset for generating fake images. @@ -1602,7 +1607,7 @@ class FakeImageDataset(MappableDataset): return cde.FakeImageNode(self.num_images, self.image_size, self.num_classes, self.base_seed, self.sampler) -class FashionMnistDataset(MappableDataset): +class FashionMnistDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the FASHION-MNIST dataset. @@ -1723,7 +1728,7 @@ class FashionMnistDataset(MappableDataset): return cde.FashionMnistNode(self.dataset_dir, self.usage, self.sampler) -class FlickrDataset(MappableDataset): +class FlickrDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing Flickr8k and Flickr30k dataset. @@ -2128,7 +2133,7 @@ class Flowers102Dataset(GeneratorDataset): return class_dict -class ImageFolderDataset(MappableDataset): +class ImageFolderDataset(MappableDataset, VisionBaseDataset): """ A source dataset that reads images from a tree of directories. All images within one folder have the same label. @@ -2257,7 +2262,7 @@ class ImageFolderDataset(MappableDataset): return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing) -class KMnistDataset(MappableDataset): +class KMnistDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the KMNIST dataset. 
@@ -2378,7 +2383,7 @@ class KMnistDataset(MappableDataset): return cde.KMnistNode(self.dataset_dir, self.usage, self.sampler) -class ManifestDataset(MappableDataset): +class ManifestDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading images from a Manifest file. @@ -2497,7 +2502,7 @@ class ManifestDataset(MappableDataset): return self.class_indexing -class MnistDataset(MappableDataset): +class MnistDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the MNIST dataset. @@ -2617,7 +2622,7 @@ class MnistDataset(MappableDataset): return cde.MnistNode(self.dataset_dir, self.usage, self.sampler) -class PhotoTourDataset(MappableDataset): +class PhotoTourDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the PhotoTour dataset. @@ -2770,7 +2775,7 @@ class PhotoTourDataset(MappableDataset): return cde.PhotoTourNode(self.dataset_dir, self.name, self.usage, self.sampler) -class Places365Dataset(MappableDataset): +class Places365Dataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the Places365 dataset. @@ -2911,7 +2916,7 @@ class Places365Dataset(MappableDataset): return cde.Places365Node(self.dataset_dir, self.usage, self.small, self.decode, self.sampler) -class QMnistDataset(MappableDataset): +class QMnistDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the QMNIST dataset. @@ -3036,7 +3041,7 @@ class QMnistDataset(MappableDataset): return cde.QMnistNode(self.dataset_dir, self.usage, self.compat, self.sampler) -class RandomDataset(SourceDataset): +class RandomDataset(SourceDataset, VisionBaseDataset): """ A source dataset that generates random data. @@ -3268,7 +3273,7 @@ class SBDataset(GeneratorDataset): num_shards=num_shards, shard_id=shard_id) -class SBUDataset(MappableDataset): +class SBUDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing the SBU dataset. @@ -3382,7 +3387,7 @@ class SBUDataset(MappableDataset): return cde.SBUNode(self.dataset_dir, self.decode, self.sampler) -class SemeionDataset(MappableDataset): +class SemeionDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing Semeion dataset. @@ -3502,7 +3507,7 @@ class SemeionDataset(MappableDataset): return cde.SemeionNode(self.dataset_dir, self.sampler) -class STL10Dataset(MappableDataset): +class STL10Dataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing STL10 dataset. @@ -3784,7 +3789,7 @@ class SVHNDataset(GeneratorDataset): num_shards=num_shards, shard_id=shard_id) -class USPSDataset(SourceDataset): +class USPSDataset(SourceDataset, VisionBaseDataset): """ A source dataset for reading and parsing the USPS dataset. @@ -3881,7 +3886,7 @@ class USPSDataset(SourceDataset): self.shard_id) -class VOCDataset(MappableDataset): +class VOCDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing VOC dataset. @@ -4083,7 +4088,7 @@ class VOCDataset(MappableDataset): return self.class_indexing -class WIDERFaceDataset(MappableDataset): +class WIDERFaceDataset(MappableDataset, VisionBaseDataset): """ A source dataset for reading and parsing WIDERFace dataset.