modify return description in comments

This commit is contained in:
Xiao Tianci 2021-01-18 16:49:38 +08:00
parent cf87c0304d
commit f7093efe6a
16 changed files with 165 additions and 107 deletions

View File

@ -102,7 +102,7 @@ def get_seed():
Get the seed.
Returns:
Int, seed.
int, seed.
"""
return _config.get_seed()
@ -131,7 +131,7 @@ def get_prefetch_size():
Get the prefetch size in number of rows.
Returns:
Size, total number of rows to be prefetched.
int, total number of rows to be prefetched.
"""
return _config.get_op_connector_size()
@ -162,7 +162,7 @@ def get_num_parallel_workers():
This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature.
Returns:
Int, number of parallel workers to be used as a default for each operation
int, number of parallel workers to be used as a default for each operation.
"""
return _config.get_num_parallel_workers()
@ -193,7 +193,7 @@ def get_numa_enable():
This is the DEFAULT numa enabled value used for all processes.
Returns:
boolean, the default state of numa enabled
bool, the default state of numa enabled.
"""
return _config.get_numa_enable()
@ -222,7 +222,7 @@ def get_monitor_sampling_interval():
Get the default interval of performance monitor sampling.
Returns:
Int, interval (in milliseconds) for performance monitor sampling.
int, interval (in milliseconds) for performance monitor sampling.
"""
return _config.get_monitor_sampling_interval()
@ -280,7 +280,8 @@ def get_auto_num_workers():
Get the setting (turned on or off) of the automatic number of workers feature.
Returns:
Bool, whether auto num worker feature is turned on
bool, whether auto num worker feature is turned on.
Examples:
>>> num_workers = ds.config.get_auto_num_workers()
"""
@ -313,7 +314,7 @@ def get_callback_timeout():
In case of a deadlock, the wait function will exit after the timeout period.
Returns:
Int, the duration in seconds
int, the duration in seconds.
"""
return _config.get_callback_timeout()
@ -323,7 +324,7 @@ def __str__():
String representation of the configurations.
Returns:
Str, configurations.
str, configurations.
"""
return str(_config)

View File

