forked from mindspore-Ecosystem/mindspore
!11379 fix missing return description in comment
From: @tiancixiao Reviewed-by: @liucunwei,@heleiwang Signed-off-by: @liucunwei
This commit is contained in:
commit
344567e5d7
|
@ -102,7 +102,7 @@ def get_seed():
|
|||
Get the seed.
|
||||
|
||||
Returns:
|
||||
Int, seed.
|
||||
int, seed.
|
||||
"""
|
||||
return _config.get_seed()
|
||||
|
||||
|
@ -131,7 +131,7 @@ def get_prefetch_size():
|
|||
Get the prefetch size in number of rows.
|
||||
|
||||
Returns:
|
||||
Size, total number of rows to be prefetched.
|
||||
int, total number of rows to be prefetched.
|
||||
"""
|
||||
return _config.get_op_connector_size()
|
||||
|
||||
|
@ -162,7 +162,7 @@ def get_num_parallel_workers():
|
|||
This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature.
|
||||
|
||||
Returns:
|
||||
Int, number of parallel workers to be used as a default for each operation
|
||||
int, number of parallel workers to be used as a default for each operation.
|
||||
"""
|
||||
return _config.get_num_parallel_workers()
|
||||
|
||||
|
@ -193,7 +193,7 @@ def get_numa_enable():
|
|||
This is the DEFAULT numa enabled value used for the all process.
|
||||
|
||||
Returns:
|
||||
boolean, the default state of numa enabled
|
||||
bool, the default state of numa enabled.
|
||||
"""
|
||||
return _config.get_numa_enable()
|
||||
|
||||
|
@ -222,7 +222,7 @@ def get_monitor_sampling_interval():
|
|||
Get the default interval of performance monitor sampling.
|
||||
|
||||
Returns:
|
||||
Int, interval (in milliseconds) for performance monitor sampling.
|
||||
int, interval (in milliseconds) for performance monitor sampling.
|
||||
"""
|
||||
return _config.get_monitor_sampling_interval()
|
||||
|
||||
|
@ -280,7 +280,8 @@ def get_auto_num_workers():
|
|||
Get the setting (turned on or off) automatic number of workers.
|
||||
|
||||
Returns:
|
||||
Bool, whether auto num worker feature is turned on
|
||||
bool, whether auto num worker feature is turned on.
|
||||
|
||||
Examples:
|
||||
>>> num_workers = ds.config.get_auto_num_workers()
|
||||
"""
|
||||
|
@ -313,7 +314,7 @@ def get_callback_timeout():
|
|||
In case of a deadlock, the wait function will exit after the timeout period.
|
||||
|
||||
Returns:
|
||||
Int, the duration in seconds
|
||||
int, the duration in seconds.
|
||||
"""
|
||||
return _config.get_callback_timeout()
|
||||
|
||||
|
@ -323,7 +324,7 @@ def __str__():
|
|||
String representation of the configurations.
|
||||
|
||||
Returns:
|
||||
Str, configurations.
|
||||
str, configurations.
|
||||
"""
|
||||
return str(_config)
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ def zip(datasets):
|
|||
The number of datasets must be more than 1.
|
||||
|
||||
Returns:
|
||||
Dataset, ZipDataset.
|
||||
ZipDataset, dataset zipped.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of datasets is 1.
|
||||
|
@ -149,8 +149,8 @@ class Dataset:
|
|||
Internal method to create an IR tree.
|
||||
|
||||
Returns:
|
||||
ir_tree, The onject of the IR tree.
|
||||
dataset, the root dataset of the IR tree.
|
||||
DatasetNode, the root node of the IR tree.
|
||||
Dataset, the root dataset of the IR tree.
|
||||
"""
|
||||
parent = self.parent
|
||||
self.parent = []
|
||||
|
@ -165,7 +165,7 @@ class Dataset:
|
|||
Internal method to parse the API tree into an IR tree.
|
||||
|
||||
Returns:
|
||||
DatasetNode, The root of the IR tree.
|
||||
DatasetNode, the root node of the IR tree.
|
||||
"""
|
||||
if len(self.parent) > 1:
|
||||
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
|
||||
|
@ -197,7 +197,7 @@ class Dataset:
|
|||
Args:
|
||||
|
||||
Returns:
|
||||
Python dictionary.
|
||||
dict, attributes related to the current class.
|
||||
"""
|
||||
args = dict()
|
||||
args["num_parallel_workers"] = self.num_parallel_workers
|
||||
|
@ -211,7 +211,7 @@ class Dataset:
|
|||
filename (str): filename of json file to be saved as
|
||||
|
||||
Returns:
|
||||
Str, JSON string of the pipeline.
|
||||
str, JSON string of the pipeline.
|
||||
"""
|
||||
return json.loads(self.parse_tree().to_json(filename))
|
||||
|
||||
|
@ -258,6 +258,9 @@ class Dataset:
|
|||
drop_remainder (bool, optional): If True, will drop the last batch for each
|
||||
bucket if it is not a full batch (default=False).
|
||||
|
||||
Returns:
|
||||
BucketBatchByLengthDataset, dataset bucketed and batched by length.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
|
@ -371,6 +374,9 @@ class Dataset:
|
|||
num_batch (int): the number of batches without blocking at the start of each epoch.
|
||||
callback (function): The callback funciton that will be invoked when sync_update is called.
|
||||
|
||||
Returns:
|
||||
SyncWaitDataset, dataset added a blocking condition.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If condition name already exists.
|
||||
|
||||
|
@ -434,7 +440,7 @@ class Dataset:
|
|||
return a 'Dataset'.
|
||||
|
||||
Returns:
|
||||
Dataset, applied by the function.
|
||||
Dataset, dataset applied by the function.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -650,7 +656,7 @@ class Dataset:
|
|||
in parallel (default=None).
|
||||
|
||||
Returns:
|
||||
FilterDataset, dataset filter.
|
||||
FilterDataset, dataset filtered.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -748,6 +754,9 @@ class Dataset:
|
|||
"""
|
||||
Internal method called by split to calculate absolute split sizes and to
|
||||
do some error checking after calculating absolute split sizes.
|
||||
|
||||
Returns:
|
||||
int, absolute split sizes of the dataset.
|
||||
"""
|
||||
# Call get_dataset_size here and check input here because
|
||||
# don't want to call this once in check_split and another time in
|
||||
|
@ -1015,7 +1024,7 @@ class Dataset:
|
|||
is specified and special_first is set to default, special_tokens will be prepended
|
||||
|
||||
Returns:
|
||||
Vocab node
|
||||
Vocab, vocab built from the dataset.
|
||||
|
||||
Example:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -1074,7 +1083,7 @@ class Dataset:
|
|||
params(dict): contains more optional parameters of sentencepiece library
|
||||
|
||||
Returns:
|
||||
SentencePieceVocab node
|
||||
SentencePieceVocab, vocab built from the dataset.
|
||||
|
||||
Example:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -1115,7 +1124,7 @@ class Dataset:
|
|||
return a preprogressing 'Dataset'.
|
||||
|
||||
Returns:
|
||||
Dataset, applied by the function.
|
||||
Dataset, dataset applied by the function.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -1159,7 +1168,7 @@ class Dataset:
|
|||
If device is Ascend, features of data will be transferred one by one. The limitation
|
||||
of data transmission per time is 256M.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
TransferDataset, dataset for transferring.
|
||||
"""
|
||||
return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
|
||||
|
@ -1287,7 +1296,7 @@ class Dataset:
|
|||
use this param to select the conversion method, only take False for better performance (default=True).
|
||||
|
||||
Returns:
|
||||
Iterator, list of ndarrays.
|
||||
TupleIterator, tuple iterator over the dataset.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -1322,7 +1331,7 @@ class Dataset:
|
|||
if output_numpy=False, iterator will output MSTensor (default=False).
|
||||
|
||||
Returns:
|
||||
Iterator, dictionary of column name-ndarray pair.
|
||||
DictIterator, dictionary iterator over the dataset.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -1352,6 +1361,9 @@ class Dataset:
|
|||
"""
|
||||
Get Input Index Information
|
||||
|
||||
Returns:
|
||||
tuple, tuple of the input index information.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
|
@ -1409,6 +1421,9 @@ class Dataset:
|
|||
def get_col_names(self):
|
||||
"""
|
||||
Get names of the columns in the dataset
|
||||
|
||||
Returns:
|
||||
list, list of column names in the dataset.
|
||||
"""
|
||||
if self._col_names is None:
|
||||
runtime_getter = self._init_tree_getters()
|
||||
|
@ -1419,8 +1434,8 @@ class Dataset:
|
|||
"""
|
||||
Get the shapes of output data.
|
||||
|
||||
Return:
|
||||
List, list of shapes of each column.
|
||||
Returns:
|
||||
list, list of shapes of each column.
|
||||
"""
|
||||
if self.saved_output_shapes is None:
|
||||
runtime_getter = self._init_tree_getters()
|
||||
|
@ -1432,8 +1447,8 @@ class Dataset:
|
|||
"""
|
||||
Get the types of output data.
|
||||
|
||||
Return:
|
||||
List of data types.
|
||||
Returns:
|
||||
list, list of data types.
|
||||
"""
|
||||
if self.saved_output_types is None:
|
||||
runtime_getter = self._init_tree_getters()
|
||||
|
@ -1445,8 +1460,8 @@ class Dataset:
|
|||
"""
|
||||
Get the number of batches in an epoch.
|
||||
|
||||
Return:
|
||||
Number, number of batches.
|
||||
Returns:
|
||||
int, number of batches.
|
||||
"""
|
||||
if self.dataset_size is None:
|
||||
runtime_getter = self._init_size_getter()
|
||||
|
@ -1457,8 +1472,8 @@ class Dataset:
|
|||
"""
|
||||
Get the number of classes in a dataset.
|
||||
|
||||
Return:
|
||||
Number, number of classes.
|
||||
Returns:
|
||||
int, number of classes.
|
||||
"""
|
||||
if self._num_classes is None:
|
||||
runtime_getter = self._init_tree_getters()
|
||||
|
@ -1511,8 +1526,8 @@ class Dataset:
|
|||
"""
|
||||
Get the size of a batch.
|
||||
|
||||
Return:
|
||||
Number, the number of data in a batch.
|
||||
Returns:
|
||||
int, the number of data in a batch.
|
||||
"""
|
||||
if self._batch_size is None:
|
||||
runtime_getter = self._init_tree_getters()
|
||||
|
@ -1525,8 +1540,8 @@ class Dataset:
|
|||
"""
|
||||
Get the replication times in RepeatDataset else 1.
|
||||
|
||||
Return:
|
||||
Number, the count of repeat.
|
||||
Returns:
|
||||
int, the count of repeat.
|
||||
"""
|
||||
if self._repeat_count is None:
|
||||
runtime_getter = self._init_tree_getters()
|
||||
|
@ -1540,8 +1555,8 @@ class Dataset:
|
|||
Get the class index.
|
||||
|
||||
Returns:
|
||||
Dict, A str-to-int mapping from label name to index.
|
||||
Dict, A str-to-list<int> mapping from label name to index for Coco ONLY. The second number
|
||||
dict, a str-to-int mapping from label name to index.
|
||||
dict, a str-to-list<int> mapping from label name to index for Coco ONLY. The second number
|
||||
in the list is used to indicate the super category
|
||||
"""
|
||||
if self.children:
|
||||
|
@ -1588,7 +1603,7 @@ class SourceDataset(Dataset):
|
|||
patterns (Union[str, list[str]]): String or list of patterns to be searched.
|
||||
|
||||
Returns:
|
||||
List, files.
|
||||
list, list of files.
|
||||
"""
|
||||
|
||||
if not isinstance(patterns, list):
|
||||
|
@ -1646,9 +1661,6 @@ class MappableDataset(SourceDataset):
|
|||
Args:
|
||||
new_sampler (Sampler): The sampler to use for the current dataset.
|
||||
|
||||
Returns:
|
||||
Dataset, that uses new_sampler.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>>
|
||||
|
@ -1909,8 +1921,9 @@ class BatchDataset(Dataset):
|
|||
|
||||
Args:
|
||||
dataset (Dataset): Dataset to be checked.
|
||||
Return:
|
||||
True or False.
|
||||
|
||||
Returns:
|
||||
bool, whether repeat is used before batch.
|
||||
"""
|
||||
if isinstance(dataset, RepeatDataset):
|
||||
return True
|
||||
|
@ -1995,18 +2008,12 @@ class BatchInfo(cde.CBatchInfo):
|
|||
def get_batch_num(self):
|
||||
"""
|
||||
Return the batch number of the current batch.
|
||||
|
||||
Return:
|
||||
Number, number of the current batch.
|
||||
"""
|
||||
return
|
||||
|
||||
def get_epoch_num(self):
|
||||
"""
|
||||
Return the epoch number of the current batch.
|
||||
|
||||
Return:
|
||||
Number, number of the current epoch.
|
||||
"""
|
||||
return
|
||||
|
||||
|
@ -2055,8 +2062,8 @@ class BlockReleasePair:
|
|||
"""
|
||||
Function for handing blocking condition.
|
||||
|
||||
Return:
|
||||
True
|
||||
Returns:
|
||||
bool, True.
|
||||
"""
|
||||
with self.cv:
|
||||
# if disable is true, the always evaluate to true
|
||||
|
@ -2145,8 +2152,9 @@ class SyncWaitDataset(Dataset):
|
|||
|
||||
Args:
|
||||
dataset (Dataset): Dataset to be checked.
|
||||
Return:
|
||||
True or False.
|
||||
|
||||
Returns:
|
||||
bool, whether sync_wait is used before batch.
|
||||
"""
|
||||
if isinstance(dataset, BatchDataset):
|
||||
return True
|
||||
|
@ -2932,6 +2940,9 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, n
|
|||
num_shards (int): Number of shard for sharding.
|
||||
shard_id (int): Shard ID.
|
||||
non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
|
||||
|
||||
Returns:
|
||||
Sampler, sampler selected based on user input.
|
||||
"""
|
||||
if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
|
||||
return None
|
||||
|
@ -4180,7 +4191,7 @@ class ManifestDataset(MappableDataset):
|
|||
Get the class index.
|
||||
|
||||
Returns:
|
||||
Dict, A str-to-int mapping from label name to index.
|
||||
dict, a str-to-int mapping from label name to index.
|
||||
"""
|
||||
if self.class_indexing is None:
|
||||
if self._class_indexing is None:
|
||||
|
@ -4579,7 +4590,7 @@ class Schema:
|
|||
Args:
|
||||
schema_file(str): Path of schema file (default=None).
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
Schema object, schema info about dataset.
|
||||
|
||||
Raises:
|
||||
|
@ -4654,7 +4665,7 @@ class Schema:
|
|||
Get a JSON string of the schema.
|
||||
|
||||
Returns:
|
||||
Str, JSON string of the schema.
|
||||
str, JSON string of the schema.
|
||||
"""
|
||||
return self.cpp_schema.to_json()
|
||||
|
||||
|
@ -4840,7 +4851,7 @@ class VOCDataset(MappableDataset):
|
|||
Get the class index.
|
||||
|
||||
Returns:
|
||||
Dict, A str-to-int mapping from label name to index.
|
||||
dict, a str-to-int mapping from label name to index.
|
||||
"""
|
||||
if self.task != "Detection":
|
||||
raise NotImplementedError("Only 'Detection' support get_class_indexing.")
|
||||
|
@ -5032,7 +5043,7 @@ class CocoDataset(MappableDataset):
|
|||
Get the class index.
|
||||
|
||||
Returns:
|
||||
Dict, A str-to-list<int> mapping from label name to index
|
||||
dict, a str-to-list<int> mapping from label name to index
|
||||
"""
|
||||
if self.task not in {"Detection", "Panoptic"}:
|
||||
raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")
|
||||
|
|
|
@ -100,7 +100,7 @@ class GraphData:
|
|||
node_type (int): Specify the type of node.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of nodes.
|
||||
numpy.ndarray, array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -124,7 +124,7 @@ class GraphData:
|
|||
edge_type (int): Specify the type of edge.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of edges.
|
||||
numpy.ndarray, array of edges.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -148,7 +148,7 @@ class GraphData:
|
|||
edge_list (Union[list, numpy.ndarray]): The given list of edges.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of nodes.
|
||||
numpy.ndarray, array of nodes.
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_list` is not list or ndarray.
|
||||
|
@ -167,7 +167,7 @@ class GraphData:
|
|||
neighbor_type (int): Specify the type of neighbor.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of nodes.
|
||||
numpy.ndarray, array of neighbors.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -201,7 +201,7 @@ class GraphData:
|
|||
neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of nodes.
|
||||
numpy.ndarray, array of neighbors.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -231,7 +231,7 @@ class GraphData:
|
|||
neg_neighbor_type (int): Specify the type of negative neighbor.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of nodes.
|
||||
numpy.ndarray, array of neighbors.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -260,7 +260,7 @@ class GraphData:
|
|||
feature_types (Union[list, numpy.ndarray]): The given list of feature types.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of features.
|
||||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -292,7 +292,7 @@ class GraphData:
|
|||
feature_types (Union[list, numpy.ndarray]): The given list of feature types.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of features.
|
||||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -320,7 +320,7 @@ class GraphData:
|
|||
the feature information of nodes, the number of edges, the type of edges, and the feature information of edges.
|
||||
|
||||
Returns:
|
||||
dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
|
||||
dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
|
||||
node_feature_type and edge_feature_type.
|
||||
"""
|
||||
if self._working_mode == 'server':
|
||||
|
@ -347,7 +347,7 @@ class GraphData:
|
|||
A default value of -1 indicates that no node is given.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of nodes.
|
||||
numpy.ndarray, array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
|
|
@ -128,6 +128,7 @@ class BuiltinSampler:
|
|||
|
||||
User should not extend this class.
|
||||
"""
|
||||
|
||||
def __init__(self, num_samples=None):
|
||||
self.child_sampler = None
|
||||
self.num_samples = num_samples
|
||||
|
@ -201,7 +202,7 @@ class BuiltinSampler:
|
|||
- None
|
||||
|
||||
Returns:
|
||||
int, The number of samples, or None
|
||||
int, the number of samples, or None
|
||||
"""
|
||||
if self.child_sampler is not None:
|
||||
child_samples = self.child_sampler.get_num_samples()
|
||||
|
|
|
@ -1063,8 +1063,9 @@ class LinearTransformation:
|
|||
the dot product with the transformation matrix, and reshapes it back to its original shape.
|
||||
|
||||
Args:
|
||||
transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
|
||||
mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
|
||||
transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), where
|
||||
:math:`D = C \times H \times W`.
|
||||
mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where :math:`D = C \times H \times W`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.dataset.transforms.py_transforms import Compose
|
||||
|
|
|
@ -23,7 +23,7 @@ from PIL import Image
|
|||
import mindspore as ms
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log
|
||||
from mindspore.dataset.engine.datasets import Dataset
|
||||
from mindspore.dataset import Dataset
|
||||
from mindspore.nn import Cell, SequentialCell
|
||||
from mindspore.ops.operations import ExpandDims
|
||||
from mindspore.train._utils import check_value_type
|
||||
|
|
|
@ -30,6 +30,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchema
|
|||
|
||||
__all__ = ['FileWriter']
|
||||
|
||||
|
||||
class FileWriter:
|
||||
"""
|
||||
Class to write user defined raw data into MindRecord File series.
|
||||
|
@ -45,6 +46,7 @@ class FileWriter:
|
|||
Raises:
|
||||
ParamValueError: If `file_name` or `shard_num` is invalid.
|
||||
"""
|
||||
|
||||
def __init__(self, file_name, shard_num=1):
|
||||
check_filename(file_name)
|
||||
self._file_name = file_name
|
||||
|
@ -84,7 +86,7 @@ class FileWriter:
|
|||
file_name (str): String of MindRecord file name.
|
||||
|
||||
Returns:
|
||||
Instance of FileWriter.
|
||||
FileWriter, file writer for the opened MindRecord file.
|
||||
|
||||
Raises:
|
||||
ParamValueError: If file_name is invalid.
|
||||
|
@ -118,7 +120,7 @@ class FileWriter:
|
|||
desc (str, optional): String of schema description (default=None).
|
||||
|
||||
Returns:
|
||||
An integer, schema id.
|
||||
int, schema id.
|
||||
|
||||
Raises:
|
||||
MRMInvalidSchemaError: If schema is invalid.
|
||||
|
@ -175,17 +177,17 @@ class FileWriter:
|
|||
|
||||
if field not in v:
|
||||
error_data_dic[i] = "for schema, {} th data is wrong, " \
|
||||
"there is not '{}' object in the raw data.".format(i, field)
|
||||
"there is not '{}' object in the raw data.".format(i, field)
|
||||
continue
|
||||
field_type = type(v[field]).__name__
|
||||
if field_type not in VALUE_TYPE_MAP:
|
||||
error_data_dic[i] = "for schema, {} th data is wrong, " \
|
||||
"data type for '{}' is not matched.".format(i, field)
|
||||
"data type for '{}' is not matched.".format(i, field)
|
||||
continue
|
||||
|
||||
if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
|
||||
error_data_dic[i] = "for schema, {} th data is wrong, " \
|
||||
"data type for '{}' is not matched.".format(i, field)
|
||||
"data type for '{}' is not matched.".format(i, field)
|
||||
continue
|
||||
|
||||
if field_type == 'ndarray':
|
||||
|
@ -206,7 +208,6 @@ class FileWriter:
|
|||
def open_and_set_header(self):
|
||||
"""
|
||||
Open writer and set header.
|
||||
|
||||
"""
|
||||
if not self._writer.is_open:
|
||||
self._writer.open(self._paths)
|
||||
|
@ -222,6 +223,9 @@ class FileWriter:
|
|||
raw_data (list[dict]): List of raw data.
|
||||
parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).
|
||||
|
||||
Returns:
|
||||
MSRStatus, SUCCESS or FAILED.
|
||||
|
||||
Raises:
|
||||
ParamTypeError: If index field is invalid.
|
||||
MRMOpenError: If failed to open MindRecord File.
|
||||
|
@ -330,7 +334,7 @@ class FileWriter:
|
|||
v (dict): Sub dict in schema
|
||||
|
||||
Returns:
|
||||
bool, True or False.
|
||||
bool, whether the array item is valid.
|
||||
str, error message.
|
||||
"""
|
||||
if v['type'] not in VALID_ARRAY_ATTRIBUTES:
|
||||
|
@ -355,7 +359,7 @@ class FileWriter:
|
|||
content (dict): Dict of raw schema.
|
||||
|
||||
Returns:
|
||||
bool, True or False.
|
||||
bool, whether the schema is valid.
|
||||
str, error message.
|
||||
"""
|
||||
error = ''
|
||||
|
|
|
@ -23,6 +23,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategor
|
|||
|
||||
__all__ = ['MindPage']
|
||||
|
||||
|
||||
class MindPage:
|
||||
"""
|
||||
Class to read MindRecord File series in pagination.
|
||||
|
@ -36,6 +37,7 @@ class MindPage:
|
|||
ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
|
||||
MRMInitSegmentError: If failed to initialize ShardSegment.
|
||||
"""
|
||||
|
||||
def __init__(self, file_name, num_consumer=4):
|
||||
if isinstance(file_name, list):
|
||||
for f in file_name:
|
||||
|
@ -69,7 +71,12 @@ class MindPage:
|
|||
return self._candidate_fields
|
||||
|
||||
def get_category_fields(self):
|
||||
"""Return candidate category fields."""
|
||||
"""
|
||||
Return candidate category fields.
|
||||
|
||||
Returns:
|
||||
list[str], by which data could be grouped.
|
||||
"""
|
||||
logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated."
|
||||
" Please use candidate_fields")
|
||||
return self.candidate_fields
|
||||
|
@ -97,12 +104,22 @@ class MindPage:
|
|||
|
||||
@property
|
||||
def category_field(self):
|
||||
"""Getter function for category fields."""
|
||||
"""
|
||||
Getter function for category fields.
|
||||
|
||||
Returns:
|
||||
list[str], by which data could be grouped.
|
||||
"""
|
||||
return self._category_field
|
||||
|
||||
@category_field.setter
|
||||
def category_field(self, category_field):
|
||||
"""Setter function for category field"""
|
||||
"""
|
||||
Setter function for category field.
|
||||
|
||||
Returns:
|
||||
MSRStatus, SUCCESS or FAILED.
|
||||
"""
|
||||
if not category_field or not isinstance(category_field, str):
|
||||
raise ParamTypeError('category_fields', 'str')
|
||||
if category_field not in self._candidate_fields:
|
||||
|
@ -132,7 +149,7 @@ class MindPage:
|
|||
num_row (int): Number of rows in a page.
|
||||
|
||||
Returns:
|
||||
List, list[dict].
|
||||
list[dict], data queried by category id.
|
||||
|
||||
Raises:
|
||||
ParamValueError: If any parameter is invalid.
|
||||
|
@ -158,7 +175,7 @@ class MindPage:
|
|||
num_row (int): Number of row in a page.
|
||||
|
||||
Returns:
|
||||
str, read at page.
|
||||
list[dict], data queried by category name.
|
||||
"""
|
||||
if not isinstance(category_name, str):
|
||||
raise ParamValueError("Category name should be str.")
|
||||
|
|
|
@ -23,6 +23,7 @@ from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHe
|
|||
|
||||
__all__ = ['ShardWriter']
|
||||
|
||||
|
||||
class ShardWriter:
|
||||
"""
|
||||
Wrapper class which is represent shardWrite class in c++ module.
|
||||
|
@ -192,9 +193,11 @@ class ShardWriter:
|
|||
if len(blob_data) == 1:
|
||||
values = [v for v in blob_data.values()]
|
||||
return bytes(values[0])
|
||||
|
||||
# convert int to bytes
|
||||
def int_to_bytes(x: int) -> bytes:
|
||||
return x.to_bytes(8, 'big')
|
||||
|
||||
merged = bytes()
|
||||
for field, v in blob_data.items():
|
||||
# convert ndarray to bytes
|
||||
|
@ -209,7 +212,7 @@ class ShardWriter:
|
|||
Flush data to disk.
|
||||
|
||||
Returns:
|
||||
Class MSRStatus, SUCCESS or FAILED.
|
||||
MSRStatus, SUCCESS or FAILED.
|
||||
|
||||
Raises:
|
||||
MRMCommitError: If failed to flush data to disk.
|
||||
|
|
|
@ -33,6 +33,7 @@ except ModuleNotFoundError:
|
|||
|
||||
__all__ = ['Cifar100ToMR']
|
||||
|
||||
|
||||
class Cifar100ToMR:
|
||||
"""
|
||||
A class to transform from cifar100 to MindRecord.
|
||||
|
@ -44,6 +45,7 @@ class Cifar100ToMR:
|
|||
Raises:
|
||||
ValueError: If source or destination is invalid.
|
||||
"""
|
||||
|
||||
def __init__(self, source, destination):
|
||||
check_filename(source)
|
||||
self.source = source
|
||||
|
@ -74,7 +76,7 @@ class Cifar100ToMR:
|
|||
fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"].
|
||||
|
||||
Returns:
|
||||
SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord.
|
||||
MSRStatus, whether cifar100 is successfully transformed to MindRecord.
|
||||
"""
|
||||
if fields and not isinstance(fields, list):
|
||||
raise ValueError("The parameter fields should be None or list")
|
||||
|
@ -114,6 +116,7 @@ class Cifar100ToMR:
|
|||
raise t.exception
|
||||
return t.res
|
||||
|
||||
|
||||
def _construct_raw_data(images, fine_labels, coarse_labels):
|
||||
"""
|
||||
Construct raw data from cifar100 data.
|
||||
|
@ -124,7 +127,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
|
|||
coarse_labels (list): coarse label list from cifar100.
|
||||
|
||||
Returns:
|
||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
||||
list[dict], data dictionary constructed from cifar100.
|
||||
"""
|
||||
if not cv2:
|
||||
raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
|
||||
|
@ -141,6 +144,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
|
|||
raw_data.append(row_data)
|
||||
return raw_data
|
||||
|
||||
|
||||
def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
|
||||
"""
|
||||
Generate MindRecord file from raw data.
|
||||
|
@ -153,7 +157,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
|
|||
schema_desc (str): String of schema description.
|
||||
|
||||
Returns:
|
||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
||||
MSRStatus, whether successfully written into MindRecord.
|
||||
"""
|
||||
schema = {"id": {"type": "int64"}, "fine_label": {"type": "int64"},
|
||||
"coarse_label": {"type": "int64"}, "data": {"type": "bytes"}}
|
||||
|
|
|
@ -25,6 +25,7 @@ from .cifar10 import Cifar10
|
|||
from ..common.exceptions import PathNotExistsError
|
||||
from ..filewriter import FileWriter
|
||||
from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED
|
||||
|
||||
try:
|
||||
cv2 = import_module("cv2")
|
||||
except ModuleNotFoundError:
|
||||
|
@ -32,6 +33,7 @@ except ModuleNotFoundError:
|
|||
|
||||
__all__ = ['Cifar10ToMR']
|
||||
|
||||
|
||||
class Cifar10ToMR:
|
||||
"""
|
||||
A class to transform from cifar10 to MindRecord.
|
||||
|
@ -43,6 +45,7 @@ class Cifar10ToMR:
|
|||
Raises:
|
||||
ValueError: If source or destination is invalid.
|
||||
"""
|
||||
|
||||
def __init__(self, source, destination):
|
||||
check_filename(source)
|
||||
self.source = source
|
||||
|
@ -73,7 +76,7 @@ class Cifar10ToMR:
|
|||
fields (list[str], optional): A list of index fields, e.g.["label"] (default=None).
|
||||
|
||||
Returns:
|
||||
SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord.
|
||||
MSRStatus, whether cifar10 is successfully transformed to MindRecord.
|
||||
"""
|
||||
if fields and not isinstance(fields, list):
|
||||
raise ValueError("The parameter fields should be None or list")
|
||||
|
@ -109,6 +112,7 @@ class Cifar10ToMR:
|
|||
raise t.exception
|
||||
return t.res
|
||||
|
||||
|
||||
def _construct_raw_data(images, labels):
|
||||
"""
|
||||
Construct raw data from cifar10 data.
|
||||
|
@ -118,7 +122,7 @@ def _construct_raw_data(images, labels):
|
|||
labels (list): label list from cifar10.
|
||||
|
||||
Returns:
|
||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
||||
list[dict], data dictionary constructed from cifar10.
|
||||
"""
|
||||
if not cv2:
|
||||
raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
|
||||
|
@ -133,6 +137,7 @@ def _construct_raw_data(images, labels):
|
|||
raw_data.append(row_data)
|
||||
return raw_data
|
||||
|
||||
|
||||
def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
|
||||
"""
|
||||
Generate MindRecord file from raw data.
|
||||
|
@ -145,7 +150,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
|
|||
schema_desc (str): String of schema description.
|
||||
|
||||
Returns:
|
||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
||||
MSRStatus, whether successfully written into MindRecord.
|
||||
"""
|
||||
schema = {"id": {"type": "int64"}, "label": {"type": "int64"},
|
||||
"data": {"type": "bytes"}}
|
||||
|
|
|
@ -29,6 +29,7 @@ except ModuleNotFoundError:
|
|||
|
||||
__all__ = ['CsvToMR']
|
||||
|
||||
|
||||
class CsvToMR:
|
||||
"""
|
||||
A class to transform from csv to MindRecord.
|
||||
|
@ -121,7 +122,7 @@ class CsvToMR:
|
|||
Executes transformation from csv to MindRecord.
|
||||
|
||||
Returns:
|
||||
SUCCESS or FAILED, whether csv is successfully transformed to MindRecord.
|
||||
MSRStatus, whether csv is successfully transformed to MindRecord.
|
||||
"""
|
||||
if not os.path.exists(self.source):
|
||||
raise IOError("Csv file {} do not exist.".format(self.source))
|
||||
|
|
|
@ -47,6 +47,7 @@ class ImageNetToMR:
|
|||
Raises:
|
||||
ValueError: If `map_file`, `image_dir` or `destination` is invalid.
|
||||
"""
|
||||
|
||||
def __init__(self, map_file, image_dir, destination, partition_number=1):
|
||||
check_filename(map_file)
|
||||
self.map_file = map_file
|
||||
|
@ -122,7 +123,7 @@ class ImageNetToMR:
|
|||
Executes transformation from imagenet to MindRecord.
|
||||
|
||||
Returns:
|
||||
SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord.
|
||||
MSRStatus, whether imagenet is successfully transformed to MindRecord.
|
||||
"""
|
||||
t0_total = time.time()
|
||||
|
||||
|
@ -133,10 +134,10 @@ class ImageNetToMR:
|
|||
logger.info("transformed MindRecord schema is: {}".format(imagenet_schema_json))
|
||||
|
||||
# set the header size
|
||||
self.writer.set_header_size(1<<24)
|
||||
self.writer.set_header_size(1 << 24)
|
||||
|
||||
# set the page size
|
||||
self.writer.set_page_size(1<<26)
|
||||
self.writer.set_page_size(1 << 26)
|
||||
|
||||
# create the schema
|
||||
self.writer.add_schema(imagenet_schema_json, "imagenet_schema")
|
||||
|
|
|
@ -32,6 +32,7 @@ except ModuleNotFoundError:
|
|||
|
||||
__all__ = ['MnistToMR']
|
||||
|
||||
|
||||
class MnistToMR:
|
||||
"""
|
||||
A class to transform from Mnist to MindRecord.
|
||||
|
@ -125,7 +126,7 @@ class MnistToMR:
|
|||
Executes transformation from Mnist train part to MindRecord.
|
||||
|
||||
Returns:
|
||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
||||
MSRStatus, whether successfully written into MindRecord.
|
||||
"""
|
||||
t0_total = time.time()
|
||||
|
||||
|
@ -173,7 +174,7 @@ class MnistToMR:
|
|||
Executes transformation from Mnist test part to MindRecord.
|
||||
|
||||
Returns:
|
||||
SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord.
|
||||
MSRStatus, whether Mnist is successfully transformed to MindRecord.
|
||||
"""
|
||||
t0_total = time.time()
|
||||
|
||||
|
@ -222,7 +223,7 @@ class MnistToMR:
|
|||
Executes transformation from Mnist to MindRecord.
|
||||
|
||||
Returns:
|
||||
SUCCESS/FAILED, whether successfully written into MindRecord.
|
||||
MSRStatus, whether successfully written into MindRecord.
|
||||
"""
|
||||
if not cv2:
|
||||
raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")
|
||||
|
|
|
@ -23,7 +23,6 @@ from mindspore import log as logger
|
|||
from ..filewriter import FileWriter
|
||||
from ..shardutils import check_filename, ExceptionThread
|
||||
|
||||
|
||||
__all__ = ['TFRecordToMR']
|
||||
|
||||
SupportedTensorFlowVersion = '1.13.0-rc1'
|
||||
|
@ -86,9 +85,10 @@ class TFRecordToMR:
|
|||
ValueError: If parameter is invalid.
|
||||
Exception: when tensorflow module is not found or version is not correct.
|
||||
"""
|
||||
|
||||
def __init__(self, source, destination, feature_dict, bytes_fields=None):
|
||||
try:
|
||||
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
|
||||
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
|
||||
except ModuleNotFoundError:
|
||||
self.tf = None
|
||||
if not self.tf:
|
||||
|
@ -265,7 +265,7 @@ class TFRecordToMR:
|
|||
Execute transformation from TFRecord to MindRecord.
|
||||
|
||||
Returns:
|
||||
SUCCESS or FAILED, whether TFRecord is successfuly transformed to MindRecord.
|
||||
MSRStatus, whether TFRecord is successfuly transformed to MindRecord.
|
||||
"""
|
||||
writer = FileWriter(self.destination)
|
||||
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
|
||||
|
|
|
@ -22,13 +22,12 @@ import pickle
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mindspore.dataset.engine import GeneratorDataset
|
||||
from mindspore.dataset import GeneratorDataset
|
||||
|
||||
import src.constants as rconst
|
||||
import src.movielens as movielens
|
||||
import src.stat_utils as stat_utils
|
||||
|
||||
|
||||
DATASET_TO_NUM_USERS_AND_ITEMS = {
|
||||
"ml-1m": (6040, 3706),
|
||||
"ml-20m": (138493, 26744)
|
||||
|
@ -205,6 +204,7 @@ class NCFDataset:
|
|||
"""
|
||||
A dataset for NCF network.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pos_users,
|
||||
pos_items,
|
||||
|
@ -407,6 +407,7 @@ class RandomSampler:
|
|||
"""
|
||||
A random sampler for dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, pos_count, num_train_negatives, batch_size):
|
||||
self.pos_count = pos_count
|
||||
self._num_samples = (1 + num_train_negatives) * self.pos_count
|
||||
|
@ -433,6 +434,7 @@ class DistributedSamplerOfTrain:
|
|||
"""
|
||||
A distributed sampler for dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
|
||||
"""
|
||||
Distributed sampler of training dataset.
|
||||
|
@ -443,15 +445,16 @@ class DistributedSamplerOfTrain:
|
|||
self._batch_size = batch_size
|
||||
|
||||
self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
|
||||
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
|
||||
self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
|
||||
self._total_num_samples = self._samples_per_rank * self._rank_size
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Returns the data after each sampling.
|
||||
"""
|
||||
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
|
||||
indices = indices.tolist()
|
||||
indices.extend(indices[:self._total_num_samples-len(indices)])
|
||||
indices.extend(indices[:self._total_num_samples - len(indices)])
|
||||
indices = indices[self._rank_id:self._total_num_samples:self._rank_size]
|
||||
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]
|
||||
|
||||
|
@ -463,10 +466,12 @@ class DistributedSamplerOfTrain:
|
|||
"""
|
||||
return self._batchs_per_rank
|
||||
|
||||
|
||||
class SequenceSampler:
|
||||
"""
|
||||
A sequence sampler for dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, eval_batch_size, num_users):
|
||||
self._eval_users_per_batch = int(
|
||||
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
||||
|
@ -491,10 +496,12 @@ class SequenceSampler:
|
|||
"""
|
||||
return self._eval_batches_per_epoch
|
||||
|
||||
|
||||
class DistributedSamplerOfEval:
|
||||
"""
|
||||
A distributed sampler for eval dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
|
||||
self._eval_users_per_batch = int(
|
||||
eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
||||
|
@ -507,8 +514,8 @@ class DistributedSamplerOfEval:
|
|||
self._eval_batch_size = eval_batch_size
|
||||
|
||||
self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
|
||||
#self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
|
||||
#self._total_num_samples = self._samples_per_rank * self._rank_size
|
||||
# self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
|
||||
# self._total_num_samples = self._samples_per_rank * self._rank_size
|
||||
|
||||
def __iter__(self):
|
||||
indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)
|
||||
|
@ -525,6 +532,7 @@ class DistributedSamplerOfEval:
|
|||
def __len__(self):
|
||||
return self._batchs_per_rank
|
||||
|
||||
|
||||
def parse_eval_batch_size(eval_batch_size):
|
||||
"""
|
||||
Parse eval batch size.
|
||||
|
|
Loading…
Reference in New Issue