forked from mindspore-Ecosystem/mindspore
fix mindrecord comments
This commit is contained in:
parent
39bc43e674
commit
d5ed59ee4d
|
@ -13,12 +13,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""
|
"""
|
||||||
Introduction to mindrecord:
|
Introduction of mindrecord:
|
||||||
|
|
||||||
Mindrecord is a module to implement reading, writing, search and
|
Mindrecord is a module to implement reading, writing, search and
|
||||||
converting for MindSpore format dataset. Users could load(modify)
|
converting for MindSpore format dataset. Users could load(modify)
|
||||||
mindrecord data through FileReader(FileWriter). Users could also
|
mindrecord data through FileReader(FileWriter). Users could also
|
||||||
convert other format dataset to mindrecord data through
|
convert other format datasets to mindrecord data through
|
||||||
corresponding sub-module.
|
corresponding sub-module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -29,10 +29,10 @@ class FileReader:
|
||||||
Class to read MindRecord File series.
|
Class to read MindRecord File series.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_name (str, list[str]): One of MindRecord File or file list.
|
file_name (str, list[str]): One of MindRecord File or a file list.
|
||||||
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
||||||
It should not be smaller than 1 or larger than the number of CPU.
|
It should not be smaller than 1 or larger than the number of CPUs.
|
||||||
columns (list[str], optional): List of fields which corresponding data would be read (default=None).
|
columns (list[str], optional): A list of fields where corresponding data would be read (default=None).
|
||||||
operator(int, optional): Reserved parameter for operators (default=None).
|
operator(int, optional): Reserved parameter for operators (default=None).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -72,7 +72,7 @@ class FileReader:
|
||||||
Yield a batch of data according to columns at a time.
|
Yield a batch of data according to columns at a time.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
dict: keys is the same as columns.
|
dictionary: keys are the same as columns.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MRMUnsupportedSchemaError: If schema is invalid.
|
MRMUnsupportedSchemaError: If schema is invalid.
|
||||||
|
|
|
@ -39,11 +39,11 @@ class FileWriter:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_name (str): File name of MindRecord File.
|
file_name (str): File name of MindRecord File.
|
||||||
shard_num (int, optional): Number of MindRecord File (default=1).
|
shard_num (int, optional): The number of MindRecord files (default=1).
|
||||||
It should be between [1, 1000].
|
It should be between [1, 1000].
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ParamValueError: If file_name or shard_num is invalid.
|
ParamValueError: If `file_name` or `shard_num` is invalid.
|
||||||
"""
|
"""
|
||||||
def __init__(self, file_name, shard_num=1):
|
def __init__(self, file_name, shard_num=1):
|
||||||
check_filename(file_name)
|
check_filename(file_name)
|
||||||
|
@ -88,7 +88,7 @@ class FileWriter:
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ParamValueError: If file_name is invalid.
|
ParamValueError: If file_name is invalid.
|
||||||
FileNameError: If path contains invalid character.
|
FileNameError: If path contains invalid characters.
|
||||||
MRMOpenError: If failed to open MindRecord File.
|
MRMOpenError: If failed to open MindRecord File.
|
||||||
MRMOpenForAppendError: If failed to open file for appending data.
|
MRMOpenForAppendError: If failed to open file for appending data.
|
||||||
"""
|
"""
|
||||||
|
@ -111,14 +111,14 @@ class FileWriter:
|
||||||
|
|
||||||
def add_schema(self, content, desc=None):
|
def add_schema(self, content, desc=None):
|
||||||
"""
|
"""
|
||||||
Returns a schema id if added schema successfully, or raise exception.
|
Return a schema id if schema is added successfully, or raise an exception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content (dict): Dict of user defined schema.
|
content (dict): Dictionary of user defined schema.
|
||||||
desc (str, optional): String of schema description (default=None).
|
desc (str, optional): String of schema description (default=None).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int, schema id.
|
An integer, schema id.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MRMInvalidSchemaError: If schema is invalid.
|
MRMInvalidSchemaError: If schema is invalid.
|
||||||
|
@ -145,7 +145,7 @@ class FileWriter:
|
||||||
ParamTypeError: If index field is invalid.
|
ParamTypeError: If index field is invalid.
|
||||||
MRMDefineIndexError: If index field is not primitive type.
|
MRMDefineIndexError: If index field is not primitive type.
|
||||||
MRMAddIndexError: If failed to add index field.
|
MRMAddIndexError: If failed to add index field.
|
||||||
MRMGetMetaError: If the schema is not set or get meta failed.
|
MRMGetMetaError: If the schema is not set or failed to get meta.
|
||||||
"""
|
"""
|
||||||
if not index_fields or not isinstance(index_fields, list):
|
if not index_fields or not isinstance(index_fields, list):
|
||||||
raise ParamTypeError('index_fields', 'list')
|
raise ParamTypeError('index_fields', 'list')
|
||||||
|
@ -205,7 +205,7 @@ class FileWriter:
|
||||||
|
|
||||||
def open_and_set_header(self):
|
def open_and_set_header(self):
|
||||||
"""
|
"""
|
||||||
Open writer and set header
|
Open writer and set header.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not self._writer.is_open:
|
if not self._writer.is_open:
|
||||||
|
@ -245,7 +245,7 @@ class FileWriter:
|
||||||
"""
|
"""
|
||||||
Set the size of header which contains shard information, schema information, \
|
Set the size of header which contains shard information, schema information, \
|
||||||
page meta information, etc. The larger the header, the more training data \
|
page meta information, etc. The larger the header, the more training data \
|
||||||
a single mindrecord file can store.
|
a single MindRecord file can store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
header_size (int): Size of header, between 16KB and 128MB.
|
header_size (int): Size of header, between 16KB and 128MB.
|
||||||
|
@ -278,7 +278,7 @@ class FileWriter:
|
||||||
|
|
||||||
def commit(self):
|
def commit(self):
|
||||||
"""
|
"""
|
||||||
Flush data to disk and generate the corresponding db files.
|
Flush data to disk and generate the corresponding database files.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MSRStatus, SUCCESS or FAILED.
|
MSRStatus, SUCCESS or FAILED.
|
||||||
|
|
|
@ -28,12 +28,12 @@ class MindPage:
|
||||||
Class to read MindRecord File series in pagination.
|
Class to read MindRecord File series in pagination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_name (str): One of MindRecord File or file list.
|
file_name (str): One of MindRecord File or a file list.
|
||||||
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
|
num_consumer(int, optional): The number of consumer threads which load data to memory (default=4).
|
||||||
It should not be smaller than 1 or larger than the number of CPU.
|
It should not be smaller than 1 or larger than the number of CPUs.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ParamValueError: If file_name, num_consumer or columns is invalid.
|
ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
|
||||||
MRMInitSegmentError: If failed to initialize ShardSegment.
|
MRMInitSegmentError: If failed to initialize ShardSegment.
|
||||||
"""
|
"""
|
||||||
def __init__(self, file_name, num_consumer=4):
|
def __init__(self, file_name, num_consumer=4):
|
||||||
|
@ -97,7 +97,7 @@ class MindPage:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def category_field(self):
|
def category_field(self):
|
||||||
"""Getter function for category field"""
|
"""Getter function for category fields."""
|
||||||
return self._category_field
|
return self._category_field
|
||||||
|
|
||||||
@category_field.setter
|
@category_field.setter
|
||||||
|
@ -127,7 +127,7 @@ class MindPage:
|
||||||
Query by category id in pagination.
|
Query by category id in pagination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
category_id (int): Category id, referred to the return of read_category_info.
|
category_id (int): Category id, referred to the return of `read_category_info`.
|
||||||
page (int): Index of page.
|
page (int): Index of page.
|
||||||
num_row (int): Number of rows in a page.
|
num_row (int): Number of rows in a page.
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ class MindPage:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
category_name (str): String of category field's value,
|
category_name (str): String of category field's value,
|
||||||
referred to the return of read_category_info.
|
referred to the return of `read_category_info`.
|
||||||
page (int): Index of page.
|
page (int): Index of page.
|
||||||
num_row (int): Number of rows in a page.
|
num_row (int): Number of rows in a page.
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ __all__ = ['Cifar100ToMR']
|
||||||
|
|
||||||
class Cifar100ToMR:
|
class Cifar100ToMR:
|
||||||
"""
|
"""
|
||||||
Class is for transformation from cifar100 to MindRecord.
|
A class to transform from cifar100 to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source (str): the cifar100 directory to be transformed.
|
source (str): the cifar100 directory to be transformed.
|
||||||
|
@ -71,10 +71,10 @@ class Cifar100ToMR:
|
||||||
Executes transformation from cifar100 to MindRecord.
|
Executes transformation from cifar100 to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fields (list[str]): list of index field, ie. ["fine_label", "coarse_label"].
|
fields (list[str]): A list of index fields, e.g. ["fine_label", "coarse_label"].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord.
|
||||||
"""
|
"""
|
||||||
if fields and not isinstance(fields, list):
|
if fields and not isinstance(fields, list):
|
||||||
raise ValueError("The parameter fields should be None or list")
|
raise ValueError("The parameter fields should be None or list")
|
||||||
|
|
|
@ -34,7 +34,7 @@ __all__ = ['Cifar10ToMR']
|
||||||
|
|
||||||
class Cifar10ToMR:
|
class Cifar10ToMR:
|
||||||
"""
|
"""
|
||||||
Class is for transformation from cifar10 to MindRecord.
|
A class to transform from cifar10 to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source (str): the cifar10 directory to be transformed.
|
source (str): the cifar10 directory to be transformed.
|
||||||
|
@ -70,10 +70,10 @@ class Cifar10ToMR:
|
||||||
Executes transformation from cifar10 to MindRecord.
|
Executes transformation from cifar10 to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fields (list[str], optional): list of index fields, ie. ["label"] (default=None).
|
fields (list[str], optional): A list of index fields, e.g. ["label"] (default=None).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord.
|
||||||
"""
|
"""
|
||||||
if fields and not isinstance(fields, list):
|
if fields and not isinstance(fields, list):
|
||||||
raise ValueError("The parameter fields should be None or list")
|
raise ValueError("The parameter fields should be None or list")
|
||||||
|
|
|
@ -31,17 +31,17 @@ __all__ = ['CsvToMR']
|
||||||
|
|
||||||
class CsvToMR:
|
class CsvToMR:
|
||||||
"""
|
"""
|
||||||
Class is for transformation from csv to MindRecord.
|
A class to transform from csv to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source (str): the file path of csv.
|
source (str): the file path of csv.
|
||||||
destination (str): the MindRecord file path to transform into.
|
destination (str): the MindRecord file path to transform into.
|
||||||
columns_list(list[str], optional): List of columns to be read(default=None).
|
columns_list(list[str], optional): A list of columns to be read (default=None).
|
||||||
partition_number (int, optional): partition size (default=1).
|
partition_number (int, optional): partition size (default=1).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If source, destination, partition_number is invalid.
|
ValueError: If `source`, `destination` or `partition_number` is invalid.
|
||||||
RuntimeError: If columns_list is invalid.
|
RuntimeError: If `columns_list` is invalid.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, source, destination, columns_list=None, partition_number=1):
|
def __init__(self, source, destination, columns_list=None, partition_number=1):
|
||||||
|
@ -121,7 +121,7 @@ class CsvToMR:
|
||||||
Executes transformation from csv to MindRecord.
|
Executes transformation from csv to MindRecord.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
SUCCESS or FAILED, whether csv is successfully transformed to MindRecord.
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(self.source):
|
if not os.path.exists(self.source):
|
||||||
raise IOError("Csv file {} do not exist.".format(self.source))
|
raise IOError("Csv file {} do not exist.".format(self.source))
|
||||||
|
|
|
@ -28,11 +28,10 @@ __all__ = ['ImageNetToMR']
|
||||||
|
|
||||||
class ImageNetToMR:
|
class ImageNetToMR:
|
||||||
"""
|
"""
|
||||||
Class is for transformation from imagenet to MindRecord.
|
A class to transform from imagenet to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
map_file (str): the map file which indicates label.
|
map_file (str): the map file that indicates label. The map file content should be like this:
|
||||||
the map file content should like this:
|
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
|
@ -41,12 +40,12 @@ class ImageNetToMR:
|
||||||
n02110185 2
|
n02110185 2
|
||||||
n02096294 3
|
n02096294 3
|
||||||
|
|
||||||
image_dir (str): image directory contains n02119789, n02100735, n02110185, n02096294 dir.
|
image_dir (str): image directory that contains the n02119789, n02100735, n02110185 and n02096294 directories.
|
||||||
destination (str): the MindRecord file path to transform into.
|
destination (str): the MindRecord file path to transform into.
|
||||||
partition_number (int, optional): partition size (default=1).
|
partition_number (int, optional): partition size (default=1).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If map_file, image_dir or destination is invalid.
|
ValueError: If `map_file`, `image_dir` or `destination` is invalid.
|
||||||
"""
|
"""
|
||||||
def __init__(self, map_file, image_dir, destination, partition_number=1):
|
def __init__(self, map_file, image_dir, destination, partition_number=1):
|
||||||
check_filename(map_file)
|
check_filename(map_file)
|
||||||
|
@ -123,7 +122,7 @@ class ImageNetToMR:
|
||||||
Executes transformation from imagenet to MindRecord.
|
Executes transformation from imagenet to MindRecord.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord.
|
||||||
"""
|
"""
|
||||||
t0_total = time.time()
|
t0_total = time.time()
|
||||||
|
|
||||||
|
|
|
@ -34,17 +34,17 @@ __all__ = ['MnistToMR']
|
||||||
|
|
||||||
class MnistToMR:
|
class MnistToMR:
|
||||||
"""
|
"""
|
||||||
Class is for transformation from Mnist to MindRecord.
|
A class to transform from Mnist to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source (str): directory which contains t10k-images-idx3-ubyte.gz,
|
source (str): directory that contains t10k-images-idx3-ubyte.gz,
|
||||||
train-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz,
|
train-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz
|
||||||
train-labels-idx1-ubyte.gz.
|
and train-labels-idx1-ubyte.gz.
|
||||||
destination (str): the MindRecord file directory to transform into.
|
destination (str): the MindRecord file directory to transform into.
|
||||||
partition_number (int, optional): partition size (default=1).
|
partition_number (int, optional): partition size (default=1).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If source/destination/partition_number is invalid.
|
ValueError: If `source`, `destination` or `partition_number` is invalid.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, source, destination, partition_number=1):
|
def __init__(self, source, destination, partition_number=1):
|
||||||
|
@ -173,7 +173,7 @@ class MnistToMR:
|
||||||
Executes transformation from Mnist test part to MindRecord.
|
Executes transformation from Mnist test part to MindRecord.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord.
|
||||||
"""
|
"""
|
||||||
t0_total = time.time()
|
t0_total = time.time()
|
||||||
|
|
||||||
|
|
|
@ -99,25 +99,25 @@ def _cast_name(key):
|
||||||
|
|
||||||
class TFRecordToMR:
|
class TFRecordToMR:
|
||||||
"""
|
"""
|
||||||
Class is for tranformation from TFRecord to MindRecord.
|
A class to transform from TFRecord to MindRecord.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source (str): the TFRecord file to be transformed.
|
source (str): the TFRecord file to be transformed.
|
||||||
destination (str): the MindRecord file path to transform into.
|
destination (str): the MindRecord file path to transform into.
|
||||||
feature_dict (dict): a dictionary that states the feature type, i.e.
|
feature_dict (dict): a dictionary that states the feature type, e.g.
|
||||||
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
|
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
|
||||||
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
|
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
|
||||||
|
|
||||||
**Follow case which uses VarLenFeature not support**
|
**The following case, which uses VarLenFeature, is not supported.**
|
||||||
|
|
||||||
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \
|
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \
|
||||||
"yyyy": tf.io.VarLenFeature(tf.int64)}, \
|
"yyyy": tf.io.VarLenFeature(tf.int64)}, \
|
||||||
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
|
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
|
||||||
bytes_fields (list, optional): the bytes fields which are in feature_dict and can be images bytes.
|
bytes_fields (list, optional): the bytes fields which are in `feature_dict` and can be image bytes.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If parameter is invalid.
|
ValueError: If parameter is invalid.
|
||||||
Exception: when tensorflow module not found or version is not correct.
|
Exception: when tensorflow module is not found or version is not correct.
|
||||||
"""
|
"""
|
||||||
def __init__(self, source, destination, feature_dict, bytes_fields=None):
|
def __init__(self, source, destination, feature_dict, bytes_fields=None):
|
||||||
if not tf:
|
if not tf:
|
||||||
|
@ -211,7 +211,7 @@ class TFRecordToMR:
|
||||||
ms_dict[cast_key] = float(val.numpy())
|
ms_dict[cast_key] = float(val.numpy())
|
||||||
|
|
||||||
def tfrecord_iterator(self):
|
def tfrecord_iterator(self):
|
||||||
"""Yield a dict with key to be fields in schema, and value to be data."""
|
"""Yield a dictionary whose keys are fields in schema."""
|
||||||
dataset = tf.data.TFRecordDataset(self.source)
|
dataset = tf.data.TFRecordDataset(self.source)
|
||||||
dataset = dataset.map(self._parse_record)
|
dataset = dataset.map(self._parse_record)
|
||||||
iterator = dataset.__iter__()
|
iterator = dataset.__iter__()
|
||||||
|
@ -237,10 +237,10 @@ class TFRecordToMR:
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
Executes transform from TFRecord to MindRecord.
|
Execute transformation from TFRecord to MindRecord.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SUCCESS/FAILED, whether successfuly written into MindRecord.
|
SUCCESS or FAILED, whether TFRecord is successfully transformed to MindRecord.
|
||||||
"""
|
"""
|
||||||
writer = FileWriter(self.destination)
|
writer = FileWriter(self.destination)
|
||||||
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
|
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
|
||||||
|
|
Loading…
Reference in New Issue