diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py
index e54ce1d6372..93749b5013c 100644
--- a/mindspore/dataset/core/config.py
+++ b/mindspore/dataset/core/config.py
@@ -102,7 +102,7 @@ def get_seed():
     Get the seed.
 
     Returns:
-        Int, seed.
+        int, seed.
     """
     return _config.get_seed()
 
@@ -131,7 +131,7 @@ def get_prefetch_size():
     Get the prefetch size in number of rows.
 
     Returns:
-        Size, total number of rows to be prefetched.
+        int, total number of rows to be prefetched.
     """
     return _config.get_op_connector_size()
 
@@ -162,7 +162,7 @@ def get_num_parallel_workers():
    This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature.
 
    Returns:
-        Int, number of parallel workers to be used as a default for each operation
+        int, number of parallel workers to be used as a default for each operation.
    """
    return _config.get_num_parallel_workers()
 
@@ -193,7 +193,7 @@ def get_numa_enable():
    This is the DEFAULT numa enabled value used for all processes.
 
    Returns:
-        boolean, the default state of numa enabled
+        bool, the default state of numa enabled.
    """
    return _config.get_numa_enable()
 
@@ -222,7 +222,7 @@ def get_monitor_sampling_interval():
    Get the default interval of performance monitor sampling.
 
    Returns:
-        Int, interval (in milliseconds) for performance monitor sampling.
+        int, interval (in milliseconds) for performance monitor sampling.
    """
    return _config.get_monitor_sampling_interval()
 
@@ -280,7 +280,8 @@ def get_auto_num_workers():
    Get the setting (turned on or off) of the automatic number of workers.
 
    Returns:
-        Bool, whether auto num worker feature is turned on
+        bool, whether auto num worker feature is turned on.
+
    Examples:
        >>> num_workers = ds.config.get_auto_num_workers()
    """
@@ -313,7 +314,7 @@ def get_callback_timeout():
    In case of a deadlock, the wait function will exit after the timeout period.
 
    Returns:
-        Int, the duration in seconds
+        int, the duration in seconds.
    """
    return _config.get_callback_timeout()
 
@@ -323,7 +324,7 @@ def __str__():
    String representation of the configurations.
 
    Returns:
-        Str, configurations.
+        str, configurations.
    """
    return str(_config)
 
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 668b88a360b..bb7da167959 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -80,7 +80,7 @@ def zip(datasets):
        The number of datasets must be more than 1.
 
    Returns:
-        Dataset, ZipDataset.
+        ZipDataset, dataset zipped.
 
    Raises:
        ValueError: If the number of datasets is 1.
@@ -149,8 +149,8 @@ class Dataset:
        Internal method to create an IR tree.
 
        Returns:
-            ir_tree, The onject of the IR tree.
-            dataset, the root dataset of the IR tree.
+            DatasetNode, the root node of the IR tree.
+            Dataset, the root dataset of the IR tree.
        """
        parent = self.parent
        self.parent = []
@@ -165,7 +165,7 @@ class Dataset:
        Internal method to parse the API tree into an IR tree.
 
        Returns:
-            DatasetNode, The root of the IR tree.
+            DatasetNode, the root node of the IR tree.
        """
        if len(self.parent) > 1:
            raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
@@ -197,7 +197,7 @@ class Dataset:
        Args:
 
        Returns:
-            Python dictionary.
+            dict, attributes related to the current class.
        """
        args = dict()
        args["num_parallel_workers"] = self.num_parallel_workers
@@ -211,7 +211,7 @@ class Dataset:
            filename (str): filename of json file to be saved as
 
        Returns:
-            Str, JSON string of the pipeline.
+            str, JSON string of the pipeline.
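+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>>
+            >>> # A minimal, hypothetical pipeline (not from the original docs); its
+            >>> # serialized form is written to the given file and also returned.
+            >>> dataset = ds.NumpySlicesDataset([1, 2, 3], column_names=["col1"])
+            >>> pipeline_json = dataset.to_json("pipeline.json")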
""" return json.loads(self.parse_tree().to_json(filename)) @@ -258,6 +258,9 @@ class Dataset: drop_remainder (bool, optional): If True, will drop the last batch for each bucket if it is not a full batch (default=False). + Returns: + BucketBatchByLengthDataset, dataset bucketed and batched by length. + Examples: >>> import mindspore.dataset as ds >>> @@ -371,6 +374,9 @@ class Dataset: num_batch (int): the number of batches without blocking at the start of each epoch. callback (function): The callback funciton that will be invoked when sync_update is called. + Returns: + SyncWaitDataset, dataset added a blocking condition. + Raises: RuntimeError: If condition name already exists. @@ -434,7 +440,7 @@ class Dataset: return a 'Dataset'. Returns: - Dataset, applied by the function. + Dataset, dataset applied by the function. Examples: >>> import mindspore.dataset as ds @@ -650,7 +656,7 @@ class Dataset: in parallel (default=None). Returns: - FilterDataset, dataset filter. + FilterDataset, dataset filtered. Examples: >>> import mindspore.dataset as ds @@ -748,6 +754,9 @@ class Dataset: """ Internal method called by split to calculate absolute split sizes and to do some error checking after calculating absolute split sizes. + + Returns: + int, absolute split sizes of the dataset. """ # Call get_dataset_size here and check input here because # don't want to call this once in check_split and another time in @@ -1015,7 +1024,7 @@ class Dataset: is specified and special_first is set to default, special_tokens will be prepended Returns: - Vocab node + Vocab, vocab built from the dataset. Example: >>> import mindspore.dataset as ds @@ -1074,7 +1083,7 @@ class Dataset: params(dict): contains more optional parameters of sentencepiece library Returns: - SentencePieceVocab node + SentencePieceVocab, vocab built from the dataset. Example: >>> import mindspore.dataset as ds @@ -1115,7 +1124,7 @@ class Dataset: return a preprogressing 'Dataset'. Returns: - Dataset, applied by the function. + Dataset, dataset applied by the function. Examples: >>> import mindspore.dataset as ds @@ -1159,7 +1168,7 @@ class Dataset: If device is Ascend, features of data will be transferred one by one. The limitation of data transmission per time is 256M. - Return: + Returns: TransferDataset, dataset for transferring. """ return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue) @@ -1287,7 +1296,7 @@ class Dataset: use this param to select the conversion method, only take False for better performance (default=True). Returns: - Iterator, list of ndarrays. + TupleIterator, tuple iterator over the dataset. Examples: >>> import mindspore.dataset as ds @@ -1322,7 +1331,7 @@ class Dataset: if output_numpy=False, iterator will output MSTensor (default=False). Returns: - Iterator, dictionary of column name-ndarray pair. + DictIterator, dictionary iterator over the dataset. Examples: >>> import mindspore.dataset as ds @@ -1352,6 +1361,9 @@ class Dataset: """ Get Input Index Information + Returns: + tuple, tuple of the input index information. + Examples: >>> import mindspore.dataset as ds >>> @@ -1409,6 +1421,9 @@ class Dataset: def get_col_names(self): """ Get names of the columns in the dataset + + Returns: + list, list of column names in the dataset. """ if self._col_names is None: runtime_getter = self._init_tree_getters() @@ -1419,8 +1434,8 @@ class Dataset: """ Get the shapes of output data. - Return: - List, list of shapes of each column. + Returns: + list, list of shapes of each column. 
""" if self.saved_output_shapes is None: runtime_getter = self._init_tree_getters() @@ -1432,8 +1447,8 @@ class Dataset: """ Get the types of output data. - Return: - List of data types. + Returns: + list, list of data types. """ if self.saved_output_types is None: runtime_getter = self._init_tree_getters() @@ -1445,8 +1460,8 @@ class Dataset: """ Get the number of batches in an epoch. - Return: - Number, number of batches. + Returns: + int, number of batches. """ if self.dataset_size is None: runtime_getter = self._init_size_getter() @@ -1457,8 +1472,8 @@ class Dataset: """ Get the number of classes in a dataset. - Return: - Number, number of classes. + Returns: + int, number of classes. """ if self._num_classes is None: runtime_getter = self._init_tree_getters() @@ -1511,8 +1526,8 @@ class Dataset: """ Get the size of a batch. - Return: - Number, the number of data in a batch. + Returns: + int, the number of data in a batch. """ if self._batch_size is None: runtime_getter = self._init_tree_getters() @@ -1525,8 +1540,8 @@ class Dataset: """ Get the replication times in RepeatDataset else 1. - Return: - Number, the count of repeat. + Returns: + int, the count of repeat. """ if self._repeat_count is None: runtime_getter = self._init_tree_getters() @@ -1540,8 +1555,8 @@ class Dataset: Get the class index. Returns: - Dict, A str-to-int mapping from label name to index. - Dict, A str-to-list mapping from label name to index for Coco ONLY. The second number + dict, a str-to-int mapping from label name to index. + dict, a str-to-list mapping from label name to index for Coco ONLY. The second number in the list is used to indicate the super category """ if self.children: @@ -1588,7 +1603,7 @@ class SourceDataset(Dataset): patterns (Union[str, list[str]]): String or list of patterns to be searched. Returns: - List, files. + list, list of files. """ if not isinstance(patterns, list): @@ -1646,9 +1661,6 @@ class MappableDataset(SourceDataset): Args: new_sampler (Sampler): The sampler to use for the current dataset. - Returns: - Dataset, that uses new_sampler. - Examples: >>> import mindspore.dataset as ds >>> @@ -1909,8 +1921,9 @@ class BatchDataset(Dataset): Args: dataset (Dataset): Dataset to be checked. - Return: - True or False. + + Returns: + bool, whether repeat is used before batch. """ if isinstance(dataset, RepeatDataset): return True @@ -1995,18 +2008,12 @@ class BatchInfo(cde.CBatchInfo): def get_batch_num(self): """ Return the batch number of the current batch. - - Return: - Number, number of the current batch. """ return def get_epoch_num(self): """ Return the epoch number of the current batch. - - Return: - Number, number of the current epoch. """ return @@ -2055,8 +2062,8 @@ class BlockReleasePair: """ Function for handing blocking condition. - Return: - True + Returns: + bool, True. """ with self.cv: # if disable is true, the always evaluate to true @@ -2145,8 +2152,9 @@ class SyncWaitDataset(Dataset): Args: dataset (Dataset): Dataset to be checked. - Return: - True or False. + + Returns: + bool, whether sync_wait is used before batch. """ if isinstance(dataset, BatchDataset): return True @@ -2932,6 +2940,9 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, n num_shards (int): Number of shard for sharding. shard_id (int): Shard ID. non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False). + + Returns: + Sampler, sampler selected based on user input. 
""" if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]): return None @@ -4180,7 +4191,7 @@ class ManifestDataset(MappableDataset): Get the class index. Returns: - Dict, A str-to-int mapping from label name to index. + dict, a str-to-int mapping from label name to index. """ if self.class_indexing is None: if self._class_indexing is None: @@ -4579,7 +4590,7 @@ class Schema: Args: schema_file(str): Path of schema file (default=None). - Return: + Returns: Schema object, schema info about dataset. Raises: @@ -4654,7 +4665,7 @@ class Schema: Get a JSON string of the schema. Returns: - Str, JSON string of the schema. + str, JSON string of the schema. """ return self.cpp_schema.to_json() @@ -4840,7 +4851,7 @@ class VOCDataset(MappableDataset): Get the class index. Returns: - Dict, A str-to-int mapping from label name to index. + dict, a str-to-int mapping from label name to index. """ if self.task != "Detection": raise NotImplementedError("Only 'Detection' support get_class_indexing.") @@ -5032,7 +5043,7 @@ class CocoDataset(MappableDataset): Get the class index. Returns: - Dict, A str-to-list mapping from label name to index + dict, a str-to-list mapping from label name to index """ if self.task not in {"Detection", "Panoptic"}: raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.") diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 1cd2bea067e..58fe40f47d9 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -100,7 +100,7 @@ class GraphData: node_type (int): Specify the type of node. Returns: - numpy.ndarray: Array of nodes. + numpy.ndarray, array of nodes. Examples: >>> import mindspore.dataset as ds @@ -124,7 +124,7 @@ class GraphData: edge_type (int): Specify the type of edge. Returns: - numpy.ndarray: array of edges. + numpy.ndarray, array of edges. Examples: >>> import mindspore.dataset as ds @@ -148,7 +148,7 @@ class GraphData: edge_list (Union[list, numpy.ndarray]): The given list of edges. Returns: - numpy.ndarray: Array of nodes. + numpy.ndarray, array of nodes. Raises: TypeError: If `edge_list` is not list or ndarray. @@ -167,7 +167,7 @@ class GraphData: neighbor_type (int): Specify the type of neighbor. Returns: - numpy.ndarray: Array of nodes. + numpy.ndarray, array of neighbors. Examples: >>> import mindspore.dataset as ds @@ -201,7 +201,7 @@ class GraphData: neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop. Returns: - numpy.ndarray: Array of nodes. + numpy.ndarray, array of neighbors. Examples: >>> import mindspore.dataset as ds @@ -231,7 +231,7 @@ class GraphData: neg_neighbor_type (int): Specify the type of negative neighbor. Returns: - numpy.ndarray: Array of nodes. + numpy.ndarray, array of neighbors. Examples: >>> import mindspore.dataset as ds @@ -260,7 +260,7 @@ class GraphData: feature_types (Union[list, numpy.ndarray]): The given list of feature types. Returns: - numpy.ndarray: array of features. + numpy.ndarray, array of features. Examples: >>> import mindspore.dataset as ds @@ -292,7 +292,7 @@ class GraphData: feature_types (Union[list, numpy.ndarray]): The given list of feature types. Returns: - numpy.ndarray: array of features. + numpy.ndarray, array of features. Examples: >>> import mindspore.dataset as ds @@ -320,7 +320,7 @@ class GraphData: the feature information of nodes, the number of edges, the type of edges, and the feature information of edges. 
        Returns:
-            dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
+            dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
                node_feature_type and edge_feature_type.
        """
        if self._working_mode == 'server':
@@ -347,7 +347,7 @@ class GraphData:
            A default value of -1 indicates that no node is given.
 
        Returns:
-            numpy.ndarray: Array of nodes.
+            numpy.ndarray, array of nodes.
 
        Examples:
            >>> import mindspore.dataset as ds
diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py
index 9107e952968..1fa866ae290 100644
--- a/mindspore/dataset/engine/samplers.py
+++ b/mindspore/dataset/engine/samplers.py
@@ -128,6 +128,7 @@ class BuiltinSampler:
    User should not extend this class.
    """
+
    def __init__(self, num_samples=None):
        self.child_sampler = None
        self.num_samples = num_samples
@@ -201,7 +202,7 @@ class BuiltinSampler:
            - None
 
        Returns:
-            int, The number of samples, or None
+            int, the number of samples, or None.
        """
        if self.child_sampler is not None:
            child_samples = self.child_sampler.get_num_samples()
diff --git a/mindspore/dataset/vision/py_transforms.py b/mindspore/dataset/vision/py_transforms.py
index cb5d6338ea0..5b9767e166b 100644
--- a/mindspore/dataset/vision/py_transforms.py
+++ b/mindspore/dataset/vision/py_transforms.py
@@ -1063,8 +1063,9 @@ class LinearTransformation:
    the dot product with the transformation matrix, and reshapes it back to its original shape.
 
    Args:
-        transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
-        mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
+        transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), where
+            :math:`D = C \times H \times W`.
+        mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where :math:`D = C \times H \times W`.
 
    Examples:
        >>> from mindspore.dataset.transforms.py_transforms import Compose
diff --git a/mindspore/explainer/_image_classification_runner.py b/mindspore/explainer/_image_classification_runner.py
index 0163ae989f1..c49cf0ed9e5 100644
--- a/mindspore/explainer/_image_classification_runner.py
+++ b/mindspore/explainer/_image_classification_runner.py
@@ -23,7 +23,7 @@ from PIL import Image
 import mindspore as ms
 import mindspore.dataset as ds
 from mindspore import log
-from mindspore.dataset.engine.datasets import Dataset
+from mindspore.dataset import Dataset
 from mindspore.nn import Cell, SequentialCell
 from mindspore.ops.operations import ExpandDims
 from mindspore.train._utils import check_value_type
diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py
index cfc9b3ec1e6..e249f6fde09 100644
--- a/mindspore/mindrecord/filewriter.py
+++ b/mindspore/mindrecord/filewriter.py
@@ -30,6 +30,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchema
 
 __all__ = ['FileWriter']
 
+
 class FileWriter:
    """
    Class to write user defined raw data into MindRecord File series.
@@ -45,6 +46,7 @@ class FileWriter:
    Raises:
        ParamValueError: If `file_name` or `shard_num` is invalid.
    """
+
    def __init__(self, file_name, shard_num=1):
        check_filename(file_name)
        self._file_name = file_name
@@ -84,7 +86,7 @@ class FileWriter:
            file_name (str): String of MindRecord file name.
 
        Returns:
-            Instance of FileWriter.
+            FileWriter, file writer for the opened MindRecord file.
 
        Raises:
            ParamValueError: If file_name is invalid.
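+
+        Examples:
+            >>> # Hypothetical file name: reopen an existing MindRecord file to append rows.
+            >>> writer = FileWriter.open_for_append("test.mindrecord")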
@@ -118,7 +120,7 @@ class FileWriter:
            desc (str, optional): String of schema description (default=None).
 
        Returns:
-            An integer, schema id.
+            int, schema id.
 
        Raises:
            MRMInvalidSchemaError: If schema is invalid.
@@ -175,17 +177,17 @@ class FileWriter:
 
                if field not in v:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "there is not '{}' object in the raw data.".format(i, field)
+                                        "there is not '{}' object in the raw data.".format(i, field)
                    continue
                field_type = type(v[field]).__name__
                if field_type not in VALUE_TYPE_MAP:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "data type for '{}' is not matched.".format(i, field)
+                                        "data type for '{}' is not matched.".format(i, field)
                    continue
                if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
                    error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "data type for '{}' is not matched.".format(i, field)
+                                        "data type for '{}' is not matched.".format(i, field)
                    continue
 
                if field_type == 'ndarray':
@@ -206,7 +208,6 @@ class FileWriter:
    def open_and_set_header(self):
        """
        Open writer and set header.
-
        """
        if not self._writer.is_open:
            self._writer.open(self._paths)
@@ -222,6 +223,9 @@ class FileWriter:
            raw_data (list[dict]): List of raw data.
            parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).
 
+        Returns:
+            MSRStatus, SUCCESS or FAILED.
+
        Raises:
            ParamTypeError: If index field is invalid.
            MRMOpenError: If failed to open MindRecord File.
@@ -330,7 +334,7 @@ class FileWriter:
            v (dict): Sub dict in schema
 
        Returns:
-            bool, True or False.
+            bool, whether the array item is valid.
            str, error message.
        """
        if v['type'] not in VALID_ARRAY_ATTRIBUTES:
@@ -355,7 +359,7 @@ class FileWriter:
            content (dict): Dict of raw schema.
 
        Returns:
-            bool, True or False.
+            bool, whether the schema is valid.
            str, error message.
        """
        error = ''
diff --git a/mindspore/mindrecord/mindpage.py b/mindspore/mindrecord/mindpage.py
index 43811728471..768e5df3d20 100644
--- a/mindspore/mindrecord/mindpage.py
+++ b/mindspore/mindrecord/mindpage.py
@@ -23,6 +23,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategor
 
 __all__ = ['MindPage']
 
+
 class MindPage:
    """
    Class to read MindRecord File series in pagination.
@@ -36,6 +37,7 @@ class MindPage:
        ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
        MRMInitSegmentError: If failed to initialize ShardSegment.
    """
+
    def __init__(self, file_name, num_consumer=4):
        if isinstance(file_name, list):
            for f in file_name:
@@ -69,7 +71,12 @@ class MindPage:
        return self._candidate_fields
 
    def get_category_fields(self):
-        """Return candidate category fields."""
+        """
+        Return candidate category fields.
+
+        Returns:
+            list[str], fields by which data could be grouped.
+        """
        logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated."
                       " Please use candidate_fields")
        return self.candidate_fields
@@ -97,12 +104,22 @@ class MindPage:
 
    @property
    def category_field(self):
-        """Getter function for category fields."""
+        """
+        Getter function for category fields.
+
+        Returns:
+            list[str], fields by which data could be grouped.
+        """
        return self._category_field
 
    @category_field.setter
    def category_field(self, category_field):
-        """Setter function for category field"""
+        """
+        Setter function for category field.
+
+        Returns:
+            MSRStatus, SUCCESS or FAILED.
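+
+        Examples:
+            >>> # Hypothetical file and field names: group pages by the "label" field.
+            >>> reader = MindPage("test.mindrecord")
+            >>> reader.category_field = "label"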
+ """ if not category_field or not isinstance(category_field, str): raise ParamTypeError('category_fields', 'str') if category_field not in self._candidate_fields: @@ -132,7 +149,7 @@ class MindPage: num_row (int): Number of rows in a page. Returns: - List, list[dict]. + list[dict], data queried by category id. Raises: ParamValueError: If any parameter is invalid. @@ -158,7 +175,7 @@ class MindPage: num_row (int): Number of row in a page. Returns: - str, read at page. + list[dict], data queried by category name. """ if not isinstance(category_name, str): raise ParamValueError("Category name should be str.") diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 37e453a6a7b..1c61674159c 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -23,6 +23,7 @@ from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHe __all__ = ['ShardWriter'] + class ShardWriter: """ Wrapper class which is represent shardWrite class in c++ module. @@ -192,9 +193,11 @@ class ShardWriter: if len(blob_data) == 1: values = [v for v in blob_data.values()] return bytes(values[0]) + # convert int to bytes def int_to_bytes(x: int) -> bytes: return x.to_bytes(8, 'big') + merged = bytes() for field, v in blob_data.items(): # convert ndarray to bytes @@ -209,7 +212,7 @@ class ShardWriter: Flush data to disk. Returns: - Class MSRStatus, SUCCESS or FAILED. + MSRStatus, SUCCESS or FAILED. Raises: MRMCommitError: If failed to flush data to disk. diff --git a/mindspore/mindrecord/tools/cifar100_to_mr.py b/mindspore/mindrecord/tools/cifar100_to_mr.py index 1c5f7458d87..137171ce5b7 100644 --- a/mindspore/mindrecord/tools/cifar100_to_mr.py +++ b/mindspore/mindrecord/tools/cifar100_to_mr.py @@ -33,6 +33,7 @@ except ModuleNotFoundError: __all__ = ['Cifar100ToMR'] + class Cifar100ToMR: """ A class to transform from cifar100 to MindRecord. @@ -44,6 +45,7 @@ class Cifar100ToMR: Raises: ValueError: If source or destination is invalid. """ + def __init__(self, source, destination): check_filename(source) self.source = source @@ -74,7 +76,7 @@ class Cifar100ToMR: fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"]. Returns: - SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord. + MSRStatus, whether cifar100 is successfully transformed to MindRecord. """ if fields and not isinstance(fields, list): raise ValueError("The parameter fields should be None or list") @@ -114,6 +116,7 @@ class Cifar100ToMR: raise t.exception return t.res + def _construct_raw_data(images, fine_labels, coarse_labels): """ Construct raw data from cifar100 data. @@ -124,7 +127,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels): coarse_labels (list): coarse label list from cifar100. Returns: - SUCCESS/FAILED, whether successfully written into MindRecord. + list[dict], data dictionary constructed from cifar100. """ if not cv2: raise ModuleNotFoundError("opencv-python module not found, please use pip install it.") @@ -141,6 +144,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels): raw_data.append(row_data) return raw_data + def _generate_mindrecord(file_name, raw_data, fields, schema_desc): """ Generate MindRecord file from raw data. @@ -153,7 +157,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc): schema_desc (str): String of schema description. Returns: - SUCCESS/FAILED, whether successfully written into MindRecord. 
+        MSRStatus, whether successfully written into MindRecord.
    """
    schema = {"id": {"type": "int64"}, "fine_label": {"type": "int64"},
              "coarse_label": {"type": "int64"}, "data": {"type": "bytes"}}
diff --git a/mindspore/mindrecord/tools/cifar10_to_mr.py b/mindspore/mindrecord/tools/cifar10_to_mr.py
index 118d56ca991..c78293a6135 100644
--- a/mindspore/mindrecord/tools/cifar10_to_mr.py
+++ b/mindspore/mindrecord/tools/cifar10_to_mr.py
@@ -25,6 +25,7 @@ from .cifar10 import Cifar10
 from ..common.exceptions import PathNotExistsError
 from ..filewriter import FileWriter
 from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED
+
 try:
    cv2 = import_module("cv2")
 except ModuleNotFoundError:
@@ -32,6 +33,7 @@ except ModuleNotFoundError:
 
 __all__ = ['Cifar10ToMR']
 
+
 class Cifar10ToMR:
    """
    A class to transform from cifar10 to MindRecord.
@@ -43,6 +45,7 @@ class Cifar10ToMR:
    Raises:
        ValueError: If source or destination is invalid.
    """
+
    def __init__(self, source, destination):
        check_filename(source)
        self.source = source
@@ -73,7 +76,7 @@ class Cifar10ToMR:
            fields (list[str], optional): A list of index fields, e.g.["label"] (default=None).
 
        Returns:
-            SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord.
+            MSRStatus, whether cifar10 is successfully transformed to MindRecord.
        """
        if fields and not isinstance(fields, list):
            raise ValueError("The parameter fields should be None or list")
@@ -109,6 +112,7 @@ class Cifar10ToMR:
            raise t.exception
        return t.res
 
+
 def _construct_raw_data(images, labels):
    """
    Construct raw data from cifar10 data.
@@ -118,7 +122,7 @@ def _construct_raw_data(images, labels):
        labels (list): label list from cifar10.
 
    Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        list[dict], data dictionary constructed from cifar10.
    """
    if not cv2:
        raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
@@ -133,6 +137,7 @@ def _construct_raw_data(images, labels):
        raw_data.append(row_data)
    return raw_data
 
+
 def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
    """
    Generate MindRecord file from raw data.
@@ -145,7 +150,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
        schema_desc (str): String of schema description.
 
    Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        MSRStatus, whether successfully written into MindRecord.
    """
    schema = {"id": {"type": "int64"}, "label": {"type": "int64"},
              "data": {"type": "bytes"}}
diff --git a/mindspore/mindrecord/tools/csv_to_mr.py b/mindspore/mindrecord/tools/csv_to_mr.py
index 5b2fa5ec9dd..f7cf7d8b1c9 100644
--- a/mindspore/mindrecord/tools/csv_to_mr.py
+++ b/mindspore/mindrecord/tools/csv_to_mr.py
@@ -29,6 +29,7 @@ except ModuleNotFoundError:
 
 __all__ = ['CsvToMR']
 
+
 class CsvToMR:
    """
    A class to transform from csv to MindRecord.
@@ -121,7 +122,7 @@ class CsvToMR:
        Executes transformation from csv to MindRecord.
 
        Returns:
-            SUCCESS or FAILED, whether csv is successfully transformed to MindRecord.
+            MSRStatus, whether csv is successfully transformed to MindRecord.
        """
        if not os.path.exists(self.source):
            raise IOError("Csv file {} do not exist.".format(self.source))
diff --git a/mindspore/mindrecord/tools/imagenet_to_mr.py b/mindspore/mindrecord/tools/imagenet_to_mr.py
index b7f8d145463..5158377fe6b 100644
--- a/mindspore/mindrecord/tools/imagenet_to_mr.py
+++ b/mindspore/mindrecord/tools/imagenet_to_mr.py
@@ -47,6 +47,7 @@ class ImageNetToMR:
    Raises:
        ValueError: If `map_file`, `image_dir` or `destination` is invalid.
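+
+    Examples:
+        >>> # Hypothetical paths: map_file lists "image_file_name label" pairs, one per line.
+        >>> converter = ImageNetToMR("map.txt", "images/", "imagenet.mindrecord", 8)
+        >>> converter.transform()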
""" + def __init__(self, map_file, image_dir, destination, partition_number=1): check_filename(map_file) self.map_file = map_file @@ -122,7 +123,7 @@ class ImageNetToMR: Executes transformation from imagenet to MindRecord. Returns: - SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord. + MSRStatus, whether imagenet is successfully transformed to MindRecord. """ t0_total = time.time() @@ -133,10 +134,10 @@ class ImageNetToMR: logger.info("transformed MindRecord schema is: {}".format(imagenet_schema_json)) # set the header size - self.writer.set_header_size(1<<24) + self.writer.set_header_size(1 << 24) # set the page size - self.writer.set_page_size(1<<26) + self.writer.set_page_size(1 << 26) # create the schema self.writer.add_schema(imagenet_schema_json, "imagenet_schema") diff --git a/mindspore/mindrecord/tools/mnist_to_mr.py b/mindspore/mindrecord/tools/mnist_to_mr.py index 5c3d5783fe4..0419a0df319 100644 --- a/mindspore/mindrecord/tools/mnist_to_mr.py +++ b/mindspore/mindrecord/tools/mnist_to_mr.py @@ -32,6 +32,7 @@ except ModuleNotFoundError: __all__ = ['MnistToMR'] + class MnistToMR: """ A class to transform from Mnist to MindRecord. @@ -125,7 +126,7 @@ class MnistToMR: Executes transformation from Mnist train part to MindRecord. Returns: - SUCCESS/FAILED, whether successfully written into MindRecord. + MSRStatus, whether successfully written into MindRecord. """ t0_total = time.time() @@ -173,7 +174,7 @@ class MnistToMR: Executes transformation from Mnist test part to MindRecord. Returns: - SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord. + MSRStatus, whether Mnist is successfully transformed to MindRecord. """ t0_total = time.time() @@ -222,7 +223,7 @@ class MnistToMR: Executes transformation from Mnist to MindRecord. Returns: - SUCCESS/FAILED, whether successfully written into MindRecord. + MSRStatus, whether successfully written into MindRecord. """ if not cv2: raise ModuleNotFoundError("opencv-python module not found, please use pip install it.") diff --git a/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/mindrecord/tools/tfrecord_to_mr.py index 34da5c24fb0..2c4b3d85e14 100644 --- a/mindspore/mindrecord/tools/tfrecord_to_mr.py +++ b/mindspore/mindrecord/tools/tfrecord_to_mr.py @@ -23,7 +23,6 @@ from mindspore import log as logger from ..filewriter import FileWriter from ..shardutils import check_filename, ExceptionThread - __all__ = ['TFRecordToMR'] SupportedTensorFlowVersion = '1.13.0-rc1' @@ -86,9 +85,10 @@ class TFRecordToMR: ValueError: If parameter is invalid. Exception: when tensorflow module is not found or version is not correct. """ + def __init__(self, source, destination, feature_dict, bytes_fields=None): try: - self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord + self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord except ModuleNotFoundError: self.tf = None if not self.tf: @@ -265,7 +265,7 @@ class TFRecordToMR: Execute transformation from TFRecord to MindRecord. Returns: - SUCCESS or FAILED, whether TFRecord is successfuly transformed to MindRecord. + MSRStatus, whether TFRecord is successfuly transformed to MindRecord. 
""" writer = FileWriter(self.destination) logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}" diff --git a/model_zoo/official/recommend/ncf/src/dataset.py b/model_zoo/official/recommend/ncf/src/dataset.py index 039b5721153..589ff742bfb 100644 --- a/model_zoo/official/recommend/ncf/src/dataset.py +++ b/model_zoo/official/recommend/ncf/src/dataset.py @@ -22,13 +22,12 @@ import pickle import numpy as np import pandas as pd -from mindspore.dataset.engine import GeneratorDataset +from mindspore.dataset import GeneratorDataset import src.constants as rconst import src.movielens as movielens import src.stat_utils as stat_utils - DATASET_TO_NUM_USERS_AND_ITEMS = { "ml-1m": (6040, 3706), "ml-20m": (138493, 26744) @@ -205,6 +204,7 @@ class NCFDataset: """ A dataset for NCF network. """ + def __init__(self, pos_users, pos_items, @@ -407,6 +407,7 @@ class RandomSampler: """ A random sampler for dataset. """ + def __init__(self, pos_count, num_train_negatives, batch_size): self.pos_count = pos_count self._num_samples = (1 + num_train_negatives) * self.pos_count @@ -433,6 +434,7 @@ class DistributedSamplerOfTrain: """ A distributed sampler for dataset. """ + def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size): """ Distributed sampler of training dataset. @@ -443,15 +445,16 @@ class DistributedSamplerOfTrain: self._batch_size = batch_size self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size)) - self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size)) + self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size)) self._total_num_samples = self._samples_per_rank * self._rank_size + def __iter__(self): """ Returns the data after each sampling. """ indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32())) indices = indices.tolist() - indices.extend(indices[:self._total_num_samples-len(indices)]) + indices.extend(indices[:self._total_num_samples - len(indices)]) indices = indices[self._rank_id:self._total_num_samples:self._rank_size] batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)] @@ -463,10 +466,12 @@ class DistributedSamplerOfTrain: """ return self._batchs_per_rank + class SequenceSampler: """ A sequence sampler for dataset. """ + def __init__(self, eval_batch_size, num_users): self._eval_users_per_batch = int( eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES)) @@ -491,10 +496,12 @@ class SequenceSampler: """ return self._eval_batches_per_epoch + class DistributedSamplerOfEval: """ A distributed sampler for eval dataset. 
""" + def __init__(self, eval_batch_size, num_users, rank_id, rank_size): self._eval_users_per_batch = int( eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES)) @@ -507,8 +514,8 @@ class DistributedSamplerOfEval: self._eval_batch_size = eval_batch_size self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size)) - #self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size)) - #self._total_num_samples = self._samples_per_rank * self._rank_size + # self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size)) + # self._total_num_samples = self._samples_per_rank * self._rank_size def __iter__(self): indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch) @@ -525,6 +532,7 @@ class DistributedSamplerOfEval: def __len__(self): return self._batchs_per_rank + def parse_eval_batch_size(eval_batch_size): """ Parse eval batch size.