fix mindrecord comments

This commit is contained in:
liyong 2020-10-13 09:53:09 +08:00
parent 39bc43e674
commit d5ed59ee4d
10 changed files with 53 additions and 54 deletions

View File

@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" """
Introduction to mindrecord: Introduction of mindrecord:
Mindrecord is a module to implement reading, writing, search and Mindrecord is a module to implement reading, writing, search and
converting for MindSpore format dataset. Users could load(modify) converting for MindSpore format dataset. Users could load(modify)
mindrecord data through FileReader(FileWriter). Users could also mindrecord data through FileReader(FileWriter). Users could also
convert other format dataset to mindrecord data through convert other format datasets to mindrecord data through
corresponding sub-module. corresponding sub-module.
""" """

View File

@ -29,10 +29,10 @@ class FileReader:
Class to read MindRecord File series. Class to read MindRecord File series.
Args: Args:
file_name (str, list[str]): One of MindRecord File or file list. file_name (str, list[str]): One of MindRecord File or a file list.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4). num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU. It should not be smaller than 1 or larger than the number of CPUs.
columns (list[str], optional): List of fields which corresponding data would be read (default=None). columns (list[str], optional): A list of fields where corresponding data would be read (default=None).
operator(int, optional): Reserved parameter for operators (default=None). operator(int, optional): Reserved parameter for operators (default=None).
Raises: Raises:
@ -72,7 +72,7 @@ class FileReader:
Yield a batch of data according to columns at a time. Yield a batch of data according to columns at a time.
Yields: Yields:
dict: keys is the same as columns. dictionary: keys are the same as columns.
Raises: Raises:
MRMUnsupportedSchemaError: If schema is invalid. MRMUnsupportedSchemaError: If schema is invalid.

View File

@ -39,11 +39,11 @@ class FileWriter:
Args: Args:
file_name (str): File name of MindRecord File. file_name (str): File name of MindRecord File.
shard_num (int, optional): Number of MindRecord File (default=1). shard_num (int, optional): The Number of MindRecord File (default=1).
It should be between [1, 1000]. It should be between [1, 1000].
Raises: Raises:
ParamValueError: If file_name or shard_num is invalid. ParamValueError: If `file_name` or `shard_num` is invalid.
""" """
def __init__(self, file_name, shard_num=1): def __init__(self, file_name, shard_num=1):
check_filename(file_name) check_filename(file_name)
@ -88,7 +88,7 @@ class FileWriter:
Raises: Raises:
ParamValueError: If file_name is invalid. ParamValueError: If file_name is invalid.
FileNameError: If path contains invalid character. FileNameError: If path contains invalid characters.
MRMOpenError: If failed to open MindRecord File. MRMOpenError: If failed to open MindRecord File.
MRMOpenForAppendError: If failed to open file for appending data. MRMOpenForAppendError: If failed to open file for appending data.
""" """
@ -111,14 +111,14 @@ class FileWriter:
def add_schema(self, content, desc=None): def add_schema(self, content, desc=None):
""" """
Returns a schema id if added schema successfully, or raise exception. Return a schema id if schema is added successfully, or raise an exception.
Args: Args:
content (dict): Dict of user defined schema. content (dict): Dictionary of user defined schema.
desc (str, optional): String of schema description (default=None). desc (str, optional): String of schema description (default=None).
Returns: Returns:
int, schema id. An integer, schema id.
Raises: Raises:
MRMInvalidSchemaError: If schema is invalid. MRMInvalidSchemaError: If schema is invalid.
@ -145,7 +145,7 @@ class FileWriter:
ParamTypeError: If index field is invalid. ParamTypeError: If index field is invalid.
MRMDefineIndexError: If index field is not primitive type. MRMDefineIndexError: If index field is not primitive type.
MRMAddIndexError: If failed to add index field. MRMAddIndexError: If failed to add index field.
MRMGetMetaError: If the schema is not set or get meta failed. MRMGetMetaError: If the schema is not set or failed to get meta.
""" """
if not index_fields or not isinstance(index_fields, list): if not index_fields or not isinstance(index_fields, list):
raise ParamTypeError('index_fields', 'list') raise ParamTypeError('index_fields', 'list')
@ -205,7 +205,7 @@ class FileWriter:
def open_and_set_header(self): def open_and_set_header(self):
""" """
Open writer and set header Open writer and set header.
""" """
if not self._writer.is_open: if not self._writer.is_open:
@ -245,7 +245,7 @@ class FileWriter:
""" """
Set the size of header which contains shard information, schema information, \ Set the size of header which contains shard information, schema information, \
page meta information, etc. The larger the header, the more training data \ page meta information, etc. The larger the header, the more training data \
a single mindrecord file can store. a single Mindrecord file can store.
Args: Args:
header_size (int): Size of header, between 16KB and 128MB. header_size (int): Size of header, between 16KB and 128MB.
@ -278,7 +278,7 @@ class FileWriter:
def commit(self): def commit(self):
""" """
Flush data to disk and generate the corresponding db files. Flush data to disk and generate the corresponding database files.
Returns: Returns:
MSRStatus, SUCCESS or FAILED. MSRStatus, SUCCESS or FAILED.

