!2306 [Dataset] Code review & improve quality
This commit is contained in:
parent
83b53559f5
commit
4e3bfcf4c9
|
@ -20,6 +20,7 @@ import os
|
|||
import pickle as pkl
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from mindspore import log as logger
|
||||
|
||||
# parse args from command line parameter 'graph_api_args'
|
||||
# args delimiter is ':'
|
||||
|
@ -58,7 +59,7 @@ def yield_nodes(task_id=0):
|
|||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
logger.info("Node task is {}".format(task_id))
|
||||
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
|
||||
objects = []
|
||||
for name in names:
|
||||
|
@ -98,7 +99,7 @@ def yield_nodes(task_id=0):
|
|||
line_count += 1
|
||||
node_ids.append(i)
|
||||
yield node
|
||||
print('Processed {} lines for nodes.'.format(line_count))
|
||||
logger.info('Processed {} lines for nodes.'.format(line_count))
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
|
@ -108,21 +109,21 @@ def yield_edges(task_id=0):
|
|||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
logger.info("Edge task is {}".format(task_id))
|
||||
with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
|
||||
graph = pkl.load(f, encoding='latin1')
|
||||
line_count = 0
|
||||
for i in graph:
|
||||
for dst_id in graph[i]:
|
||||
if not i in node_ids:
|
||||
print('Source node {} does not exist.'.format(i))
|
||||
logger.info('Source node {} does not exist.'.format(i))
|
||||
continue
|
||||
if not dst_id in node_ids:
|
||||
print('Destination node {} does not exist.'.format(
|
||||
logger.info('Destination node {} does not exist.'.format(
|
||||
dst_id))
|
||||
continue
|
||||
edge = {'id': line_count,
|
||||
'src_id': i, 'dst_id': dst_id, 'type': 0}
|
||||
line_count += 1
|
||||
yield edge
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
||||
logger.info('Processed {} lines for edges.'.format(line_count))
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
Graph data convert tool for MindRecord.
|
||||
"""
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
|
||||
__all__ = ['GraphMapSchema']
|
||||
|
||||
|
@ -41,6 +42,7 @@ class GraphMapSchema:
|
|||
"edge_feature_index": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
|
||||
@property
|
||||
def get_schema(self):
|
||||
"""
|
||||
Get schema
|
||||
|
@ -52,6 +54,7 @@ class GraphMapSchema:
|
|||
Set node features profile
|
||||
"""
|
||||
if num_features != len(features_data_type) or num_features != len(features_shape):
|
||||
logger.info("Node feature profile is not match.")
|
||||
raise ValueError("Node feature profile is not match.")
|
||||
|
||||
self.num_node_features = num_features
|
||||
|
@ -66,6 +69,7 @@ class GraphMapSchema:
|
|||
Set edge features profile
|
||||
"""
|
||||
if num_features != len(features_data_type) or num_features != len(features_shape):
|
||||
logger.info("Edge feature profile is not match.")
|
||||
raise ValueError("Edge feature profile is not match.")
|
||||
|
||||
self.num_edge_features = num_features
|
||||
|
@ -83,6 +87,10 @@ class GraphMapSchema:
|
|||
Returns:
|
||||
graph data with union schema
|
||||
"""
|
||||
if node is None:
|
||||
logger.info("node cannot be None.")
|
||||
raise ValueError("node cannot be None.")
|
||||
|
||||
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
|
||||
"node_feature_index": []}
|
||||
for i in range(self.num_node_features):
|
||||
|
@ -117,6 +125,10 @@ class GraphMapSchema:
|
|||
Returns:
|
||||
graph data with union schema
|
||||
"""
|
||||
if edge is None:
|
||||
logger.info("edge cannot be None.")
|
||||
raise ValueError("edge cannot be None.")
|
||||
|
||||
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
|
||||
"type": edge["type"], "edge_feature_index": []}
|
||||
|
||||
|
|
|
@ -164,7 +164,7 @@ if __name__ == "__main__":
|
|||
num_features, feature_data_types, feature_shapes = mr_api.edge_profile
|
||||
graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
|
||||
|
||||
graph_schema = graph_map_schema.get_schema()
|
||||
graph_schema = graph_map_schema.get_schema
|
||||
|
||||
# init writer
|
||||
writer = init_writer(graph_schema)
|
||||
|
|
|
@ -983,7 +983,9 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
|
|||
continue;
|
||||
}
|
||||
}
|
||||
memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
|
||||
int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size,
|
||||
count * type_size);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric");
|
||||
out_index += count;
|
||||
if (i < indices.size() - 1) {
|
||||
src_start = HandleNeg(indices[i + 1], dim_length); // next index
|
||||
|
|
|
@ -101,6 +101,9 @@ class BucketBatchByLengthOp : public PipelineOp {
|
|||
std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info,
|
||||
bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size);
|
||||
|
||||
// Destructor
|
||||
~BucketBatchByLengthOp() = default;
|
||||
|
||||
// Might need to batch remaining buckets after receiving eoe, so override this method.
|
||||
// @param int32_t workerId
|
||||
// @return Status - The error code returned
|
||||
|
|
|
@ -36,6 +36,7 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
|
|||
: mr_path_(mr_filepath),
|
||||
num_workers_(num_workers),
|
||||
row_id_(0),
|
||||
shard_reader_(nullptr),
|
||||
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
|
||||
|
||||
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
|
||||
|
|
|
@ -37,7 +37,7 @@ namespace dataset {
|
|||
|
||||
// Driver method for TreePass
|
||||
Status TreePass::Run(ExecutionTree *tree, bool *modified) {
|
||||
if (!tree || !modified) {
|
||||
if (tree == nullptr || modified == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
|
||||
}
|
||||
return this->RunOnTree(tree, modified);
|
||||
|
@ -45,7 +45,7 @@ Status TreePass::Run(ExecutionTree *tree, bool *modified) {
|
|||
|
||||
// Driver method for NodePass
|
||||
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
|
||||
if (!tree || !modified) {
|
||||
if (tree == nullptr || modified == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
|
||||
}
|
||||
std::shared_ptr<DatasetOp> root = tree->root();
|
||||
|
|
|
@ -44,7 +44,7 @@ class ConnectorSize : public Sampling {
|
|||
public:
|
||||
explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {}
|
||||
|
||||
~ConnectorSize() = default;
|
||||
~ConnectorSize() override = default;
|
||||
|
||||
// Driver function for connector size sampling.
|
||||
// This function samples the connector size of every nodes within the ExecutionTree
|
||||
|
|
|
@ -26,6 +26,7 @@ Monitor::Monitor(ExecutionTree *tree) : tree_(tree) {
|
|||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
sampling_interval_ = cfg->monitor_sampling_interval();
|
||||
max_samples_ = 0;
|
||||
cur_row_ = 0;
|
||||
}
|
||||
|
||||
Status Monitor::operator()() {
|
||||
|
|
|
@ -34,6 +34,8 @@ class Slice {
|
|||
Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
|
||||
explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
|
||||
|
||||
~Slice() = default;
|
||||
|
||||
std::vector<dsize_t> Indices(dsize_t length) {
|
||||
std::vector<dsize_t> indices;
|
||||
dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);
|
||||
|
|
|
@ -29,8 +29,8 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow
|
|||
BOUNDING_BOX_CHECK(input);
|
||||
if (distribution_(rnd_)) {
|
||||
// To test bounding boxes algorithm, create random bboxes from image dims
|
||||
size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes
|
||||
float img_center = (input[0]->shape()[1] / 2); // get the center of the image
|
||||
size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes
|
||||
float img_center = (input[0]->shape()[1] / 2.); // get the center of the image
|
||||
|
||||
for (int i = 0; i < num_of_boxes; i++) {
|
||||
uint32_t b_w = 0; // bounding box width
|
||||
|
|
|
@ -49,6 +49,7 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
|
|||
preserve_unused_token_(preserve_unused_token),
|
||||
case_fold_(std::make_unique<CaseFoldOp>()),
|
||||
nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)),
|
||||
normalization_form_(normalization_form),
|
||||
common_normalize_(std::make_unique<NormalizeUTF8Op>(normalization_form)),
|
||||
replace_accent_chars_(std::make_unique<RegexReplaceOp>("\\p{Mn}", "")),
|
||||
replace_control_chars_(std::make_unique<RegexReplaceOp>("\\p{Cc}|\\p{Cf}", " ")) {
|
||||
|
|
|
@ -35,9 +35,9 @@ class BasicTokenizerOp : public TensorOp {
|
|||
static const bool kDefKeepWhitespace;
|
||||
static const NormalizeForm kDefNormalizationForm;
|
||||
static const bool kDefPreserveUnusedToken;
|
||||
BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
|
||||
NormalizeForm normalization_form = kDefNormalizationForm,
|
||||
bool preserve_unused_token = kDefPreserveUnusedToken);
|
||||
explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
|
||||
NormalizeForm normalization_form = kDefNormalizationForm,
|
||||
bool preserve_unused_token = kDefPreserveUnusedToken);
|
||||
|
||||
~BasicTokenizerOp() override = default;
|
||||
|
||||
|
|
|
@ -28,14 +28,14 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class BertTokenizerOp : public TensorOp {
|
||||
public:
|
||||
BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
|
||||
const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
|
||||
const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||
const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
|
||||
bool lower_case = BasicTokenizerOp::kDefLowerCase,
|
||||
bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
|
||||
NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
|
||||
bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
|
||||
explicit BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
|
||||
const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
|
||||
const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||
const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
|
||||
bool lower_case = BasicTokenizerOp::kDefLowerCase,
|
||||
bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
|
||||
NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
|
||||
bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
|
||||
: wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token),
|
||||
basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
|
|||
// @return
|
||||
Status insert(const value_type &val, key_type *key = nullptr) {
|
||||
key_type my_inx = inx_.fetch_add(1);
|
||||
if (key) {
|
||||
if (key != nullptr) {
|
||||
*key = my_inx;
|
||||
}
|
||||
return my_tree::DoInsert(my_inx, val);
|
||||
|
|
|
@ -323,7 +323,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
|
|||
}
|
||||
|
||||
vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
|
||||
uint64_t i_size = kUnsignedOne << int_type;
|
||||
uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type);
|
||||
// Get number of elements
|
||||
uint64_t src_n_int = src_bytes.size() / i_size;
|
||||
// Calculate bitmap size (bytes)
|
||||
|
@ -344,7 +344,7 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
|
|||
// Initialize destination data type
|
||||
IntegerType dst_int_type = kInt8Type;
|
||||
// Shift to next int position
|
||||
uint64_t pos = i * (kUnsignedOne << int_type);
|
||||
uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type));
|
||||
// Narrow down this int
|
||||
int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ class Shuffle(str, Enum):
|
|||
@check_zip
|
||||
def zip(datasets):
|
||||
"""
|
||||
Zips the datasets in the input tuple of datasets.
|
||||
Zip the datasets in the input tuple of datasets.
|
||||
|
||||
Args:
|
||||
datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
|
||||
|
@ -152,7 +152,7 @@ class Dataset:
|
|||
|
||||
def get_args(self):
|
||||
"""
|
||||
Returns attributes (member variables) related to the current class.
|
||||
Return attributes (member variables) related to the current class.
|
||||
|
||||
Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.
|
||||
|
||||
|
@ -239,7 +239,7 @@ class Dataset:
|
|||
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
|
||||
input_columns=None, pad_info=None):
|
||||
"""
|
||||
Combines batch_size number of consecutive rows into batches.
|
||||
Combine batch_size number of consecutive rows into batches.
|
||||
|
||||
For any child node, a batch is treated as a single row.
|
||||
For any column, all the elements within that column must have the same shape.
|
||||
|
@ -340,7 +340,7 @@ class Dataset:
|
|||
|
||||
def flat_map(self, func):
|
||||
"""
|
||||
Maps `func` to each row in dataset and flatten the result.
|
||||
Map `func` to each row in dataset and flatten the result.
|
||||
|
||||
The specified `func` is a function that must take one 'Ndarray' as input
|
||||
and return a 'Dataset'.
|
||||
|
@ -370,6 +370,7 @@ class Dataset:
|
|||
"""
|
||||
dataset = None
|
||||
if not hasattr(func, '__call__'):
|
||||
logger.error("func must be a function.")
|
||||
raise TypeError("func must be a function.")
|
||||
|
||||
for row_data in self:
|
||||
|
@ -379,6 +380,7 @@ class Dataset:
|
|||
dataset += func(row_data)
|
||||
|
||||
if not isinstance(dataset, Dataset):
|
||||
logger.error("flat_map must return a Dataset object.")
|
||||
raise TypeError("flat_map must return a Dataset object.")
|
||||
return dataset
|
||||
|
||||
|
@ -386,7 +388,7 @@ class Dataset:
|
|||
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
|
||||
num_parallel_workers=None, python_multiprocessing=False):
|
||||
"""
|
||||
Applies each operation in operations to this dataset.
|
||||
Apply each operation in operations to this dataset.
|
||||
|
||||
The order of operations is determined by the position of each operation in operations.
|
||||
operations[0] will be applied first, then operations[1], then operations[2], etc.
|
||||
|
@ -570,7 +572,7 @@ class Dataset:
|
|||
@check_repeat
|
||||
def repeat(self, count=None):
|
||||
"""
|
||||
Repeats this dataset count times. Repeat indefinitely if the count is None or -1.
|
||||
Repeat this dataset count times. Repeat indefinitely if the count is None or -1.
|
||||
|
||||
Note:
|
||||
The order of using repeat and batch reflects the number of batches. Recommend that
|
||||
|
@ -662,13 +664,16 @@ class Dataset:
|
|||
dataset_size = self.get_dataset_size()
|
||||
|
||||
if dataset_size is None or dataset_size <= 0:
|
||||
raise RuntimeError("dataset size unknown, unable to split.")
|
||||
raise RuntimeError("dataset_size is unknown, unable to split.")
|
||||
|
||||
if not isinstance(sizes, list):
|
||||
raise RuntimeError("sizes should be a list.")
|
||||
|
||||
all_int = all(isinstance(item, int) for item in sizes)
|
||||
if all_int:
|
||||
sizes_sum = sum(sizes)
|
||||
if sizes_sum != dataset_size:
|
||||
raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
|
||||
raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
|
||||
.format(sizes_sum, dataset_size))
|
||||
return sizes
|
||||
|
||||
|
@ -676,7 +681,7 @@ class Dataset:
|
|||
for item in sizes:
|
||||
absolute_size = int(round(item * dataset_size))
|
||||
if absolute_size == 0:
|
||||
raise RuntimeError("split percentage {} is too small.".format(item))
|
||||
raise RuntimeError("Split percentage {} is too small.".format(item))
|
||||
absolute_sizes.append(absolute_size)
|
||||
|
||||
absolute_sizes_sum = sum(absolute_sizes)
|
||||
|
@ -694,7 +699,7 @@ class Dataset:
|
|||
break
|
||||
|
||||
if sum(absolute_sizes) != dataset_size:
|
||||
raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
|
||||
raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
|
||||
.format(absolute_sizes_sum, dataset_size))
|
||||
|
||||
return absolute_sizes
|
||||
|
@ -702,7 +707,7 @@ class Dataset:
|
|||
@check_split
|
||||
def split(self, sizes, randomize=True):
|
||||
"""
|
||||
Splits the dataset into smaller, non-overlapping datasets.
|
||||
Split the dataset into smaller, non-overlapping datasets.
|
||||
|
||||
This is a general purpose split function which can be called from any operator in the pipeline.
|
||||
There is another, optimized split function, which will be called automatically if ds.split is
|
||||
|
@ -759,10 +764,10 @@ class Dataset:
|
|||
>>> train, test = data.split([0.9, 0.1])
|
||||
"""
|
||||
if self.is_shuffled():
|
||||
logger.warning("dataset is shuffled before split.")
|
||||
logger.warning("Dataset is shuffled before split.")
|
||||
|
||||
if self.is_sharded():
|
||||
raise RuntimeError("dataset should not be sharded before split.")
|
||||
raise RuntimeError("Dataset should not be sharded before split.")
|
||||
|
||||
absolute_sizes = self._get_absolute_split_sizes(sizes)
|
||||
splits = []
|
||||
|
@ -788,7 +793,7 @@ class Dataset:
|
|||
@check_zip_dataset
|
||||
def zip(self, datasets):
|
||||
"""
|
||||
Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
|
||||
Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
|
||||
|
||||
Args:
|
||||
datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset
|
||||
|
@ -845,7 +850,7 @@ class Dataset:
|
|||
@check_rename
|
||||
def rename(self, input_columns, output_columns):
|
||||
"""
|
||||
Renames the columns in input datasets.
|
||||
Rename the columns in input datasets.
|
||||
|
||||
Args:
|
||||
input_columns (list[str]): list of names of the input columns.
|
||||
|
@ -871,7 +876,7 @@ class Dataset:
|
|||
@check_project
|
||||
def project(self, columns):
|
||||
"""
|
||||
Projects certain columns in input datasets.
|
||||
Project certain columns in input datasets.
|
||||
|
||||
The specified columns will be selected from the dataset and passed down
|
||||
the pipeline in the order specified. The other columns are discarded.
|
||||
|
@ -936,7 +941,7 @@ class Dataset:
|
|||
|
||||
def device_que(self, prefetch_size=None):
|
||||
"""
|
||||
Returns a transferredDataset that transfer data through device.
|
||||
Return a transferredDataset that transfer data through device.
|
||||
|
||||
Args:
|
||||
prefetch_size (int, optional): prefetch number of records ahead of the
|
||||
|
@ -953,7 +958,7 @@ class Dataset:
|
|||
|
||||
def to_device(self, num_batch=None):
|
||||
"""
|
||||
Transfers data through CPU, GPU or Ascend devices.
|
||||
Transfer data through CPU, GPU or Ascend devices.
|
||||
|
||||
Args:
|
||||
num_batch (int, optional): limit the number of batch to be sent to device (default=None).
|
||||
|
@ -988,7 +993,7 @@ class Dataset:
|
|||
raise TypeError("Please set device_type in context")
|
||||
|
||||
if device_type not in ('Ascend', 'GPU', 'CPU'):
|
||||
raise ValueError("only support CPU, Ascend, GPU")
|
||||
raise ValueError("Only support CPU, Ascend, GPU")
|
||||
|
||||
if num_batch is None or num_batch == 0:
|
||||
raise ValueError("num_batch is None or 0.")
|
||||
|
@ -1089,7 +1094,7 @@ class Dataset:
|
|||
|
||||
def _get_pipeline_info(self):
|
||||
"""
|
||||
Gets pipeline information.
|
||||
Get pipeline information.
|
||||
"""
|
||||
device_iter = TupleIterator(self)
|
||||
self._output_shapes = device_iter.get_output_shapes()
|
||||
|
@ -1344,7 +1349,7 @@ class MappableDataset(SourceDataset):
|
|||
@check_split
|
||||
def split(self, sizes, randomize=True):
|
||||
"""
|
||||
Splits the dataset into smaller, non-overlapping datasets.
|
||||
Split the dataset into smaller, non-overlapping datasets.
|
||||
|
||||
There is the optimized split function, which will be called automatically when the dataset
|
||||
that calls this function is a MappableDataset.
|
||||
|
@ -1411,10 +1416,10 @@ class MappableDataset(SourceDataset):
|
|||
>>> train.use_sampler(train_sampler)
|
||||
"""
|
||||
if self.is_shuffled():
|
||||
logger.warning("dataset is shuffled before split.")
|
||||
logger.warning("Dataset is shuffled before split.")
|
||||
|
||||
if self.is_sharded():
|
||||
raise RuntimeError("dataset should not be sharded before split.")
|
||||
raise RuntimeError("Dataset should not be sharded before split.")
|
||||
|
||||
absolute_sizes = self._get_absolute_split_sizes(sizes)
|
||||
splits = []
|
||||
|
@ -1633,7 +1638,7 @@ class BlockReleasePair:
|
|||
|
||||
def __init__(self, init_release_rows, callback=None):
|
||||
if isinstance(init_release_rows, int) and init_release_rows <= 0:
|
||||
raise ValueError("release_rows need to be greater than 0.")
|
||||
raise ValueError("release_rows need to be greater than 0.")
|
||||
self.row_count = -init_release_rows
|
||||
self.cv = threading.Condition()
|
||||
self.callback = callback
|
||||
|
@ -2699,10 +2704,10 @@ class MindDataset(MappableDataset):
|
|||
self.shard_id = shard_id
|
||||
|
||||
if block_reader is True and num_shards is not None:
|
||||
raise ValueError("block reader not allowed true when use partitions")
|
||||
raise ValueError("block_reader not allowed true when use partitions")
|
||||
|
||||
if block_reader is True and shuffle is True:
|
||||
raise ValueError("block reader not allowed true when use shuffle")
|
||||
raise ValueError("block_reader not allowed true when use shuffle")
|
||||
|
||||
if block_reader is True:
|
||||
logger.warning("WARN: global shuffle is not used.")
|
||||
|
@ -2711,14 +2716,14 @@ class MindDataset(MappableDataset):
|
|||
if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
|
||||
samplers.DistributedSampler, samplers.RandomSampler,
|
||||
samplers.SequentialSampler)) is False:
|
||||
raise ValueError("the sampler is not supported yet.")
|
||||
raise ValueError("The sampler is not supported yet.")
|
||||
|
||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||
self.num_samples = num_samples
|
||||
|
||||
# sampler exclusive
|
||||
if block_reader is True and sampler is not None:
|
||||
raise ValueError("block reader not allowed true when use sampler")
|
||||
raise ValueError("block_reader not allowed true when use sampler")
|
||||
|
||||
if num_padded is None:
|
||||
num_padded = 0
|
||||
|
@ -2770,7 +2775,7 @@ class MindDataset(MappableDataset):
|
|||
if value >= 0:
|
||||
self._dataset_size = value
|
||||
else:
|
||||
raise ValueError('set dataset_size with negative value {}'.format(value))
|
||||
raise ValueError('Set dataset_size with negative value {}'.format(value))
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_option is None:
|
||||
|
@ -2872,7 +2877,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
|
|||
|
||||
def _fetch_py_sampler_indices(sampler, num_samples):
|
||||
"""
|
||||
Indices fetcher for python sampler.
|
||||
Indice fetcher for python sampler.
|
||||
"""
|
||||
if num_samples is not None:
|
||||
sampler_iter = iter(sampler)
|
||||
|
@ -3163,7 +3168,7 @@ class GeneratorDataset(MappableDataset):
|
|||
if value >= 0:
|
||||
self._dataset_size = value
|
||||
else:
|
||||
raise ValueError('set dataset_size with negative value {}'.format(value))
|
||||
raise ValueError('Set dataset_size with negative value {}'.format(value))
|
||||
|
||||
def __deepcopy__(self, memodict):
|
||||
if id(self) in memodict:
|
||||
|
@ -3313,7 +3318,7 @@ class TFRecordDataset(SourceDataset):
|
|||
if value >= 0:
|
||||
self._dataset_size = value
|
||||
else:
|
||||
raise ValueError('set dataset_size with negative value {}'.format(value))
|
||||
raise ValueError('Set dataset_size with negative value {}'.format(value))
|
||||
|
||||
def is_shuffled(self):
|
||||
return self.shuffle_files
|
||||
|
@ -4382,7 +4387,9 @@ class CelebADataset(MappableDataset):
|
|||
try:
|
||||
with open(attr_file, 'r') as f:
|
||||
num_rows = int(f.readline())
|
||||
except Exception:
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError("attr_file not found.")
|
||||
except BaseException:
|
||||
raise RuntimeError("Get dataset size failed from attribution file.")
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is not None:
|
||||
|
|
|
@ -319,7 +319,7 @@ class PKSampler(BuiltinSampler):
|
|||
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
|
||||
|
||||
if num_class is not None:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Not support specify num_class")
|
||||
|
||||
if not isinstance(shuffle, bool):
|
||||
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
|
||||
|
@ -551,8 +551,8 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
|
||||
Args:
|
||||
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
|
||||
num_samples (int): Number of elements to sample (default=None, all elements).
|
||||
replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
|
||||
num_samples (int, optional): Number of elements to sample (default=None, all elements).
|
||||
replacement (bool): If True, put the sample ID back for the next draw (default=True).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
|
|
@ -50,7 +50,7 @@ def check_filename(path):
|
|||
Exception: when error
|
||||
"""
|
||||
if not isinstance(path, str):
|
||||
raise ValueError("path: {} is not string".format(path))
|
||||
raise TypeError("path: {} is not string".format(path))
|
||||
filename = os.path.basename(path)
|
||||
|
||||
# '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
|
||||
|
@ -143,7 +143,7 @@ def check_sampler_shuffle_shard_options(param_dict):
|
|||
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
|
||||
|
||||
if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
||||
raise ValueError("sampler is not a valid Sampler type.")
|
||||
raise TypeError("sampler is not a valid Sampler type.")
|
||||
|
||||
if sampler is not None:
|
||||
if shuffle is not None:
|
||||
|
@ -328,13 +328,13 @@ def check_vocdataset(method):
|
|||
if task is None:
|
||||
raise ValueError("task is not provided.")
|
||||
if not isinstance(task, str):
|
||||
raise ValueError("task is not str type.")
|
||||
raise TypeError("task is not str type.")
|
||||
# check mode; required argument
|
||||
mode = param_dict.get('mode')
|
||||
if mode is None:
|
||||
raise ValueError("mode is not provided.")
|
||||
if not isinstance(mode, str):
|
||||
raise ValueError("mode is not str type.")
|
||||
raise TypeError("mode is not str type.")
|
||||
|
||||
imagesets_file = ""
|
||||
if task == "Segmentation":
|
||||
|
@ -388,7 +388,7 @@ def check_cocodataset(method):
|
|||
if task is None:
|
||||
raise ValueError("task is not provided.")
|
||||
if not isinstance(task, str):
|
||||
raise ValueError("task is not str type.")
|
||||
raise TypeError("task is not str type.")
|
||||
|
||||
if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
|
||||
raise ValueError("Invalid task type")
|
||||
|
@ -556,7 +556,7 @@ def check_generatordataset(method):
|
|||
|
||||
def check_batch_size(batch_size):
|
||||
if not (isinstance(batch_size, int) or (callable(batch_size))):
|
||||
raise ValueError("batch_size should either be an int or a callable.")
|
||||
raise TypeError("batch_size should either be an int or a callable.")
|
||||
if callable(batch_size):
|
||||
sig = ins.signature(batch_size)
|
||||
if len(sig.parameters) != 1:
|
||||
|
@ -706,6 +706,7 @@ def check_batch(method):
|
|||
|
||||
def check_sync_wait(method):
|
||||
"""check the input arguments of sync_wait."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -773,7 +774,7 @@ def check_filter(method):
|
|||
param_dict = make_param_dict(method, args, kwargs)
|
||||
predicate = param_dict.get("predicate")
|
||||
if not callable(predicate):
|
||||
raise ValueError("Predicate should be a python function or a callable python object.")
|
||||
raise TypeError("Predicate should be a python function or a callable python object.")
|
||||
|
||||
nreq_param_int = ['num_parallel_workers']
|
||||
check_param_type(nreq_param_int, param_dict, int)
|
||||
|
@ -865,7 +866,7 @@ def check_zip_dataset(method):
|
|||
raise ValueError("datasets is not provided.")
|
||||
|
||||
if not isinstance(ds, (tuple, datasets.Dataset)):
|
||||
raise ValueError("datasets is not tuple or of type Dataset.")
|
||||
raise TypeError("datasets is not tuple or of type Dataset.")
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
|
@ -885,7 +886,7 @@ def check_concat(method):
|
|||
raise ValueError("datasets is not provided.")
|
||||
|
||||
if not isinstance(ds, (list, datasets.Dataset)):
|
||||
raise ValueError("datasets is not list or of type Dataset.")
|
||||
raise TypeError("datasets is not list or of type Dataset.")
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
|
@ -964,7 +965,7 @@ def check_add_column(method):
|
|||
de_type = param_dict.get("de_type")
|
||||
if de_type is not None:
|
||||
if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
|
||||
raise ValueError("Unknown column type.")
|
||||
raise TypeError("Unknown column type.")
|
||||
else:
|
||||
raise TypeError("Expected non-empty string.")
|
||||
|
||||
|
|
|
@ -10,6 +10,6 @@ wheel >= 0.32.0
|
|||
decorator >= 4.4.0
|
||||
setuptools >= 40.8.0
|
||||
matplotlib >= 3.1.3 # for ut test
|
||||
opencv-python >= 4.2.0.32 # for ut test
|
||||
opencv-python >= 4.1.2.30 # for ut test
|
||||
sklearn >= 0.0 # for st test
|
||||
pandas >= 1.0.2 # for ut test
|
|
@ -42,15 +42,15 @@ def split_with_invalid_inputs(d):
|
|||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_, _ = d.split([3, 1])
|
||||
assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
|
||||
assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_, _ = d.split([5, 1])
|
||||
assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
|
||||
assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
|
||||
assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
|
||||
assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
|
||||
|
||||
with pytest.raises(ValueError) as info:
|
||||
_, _ = d.split([-0.5, 0.5])
|
||||
|
@ -80,7 +80,7 @@ def test_unmappable_invalid_input():
|
|||
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_, _ = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
assert "Dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
|
||||
def test_unmappable_split():
|
||||
|
@ -274,7 +274,7 @@ def test_mappable_invalid_input():
|
|||
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_, _ = d.split([4, 1])
|
||||
assert "dataset should not be sharded before split" in str(info.value)
|
||||
assert "Dataset should not be sharded before split" in str(info.value)
|
||||
|
||||
|
||||
def test_mappable_split_general():
|
||||
|
|
Loading…
Reference in New Issue