@ -80,7 +80,7 @@ def zip(datasets):
The number of datasets must be more than 1.
Returns:
Dataset, ZipDataset.
ZipDataset, dataset zipped.
Raises:
ValueError: If the number of datasets is 1.
@ -149,8 +149,8 @@ class Dataset:
Internal method to create an IR tree.
Returns:
ir_tree, The onject of the IR tree.
dataset, the root dataset of the IR tree.
DatasetNode, the root node of the IR tree.
Dataset, the root dataset of the IR tree.
"""
parent = self.parent
self.parent = []
@ -165,7 +165,7 @@ class Dataset:
Internal method to parse the API tree into an IR tree.
Returns:
DatasetNode, The root of the IR tree.
DatasetNode, the root node of the IR tree.
"""
if len(self.parent) > 1:
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
@ -197,7 +197,7 @@ class Dataset:
Args:
Returns:
Python dictionary.
dict, attributes related to the current class.
"""
args = dict()
args["num_parallel_workers"] = self.num_parallel_workers
@ -211,7 +211,7 @@ class Dataset:
filename (str): filename of json file to be saved as
Returns:
Str, JSON string of the pipeline.
str, JSON string of the pipeline.
"""
return json.loads(self.parse_tree().to_json(filename))
@ -258,6 +258,9 @@ class Dataset:
drop_remainder (bool, optional): If True, will drop the last batch for each
bucket if it is not a full batch (default=False).
Returns:
BucketBatchByLengthDataset, dataset bucketed and batched by length.
Examples:
>>> import mindspore.dataset as ds
>>>
@ -371,6 +374,9 @@ class Dataset:
num_batch (int): the number of batches without blocking at the start of each epoch.
callback (function): The callback function that will be invoked when sync_update is called.
Returns:
SyncWaitDataset, dataset with a blocking condition added.
Raises:
RuntimeError: If condition name already exists.
@ -434,7 +440,7 @@ class Dataset:
return a 'Dataset'.
Returns:
Dataset, applied by the function.
Dataset, dataset applied by the function.
Examples:
>>> import mindspore.dataset as ds
@ -650,7 +656,7 @@ class Dataset:
in parallel (default=None).
Returns:
FilterDataset, dataset filter.
FilterDataset, dataset filtered.
Examples:
>>> import mindspore.dataset as ds
@ -748,6 +754,9 @@ class Dataset:
"""
Internal method called by split to calculate absolute split sizes and to
do some error checking after calculating absolute split sizes.
Returns:
int, absolute split sizes of the dataset.
"""
# Call get_dataset_size here and check input here because
# don't want to call this once in check_split and another time in
@ -1015,7 +1024,7 @@ class Dataset:
is specified and special_first is set to default, special_tokens will be prepended
Returns:
Vocab node
Vocab, vocab built from the dataset.
Example:
>>> import mindspore.dataset as ds
@ -1074,7 +1083,7 @@ class Dataset:
params(dict): contains more optional parameters of sentencepiece library
Returns:
SentencePieceVocab node
SentencePieceVocab, vocab built from the dataset.
Example:
>>> import mindspore.dataset as ds
@ -1115,7 +1124,7 @@ class Dataset:
return a preprocessed 'Dataset'.
Returns:
Dataset, applied by the function.
Dataset, dataset applied by the function.
Examples:
>>> import mindspore.dataset as ds
@ -1159,7 +1168,7 @@ class Dataset:
If device is Ascend, features of data will be transferred one by one. The limitation
of data transmission per time is 256M.
Return:
Returns:
TransferDataset, dataset for transferring.
"""
return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
@ -1287,7 +1296,7 @@ class Dataset:
use this param to select the conversion method, only take False for better performance (default=True).
Returns:
Iterator, list of ndarrays.
TupleIterator, tuple iterator over the dataset.
Examples:
>>> import mindspore.dataset as ds
@ -1322,7 +1331,7 @@ class Dataset:
if output_numpy=False, iterator will output MSTensor (default=False).
Returns:
Iterator, dictionary of column name-ndarray pair.
DictIterator, dictionary iterator over the dataset.
Examples:
>>> import mindspore.dataset as ds
@ -1352,6 +1361,9 @@ class Dataset:
"""
Get Input Index Information
Returns:
tuple, tuple of the input index information.
Examples:
>>> import mindspore.dataset as ds
>>>
@ -1409,6 +1421,9 @@ class Dataset:
def get_col_names(self):
"""
Get names of the columns in the dataset
Returns:
list, list of column names in the dataset.
"""
if self._col_names is None:
runtime_getter = self._init_tree_getters()
@ -1419,8 +1434,8 @@ class Dataset:
"""
Get the shapes of output data.
Return:
List, list of shapes of each column.
Returns:
list, list of shapes of each column.
"""
if self.saved_output_shapes is None:
runtime_getter = self._init_tree_getters()
@ -1432,8 +1447,8 @@ class Dataset:
"""
Get the types of output data.
Return:
List of data types.
Returns:
list, list of data types.
"""
if self.saved_output_types is None:
runtime_getter = self._init_tree_getters()
@ -1445,8 +1460,8 @@ class Dataset:
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
Returns:
int, number of batches.
"""
if self.dataset_size is None:
runtime_getter = self._init_size_getter()
@ -1457,8 +1472,8 @@ class Dataset:
"""
Get the number of classes in a dataset.
Return:
Number, number of classes.
Returns:
int, number of classes.
"""
if self._num_classes is None:
runtime_getter = self._init_tree_getters()
@ -1511,8 +1526,8 @@ class Dataset:
"""
Get the size of a batch.
Return:
Number, the number of data in a batch.
Returns:
int, the number of data in a batch.
"""
if self._batch_size is None:
runtime_getter = self._init_tree_getters()
@ -1525,8 +1540,8 @@ class Dataset:
"""
Get the replication times in RepeatDataset else 1.
Return:
Number, the count of repeat.
Returns:
int, the count of repeat.
"""
if self._repeat_count is None:
runtime_getter = self._init_tree_getters()
@ -1540,8 +1555,8 @@ class Dataset:
Get the class index.
Returns:
Dict, A str-to-int mapping from label name to index.
Dict, A str-to-list<int> mapping from label name to index for Coco ONLY. The second number
dict, a str-to-int mapping from label name to index.
dict, a str-to-list<int> mapping from label name to index for Coco ONLY. The second number
in the list is used to indicate the super category
"""
if self.children:
@ -1588,7 +1603,7 @@ class SourceDataset(Dataset):
patterns (Union[str, list[str]]): String or list of patterns to be searched.
Returns:
List, files.
list, list of files.
"""
if not isinstance(patterns, list):
@ -1646,9 +1661,6 @@ class MappableDataset(SourceDataset):
Args:
new_sampler (Sampler): The sampler to use for the current dataset.
Returns:
Dataset, that uses new_sampler.
Examples:
>>> import mindspore.dataset as ds
>>>
@ -1909,8 +1921,9 @@ class BatchDataset(Dataset):
Args:
dataset (Dataset): Dataset to be checked.
Return:
True or False.
Returns:
bool, whether repeat is used before batch.
"""
if isinstance(dataset, RepeatDataset):
return True
@ -1995,18 +2008,12 @@ class BatchInfo(cde.CBatchInfo):
def get_batch_num(self):
"""
Return the batch number of the current batch.
Return:
Number, number of the current batch.
"""
return
def get_epoch_num(self):
"""
Return the epoch number of the current batch.
Return:
Number, number of the current epoch.
"""
return
@ -2055,8 +2062,8 @@ class BlockReleasePair:
"""
Function for handling blocking condition.
Return:
True
Returns:
bool, True.
"""
with self.cv:
# if disable is true, then always evaluate to true
@ -2145,8 +2152,9 @@ class SyncWaitDataset(Dataset):
Args:
dataset (Dataset): Dataset to be checked.
Return:
True or False.
Returns:
bool, whether sync_wait is used before batch.
"""
if isinstance(dataset, BatchDataset):
return True
@ -2932,6 +2940,9 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, n
num_shards (int): Number of shard for sharding.
shard_id (int): Shard ID.
non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
Returns:
Sampler, sampler selected based on user input.
"""
if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
return None
@ -4180,7 +4191,7 @@ class ManifestDataset(MappableDataset):
Get the class index.
Returns:
Dict, A str-to-int mapping from label name to index.
dict, a str-to-int mapping from label name to index.
"""
if self.class_indexing is None:
if self._class_indexing is None:
@ -4579,7 +4590,7 @@ class Schema:
Args:
schema_file(str): Path of schema file (default=None).
Return:
Returns:
Schema object, schema info about dataset.
Raises:
@ -4654,7 +4665,7 @@ class Schema:
Get a JSON string of the schema.
Returns:
Str, JSON string of the schema.
str, JSON string of the schema.
"""
return self.cpp_schema.to_json()
@ -4840,7 +4851,7 @@ class VOCDataset(MappableDataset):
Get the class index.
Returns:
Dict, A str-to-int mapping from label name to index.
dict, a str-to-int mapping from label name to index.
"""
if self.task != "Detection":
raise NotImplementedError("Only 'Detection' support get_class_indexing.")
@ -5032,7 +5043,7 @@ class CocoDataset(MappableDataset):
Get the class index.
Returns:
Dict, A str-to-list<int> mapping from label name to index
dict, a str-to-list<int> mapping from label name to index
"""
if self.task not in {"Detection", "Panoptic"}:
raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")