View File

@ -28,12 +28,12 @@ class MindPage:
Class to read MindRecord File series in pagination. Class to read MindRecord File series in pagination.
Args: Args:
file_name (str): One of MindRecord File or file list. file_name (str): One of MindRecord File or a file list.
num_consumer(int, optional): Number of consumer threads which load data to memory (default=4). num_consumer(int, optional): The number of consumer threads which load data to memory (default=4).
It should not be smaller than 1 or larger than the number of CPU. It should not be smaller than 1 or larger than the number of CPUs.
Raises: Raises:
ParamValueError: If file_name, num_consumer or columns is invalid. ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
MRMInitSegmentError: If failed to initialize ShardSegment. MRMInitSegmentError: If failed to initialize ShardSegment.
""" """
def __init__(self, file_name, num_consumer=4): def __init__(self, file_name, num_consumer=4):
@ -97,7 +97,7 @@ class MindPage:
@property @property
def category_field(self): def category_field(self):
"""Getter function for category field""" """Getter function for category fields."""
return self._category_field return self._category_field
@category_field.setter @category_field.setter
@ -127,7 +127,7 @@ class MindPage:
Query by category id in pagination. Query by category id in pagination.
Args: Args:
category_id (int): Category id, referred to the return of read_category_info. category_id (int): Category id, referred to the return of `read_category_info`.
page (int): Index of page. page (int): Index of page.
num_row (int): Number of rows in a page. num_row (int): Number of rows in a page.
@ -153,7 +153,7 @@ class MindPage:
Args: Args:
category_name (str): String of category field's value, category_name (str): String of category field's value,
referred to the return of read_category_info. referred to the return of `read_category_info`.
page (int): Index of page. page (int): Index of page.
num_row (int): Number of row in a page. num_row (int): Number of row in a page.

View File

@ -35,7 +35,7 @@ __all__ = ['Cifar100ToMR']
class Cifar100ToMR: class Cifar100ToMR:
""" """
Class is for transformation from cifar100 to MindRecord. A class to transform from cifar100 to MindRecord.
Args: Args:
source (str): the cifar100 directory to be transformed. source (str): the cifar100 directory to be transformed.
@ -71,10 +71,10 @@ class Cifar100ToMR:
Executes transformation from cifar100 to MindRecord. Executes transformation from cifar100 to MindRecord.
Args: Args:
fields (list[str]): list of index field, ie. ["fine_label", "coarse_label"]. fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"].
Returns: Returns:
SUCCESS/FAILED, whether successfully written into MindRecord. SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord.
""" """
if fields and not isinstance(fields, list): if fields and not isinstance(fields, list):
raise ValueError("The parameter fields should be None or list") raise ValueError("The parameter fields should be None or list")

View File

