!28864 enhance dataset step2 for br:r1.6

Merge pull request !28864 from guozhijian/enhance_dataset_r1.6
i-robot 2022-01-13 08:32:50 +00:00 committed by Gitee
commit c73cac6a72
7 changed files with 321 additions and 251 deletions

mindspore/dataset/engine/__init__.py

@ -36,69 +36,69 @@ from .iterators import *
from .samplers import *
from .serializer_deserializer import compare, deserialize, serialize, show
__all__ = ["Caltech101Dataset", # vision dataset
"Caltech256Dataset", # vision dataset
"CelebADataset", # vision dataset
"Cifar10Dataset", # vision dataset
"Cifar100Dataset", # vision dataset
"CityscapesDataset", # vision dataset
"CocoDataset", # vision dataset
"DIV2KDataset", # vision dataset
"EMnistDataset", # vision dataset
"FakeImageDataset", # vision dataset
"FashionMnistDataset", # vision dataset
"FlickrDataset", # vision dataset
"Flowers102Dataset", # vision dataset
"ImageFolderDataset", # vision dataset
"KMnistDataset", # vision dataset
"ManifestDataset", # vision dataset
"MnistDataset", # vision dataset
"PhotoTourDataset", # vision dataset
"Places365Dataset", # vision dataset
"QMnistDataset", # vision dataset
"RandomDataset", # vision dataset
"SBDataset", # vision dataset
"SBUDataset", # vision dataset
"SemeionDataset", # vision dataset
"STL10Dataset", # vision dataset
"SVHNDataset", # vision dataset
"USPSDataset", # vision dataset
"VOCDataset", # vision dataset
"WIDERFaceDataset", # vision dataset
"AGNewsDataset", # text dataset
"AmazonReviewDataset", # text dataset
"CLUEDataset", # text dataset
"CoNLL2000Dataset", # text dataset
"CSVDataset", # text dataset
"DBpediaDataset", # text dataset
"EnWik9Dataset", # text dataset
"IMDBDataset", # text dataset
"IWSLT2016Dataset", # text dataset
"IWSLT2017Dataset", # text dataset
"PennTreebankDataset", # text dataset
"SogouNewsDataset", # text dataset
"TextFileDataset", # text dataset
"UDPOSDataset", # text dataset
"WikiTextDataset", # text dataset
"YahooAnswersDataset", # text dataset
"YelpReviewDataset", # text dataset
"LJSpeechDataset", # audio dataset
"SpeechCommandsDataset", # audio dataset
"TedliumDataset", # audio dataset
"YesNoDataset", # audio dataset
"MindDataset", # standard format dataset
"TFRecordDataset", # standard format dataset
"GeneratorDataset", # user defined dataset
"NumpySlicesDataset", # user defined dataset
"PaddedDataset", # user defined dataset
"GraphData", # graph data
"DistributedSampler", # sampler
"RandomSampler", # sampler
"SequentialSampler", # sampler
"SubsetRandomSampler", # sampler
"SubsetSampler", # sampler
"PKSampler", # sampler
"WeightedRandomSampler", # sampler
__all__ = ["Caltech101Dataset", # Vision
"Caltech256Dataset", # Vision
"CelebADataset", # Vision
"Cifar10Dataset", # Vision
"Cifar100Dataset", # Vision
"CityscapesDataset", # Vision
"CocoDataset", # Vision
"DIV2KDataset", # Vision
"EMnistDataset", # Vision
"FakeImageDataset", # Vision
"FashionMnistDataset", # Vision
"FlickrDataset", # Vision
"Flowers102Dataset", # Vision
"ImageFolderDataset", # Vision
"KMnistDataset", # Vision
"ManifestDataset", # Vision
"MnistDataset", # Vision
"PhotoTourDataset", # Vision
"Places365Dataset", # Vision
"QMnistDataset", # Vision
"RandomDataset", # Vision
"SBDataset", # Vision
"SBUDataset", # Vision
"SemeionDataset", # Vision
"STL10Dataset", # Vision
"SVHNDataset", # Vision
"USPSDataset", # Vision
"VOCDataset", # Vision
"WIDERFaceDataset", # Vision
"AGNewsDataset", # Text
"AmazonReviewDataset", # Text
"CLUEDataset", # Text
"CoNLL2000Dataset", # Text
"DBpediaDataset", # Text
"EnWik9Dataset", # Text
"IMDBDataset", # Text
"IWSLT2016Dataset", # Text
"IWSLT2017Dataset", # Text
"PennTreebankDataset", # Text
"SogouNewsDataset", # Text
"TextFileDataset", # Text
"UDPOSDataset", # Text
"WikiTextDataset", # Text
"YahooAnswersDataset", # Text
"YelpReviewDataset", # Text
"LJSpeechDataset", # Audio
"SpeechCommandsDataset", # Audio
"TedliumDataset", # Audio
"YesNoDataset", # Audio
"CSVDataset", # Standard Format
"MindDataset", # Standard Format
"TFRecordDataset", # Standard Format
"GeneratorDataset", # User Defined
"NumpySlicesDataset", # User Defined
"PaddedDataset", # User Defined
"GraphData", # Graph Data
"DistributedSampler", # Sampler
"RandomSampler", # Sampler
"SequentialSampler", # Sampler
"SubsetRandomSampler", # Sampler
"SubsetSampler", # Sampler
"PKSampler", # Sampler
"WeightedRandomSampler", # Sampler
"DatasetCache",
"DSCallback",
"WaitedDSCallback",

mindspore/dataset/engine/datasets.py

@ -13,15 +13,21 @@
# limitations under the License.
# ==============================================================================
"""
This dataset module supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
1. This file is an abstraction of the dataset loading classes. It contains
the basic dataset operations (skip, filter, map, batch, ...).
2. Specific dataset loading classes can be found in the datasets_vision.py, datasets_text.py,
datasets_audio.py, datasets_standard_format.py and datasets_user_defined.py files.
datasets_vision.py: contains vision dataset loading classes.
datasets_text.py: contains text dataset loading classes.
datasets_audio.py: contains audio dataset loading classes.
datasets_standard_format.py: contains loading classes for standard formats, to which
any other kind of dataset can be converted.
datasets_user_defined.py: contains basic classes that help users define
flexible ways to load datasets.
"""
import atexit
import glob
import json
import math
import os
import signal
import stat
@ -215,6 +221,44 @@ class Dataset:
This class is the base class of SourceDataset and Dataset, and represents
a node in the data flow graph.
Dataset
-----------------------------------------------------------
        |                  |                  |            |
VisionBaseDataset   TextBaseDataset   AudioBaseDataset     |
        -                  -                  -            |
        |                  |                  |            |
        ------------------------------------------         |
                   UnionBaseDataset                        |
                                                           |
                                                    SourceDataset
                                                           -
                                                           |
                                                   MappableDataset

DatasetOperator: MapDataset(UnionBaseDataset)
                 BatchDataset(UnionBaseDataset)
                 BucketBatchByLengthDataset(UnionBaseDataset)
                 ShuffleDataset(UnionBaseDataset)
                 FilterDataset(UnionBaseDataset)
                 RepeatDataset(UnionBaseDataset)
                 SkipDataset(UnionBaseDataset)
                 TakeDataset(UnionBaseDataset)
                 ZipDataset(UnionBaseDataset)
                 ConcatDataset(UnionBaseDataset)
                 RenameDataset(UnionBaseDataset)
                 ProjectDataset(UnionBaseDataset)
                 SyncWaitDataset(UnionBaseDataset)

Impl Dataset - vision:       ImageFolderDataset(MappableDataset, VisionBaseDataset)
                             USPSDataset(SourceDataset, VisionBaseDataset)
Impl Dataset - text:         TextFileDataset(SourceDataset, TextBaseDataset)
                             YahooAnswersDataset(SourceDataset, TextBaseDataset)
Impl Dataset - audio:        LJSpeechDataset(MappableDataset, AudioBaseDataset)
                             TedliumDataset(MappableDataset, AudioBaseDataset)
Impl Dataset - standard:     MindDataset(MappableDataset, UnionBaseDataset)
                             TFRecordDataset(SourceDataset, UnionBaseDataset)
Impl Dataset - user defined: GeneratorDataset(MappableDataset, UnionBaseDataset)
                             NumpySlicesDataset(GeneratorDataset)
Args:
num_parallel_workers (int, optional): Number of workers to process the dataset in parallel
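The hierarchy above is the point of this change: every operator result (MapDataset, BatchDataset, ...) now derives from UnionBaseDataset, which mixes in all three domain bases, so domain-specific helpers survive any chain of operations. A toy sketch of the same diamond layout (illustrative names only, not MindSpore's real implementation):

class Dataset:
    """Stand-in for the pipeline base node."""

class VisionBaseDataset(Dataset):
    def vision_helper(self):
        return "vision"

class TextBaseDataset(Dataset):
    def build_vocab_stub(self):
        return "vocab"

class AudioBaseDataset(Dataset):
    def audio_helper(self):
        return "audio"

class UnionBaseDataset(VisionBaseDataset, TextBaseDataset, AudioBaseDataset):
    """Operator results inherit this, keeping every domain helper reachable."""

class MapDataset(UnionBaseDataset):
    pass

node = MapDataset()
assert node.build_vocab_stub() == "vocab"  # text helper still available after map
assert node.vision_helper() == "vision"    # and so is the vision helper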
@ -1796,6 +1840,18 @@ class Dataset:
return ir_node
class VisionBaseDataset(Dataset):
"""
Abstract class to represent a vision source dataset which produces content to the data pipeline.
"""
def __init__(self, children=None, num_parallel_workers=None, cache=None):
super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)
def parse(self, children=None):
raise NotImplementedError("Dataset has to implement parse method.")
class TextBaseDataset(Dataset):
"""
Abstract class to represent a text source dataset which produces content to the data pipeline.
@ -1928,6 +1984,30 @@ class TextBaseDataset(Dataset):
return vocab
class AudioBaseDataset(Dataset):
"""
Abstract class to represent an audio source dataset which produces content to the data pipeline.
"""
def __init__(self, children=None, num_parallel_workers=None, cache=None):
super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)
def parse(self, children=None):
raise NotImplementedError("Dataset has to implement parse method.")
class UnionBaseDataset(VisionBaseDataset, TextBaseDataset, AudioBaseDataset):
"""
Abstract class to represent a union source dataset which produces content to the data pipeline.
"""
def __init__(self, children=None, num_parallel_workers=None, cache=None):
super().__init__(children=children, num_parallel_workers=num_parallel_workers, cache=cache)
def parse(self, children=None):
raise NotImplementedError("Dataset has to implement parse method.")
class SourceDataset(Dataset):
"""
Abstract class to represent a source dataset which produces content to the data pipeline.
@ -2159,7 +2239,7 @@ class MappableDataset(SourceDataset):
return tuple(splits)
class BucketBatchByLengthDataset(Dataset):
class BucketBatchByLengthDataset(UnionBaseDataset):
"""
The result of applying BucketBatchByLength operator to the input dataset.
"""
@ -2211,7 +2291,7 @@ def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
raise RuntimeError("Expected /dev/shm to exist.")
class BatchDataset(Dataset):
class BatchDataset(UnionBaseDataset):
"""
The result of applying Batch operator to the input dataset.
@ -2482,7 +2562,7 @@ class BlockReleasePair:
self.cv.notify_all()
class SyncWaitDataset(Dataset):
class SyncWaitDataset(UnionBaseDataset):
"""
The result of adding a blocking condition to the input Dataset.
@ -2551,7 +2631,7 @@ class SyncWaitDataset(Dataset):
self._pair.reset()
class ShuffleDataset(Dataset):
class ShuffleDataset(UnionBaseDataset):
"""
The result of applying Shuffle operator to the input Dataset.
@ -2830,7 +2910,7 @@ class _ExceptHookHandler:
_mp_pool_exit_preprocess()
class MapDataset(TextBaseDataset, Dataset):
class MapDataset(UnionBaseDataset):
"""
The result of applying the Map operator to the input Dataset.
@ -3003,7 +3083,7 @@ class MapDataset(TextBaseDataset, Dataset):
self._abort_watchdog()
class FilterDataset(Dataset):
class FilterDataset(UnionBaseDataset):
"""
The result of applying filter predicate to the input Dataset.
@ -3025,7 +3105,7 @@ class FilterDataset(Dataset):
return cde.FilterNode(children[0], self.predicate, self.input_columns)
class RepeatDataset(Dataset):
class RepeatDataset(UnionBaseDataset):
"""
The result of applying Repeat operator to the input Dataset.
@ -3042,7 +3122,7 @@ class RepeatDataset(Dataset):
return cde.RepeatNode(children[0], self.count)
class SkipDataset(Dataset):
class SkipDataset(UnionBaseDataset):
"""
The result of applying Skip operator to the input Dataset.
@ -3059,7 +3139,7 @@ class SkipDataset(Dataset):
return cde.SkipNode(children[0], self.count)
class TakeDataset(Dataset):
class TakeDataset(UnionBaseDataset):
"""
The result of applying Take operator to the input Dataset.
@ -3076,7 +3156,7 @@ class TakeDataset(Dataset):
return cde.TakeNode(children[0], self.count)
class ZipDataset(Dataset):
class ZipDataset(UnionBaseDataset):
"""
The result of applying Zip operator to the input Dataset.
@ -3097,7 +3177,7 @@ class ZipDataset(Dataset):
return any([c.is_sync() for c in self.children])
class ConcatDataset(Dataset):
class ConcatDataset(UnionBaseDataset):
"""
The result of applying concat dataset operator to the input Dataset.
@ -3126,7 +3206,7 @@ class ConcatDataset(Dataset):
child_index += 1
# _children_flag_and_nums: A list of pair<int, int>. The first element of the pair is a flag that characterizes
# whether the data set is mappable. The second element of the pair is the length of the dataset
# whether the dataset is mappable. The second element of the pair is the length of the dataset
self._children_flag_and_nums = []
# _children_start_end_index_: A list of pair<int, int>. The elements of the pair are used to characterize
@ -3208,7 +3288,7 @@ class ConcatDataset(Dataset):
cumulative_samples_nums %= sampler.num_shards
class RenameDataset(Dataset):
class RenameDataset(UnionBaseDataset):
"""
The result of applying Rename operator to the input Dataset.
@ -3237,7 +3317,7 @@ def to_list(items):
return items
class ProjectDataset(Dataset):
class ProjectDataset(UnionBaseDataset):
"""
The result of applying Project operator to the input Dataset.
@ -3403,37 +3483,6 @@ class TransferDataset(Dataset):
self._to_device.release()
class RangeDataset(MappableDataset):
"""
A source dataset that reads and parses datasets stored on disk in a range.
Args:
start (int): Starting index.
stop (int): Ending index.
step (int): Step size in the range specified by start and stop.
"""
def __init__(self, start, stop, step):
super().__init__()
self.start = start
self.stop = stop
self.step = step
def parse(self, children=None):
raise NotImplementedError("Dataset has to implement parse method.")
def is_shuffled(self):
return False
def is_sharded(self):
return False
def get_dataset_size(self):
if self.dataset_size is None:
self.dataset_size = math.ceil((self.stop - self.start) / self.step)
return self.dataset_size
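For reference, the removed size formula matches Python range semantics, e.g.:

import math
start, stop, step = 0, 10, 3  # yields indices 0, 3, 6, 9
assert math.ceil((stop - start) / step) == len(range(start, stop, step)) == 4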
class Schema:
"""
Class to represent a schema of a dataset.

mindspore/dataset/engine/datasets_audio.py

@ -13,21 +13,26 @@
# limitations under the License.
# ==============================================================================
"""
This dataset module supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
This file contains the specific audio dataset loading classes. You can easily use
these classes to load a prepared dataset. For example:
LJSpeechDataset: the LJ Speech dataset.
YesNoDataset: the YesNo dataset.
SpeechCommandsDataset: the Speech Commands dataset.
TedliumDataset: the TED-LIUM dataset.
...
After declaring the dataset object, you can further apply dataset operations
(e.g. filter, skip, concat, map, batch) on it.
"""
import mindspore._c_dataengine as cde
from .datasets import MappableDataset
from .datasets import AudioBaseDataset, MappableDataset
from .validators import check_lj_speech_dataset, check_yes_no_dataset, check_speech_commands_dataset, \
check_tedlium_dataset
from ..core.validator_helpers import replace_none
class LJSpeechDataset(MappableDataset):
class LJSpeechDataset(MappableDataset, AudioBaseDataset):
"""
A source dataset for reading and parsing LJSpeech dataset.
@ -163,7 +168,7 @@ class LJSpeechDataset(MappableDataset):
return cde.LJSpeechNode(self.dataset_dir, self.sampler)
class SpeechCommandsDataset(MappableDataset):
class SpeechCommandsDataset(MappableDataset, AudioBaseDataset):
"""
A source dataset for reading and parsing the SpeechCommands dataset.
@ -287,7 +292,7 @@ class SpeechCommandsDataset(MappableDataset):
return cde.SpeechCommandsNode(self.dataset_dir, self.usage, self.sampler)
class TedliumDataset(MappableDataset):
class TedliumDataset(MappableDataset, AudioBaseDataset):
"""
A source dataset for reading and parsing Tedlium dataset.
The columns of generated dataset depend on the source SPH files and the corresponding STM files.
@ -499,7 +504,7 @@ class TedliumDataset(MappableDataset):
return cde.TedliumNode(self.dataset_dir, self.release, self.usage, self.extensions, self.sampler)
class YesNoDataset(MappableDataset):
class YesNoDataset(MappableDataset, AudioBaseDataset):
"""
A source dataset for reading and parsing the YesNo dataset.

mindspore/dataset/engine/datasets_standard_format.py

@ -13,25 +13,90 @@
# limitations under the License.
# ==============================================================================
"""
This dataset module supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
This file contains standard format dataset loading classes.
You can convert a dataset to a standard format using the following steps:
1. Use the mindspore.mindrecord.FileWriter / tf.io.TFRecordWriter API to
convert the dataset to MindRecord / TFRecord.
2. Use MindDataset / TFRecordDataset to load the MindRecord / TFRecord files.
After declaring the dataset object, you can further apply dataset operations
(e.g. filter, skip, concat, map, batch) on it.
"""
import numpy as np
import mindspore._c_dataengine as cde
from mindspore import log as logger
from .datasets import MappableDataset, SourceDataset, TextBaseDataset, Shuffle, Schema, \
from .datasets import UnionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema, \
shuffle_to_shuffle_mode, shuffle_to_bool
from .validators import check_minddataset, check_tfrecorddataset
from .validators import check_minddataset, check_tfrecorddataset, check_csvdataset
from ..core.validator_helpers import replace_none
from . import samplers
class MindDataset(MappableDataset, TextBaseDataset):
class CSVDataset(SourceDataset, UnionBaseDataset):
"""
A source dataset that reads and parses comma-separated values (CSV) datasets.
The columns of generated dataset depend on the source CSV files.
Args:
dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search
for a pattern of files. The list will be sorted in a lexicographical order.
field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
column_defaults (list, optional): List of default values for the CSV fields (default=None). Each item
in the list must be of a valid type (float, int, or string). If this is not provided, all
columns are treated as string type.
column_names (list[str], optional): List of column names of the dataset (default=None). If this
is not provided, infers the column_names from the first row of CSV file.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, will include all samples).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_files are not valid or do not exist.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
Examples:
>>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files
>>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4'])
"""
@check_csvdataset
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
self.field_delim = replace_none(field_delim, ',')
self.column_defaults = replace_none(column_defaults, [])
self.column_names = replace_none(column_names, [])
def parse(self, children=None):
return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names,
self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)
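To make the column_defaults / column_names semantics concrete, a self-contained sketch (toy CSV written on the fly):

import mindspore.dataset as ds

with open("toy.csv", "w") as f:
    f.write("1,0.5,hello\n2,1.5,world\n")

# column_defaults fixes each column's parsed type (int, float, string);
# without it, every column comes back as a string.
dataset = ds.CSVDataset(dataset_files="toy.csv",
                        column_names=["id", "score", "text"],
                        column_defaults=[0, 0.0, ""],
                        shuffle=False)
for item in dataset.create_dict_iterator(output_numpy=True):
    print(item["id"], item["score"], item["text"])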
class MindDataset(MappableDataset, UnionBaseDataset):
"""
A source dataset for reading and parsing MindRecord dataset.
@ -160,7 +225,7 @@ class MindDataset(MappableDataset, TextBaseDataset):
self.new_padded_sample[k] = v
class TFRecordDataset(SourceDataset, TextBaseDataset):
class TFRecordDataset(SourceDataset, UnionBaseDataset):
"""
A source dataset for reading and parsing datasets stored on disk in TFData format.

mindspore/dataset/engine/datasets_text.py

@ -13,17 +13,22 @@
# limitations under the License.
# ==============================================================================
"""
This dataset module supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
This file contains the specific text dataset loading classes. You can easily use
these classes to load a prepared dataset. For example:
IMDBDataset: the IMDB dataset.
WikiTextDataset: the WikiText dataset.
CLUEDataset: the CLUE dataset.
YelpReviewDataset: the Yelp Review dataset.
...
After declaring the dataset object, you can further apply dataset operations
(e.g. filter, skip, concat, map, batch) on it.
"""
import mindspore._c_dataengine as cde
from .datasets import MappableDataset, SourceDataset, TextBaseDataset, Shuffle
from .datasets import TextBaseDataset, SourceDataset, MappableDataset, Shuffle
from .validators import check_imdb_dataset, check_iwslt2016_dataset, check_iwslt2017_dataset, \
check_penn_treebank_dataset, check_ag_news_dataset, check_amazon_review_dataset, check_udpos_dataset, \
check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, check_csvdataset, \
check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, \
check_sogou_news_dataset, check_textfiledataset, check_dbpedia_dataset, check_yelp_review_dataset, \
check_en_wik9_dataset, check_yahoo_answers_dataset
@ -117,7 +122,7 @@ class AGNewsDataset(SourceDataset, TextBaseDataset):
self.shard_id)
class AmazonReviewDataset(SourceDataset):
class AmazonReviewDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses Amazon Review Polarity and Amazon Review Full datasets.
@ -356,7 +361,7 @@ class CLUEDataset(SourceDataset, TextBaseDataset):
self.num_shards, self.shard_id)
class CoNLL2000Dataset(SourceDataset):
class CoNLL2000Dataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses CoNLL2000 dataset.
@ -414,68 +419,6 @@ class CoNLL2000Dataset(SourceDataset):
self.shard_id)
class CSVDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses comma-separated values (CSV) datasets.
The columns of generated dataset depend on the source CSV files.
Args:
dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search
for a pattern of files. The list will be sorted in a lexicographical order.
field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
column_defaults (list, optional): List of default values for the CSV field (default=None). Each item
in the list is either a valid type (float, int, or string). If this is not provided, treats all
columns as string type.
column_names (list[str], optional): List of column names of the dataset (default=None). If this
is not provided, infers the column_names from the first row of CSV file.
num_samples (int, optional): The number of samples to be included in the dataset
(default=None, will include all images).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, `num_samples` reflects the maximum sample number of per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_files are not valid or do not exist.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
Examples:
>>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files
>>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4'])
"""
@check_csvdataset
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
self.field_delim = replace_none(field_delim, ',')
self.column_defaults = replace_none(column_defaults, [])
self.column_names = replace_none(column_names, [])
def parse(self, children=None):
return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names,
self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id)
class DBpediaDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses the DBpedia dataset.
@ -572,7 +515,7 @@ class DBpediaDataset(SourceDataset, TextBaseDataset):
self.shard_id)
class EnWik9Dataset(SourceDataset):
class EnWik9Dataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses EnWik9 dataset.
@ -647,7 +590,7 @@ class EnWik9Dataset(SourceDataset):
return cde.EnWik9Node(self.dataset_dir, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
class IMDBDataset(MappableDataset):
class IMDBDataset(MappableDataset, TextBaseDataset):
"""
A source dataset for reading and parsing Internet Movie Database (IMDb).
@ -726,7 +669,7 @@ class IMDBDataset(MappableDataset):
About IMDBDataset:
The IMDB dataset contains 50, 000 highly polarized reviews from the Internet Movie Database (IMDB). The data set
The IMDB dataset contains 50,000 highly polarized reviews from the Internet Movie Database (IMDB). The dataset
was divided into 25,000 comments for training and 25,000 comments for testing, with both the training set and test
set containing 50% positive and 50% negative comments. Train labels and test labels are all lists of 0 and 1, where
0 stands for negative and 1 for positive.
@ -841,13 +784,13 @@ class IWSLT2016Dataset(SourceDataset, TextBaseDataset):
About IWSLT2016 dataset:
IWSLT is an international oral translation conference, a major annual scientific conference dedicated to all aspects
of oral translation. The MT task of the IWSLT evaluation activity constitutes a data set, which can be publicly
obtained through the WIT3 website wit3.fbk.eu. The IWSLT2016 data set includes translations from English to Arabic,
of oral translation. The MT task of the IWSLT evaluation activity constitutes a dataset, which can be publicly
obtained through the WIT3 website wit3.fbk.eu. The IWSLT2016 dataset includes translations from English to Arabic,
Czech, French, and German, and translations from Arabic, Czech, French, and German to English.
You can unzip the original IWSLT2016 dataset files into this directory structure and read them with MindSpore's API. After
decompression, you also need to decompress the data set to be read in the specified folder. For example, if you want
to read the data set of de-en, you need to unzip the tgz file in the de/en directory, the data set is in the
decompression, you also need to decompress the dataset to be read in the specified folder. For example, if you want
to read the dataset of de-en, you need to unzip the tgz file in the de/en directory, the dataset is in the
unzipped folder.
.. code-block::
@ -961,9 +904,9 @@ class IWSLT2017Dataset(SourceDataset, TextBaseDataset):
About IWSLT2017 dataset:
IWSLT is an international oral translation conference, a major annual scientific conference dedicated to all aspects
of oral translation. The MT task of the IWSLT evaluation activity constitutes a data set, which can be publicly
obtained through the WIT3 website wit3.fbk.eu. The IWSLT2017 data set involves German, English, Italian, Dutch, and
Romanian. The data set includes translations in any two different languages.
of oral translation. The MT task of the IWSLT evaluation activity constitutes a dataset, which can be publicly
obtained through the WIT3 website wit3.fbk.eu. The IWSLT2017 dataset involves German, English, Italian, Dutch, and
Romanian. The dataset includes translations in any two different languages.
You can unzip the original IWSLT2017 dataset files into this directory structure and read them with MindSpore's API. You
need to decompress the dataset package in texts/DeEnItNlRo/DeEnItNlRo directory to get the DeEnItNlRo-DeEnItNlRo
@ -1099,7 +1042,7 @@ class PennTreebankDataset(SourceDataset, TextBaseDataset):
self.shard_id)
class SogouNewsDataset(SourceDataset):
class SogouNewsDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses Sogou News dataset.
@ -1239,7 +1182,7 @@ class TextFileDataset(SourceDataset, TextBaseDataset):
self.shard_id)
class UDPOSDataset(SourceDataset):
class UDPOSDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses UDPOS dataset.
@ -1297,7 +1240,7 @@ class UDPOSDataset(SourceDataset):
self.shard_id)
class WikiTextDataset(SourceDataset):
class WikiTextDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses WikiText2 and WikiText103 datasets.
@ -1375,7 +1318,7 @@ class WikiTextDataset(SourceDataset):
self.shard_id)
class YahooAnswersDataset(SourceDataset):
class YahooAnswersDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses the YahooAnswers dataset.

mindspore/dataset/engine/datasets_user_defined.py

@ -13,10 +13,13 @@
# limitations under the License.
# ==============================================================================
"""
This dataset module supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
This file contains basic classes that help users load datasets in flexible ways.
You can define your own dataset loading class and use GeneratorDataset to help load data.
You can refer to the
`tutorial <https://www.mindspore.cn/docs/programming_guide/en/master/dataset_loading.html#loading-user-defined-dataset>`_
for help defining your dataset loading.
After declaring the dataset object, you can further apply dataset operations
(e.g. filter, skip, concat, map, batch) on it.
"""
import builtins
import math
@ -38,7 +41,7 @@ import mindspore._c_dataengine as cde
from mindspore.common import Tensor
from mindspore import log as logger
from .datasets import MappableDataset, TextBaseDataset, Schema, to_list, _watch_dog, _check_shm_usage
from .datasets import UnionBaseDataset, MappableDataset, Schema, to_list, _watch_dog, _check_shm_usage
from . import samplers
from .queue import _SharedQueue
from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset
@ -427,7 +430,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
return True
class GeneratorDataset(MappableDataset, TextBaseDataset):
class GeneratorDataset(MappableDataset, UnionBaseDataset):
"""
A source dataset that generates data from Python by invoking Python data source each epoch.
@ -873,7 +876,7 @@ class _PaddedDataset:
class PaddedDataset(GeneratorDataset):
"""
Creates a dataset with filler data provided by user. Mainly used to add to the original data set
Creates a dataset with filler data provided by user. Mainly used to add to the original dataset
and assign it to the corresponding shard.
Args:

mindspore/dataset/engine/datasets_vision.py

@ -13,10 +13,15 @@
# limitations under the License.
# ==============================================================================
"""
This dataset module supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
This file contains the specific vision dataset loading classes. You can easily use
these classes to load a prepared dataset. For example:
ImageFolderDataset: reads images from a tree of directories (e.g. ImageNet).
Cifar10Dataset: the CIFAR-10 dataset (binary version).
Cifar100Dataset: the CIFAR-100 dataset (binary version).
MnistDataset: the MNIST dataset.
...
After declaring the dataset object, you can further apply dataset operations
(e.g. filter, skip, concat, map, batch) on it.
"""
import os
import numpy as np
@ -25,7 +30,7 @@ from PIL import Image
import mindspore._c_dataengine as cde
from .datasets import MappableDataset, SourceDataset, Shuffle, Schema
from .datasets import VisionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema
from .datasets_user_defined import GeneratorDataset
from .validators import check_imagefolderdataset, \
check_mnist_cifar_dataset, check_manifestdataset, check_vocdataset, check_cocodataset, \
@ -268,7 +273,7 @@ class Caltech101Dataset(GeneratorDataset):
return class_dict
class Caltech256Dataset(MappableDataset):
class Caltech256Dataset(MappableDataset, VisionBaseDataset):
"""
A source dataset that reads and parses Caltech256 dataset.
@ -394,7 +399,7 @@ class Caltech256Dataset(MappableDataset):
return cde.Caltech256Node(self.dataset_dir, self.decode, self.sampler)
class CelebADataset(MappableDataset):
class CelebADataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing CelebA dataset.
Only support to read `list_attr_celeba.txt` currently, which is the attribute annotations of the dataset.
@ -558,7 +563,7 @@ class CelebADataset(MappableDataset):
class Cifar10Dataset(MappableDataset):
class Cifar10Dataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing Cifar10 dataset.
This api only supports parsing Cifar10 file in binary version now.
@ -689,7 +694,7 @@ class Cifar10Dataset(MappableDataset):
return cde.Cifar10Node(self.dataset_dir, self.usage, self.sampler)
class Cifar100Dataset(MappableDataset):
class Cifar100Dataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing Cifar100 dataset.
@ -813,7 +818,7 @@ class Cifar100Dataset(MappableDataset):
return cde.Cifar100Node(self.dataset_dir, self.usage, self.sampler)
class CityscapesDataset(MappableDataset):
class CityscapesDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing Cityscapes dataset.
@ -983,7 +988,7 @@ class CityscapesDataset(MappableDataset):
return cde.CityscapesNode(self.dataset_dir, self.usage, self.quality_mode, self.task, self.decode, self.sampler)
class CocoDataset(MappableDataset):
class CocoDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing COCO dataset.
@ -1194,7 +1199,7 @@ class CocoDataset(MappableDataset):
return self._class_indexing
class DIV2KDataset(MappableDataset):
class DIV2KDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing DIV2KDataset dataset.
@ -1381,7 +1386,7 @@ class DIV2KDataset(MappableDataset):
return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler)
class EMnistDataset(MappableDataset):
class EMnistDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the EMNIST dataset.
@ -1513,7 +1518,7 @@ class EMnistDataset(MappableDataset):
return cde.EMnistNode(self.dataset_dir, self.name, self.usage, self.sampler)
class FakeImageDataset(MappableDataset):
class FakeImageDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for generating fake images.
@ -1602,7 +1607,7 @@ class FakeImageDataset(MappableDataset):
return cde.FakeImageNode(self.num_images, self.image_size, self.num_classes, self.base_seed, self.sampler)
class FashionMnistDataset(MappableDataset):
class FashionMnistDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the FASHION-MNIST dataset.
@ -1723,7 +1728,7 @@ class FashionMnistDataset(MappableDataset):
return cde.FashionMnistNode(self.dataset_dir, self.usage, self.sampler)
class FlickrDataset(MappableDataset):
class FlickrDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing Flickr8k and Flickr30k dataset.
@ -2128,7 +2133,7 @@ class Flowers102Dataset(GeneratorDataset):
return class_dict
class ImageFolderDataset(MappableDataset):
class ImageFolderDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset that reads images from a tree of directories.
All images within one folder have the same label.
@ -2257,7 +2262,7 @@ class ImageFolderDataset(MappableDataset):
return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing)
class KMnistDataset(MappableDataset):
class KMnistDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the KMNIST dataset.
@ -2378,7 +2383,7 @@ class KMnistDataset(MappableDataset):
return cde.KMnistNode(self.dataset_dir, self.usage, self.sampler)
class ManifestDataset(MappableDataset):
class ManifestDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading images from a Manifest file.
@ -2497,7 +2502,7 @@ class ManifestDataset(MappableDataset):
return self.class_indexing
class MnistDataset(MappableDataset):
class MnistDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the MNIST dataset.
@ -2617,7 +2622,7 @@ class MnistDataset(MappableDataset):
return cde.MnistNode(self.dataset_dir, self.usage, self.sampler)
class PhotoTourDataset(MappableDataset):
class PhotoTourDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the PhotoTour dataset.
@ -2770,7 +2775,7 @@ class PhotoTourDataset(MappableDataset):
return cde.PhotoTourNode(self.dataset_dir, self.name, self.usage, self.sampler)
class Places365Dataset(MappableDataset):
class Places365Dataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the Places365 dataset.
@ -2911,7 +2916,7 @@ class Places365Dataset(MappableDataset):
return cde.Places365Node(self.dataset_dir, self.usage, self.small, self.decode, self.sampler)
class QMnistDataset(MappableDataset):
class QMnistDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the QMNIST dataset.
@ -3036,7 +3041,7 @@ class QMnistDataset(MappableDataset):
return cde.QMnistNode(self.dataset_dir, self.usage, self.compat, self.sampler)
class RandomDataset(SourceDataset):
class RandomDataset(SourceDataset, VisionBaseDataset):
"""
A source dataset that generates random data.
@ -3268,7 +3273,7 @@ class SBDataset(GeneratorDataset):
num_shards=num_shards, shard_id=shard_id)
class SBUDataset(MappableDataset):
class SBUDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the SBU dataset.
@ -3382,7 +3387,7 @@ class SBUDataset(MappableDataset):
return cde.SBUNode(self.dataset_dir, self.decode, self.sampler)
class SemeionDataset(MappableDataset):
class SemeionDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing Semeion dataset.
@ -3502,7 +3507,7 @@ class SemeionDataset(MappableDataset):
return cde.SemeionNode(self.dataset_dir, self.sampler)
class STL10Dataset(MappableDataset):
class STL10Dataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing STL10 dataset.
@ -3784,7 +3789,7 @@ class SVHNDataset(GeneratorDataset):
num_shards=num_shards, shard_id=shard_id)
class USPSDataset(SourceDataset):
class USPSDataset(SourceDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing the USPS dataset.
@ -3881,7 +3886,7 @@ class USPSDataset(SourceDataset):
self.shard_id)
class VOCDataset(MappableDataset):
class VOCDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing VOC dataset.
@ -4083,7 +4088,7 @@ class VOCDataset(MappableDataset):
return self.class_indexing
class WIDERFaceDataset(MappableDataset):
class WIDERFaceDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset for reading and parsing WIDERFace dataset.