View File

@ -100,7 +100,7 @@ class GraphData:
node_type (int): Specify the type of node.
Returns:
numpy.ndarray: Array of nodes.
numpy.ndarray, array of nodes.
Examples:
>>> import mindspore.dataset as ds
@ -124,7 +124,7 @@ class GraphData:
edge_type (int): Specify the type of edge.
Returns:
numpy.ndarray: array of edges.
numpy.ndarray, array of edges.
Examples:
>>> import mindspore.dataset as ds
@ -148,7 +148,7 @@ class GraphData:
edge_list (Union[list, numpy.ndarray]): The given list of edges.
Returns:
numpy.ndarray: Array of nodes.
numpy.ndarray, array of nodes.
Raises:
TypeError: If `edge_list` is not list or ndarray.
@ -167,7 +167,7 @@ class GraphData:
neighbor_type (int): Specify the type of neighbor.
Returns:
numpy.ndarray: Array of nodes.
numpy.ndarray, array of neighbors.
Examples:
>>> import mindspore.dataset as ds
@ -201,7 +201,7 @@ class GraphData:
neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop.
Returns:
numpy.ndarray: Array of nodes.
numpy.ndarray, array of neighbors.
Examples:
>>> import mindspore.dataset as ds
@ -231,7 +231,7 @@ class GraphData:
neg_neighbor_type (int): Specify the type of negative neighbor.
Returns:
numpy.ndarray: Array of nodes.
numpy.ndarray, array of neighbors.
Examples:
>>> import mindspore.dataset as ds
@ -260,7 +260,7 @@ class GraphData:
feature_types (Union[list, numpy.ndarray]): The given list of feature types.
Returns:
numpy.ndarray: array of features.
numpy.ndarray, array of features.
Examples:
>>> import mindspore.dataset as ds
@ -292,7 +292,7 @@ class GraphData:
feature_types (Union[list, numpy.ndarray]): The given list of feature types.
Returns:
numpy.ndarray: array of features.
numpy.ndarray, array of features.
Examples:
>>> import mindspore.dataset as ds
@ -320,7 +320,7 @@ class GraphData:
the feature information of nodes, the number of edges, the type of edges, and the feature information of edges.
Returns:
dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
node_feature_type and edge_feature_type.
"""
if self._working_mode == 'server':
@ -347,7 +347,7 @@ class GraphData:
A default value of -1 indicates that no node is given.
Returns:
numpy.ndarray: Array of nodes.
numpy.ndarray, array of nodes.
Examples:
>>> import mindspore.dataset as ds

View File

@ -128,6 +128,7 @@ class BuiltinSampler:
User should not extend this class.
"""
def __init__(self, num_samples=None):
self.child_sampler = None
self.num_samples = num_samples
@ -201,7 +202,7 @@ class BuiltinSampler:
- None
Returns:
int, The number of samples, or None
int, the number of samples, or None
"""
if self.child_sampler is not None:
child_samples = self.child_sampler.get_num_samples()

View File

@ -1063,8 +1063,9 @@ class LinearTransformation:
the dot product with the transformation matrix, and reshapes it back to its original shape.
Args:
transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), where
:math:`D = C \times H \times W`.
mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where :math:`D = C \times H \times W`.
Examples:
>>> from mindspore.dataset.transforms.py_transforms import Compose

View File

@ -23,7 +23,7 @@ from PIL import Image
import mindspore as ms
import mindspore.dataset as ds
from mindspore import log
from mindspore.dataset.engine.datasets import Dataset
from mindspore.dataset import Dataset
from mindspore.nn import Cell, SequentialCell
from mindspore.ops.operations import ExpandDims
from mindspore.train._utils import check_value_type

View File

@ -30,6 +30,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchema
__all__ = ['FileWriter']
class FileWriter:
"""
Class to write user defined raw data into MindRecord File series.
@ -45,6 +46,7 @@ class FileWriter:
Raises:
ParamValueError: If `file_name` or `shard_num` is invalid.
"""
def __init__(self, file_name, shard_num=1):
check_filename(file_name)
self._file_name = file_name
@ -84,7 +86,7 @@ class FileWriter:
file_name (str): String of MindRecord file name.
Returns:
Instance of FileWriter.
FileWriter, file writer for the opened MindRecord file.
Raises:
ParamValueError: If file_name is invalid.
@ -118,7 +120,7 @@ class FileWriter:
desc (str, optional): String of schema description (default=None).
Returns:
An integer, schema id.
int, schema id.
Raises:
MRMInvalidSchemaError: If schema is invalid.
@ -175,17 +177,17 @@ class FileWriter:
if field not in v:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"there is not '{}' object in the raw data.".format(i, field)
"there is not '{}' object in the raw data.".format(i, field)
continue
field_type = type(v[field]).__name__
if field_type not in VALUE_TYPE_MAP:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"data type for '{}' is not matched.".format(i, field)
"data type for '{}' is not matched.".format(i, field)
continue
if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
error_data_dic[i] = "for schema, {} th data is wrong, " \
"data type for '{}' is not matched.".format(i, field)
"data type for '{}' is not matched.".format(i, field)
continue
if field_type == 'ndarray':
@ -206,7 +208,6 @@ class FileWriter:
def open_and_set_header(self):
"""
Open writer and set header.
"""
if not self._writer.is_open:
self._writer.open(self._paths)
@ -222,6 +223,9 @@ class FileWriter:
raw_data (list[dict]): List of raw data.
parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
ParamTypeError: If index field is invalid.
MRMOpenError: If failed to open MindRecord File.
@ -330,7 +334,7 @@ class FileWriter:
v (dict): Sub dict in schema
Returns:
bool, True or False.
bool, whether the array item is valid.
str, error message.
"""
if v['type'] not in VALID_ARRAY_ATTRIBUTES:
@ -355,7 +359,7 @@ class FileWriter:
content (dict): Dict of raw schema.
Returns:
bool, True or False.
bool, whether the schema is valid.
str, error message.
"""
error = ''

View File

@ -23,6 +23,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategor
__all__ = ['MindPage']
class MindPage:
"""
Class to read MindRecord File series in pagination.
@ -36,6 +37,7 @@ class MindPage:
ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
MRMInitSegmentError: If failed to initialize ShardSegment.
"""
def __init__(self, file_name, num_consumer=4):
if isinstance(file_name, list):
for f in file_name:
@ -69,7 +71,12 @@ class MindPage:
return self._candidate_fields
def get_category_fields(self):
"""Return candidate category fields."""
"""
Return candidate category fields.
Returns:
list[str], by which data could be grouped.
"""
logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated."
" Please use candidate_fields")
return self.candidate_fields
@ -97,12 +104,22 @@ class MindPage:
@property
def category_field(self):
"""Getter function for category fields."""
"""
Getter function for category fields.
Returns:
list[str], by which data could be grouped.
"""
return self._category_field
@category_field.setter
def category_field(self, category_field):
"""Setter function for category field"""
"""
Setter function for category field.
Returns:
MSRStatus, SUCCESS or FAILED.
"""
if not category_field or not isinstance(category_field, str):
raise ParamTypeError('category_fields', 'str')
if category_field not in self._candidate_fields:
@ -132,7 +149,7 @@ class MindPage:
num_row (int): Number of rows in a page.
Returns:
List, list[dict].
list[dict], data queried by category id.
Raises:
ParamValueError: If any parameter is invalid.
@ -158,7 +175,7 @@ class MindPage:
num_row (int): Number of row in a page.
Returns:
str, read at page.
list[dict], data queried by category name.
"""
if not isinstance(category_name, str):
raise ParamValueError("Category name should be str.")