@ -34,7 +34,7 @@ __all__ = ['Cifar10ToMR']
class Cifar10ToMR: class Cifar10ToMR:
""" """
Class is for transformation from cifar10 to MindRecord. A class to transform from cifar10 to MindRecord.
Args: Args:
source (str): the cifar10 directory to be transformed. source (str): the cifar10 directory to be transformed.
@ -70,10 +70,10 @@ class Cifar10ToMR:
Executes transformation from cifar10 to MindRecord. Executes transformation from cifar10 to MindRecord.
Args: Args:
fields (list[str], optional): list of index fields, ie. ["label"] (default=None). fields (list[str], optional): A list of index fields, e.g.["label"] (default=None).
Returns: Returns:
SUCCESS/FAILED, whether successfully written into MindRecord. SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord.
""" """
if fields and not isinstance(fields, list): if fields and not isinstance(fields, list):
raise ValueError("The parameter fields should be None or list") raise ValueError("The parameter fields should be None or list")

View File

@ -31,17 +31,17 @@ __all__ = ['CsvToMR']
class CsvToMR: class CsvToMR:
""" """
Class is for transformation from csv to MindRecord. A class to transform from csv to MindRecord.
Args: Args:
source (str): the file path of csv. source (str): the file path of csv.
destination (str): the MindRecord file path to transform into. destination (str): the MindRecord file path to transform into.
columns_list(list[str], optional): List of columns to be read(default=None). columns_list(list[str], optional): A list of columns to be read(default=None).
partition_number (int, optional): partition size (default=1). partition_number (int, optional): partition size (default=1).
Raises: Raises:
ValueError: If source, destination, partition_number is invalid. ValueError: If `source`, `destination`, `partition_number` is invalid.
RuntimeError: If columns_list is invalid. RuntimeError: If `columns_list` is invalid.
""" """
def __init__(self, source, destination, columns_list=None, partition_number=1): def __init__(self, source, destination, columns_list=None, partition_number=1):
@ -121,7 +121,7 @@ class CsvToMR:
Executes transformation from csv to MindRecord. Executes transformation from csv to MindRecord.
Returns: Returns:
SUCCESS/FAILED, whether successfully written into MindRecord. SUCCESS or FAILED, whether csv is successfully transformed to MindRecord.
""" """
if not os.path.exists(self.source): if not os.path.exists(self.source):
raise IOError("Csv file {} do not exist.".format(self.source)) raise IOError("Csv file {} do not exist.".format(self.source))

View File

@ -28,11 +28,10 @@ __all__ = ['ImageNetToMR']
class ImageNetToMR: class ImageNetToMR:
""" """
Class is for transformation from imagenet to MindRecord. A class to transform from imagenet to MindRecord.
Args: Args:
map_file (str): the map file which indicates label. map_file (str): the map file that indicates label. The map file content should be like this:
the map file content should like this:
.. code-block:: .. code-block::
@ -41,12 +40,12 @@ class ImageNetToMR:
n02110185 2 n02110185 2
n02096294 3 n02096294 3
image_dir (str): image directory contains n02119789, n02100735, n02110185, n02096294 dir. image_dir (str): image directory contains n02119789, n02100735, n02110185 and n02096294 directory.
destination (str): the MindRecord file path to transform into. destination (str): the MindRecord file path to transform into.
partition_number (int, optional): partition size (default=1). partition_number (int, optional): partition size (default=1).
Raises: Raises:
ValueError: If map_file, image_dir or destination is invalid. ValueError: If `map_file`, `image_dir` or `destination` is invalid.
""" """
def __init__(self, map_file, image_dir, destination, partition_number=1): def __init__(self, map_file, image_dir, destination, partition_number=1):
check_filename(map_file) check_filename(map_file)
@ -123,7 +122,7 @@ class ImageNetToMR:
Executes transformation from imagenet to MindRecord. Executes transformation from imagenet to MindRecord.
Returns: Returns:
SUCCESS/FAILED, whether successfully written into MindRecord. SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord.
""" """
t0_total = time.time() t0_total = time.time()

