From 4e3bfcf4c95129925f7349f14a7553de33c24fc2 Mon Sep 17 00:00:00 2001 From: YangLuo Date: Thu, 18 Jun 2020 21:03:55 +0800 Subject: [PATCH] !2306 [Dataset] Code review & improve quality --- .../graph_to_mindrecord/citeseer/mr_api.py | 13 ++-- .../graph_to_mindrecord/graph_map_schema.py | 12 +++ example/graph_to_mindrecord/writer.py | 2 +- mindspore/ccsrc/dataset/core/tensor.cc | 4 +- .../datasetops/bucket_batch_by_length_op.h | 3 + .../ccsrc/dataset/engine/gnn/graph_loader.cc | 1 + mindspore/ccsrc/dataset/engine/opt/pass.cc | 4 +- .../dataset/engine/perf/connector_size.h | 2 +- .../ccsrc/dataset/engine/perf/monitor.cc | 1 + .../ccsrc/dataset/kernels/data/slice_op.h | 2 + .../image/random_horizontal_flip_bbox_op.cc | 4 +- .../text/kernels/basic_tokenizer_op.cc | 1 + .../dataset/text/kernels/basic_tokenizer_op.h | 6 +- .../dataset/text/kernels/bert_tokenizer_op.h | 16 ++-- mindspore/ccsrc/dataset/util/auto_index.h | 2 +- .../ccsrc/mindrecord/meta/shard_column.cc | 4 +- mindspore/dataset/engine/datasets.py | 73 ++++++++++--------- mindspore/dataset/engine/samplers.py | 6 +- mindspore/dataset/engine/validators.py | 21 +++--- requirements.txt | 2 +- tests/ut/python/dataset/test_split.py | 10 +-- 21 files changed, 110 insertions(+), 79 deletions(-) diff --git a/example/graph_to_mindrecord/citeseer/mr_api.py b/example/graph_to_mindrecord/citeseer/mr_api.py index 69bc442f4e8..aa9e2a2c4d4 100644 --- a/example/graph_to_mindrecord/citeseer/mr_api.py +++ b/example/graph_to_mindrecord/citeseer/mr_api.py @@ -20,6 +20,7 @@ import os import pickle as pkl import numpy as np import scipy.sparse as sp +from mindspore import log as logger # parse args from command line parameter 'graph_api_args' # args delimiter is ':' @@ -58,7 +59,7 @@ def yield_nodes(task_id=0): Yields: data (dict): data row which is dict. """ - print("Node task is {}".format(task_id)) + logger.info("Node task is {}".format(task_id)) names = ['x', 'y', 'tx', 'ty', 'allx', 'ally'] objects = [] for name in names: @@ -98,7 +99,7 @@ def yield_nodes(task_id=0): line_count += 1 node_ids.append(i) yield node - print('Processed {} lines for nodes.'.format(line_count)) + logger.info('Processed {} lines for nodes.'.format(line_count)) def yield_edges(task_id=0): @@ -108,21 +109,21 @@ def yield_edges(task_id=0): Yields: data (dict): data row which is dict. """ - print("Edge task is {}".format(task_id)) + logger.info("Edge task is {}".format(task_id)) with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f: graph = pkl.load(f, encoding='latin1') line_count = 0 for i in graph: for dst_id in graph[i]: if not i in node_ids: - print('Source node {} does not exist.'.format(i)) + logger.info('Source node {} does not exist.'.format(i)) continue if not dst_id in node_ids: - print('Destination node {} does not exist.'.format( + logger.info('Destination node {} does not exist.'.format( dst_id)) continue edge = {'id': line_count, 'src_id': i, 'dst_id': dst_id, 'type': 0} line_count += 1 yield edge - print('Processed {} lines for edges.'.format(line_count)) + logger.info('Processed {} lines for edges.'.format(line_count)) diff --git a/example/graph_to_mindrecord/graph_map_schema.py b/example/graph_to_mindrecord/graph_map_schema.py index e131de9f650..1da1ced2f7d 100644 --- a/example/graph_to_mindrecord/graph_map_schema.py +++ b/example/graph_to_mindrecord/graph_map_schema.py @@ -16,6 +16,7 @@ Graph data convert tool for MindRecord. 
""" import numpy as np +from mindspore import log as logger __all__ = ['GraphMapSchema'] @@ -41,6 +42,7 @@ class GraphMapSchema: "edge_feature_index": {"type": "int32", "shape": [-1]} } + @property def get_schema(self): """ Get schema @@ -52,6 +54,7 @@ class GraphMapSchema: Set node features profile """ if num_features != len(features_data_type) or num_features != len(features_shape): + logger.info("Node feature profile is not match.") raise ValueError("Node feature profile is not match.") self.num_node_features = num_features @@ -66,6 +69,7 @@ class GraphMapSchema: Set edge features profile """ if num_features != len(features_data_type) or num_features != len(features_shape): + logger.info("Edge feature profile is not match.") raise ValueError("Edge feature profile is not match.") self.num_edge_features = num_features @@ -83,6 +87,10 @@ class GraphMapSchema: Returns: graph data with union schema """ + if node is None: + logger.info("node cannot be None.") + raise ValueError("node cannot be None.") + node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"], "node_feature_index": []} for i in range(self.num_node_features): @@ -117,6 +125,10 @@ class GraphMapSchema: Returns: graph data with union schema """ + if edge is None: + logger.info("edge cannot be None.") + raise ValueError("edge cannot be None.") + edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e', "type": edge["type"], "edge_feature_index": []} diff --git a/example/graph_to_mindrecord/writer.py b/example/graph_to_mindrecord/writer.py index 1024c823729..9dce63e265e 100644 --- a/example/graph_to_mindrecord/writer.py +++ b/example/graph_to_mindrecord/writer.py @@ -164,7 +164,7 @@ if __name__ == "__main__": num_features, feature_data_types, feature_shapes = mr_api.edge_profile graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes) - graph_schema = graph_map_schema.get_schema() + graph_schema = graph_map_schema.get_schema # init writer writer = init_writer(graph_schema) diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc index a3c3e4533cb..abab8cf3f47 100644 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ b/mindspore/ccsrc/dataset/core/tensor.cc @@ -983,7 +983,9 @@ Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vectorSizeInBytes(), data_ + src_start * type_size, count * type_size); + int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, + count * type_size); + CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric"); out_index += count; if (i < indices.size() - 1) { src_start = HandleNeg(indices[i + 1], dim_length); // next index diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h index 9a9025f2375..bf0bcb0e787 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -101,6 +101,9 @@ class BucketBatchByLengthOp : public PipelineOp { std::vector bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); + // Destructor + ~BucketBatchByLengthOp() = default; + // Might need to batch remaining buckets after receiving eoe, so override this method. 
// @param int32_t workerId // @return Status - The error code returned diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc index 9e5cbbb7889..6504d088bf2 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc @@ -36,6 +36,7 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) : mr_path_(mr_filepath), num_workers_(num_workers), row_id_(0), + shard_reader_(nullptr), keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.cc b/mindspore/ccsrc/dataset/engine/opt/pass.cc index 91f74581f9b..a032d46cba5 100644 --- a/mindspore/ccsrc/dataset/engine/opt/pass.cc +++ b/mindspore/ccsrc/dataset/engine/opt/pass.cc @@ -37,7 +37,7 @@ namespace dataset { // Driver method for TreePass Status TreePass::Run(ExecutionTree *tree, bool *modified) { - if (!tree || !modified) { + if (tree == nullptr || modified == nullptr) { return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); } return this->RunOnTree(tree, modified); @@ -45,7 +45,7 @@ Status TreePass::Run(ExecutionTree *tree, bool *modified) { // Driver method for NodePass Status NodePass::Run(ExecutionTree *tree, bool *modified) { - if (!tree || !modified) { + if (tree == nullptr || modified == nullptr) { return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); } std::shared_ptr root = tree->root(); diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_size.h b/mindspore/ccsrc/dataset/engine/perf/connector_size.h index 321972d28c9..6840ffe2449 100644 --- a/mindspore/ccsrc/dataset/engine/perf/connector_size.h +++ b/mindspore/ccsrc/dataset/engine/perf/connector_size.h @@ -44,7 +44,7 @@ class ConnectorSize : public Sampling { public: explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {} - ~ConnectorSize() = default; + ~ConnectorSize() override = default; // Driver function for connector size sampling. 
// This function samples the connector size of every nodes within the ExecutionTree diff --git a/mindspore/ccsrc/dataset/engine/perf/monitor.cc b/mindspore/ccsrc/dataset/engine/perf/monitor.cc index 9064604075b..c9dce004b59 100644 --- a/mindspore/ccsrc/dataset/engine/perf/monitor.cc +++ b/mindspore/ccsrc/dataset/engine/perf/monitor.cc @@ -26,6 +26,7 @@ Monitor::Monitor(ExecutionTree *tree) : tree_(tree) { std::shared_ptr cfg = GlobalContext::config_manager(); sampling_interval_ = cfg->monitor_sampling_interval(); max_samples_ = 0; + cur_row_ = 0; } Status Monitor::operator()() { diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.h b/mindspore/ccsrc/dataset/kernels/data/slice_op.h index 1bc7f0d5b9a..0a24ae171ee 100644 --- a/mindspore/ccsrc/dataset/kernels/data/slice_op.h +++ b/mindspore/ccsrc/dataset/kernels/data/slice_op.h @@ -34,6 +34,8 @@ class Slice { Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {} explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {} + ~Slice() = default; + std::vector Indices(dsize_t length) { std::vector indices; dsize_t index = std::min(Tensor::HandleNeg(start_, length), length); diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc index 71030d1cb11..5a5c632e81d 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc @@ -29,8 +29,8 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow BOUNDING_BOX_CHECK(input); if (distribution_(rnd_)) { // To test bounding boxes algorithm, create random bboxes from image dims - size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes - float img_center = (input[0]->shape()[1] / 2); // get the center of the image + size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes + float img_center = (input[0]->shape()[1] / 2.); // get the center of the image for (int i = 0; i < num_of_boxes; i++) { uint32_t b_w = 0; // bounding box width diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc index e8f5f1f15af..1128990b44e 100644 --- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc +++ b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc @@ -49,6 +49,7 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal preserve_unused_token_(preserve_unused_token), case_fold_(std::make_unique()), nfd_normalize_(std::make_unique(NormalizeForm::kNfd)), + normalization_form_(normalization_form), common_normalize_(std::make_unique(normalization_form)), replace_accent_chars_(std::make_unique("\\p{Mn}", "")), replace_control_chars_(std::make_unique("\\p{Cc}|\\p{Cf}", " ")) { diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h index da79ad08766..a37e841573e 100644 --- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h +++ b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h @@ -35,9 +35,9 @@ class BasicTokenizerOp : public TensorOp { static const bool kDefKeepWhitespace; static const NormalizeForm kDefNormalizationForm; static const bool kDefPreserveUnusedToken; - BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace, - NormalizeForm normalization_form = kDefNormalizationForm, - 
bool preserve_unused_token = kDefPreserveUnusedToken); + explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace, + NormalizeForm normalization_form = kDefNormalizationForm, + bool preserve_unused_token = kDefPreserveUnusedToken); ~BasicTokenizerOp() override = default; diff --git a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h index 61c6785f357..660fdc7ba58 100644 --- a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h +++ b/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h @@ -28,14 +28,14 @@ namespace mindspore { namespace dataset { class BertTokenizerOp : public TensorOp { public: - BertTokenizerOp(const std::shared_ptr &vocab, - const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator, - const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken, - const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken, - bool lower_case = BasicTokenizerOp::kDefLowerCase, - bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace, - NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm, - bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken) + explicit BertTokenizerOp(const std::shared_ptr &vocab, + const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator, + const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken, + const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken, + bool lower_case = BasicTokenizerOp::kDefLowerCase, + bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace, + NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm, + bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken) : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token), basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {} diff --git a/mindspore/ccsrc/dataset/util/auto_index.h b/mindspore/ccsrc/dataset/util/auto_index.h index 2b4c2d68833..11a2e90b00d 100644 --- a/mindspore/ccsrc/dataset/util/auto_index.h +++ b/mindspore/ccsrc/dataset/util/auto_index.h @@ -48,7 +48,7 @@ class AutoIndexObj : public BPlusTree { // @return Status insert(const value_type &val, key_type *key = nullptr) { key_type my_inx = inx_.fetch_add(1); - if (key) { + if (key != nullptr) { *key = my_inx; } return my_tree::DoInsert(my_inx, val); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/mindrecord/meta/shard_column.cc index 8a2fd47bf6a..28dc243e172 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_column.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_column.cc @@ -323,7 +323,7 @@ std::vector ShardColumn::CompressBlob(const std::vector &blob) } vector ShardColumn::CompressInt(const vector &src_bytes, const IntegerType &int_type) { - uint64_t i_size = kUnsignedOne << int_type; + uint64_t i_size = kUnsignedOne << static_cast(int_type); // Get number of elements uint64_t src_n_int = src_bytes.size() / i_size; // Calculate bitmap size (bytes) @@ -344,7 +344,7 @@ vector ShardColumn::CompressInt(const vector &src_bytes, const // Initialize destination data type IntegerType dst_int_type = kInt8Type; // Shift to next int position - uint64_t pos = i * (kUnsignedOne << int_type); + uint64_t pos = i * (kUnsignedOne << static_cast(int_type)); // Narrow down this int int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, 
&dst_int_type); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 603ad35c7ed..ca6f7ca33e5 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -61,7 +61,7 @@ class Shuffle(str, Enum): @check_zip def zip(datasets): """ - Zips the datasets in the input tuple of datasets. + Zip the datasets in the input tuple of datasets. Args: datasets (tuple of class Dataset): A tuple of datasets to be zipped together. @@ -152,7 +152,7 @@ class Dataset: def get_args(self): """ - Returns attributes (member variables) related to the current class. + Return attributes (member variables) related to the current class. Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'. @@ -239,7 +239,7 @@ class Dataset: def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None, pad_info=None): """ - Combines batch_size number of consecutive rows into batches. + Combine batch_size number of consecutive rows into batches. For any child node, a batch is treated as a single row. For any column, all the elements within that column must have the same shape. @@ -340,7 +340,7 @@ class Dataset: def flat_map(self, func): """ - Maps `func` to each row in dataset and flatten the result. + Map `func` to each row in dataset and flatten the result. The specified `func` is a function that must take one 'Ndarray' as input and return a 'Dataset'. @@ -370,6 +370,7 @@ class Dataset: """ dataset = None if not hasattr(func, '__call__'): + logger.error("func must be a function.") raise TypeError("func must be a function.") for row_data in self: @@ -379,6 +380,7 @@ class Dataset: dataset += func(row_data) if not isinstance(dataset, Dataset): + logger.error("flat_map must return a Dataset object.") raise TypeError("flat_map must return a Dataset object.") return dataset @@ -386,7 +388,7 @@ class Dataset: def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, num_parallel_workers=None, python_multiprocessing=False): """ - Applies each operation in operations to this dataset. + Apply each operation in operations to this dataset. The order of operations is determined by the position of each operation in operations. operations[0] will be applied first, then operations[1], then operations[2], etc. @@ -570,7 +572,7 @@ class Dataset: @check_repeat def repeat(self, count=None): """ - Repeats this dataset count times. Repeat indefinitely if the count is None or -1. + Repeat this dataset count times. Repeat indefinitely if the count is None or -1. Note: The order of using repeat and batch reflects the number of batches. Recommend that @@ -662,13 +664,16 @@ class Dataset: dataset_size = self.get_dataset_size() if dataset_size is None or dataset_size <= 0: - raise RuntimeError("dataset size unknown, unable to split.") + raise RuntimeError("dataset_size is unknown, unable to split.") + + if not isinstance(sizes, list): + raise RuntimeError("sizes should be a list.") all_int = all(isinstance(item, int) for item in sizes) if all_int: sizes_sum = sum(sizes) if sizes_sum != dataset_size: - raise RuntimeError("sum of split sizes {} is not equal to dataset size {}." + raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}." 
.format(sizes_sum, dataset_size)) return sizes @@ -676,7 +681,7 @@ class Dataset: for item in sizes: absolute_size = int(round(item * dataset_size)) if absolute_size == 0: - raise RuntimeError("split percentage {} is too small.".format(item)) + raise RuntimeError("Split percentage {} is too small.".format(item)) absolute_sizes.append(absolute_size) absolute_sizes_sum = sum(absolute_sizes) @@ -694,7 +699,7 @@ class Dataset: break if sum(absolute_sizes) != dataset_size: - raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}." + raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}." .format(absolute_sizes_sum, dataset_size)) return absolute_sizes @@ -702,7 +707,7 @@ class Dataset: @check_split def split(self, sizes, randomize=True): """ - Splits the dataset into smaller, non-overlapping datasets. + Split the dataset into smaller, non-overlapping datasets. This is a general purpose split function which can be called from any operator in the pipeline. There is another, optimized split function, which will be called automatically if ds.split is @@ -759,10 +764,10 @@ class Dataset: >>> train, test = data.split([0.9, 0.1]) """ if self.is_shuffled(): - logger.warning("dataset is shuffled before split.") + logger.warning("Dataset is shuffled before split.") if self.is_sharded(): - raise RuntimeError("dataset should not be sharded before split.") + raise RuntimeError("Dataset should not be sharded before split.") absolute_sizes = self._get_absolute_split_sizes(sizes) splits = [] @@ -788,7 +793,7 @@ class Dataset: @check_zip_dataset def zip(self, datasets): """ - Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name. + Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name. Args: datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset @@ -845,7 +850,7 @@ class Dataset: @check_rename def rename(self, input_columns, output_columns): """ - Renames the columns in input datasets. + Rename the columns in input datasets. Args: input_columns (list[str]): list of names of the input columns. @@ -871,7 +876,7 @@ class Dataset: @check_project def project(self, columns): """ - Projects certain columns in input datasets. + Project certain columns in input datasets. The specified columns will be selected from the dataset and passed down the pipeline in the order specified. The other columns are discarded. @@ -936,7 +941,7 @@ class Dataset: def device_que(self, prefetch_size=None): """ - Returns a transferredDataset that transfer data through device. + Return a transferredDataset that transfer data through device. Args: prefetch_size (int, optional): prefetch number of records ahead of the @@ -953,7 +958,7 @@ class Dataset: def to_device(self, num_batch=None): """ - Transfers data through CPU, GPU or Ascend devices. + Transfer data through CPU, GPU or Ascend devices. Args: num_batch (int, optional): limit the number of batch to be sent to device (default=None). @@ -988,7 +993,7 @@ class Dataset: raise TypeError("Please set device_type in context") if device_type not in ('Ascend', 'GPU', 'CPU'): - raise ValueError("only support CPU, Ascend, GPU") + raise ValueError("Only support CPU, Ascend, GPU") if num_batch is None or num_batch == 0: raise ValueError("num_batch is None or 0.") @@ -1089,7 +1094,7 @@ class Dataset: def _get_pipeline_info(self): """ - Gets pipeline information. + Get pipeline information. 
""" device_iter = TupleIterator(self) self._output_shapes = device_iter.get_output_shapes() @@ -1344,7 +1349,7 @@ class MappableDataset(SourceDataset): @check_split def split(self, sizes, randomize=True): """ - Splits the dataset into smaller, non-overlapping datasets. + Split the dataset into smaller, non-overlapping datasets. There is the optimized split function, which will be called automatically when the dataset that calls this function is a MappableDataset. @@ -1411,10 +1416,10 @@ class MappableDataset(SourceDataset): >>> train.use_sampler(train_sampler) """ if self.is_shuffled(): - logger.warning("dataset is shuffled before split.") + logger.warning("Dataset is shuffled before split.") if self.is_sharded(): - raise RuntimeError("dataset should not be sharded before split.") + raise RuntimeError("Dataset should not be sharded before split.") absolute_sizes = self._get_absolute_split_sizes(sizes) splits = [] @@ -1633,7 +1638,7 @@ class BlockReleasePair: def __init__(self, init_release_rows, callback=None): if isinstance(init_release_rows, int) and init_release_rows <= 0: - raise ValueError("release_rows need to be greater than 0.") + raise ValueError("release_rows need to be greater than 0.") self.row_count = -init_release_rows self.cv = threading.Condition() self.callback = callback @@ -2699,10 +2704,10 @@ class MindDataset(MappableDataset): self.shard_id = shard_id if block_reader is True and num_shards is not None: - raise ValueError("block reader not allowed true when use partitions") + raise ValueError("block_reader not allowed true when use partitions") if block_reader is True and shuffle is True: - raise ValueError("block reader not allowed true when use shuffle") + raise ValueError("block_reader not allowed true when use shuffle") if block_reader is True: logger.warning("WARN: global shuffle is not used.") @@ -2711,14 +2716,14 @@ class MindDataset(MappableDataset): if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SequentialSampler)) is False: - raise ValueError("the sampler is not supported yet.") + raise ValueError("The sampler is not supported yet.") self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) self.num_samples = num_samples # sampler exclusive if block_reader is True and sampler is not None: - raise ValueError("block reader not allowed true when use sampler") + raise ValueError("block_reader not allowed true when use sampler") if num_padded is None: num_padded = 0 @@ -2770,7 +2775,7 @@ class MindDataset(MappableDataset): if value >= 0: self._dataset_size = value else: - raise ValueError('set dataset_size with negative value {}'.format(value)) + raise ValueError('Set dataset_size with negative value {}'.format(value)) def is_shuffled(self): if self.shuffle_option is None: @@ -2872,7 +2877,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): def _fetch_py_sampler_indices(sampler, num_samples): """ - Indices fetcher for python sampler. + Indice fetcher for python sampler. 
""" if num_samples is not None: sampler_iter = iter(sampler) @@ -3163,7 +3168,7 @@ class GeneratorDataset(MappableDataset): if value >= 0: self._dataset_size = value else: - raise ValueError('set dataset_size with negative value {}'.format(value)) + raise ValueError('Set dataset_size with negative value {}'.format(value)) def __deepcopy__(self, memodict): if id(self) in memodict: @@ -3313,7 +3318,7 @@ class TFRecordDataset(SourceDataset): if value >= 0: self._dataset_size = value else: - raise ValueError('set dataset_size with negative value {}'.format(value)) + raise ValueError('Set dataset_size with negative value {}'.format(value)) def is_shuffled(self): return self.shuffle_files @@ -4382,7 +4387,9 @@ class CelebADataset(MappableDataset): try: with open(attr_file, 'r') as f: num_rows = int(f.readline()) - except Exception: + except FileNotFoundError: + raise RuntimeError("attr_file not found.") + except BaseException: raise RuntimeError("Get dataset size failed from attribution file.") rows_per_shard = get_num_rows(num_rows, self.num_shards) if self.num_samples is not None: diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index d89d90c3283..b74874f9cf3 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -319,7 +319,7 @@ class PKSampler(BuiltinSampler): raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val)) if num_class is not None: - raise NotImplementedError + raise NotImplementedError("Not support specify num_class") if not isinstance(shuffle, bool): raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) @@ -551,8 +551,8 @@ class WeightedRandomSampler(BuiltinSampler): Args: weights (list[float]): A sequence of weights, not necessarily summing up to 1. - num_samples (int): Number of elements to sample (default=None, all elements). - replacement (bool, optional): If True, put the sample ID back for the next draw (default=True). + num_samples (int, optional): Number of elements to sample (default=None, all elements). + replacement (bool): If True, put the sample ID back for the next draw (default=True). 
Examples: >>> import mindspore.dataset as ds diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 6cfea7c8b89..005f7072aa2 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -50,7 +50,7 @@ def check_filename(path): Exception: when error """ if not isinstance(path, str): - raise ValueError("path: {} is not string".format(path)) + raise TypeError("path: {} is not string".format(path)) filename = os.path.basename(path) # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`', @@ -143,7 +143,7 @@ def check_sampler_shuffle_shard_options(param_dict): num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): - raise ValueError("sampler is not a valid Sampler type.") + raise TypeError("sampler is not a valid Sampler type.") if sampler is not None: if shuffle is not None: @@ -328,13 +328,13 @@ def check_vocdataset(method): if task is None: raise ValueError("task is not provided.") if not isinstance(task, str): - raise ValueError("task is not str type.") + raise TypeError("task is not str type.") # check mode; required argument mode = param_dict.get('mode') if mode is None: raise ValueError("mode is not provided.") if not isinstance(mode, str): - raise ValueError("mode is not str type.") + raise TypeError("mode is not str type.") imagesets_file = "" if task == "Segmentation": @@ -388,7 +388,7 @@ def check_cocodataset(method): if task is None: raise ValueError("task is not provided.") if not isinstance(task, str): - raise ValueError("task is not str type.") + raise TypeError("task is not str type.") if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}: raise ValueError("Invalid task type") @@ -556,7 +556,7 @@ def check_generatordataset(method): def check_batch_size(batch_size): if not (isinstance(batch_size, int) or (callable(batch_size))): - raise ValueError("batch_size should either be an int or a callable.") + raise TypeError("batch_size should either be an int or a callable.") if callable(batch_size): sig = ins.signature(batch_size) if len(sig.parameters) != 1: @@ -706,6 +706,7 @@ def check_batch(method): def check_sync_wait(method): """check the input arguments of sync_wait.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -773,7 +774,7 @@ def check_filter(method): param_dict = make_param_dict(method, args, kwargs) predicate = param_dict.get("predicate") if not callable(predicate): - raise ValueError("Predicate should be a python function or a callable python object.") + raise TypeError("Predicate should be a python function or a callable python object.") nreq_param_int = ['num_parallel_workers'] check_param_type(nreq_param_int, param_dict, int) @@ -865,7 +866,7 @@ def check_zip_dataset(method): raise ValueError("datasets is not provided.") if not isinstance(ds, (tuple, datasets.Dataset)): - raise ValueError("datasets is not tuple or of type Dataset.") + raise TypeError("datasets is not tuple or of type Dataset.") return method(*args, **kwargs) @@ -885,7 +886,7 @@ def check_concat(method): raise ValueError("datasets is not provided.") if not isinstance(ds, (list, datasets.Dataset)): - raise ValueError("datasets is not list or of type Dataset.") + raise TypeError("datasets is not list or of type Dataset.") return method(*args, **kwargs) @@ -964,7 +965,7 @@ def check_add_column(method): de_type = 
param_dict.get("de_type") if de_type is not None: if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type): - raise ValueError("Unknown column type.") + raise TypeError("Unknown column type.") else: raise TypeError("Expected non-empty string.") diff --git a/requirements.txt b/requirements.txt index 32fc6473649..4038e63ea77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,6 @@ wheel >= 0.32.0 decorator >= 4.4.0 setuptools >= 40.8.0 matplotlib >= 3.1.3 # for ut test -opencv-python >= 4.2.0.32 # for ut test +opencv-python >= 4.1.2.30 # for ut test sklearn >= 0.0 # for st test pandas >= 1.0.2 # for ut test \ No newline at end of file diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py index b904e2e0168..a51e8524545 100644 --- a/tests/ut/python/dataset/test_split.py +++ b/tests/ut/python/dataset/test_split.py @@ -42,15 +42,15 @@ def split_with_invalid_inputs(d): with pytest.raises(RuntimeError) as info: _, _ = d.split([3, 1]) - assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value) + assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value) with pytest.raises(RuntimeError) as info: _, _ = d.split([5, 1]) - assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value) + assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value) with pytest.raises(RuntimeError) as info: _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25]) - assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value) + assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value) with pytest.raises(ValueError) as info: _, _ = d.split([-0.5, 0.5]) @@ -80,7 +80,7 @@ def test_unmappable_invalid_input(): d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0) with pytest.raises(RuntimeError) as info: _, _ = d.split([4, 1]) - assert "dataset should not be sharded before split" in str(info.value) + assert "Dataset should not be sharded before split" in str(info.value) def test_unmappable_split(): @@ -274,7 +274,7 @@ def test_mappable_invalid_input(): d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0) with pytest.raises(RuntimeError) as info: _, _ = d.split([4, 1]) - assert "dataset should not be sharded before split" in str(info.value) + assert "Dataset should not be sharded before split" in str(info.value) def test_mappable_split_general():
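
Note on the get_schema change: the graph_map_schema.py and writer.py hunks above turn `get_schema` from a regular method into a read-only `@property`, which is why the call site in writer.py drops the parentheses. Below is a minimal sketch of that pattern, not the full class from the patch; the schema contents shown are a simplified placeholder for the union schema keys (`first_id`, `second_id`, `third_id`, ...) defined in graph_map_schema.py.

```python
# Minimal sketch of the method-to-property change applied to GraphMapSchema.
# The schema body below is an illustrative stand-in, not the complete union
# schema from the patch.
class GraphMapSchema:
    def __init__(self):
        self.union_schema_in_mindrecord = {
            "first_id": {"type": "int64"},
            "second_id": {"type": "int64"},
            "third_id": {"type": "int64"},
        }

    @property
    def get_schema(self):
        """Return the union schema; accessed as an attribute, not called."""
        return self.union_schema_in_mindrecord


schema = GraphMapSchema()
# Before the patch: graph_schema = schema.get_schema()
# After the patch:  the parentheses are dropped, as done in writer.py.
graph_schema = schema.get_schema
```

Callers that still invoke `get_schema()` after this change would raise `TypeError: 'dict' object is not callable`, so every call site (as in writer.py) has to be updated together with the property.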