View File

@ -23,6 +23,7 @@ from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHe
__all__ = ['ShardWriter']
class ShardWriter:
"""
Wrapper class which represents the shardWrite class in the C++ module.
@ -192,9 +193,11 @@ class ShardWriter:
if len(blob_data) == 1:
values = [v for v in blob_data.values()]
return bytes(values[0])
# convert int to bytes
def int_to_bytes(x: int) -> bytes:
return x.to_bytes(8, 'big')
merged = bytes()
for field, v in blob_data.items():
# convert ndarray to bytes
@ -209,7 +212,7 @@ class ShardWriter:
Flush data to disk.
Returns:
Class MSRStatus, SUCCESS or FAILED.
MSRStatus, SUCCESS or FAILED.
Raises:
MRMCommitError: If failed to flush data to disk.

View File

@ -33,6 +33,7 @@ except ModuleNotFoundError:
__all__ = ['Cifar100ToMR']
class Cifar100ToMR:
"""
A class to transform from cifar100 to MindRecord.
@ -44,6 +45,7 @@ class Cifar100ToMR:
Raises:
ValueError: If source or destination is invalid.
"""
def __init__(self, source, destination):
check_filename(source)
self.source = source
@ -74,7 +76,7 @@ class Cifar100ToMR:
fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"].
Returns:
SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord.
MSRStatus, whether cifar100 is successfully transformed to MindRecord.
"""
if fields and not isinstance(fields, list):
raise ValueError("The parameter fields should be None or list")
@ -114,6 +116,7 @@ class Cifar100ToMR:
raise t.exception
return t.res
def _construct_raw_data(images, fine_labels, coarse_labels):
"""
Construct raw data from cifar100 data.
@ -124,7 +127,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
coarse_labels (list): coarse label list from cifar100.
Returns:
SUCCESS/FAILED, whether successfully written into MindRecord.
list[dict], data dictionary constructed from cifar100.
"""
if not cv2:
raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
@ -141,6 +144,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
raw_data.append(row_data)
return raw_data
def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
"""
Generate MindRecord file from raw data.
@ -153,7 +157,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
schema_desc (str): String of schema description.
Returns:
SUCCESS/FAILED, whether successfully written into MindRecord.
MSRStatus, whether successfully written into MindRecord.
"""
schema = {"id": {"type": "int64"}, "fine_label": {"type": "int64"},
"coarse_label": {"type": "int64"}, "data": {"type": "bytes"}}