View File

@ -34,17 +34,17 @@ __all__ = ['MnistToMR']
class MnistToMR: class MnistToMR:
""" """
Class is for transformation from Mnist to MindRecord. A class to transform from Mnist to MindRecord.
Args: Args:
source (str): directory which contains t10k-images-idx3-ubyte.gz, source (str): directory that contains t10k-images-idx3-ubyte.gz,
train-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz, train-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz
train-labels-idx1-ubyte.gz. and train-labels-idx1-ubyte.gz.
destination (str): the MindRecord file directory to transform into. destination (str): the MindRecord file directory to transform into.
partition_number (int, optional): partition size (default=1). partition_number (int, optional): partition size (default=1).
Raises: Raises:
ValueError: If source/destination/partition_number is invalid. ValueError: If `source`, `destination`, `partition_number` is invalid.
""" """
def __init__(self, source, destination, partition_number=1): def __init__(self, source, destination, partition_number=1):
@ -173,7 +173,7 @@ class MnistToMR:
Executes transformation from Mnist test part to MindRecord. Executes transformation from Mnist test part to MindRecord.
Returns: Returns:
SUCCESS/FAILED, whether successfully written into MindRecord. SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord.
""" """
t0_total = time.time() t0_total = time.time()

View File

@ -99,25 +99,25 @@ def _cast_name(key):
class TFRecordToMR: class TFRecordToMR:
""" """
Class is for tranformation from TFRecord to MindRecord. A class to transform from TFRecord to MindRecord.
Args: Args:
source (str): the TFRecord file to be transformed. source (str): the TFRecord file to be transformed.
destination (str): the MindRecord file path to tranform into. destination (str): the MindRecord file path to tranform into.
feature_dict (dict): a dictionary that states the feature type, i.e. feature_dict (dict): a dictionary that states the feature type, e.g.
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \ feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
"yyyy": tf.io.FixedLenFeature([], tf.int64)} "yyyy": tf.io.FixedLenFeature([], tf.int64)}
**Follow case which uses VarLenFeature not support** **Follow case which uses VarLenFeature is not supported.**
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \ feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \
"yyyy": tf.io.VarLenFeature(tf.int64)}, \ "yyyy": tf.io.VarLenFeature(tf.int64)}, \
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}} "sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}}
bytes_fields (list, optional): the bytes fields which are in feature_dict and can be images bytes. bytes_fields (list, optional): the bytes fields which are in `feature_dict` and can be images bytes.
Raises: Raises:
ValueError: If parameter is invalid. ValueError: If parameter is invalid.
Exception: when tensorflow module not found or version is not correct. Exception: when tensorflow module is not found or version is not correct.
""" """
def __init__(self, source, destination, feature_dict, bytes_fields=None): def __init__(self, source, destination, feature_dict, bytes_fields=None):
if not tf: if not tf:
@ -211,7 +211,7 @@ class TFRecordToMR:
ms_dict[cast_key] = float(val.numpy()) ms_dict[cast_key] = float(val.numpy())
def tfrecord_iterator(self): def tfrecord_iterator(self):
"""Yield a dict with key to be fields in schema, and value to be data.""" """Yield a dictionary whose keys are fields in schema."""
dataset = tf.data.TFRecordDataset(self.source) dataset = tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record) dataset = dataset.map(self._parse_record)
iterator = dataset.__iter__() iterator = dataset.__iter__()
@ -237,10 +237,10 @@ class TFRecordToMR:
def run(self): def run(self):
""" """
Executes transform from TFRecord to MindRecord. Execute transformation from TFRecord to MindRecord.
Returns: Returns:
SUCCESS/FAILED, whether successfuly written into MindRecord. SUCCESS or FAILED, whether TFRecord is successfuly transformed to MindRecord.
""" """
writer = FileWriter(self.destination) writer = FileWriter(self.destination)
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}" logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"