View File

@ -25,6 +25,7 @@ from .cifar10 import Cifar10
from ..common.exceptions import PathNotExistsError
from ..filewriter import FileWriter
from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED
try:
cv2 = import_module("cv2")
except ModuleNotFoundError:
@ -32,6 +33,7 @@ except ModuleNotFoundError:
__all__ = ['Cifar10ToMR']
class Cifar10ToMR:
"""
A class to transform from cifar10 to MindRecord.
@ -43,6 +45,7 @@ class Cifar10ToMR:
Raises:
ValueError: If source or destination is invalid.
"""
def __init__(self, source, destination):
check_filename(source)
self.source = source
@ -73,7 +76,7 @@ class Cifar10ToMR:
fields (list[str], optional): A list of index fields, e.g.["label"] (default=None).
Returns:
SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord.
MSRStatus, whether cifar10 is successfully transformed to MindRecord.
"""
if fields and not isinstance(fields, list):
raise ValueError("The parameter fields should be None or list")
@ -109,6 +112,7 @@ class Cifar10ToMR:
raise t.exception
return t.res
def _construct_raw_data(images, labels):
"""
Construct raw data from cifar10 data.
@ -118,7 +122,7 @@ def _construct_raw_data(images, labels):
labels (list): label list from cifar10.
Returns:
SUCCESS/FAILED, whether successfully written into MindRecord.
list[dict], data dictionary constructed from cifar10.
"""
if not cv2:
raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
@ -133,6 +137,7 @@ def _construct_raw_data(images, labels):
raw_data.append(row_data)
return raw_data
def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
"""
Generate MindRecord file from raw data.
@ -145,7 +150,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
schema_desc (str): String of schema description.
Returns:
SUCCESS/FAILED, whether successfully written into MindRecord.
MSRStatus, whether successfully written into MindRecord.
"""
schema = {"id": {"type": "int64"}, "label": {"type": "int64"},
"data": {"type": "bytes"}}

View File

@ -29,6 +29,7 @@ except ModuleNotFoundError:
__all__ = ['CsvToMR']
class CsvToMR:
"""
A class to transform from csv to MindRecord.
@ -121,7 +122,7 @@ class CsvToMR:
Executes transformation from csv to MindRecord.
Returns:
SUCCESS or FAILED, whether csv is successfully transformed to MindRecord.
MSRStatus, whether csv is successfully transformed to MindRecord.
"""
if not os.path.exists(self.source):
raise IOError("Csv file {} do not exist.".format(self.source))

View File

@ -47,6 +47,7 @@ class ImageNetToMR:
Raises:
ValueError: If `map_file`, `image_dir` or `destination` is invalid.
"""
def __init__(self, map_file, image_dir, destination, partition_number=1):
check_filename(map_file)
self.map_file = map_file
@ -122,7 +123,7 @@ class ImageNetToMR:
Executes transformation from imagenet to MindRecord.
Returns:
SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord.
MSRStatus, whether imagenet is successfully transformed to MindRecord.
"""
t0_total = time.time()
@ -133,10 +134,10 @@ class ImageNetToMR:
logger.info("transformed MindRecord schema is: {}".format(imagenet_schema_json))
# set the header size
self.writer.set_header_size(1<<24)
self.writer.set_header_size(1 << 24)
# set the page size
self.writer.set_page_size(1<<26)
self.writer.set_page_size(1 << 26)
# create the schema
self.writer.add_schema(imagenet_schema_json, "imagenet_schema")

View File

@ -32,6 +32,7 @@ except ModuleNotFoundError:
__all__ = ['MnistToMR']
class MnistToMR:
"""
A class to transform from Mnist to MindRecord.
@ -125,7 +126,7 @@ class MnistToMR:
Executes transformation from Mnist train part to MindRecord.
Returns:
SUCCESS/FAILED, whether successfully written into MindRecord.
MSRStatus, whether successfully written into MindRecord.
"""
t0_total = time.time()
@ -173,7 +174,7 @@ class MnistToMR:
Executes transformation from Mnist test part to MindRecord.
Returns:
SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord.
MSRStatus, whether Mnist is successfully transformed to MindRecord.
"""
t0_total = time.time()
@ -222,7 +223,7 @@ class MnistToMR:
Executes transformation from Mnist to MindRecord.
Returns:
SUCCESS/FAILED, whether successfully written into MindRecord.
MSRStatus, whether successfully written into MindRecord.
"""
if not cv2:
raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")

View File

@ -23,7 +23,6 @@ from mindspore import log as logger
from ..filewriter import FileWriter
from ..shardutils import check_filename, ExceptionThread
__all__ = ['TFRecordToMR']
SupportedTensorFlowVersion = '1.13.0-rc1'
@ -86,9 +85,10 @@ class TFRecordToMR:
ValueError: If parameter is invalid.
Exception: when tensorflow module is not found or version is not correct.
"""
def __init__(self, source, destination, feature_dict, bytes_fields=None):
try:
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
except ModuleNotFoundError:
self.tf = None
if not self.tf:
@ -265,7 +265,7 @@ class TFRecordToMR:
Execute transformation from TFRecord to MindRecord.
Returns:
SUCCESS or FAILED, whether TFRecord is successfuly transformed to MindRecord.
MSRStatus, whether TFRecord is successfully transformed to MindRecord.
"""
writer = FileWriter(self.destination)
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"

View File

@ -22,13 +22,12 @@ import pickle
import numpy as np
import pandas as pd
from mindspore.dataset.engine import GeneratorDataset
from mindspore.dataset import GeneratorDataset
import src.constants as rconst
import src.movielens as movielens
import src.stat_utils as stat_utils
DATASET_TO_NUM_USERS_AND_ITEMS = {
"ml-1m": (6040, 3706),
"ml-20m": (138493, 26744)
@ -205,6 +204,7 @@ class NCFDataset:
"""
A dataset for NCF network.
"""
def __init__(self,
pos_users,
pos_items,
@ -407,6 +407,7 @@ class RandomSampler:
"""
A random sampler for dataset.
"""
def __init__(self, pos_count, num_train_negatives, batch_size):
self.pos_count = pos_count
self._num_samples = (1 + num_train_negatives) * self.pos_count
@ -433,6 +434,7 @@ class DistributedSamplerOfTrain:
"""
A distributed sampler for dataset.
"""
def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
"""
Distributed sampler of training dataset.
@ -443,15 +445,16 @@ class DistributedSamplerOfTrain:
self._batch_size = batch_size
self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
self._total_num_samples = self._samples_per_rank * self._rank_size
def __iter__(self):
"""
Returns the data after each sampling.
"""
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
indices = indices.tolist()
indices.extend(indices[:self._total_num_samples-len(indices)])
indices.extend(indices[:self._total_num_samples - len(indices)])
indices = indices[self._rank_id:self._total_num_samples:self._rank_size]
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]
@ -463,10 +466,12 @@ class DistributedSamplerOfTrain:
"""
return self._batchs_per_rank
class SequenceSampler:
"""
A sequence sampler for dataset.
"""
def __init__(self, eval_batch_size, num_users):
self._eval_users_per_batch = int(
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
@ -491,10 +496,12 @@ class SequenceSampler:
"""
return self._eval_batches_per_epoch
class DistributedSamplerOfEval:
"""
A distributed sampler for eval dataset.
"""
def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
self._eval_users_per_batch = int(
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
@ -507,8 +514,8 @@ class DistributedSamplerOfEval:
self._eval_batch_size = eval_batch_size
self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
#self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
#self._total_num_samples = self._samples_per_rank * self._rank_size
# self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
# self._total_num_samples = self._samples_per_rank * self._rank_size
def __iter__(self):
indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)
@ -525,6 +532,7 @@ class DistributedSamplerOfEval:
def __len__(self):
return self._batchs_per_rank
def parse_eval_batch_size(eval_batch_size):
"""
Parse eval batch size.