!2306 [Dataset] Code review & improve quality

YangLuo 2020-06-18 21:03:55 +08:00
parent 83b53559f5
commit 4e3bfcf4c9
21 changed files with 110 additions and 79 deletions

View File

@@ -20,6 +20,7 @@ import os
import pickle as pkl
import numpy as np
import scipy.sparse as sp
+from mindspore import log as logger
# parse args from command line parameter 'graph_api_args'
# args delimiter is ':'
@@ -58,7 +59,7 @@ def yield_nodes(task_id=0):
Yields:
data (dict): data row which is dict.
"""
-print("Node task is {}".format(task_id))
+logger.info("Node task is {}".format(task_id))
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
objects = []
for name in names:
@@ -98,7 +99,7 @@ def yield_nodes(task_id=0):
line_count += 1
node_ids.append(i)
yield node
-print('Processed {} lines for nodes.'.format(line_count))
+logger.info('Processed {} lines for nodes.'.format(line_count))
def yield_edges(task_id=0):
@@ -108,21 +109,21 @@ def yield_edges(task_id=0):
Yields:
data (dict): data row which is dict.
"""
-print("Edge task is {}".format(task_id))
+logger.info("Edge task is {}".format(task_id))
with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
graph = pkl.load(f, encoding='latin1')
line_count = 0
for i in graph:
for dst_id in graph[i]:
if not i in node_ids:
-print('Source node {} does not exist.'.format(i))
+logger.info('Source node {} does not exist.'.format(i))
continue
if not dst_id in node_ids:
-print('Destination node {} does not exist.'.format(
+logger.info('Destination node {} does not exist.'.format(
dst_id))
continue
edge = {'id': line_count,
'src_id': i, 'dst_id': dst_id, 'type': 0}
line_count += 1
yield edge
-print('Processed {} lines for edges.'.format(line_count))
+logger.info('Processed {} lines for edges.'.format(line_count))
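The hunks above swap bare print() calls for the MindSpore logger. A minimal sketch of the pattern, using the same `from mindspore import log as logger` import as the diff; the generator below is a stand-in, not the real conversion script:

    from mindspore import log as logger

    def yield_rows(rows, task_id=0):
        """Toy generator mirroring the print-to-logger change."""
        logger.info("Task is {}".format(task_id))   # was: print(...)
        line_count = 0
        for row in rows:
            line_count += 1
            yield row
        logger.info('Processed {} lines.'.format(line_count))

Routed through the logger, these messages respect the framework's log-level configuration instead of always landing on stdout.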

View File

@@ -16,6 +16,7 @@
Graph data convert tool for MindRecord.
"""
import numpy as np
+from mindspore import log as logger
__all__ = ['GraphMapSchema']
@@ -41,6 +42,7 @@ class GraphMapSchema:
"edge_feature_index": {"type": "int32", "shape": [-1]}
}
+@property
def get_schema(self):
"""
Get schema
@@ -52,6 +54,7 @@ class GraphMapSchema:
Set node features profile
"""
if num_features != len(features_data_type) or num_features != len(features_shape):
+logger.info("Node feature profile is not match.")
raise ValueError("Node feature profile is not match.")
self.num_node_features = num_features
@@ -66,6 +69,7 @@ class GraphMapSchema:
Set edge features profile
"""
if num_features != len(features_data_type) or num_features != len(features_shape):
+logger.info("Edge feature profile is not match.")
raise ValueError("Edge feature profile is not match.")
self.num_edge_features = num_features
@@ -83,6 +87,10 @@ class GraphMapSchema:
Returns:
graph data with union schema
"""
+if node is None:
+logger.info("node cannot be None.")
+raise ValueError("node cannot be None.")
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
"node_feature_index": []}
for i in range(self.num_node_features):
@@ -117,6 +125,10 @@ class GraphMapSchema:
Returns:
graph data with union schema
"""
+if edge is None:
+logger.info("edge cannot be None.")
+raise ValueError("edge cannot be None.")
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
"type": edge["type"], "edge_feature_index": []}

View File

@@ -164,7 +164,7 @@ if __name__ == "__main__":
num_features, feature_data_types, feature_shapes = mr_api.edge_profile
graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
-graph_schema = graph_map_schema.get_schema()
+graph_schema = graph_map_schema.get_schema
# init writer
writer = init_writer(graph_schema)
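This call-site change pairs with the @property added to get_schema above: a property is read as an attribute, so the parentheses have to go. A toy illustration, not the real GraphMapSchema:

    class SchemaHolder:
        """Minimal stand-in for a class that exposes its schema as a property."""

        def __init__(self):
            self._schema = {"first_id": {"type": "int64"}}

        @property
        def get_schema(self):
            return self._schema

    holder = SchemaHolder()
    schema = holder.get_schema      # attribute access, no ()
    # holder.get_schema() would now fail: the returned dict is not callable

Keeping a get_ prefix on a property reads a little oddly; a plain `schema` property would be more idiomatic, but retaining the old name means call sites only drop the parentheses.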

View File

@@ -983,7 +983,9 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
continue;
}
}
-memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
+int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size,
+                           count * type_size);
+CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric");
out_index += count;
if (i < indices.size() - 1) {
src_start = HandleNeg(indices[i + 1], dim_length);  // next index

View File

@@ -101,6 +101,9 @@ class BucketBatchByLengthOp : public PipelineOp {
std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info,
bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size);
+// Destructor
+~BucketBatchByLengthOp() = default;
// Might need to batch remaining buckets after receiving eoe, so override this method.
// @param int32_t workerId
// @return Status - The error code returned

View File

@@ -36,6 +36,7 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
: mr_path_(mr_filepath),
num_workers_(num_workers),
row_id_(0),
+shard_reader_(nullptr),
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,

View File

@@ -37,7 +37,7 @@ namespace dataset {
// Driver method for TreePass
Status TreePass::Run(ExecutionTree *tree, bool *modified) {
-if (!tree || !modified) {
+if (tree == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
}
return this->RunOnTree(tree, modified);
@@ -45,7 +45,7 @@ Status TreePass::Run(ExecutionTree *tree, bool *modified) {
// Driver method for NodePass
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
-if (!tree || !modified) {
+if (tree == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
}
std::shared_ptr<DatasetOp> root = tree->root();

View File

@@ -44,7 +44,7 @@ class ConnectorSize : public Sampling {
public:
explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {}
-~ConnectorSize() = default;
+~ConnectorSize() override = default;
// Driver function for connector size sampling.
// This function samples the connector size of every nodes within the ExecutionTree

View File

@@ -26,6 +26,7 @@ Monitor::Monitor(ExecutionTree *tree) : tree_(tree) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
sampling_interval_ = cfg->monitor_sampling_interval();
max_samples_ = 0;
+cur_row_ = 0;
}
Status Monitor::operator()() {

View File

@@ -34,6 +34,8 @@ class Slice {
Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
+~Slice() = default;
std::vector<dsize_t> Indices(dsize_t length) {
std::vector<dsize_t> indices;
dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);

View File

@@ -29,8 +29,8 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow
BOUNDING_BOX_CHECK(input);
if (distribution_(rnd_)) {
// To test bounding boxes algorithm, create random bboxes from image dims
size_t num_of_boxes = input[1]->shape()[0];  // set to give number of bboxes
-float img_center = (input[0]->shape()[1] / 2);   // get the center of the image
+float img_center = (input[0]->shape()[1] / 2.);  // get the center of the image
for (int i = 0; i < num_of_boxes; i++) {
uint32_t b_w = 0;  // bounding box width

View File

@@ -49,6 +49,7 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
preserve_unused_token_(preserve_unused_token),
case_fold_(std::make_unique<CaseFoldOp>()),
nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)),
+normalization_form_(normalization_form),
common_normalize_(std::make_unique<NormalizeUTF8Op>(normalization_form)),
replace_accent_chars_(std::make_unique<RegexReplaceOp>("\\p{Mn}", "")),
replace_control_chars_(std::make_unique<RegexReplaceOp>("\\p{Cc}|\\p{Cf}", " ")) {

View File

@@ -35,9 +35,9 @@ class BasicTokenizerOp : public TensorOp {
static const bool kDefKeepWhitespace;
static const NormalizeForm kDefNormalizationForm;
static const bool kDefPreserveUnusedToken;
-BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
+explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
NormalizeForm normalization_form = kDefNormalizationForm,
bool preserve_unused_token = kDefPreserveUnusedToken);
~BasicTokenizerOp() override = default;

View File

@@ -28,14 +28,14 @@ namespace mindspore {
namespace dataset {
class BertTokenizerOp : public TensorOp {
public:
-BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
+explicit BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
bool lower_case = BasicTokenizerOp::kDefLowerCase,
bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
: wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token),
basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {}

View File

@@ -48,7 +48,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
// @return
Status insert(const value_type &val, key_type *key = nullptr) {
key_type my_inx = inx_.fetch_add(1);
-if (key) {
+if (key != nullptr) {
*key = my_inx;
}
return my_tree::DoInsert(my_inx, val);

View File

@@ -323,7 +323,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
}
vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
-uint64_t i_size = kUnsignedOne << int_type;
+uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type);
// Get number of elements
uint64_t src_n_int = src_bytes.size() / i_size;
// Calculate bitmap size (bytes)
@@ -344,7 +344,7 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
// Initialize destination data type
IntegerType dst_int_type = kInt8Type;
// Shift to next int position
-uint64_t pos = i * (kUnsignedOne << int_type);
+uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type));
// Narrow down this int
int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);

View File

@@ -61,7 +61,7 @@ class Shuffle(str, Enum):
@check_zip
def zip(datasets):
"""
-Zips the datasets in the input tuple of datasets.
+Zip the datasets in the input tuple of datasets.
Args:
datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
@@ -152,7 +152,7 @@ class Dataset:
def get_args(self):
"""
-Returns attributes (member variables) related to the current class.
+Return attributes (member variables) related to the current class.
Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.
@@ -239,7 +239,7 @@ class Dataset:
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
input_columns=None, pad_info=None):
"""
-Combines batch_size number of consecutive rows into batches.
+Combine batch_size number of consecutive rows into batches.
For any child node, a batch is treated as a single row.
For any column, all the elements within that column must have the same shape.
@@ -340,7 +340,7 @@ class Dataset:
def flat_map(self, func):
"""
-Maps `func` to each row in dataset and flatten the result.
+Map `func` to each row in dataset and flatten the result.
The specified `func` is a function that must take one 'Ndarray' as input
and return a 'Dataset'.
@@ -370,6 +370,7 @@ class Dataset:
"""
dataset = None
if not hasattr(func, '__call__'):
+logger.error("func must be a function.")
raise TypeError("func must be a function.")
for row_data in self:
@@ -379,6 +380,7 @@ class Dataset:
dataset += func(row_data)
if not isinstance(dataset, Dataset):
+logger.error("flat_map must return a Dataset object.")
raise TypeError("flat_map must return a Dataset object.")
return dataset
@@ -386,7 +388,7 @@ class Dataset:
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False):
"""
-Applies each operation in operations to this dataset.
+Apply each operation in operations to this dataset.
The order of operations is determined by the position of each operation in operations.
operations[0] will be applied first, then operations[1], then operations[2], etc.
@@ -570,7 +572,7 @@ class Dataset:
@check_repeat
def repeat(self, count=None):
"""
-Repeats this dataset count times. Repeat indefinitely if the count is None or -1.
+Repeat this dataset count times. Repeat indefinitely if the count is None or -1.
Note:
The order of using repeat and batch reflects the number of batches. Recommend that
@@ -662,13 +664,16 @@ class Dataset:
dataset_size = self.get_dataset_size()
if dataset_size is None or dataset_size <= 0:
-raise RuntimeError("dataset size unknown, unable to split.")
+raise RuntimeError("dataset_size is unknown, unable to split.")
+if not isinstance(sizes, list):
+raise RuntimeError("sizes should be a list.")
all_int = all(isinstance(item, int) for item in sizes)
if all_int:
sizes_sum = sum(sizes)
if sizes_sum != dataset_size:
-raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
+raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
.format(sizes_sum, dataset_size))
return sizes
@@ -676,7 +681,7 @@ class Dataset:
for item in sizes:
absolute_size = int(round(item * dataset_size))
if absolute_size == 0:
-raise RuntimeError("split percentage {} is too small.".format(item))
+raise RuntimeError("Split percentage {} is too small.".format(item))
absolute_sizes.append(absolute_size)
absolute_sizes_sum = sum(absolute_sizes)
@@ -694,7 +699,7 @@ class Dataset:
break
if sum(absolute_sizes) != dataset_size:
-raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
+raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
.format(absolute_sizes_sum, dataset_size))
return absolute_sizes
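For orientation, a rough standalone sketch of what these split-size hunks compute, assuming dataset_size is known; the helper name is illustrative, and the rounding-correction loop the real implementation performs afterwards is omitted:

    def to_absolute_sizes(sizes, dataset_size):
        """Turn split sizes (row counts or fractions) into absolute row counts."""
        if not isinstance(sizes, list):
            raise RuntimeError("sizes should be a list.")
        if all(isinstance(item, int) for item in sizes):
            if sum(sizes) != dataset_size:
                raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
                                   .format(sum(sizes), dataset_size))
            return sizes
        absolute_sizes = [int(round(item * dataset_size)) for item in sizes]
        if any(size == 0 for size in absolute_sizes):
            raise RuntimeError("Split percentage is too small.")
        return absolute_sizes

    print(to_absolute_sizes([0.9, 0.1], 50))   # [45, 5]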
@@ -702,7 +707,7 @@ class Dataset:
@check_split
def split(self, sizes, randomize=True):
"""
-Splits the dataset into smaller, non-overlapping datasets.
+Split the dataset into smaller, non-overlapping datasets.
This is a general purpose split function which can be called from any operator in the pipeline.
There is another, optimized split function, which will be called automatically if ds.split is
@@ -759,10 +764,10 @@ class Dataset:
>>> train, test = data.split([0.9, 0.1])
"""
if self.is_shuffled():
-logger.warning("dataset is shuffled before split.")
+logger.warning("Dataset is shuffled before split.")
if self.is_sharded():
-raise RuntimeError("dataset should not be sharded before split.")
+raise RuntimeError("Dataset should not be sharded before split.")
absolute_sizes = self._get_absolute_split_sizes(sizes)
splits = []
@@ -788,7 +793,7 @@ class Dataset:
@check_zip_dataset
def zip(self, datasets):
"""
-Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
+Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
Args:
datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset
@@ -845,7 +850,7 @@ class Dataset:
@check_rename
def rename(self, input_columns, output_columns):
"""
-Renames the columns in input datasets.
+Rename the columns in input datasets.
Args:
input_columns (list[str]): list of names of the input columns.
@@ -871,7 +876,7 @@ class Dataset:
@check_project
def project(self, columns):
"""
-Projects certain columns in input datasets.
+Project certain columns in input datasets.
The specified columns will be selected from the dataset and passed down
the pipeline in the order specified. The other columns are discarded.
@@ -936,7 +941,7 @@ class Dataset:
def device_que(self, prefetch_size=None):
"""
-Returns a transferredDataset that transfer data through device.
+Return a transferredDataset that transfer data through device.
Args:
prefetch_size (int, optional): prefetch number of records ahead of the
@@ -953,7 +958,7 @@ class Dataset:
def to_device(self, num_batch=None):
"""
-Transfers data through CPU, GPU or Ascend devices.
+Transfer data through CPU, GPU or Ascend devices.
Args:
num_batch (int, optional): limit the number of batch to be sent to device (default=None).
@@ -988,7 +993,7 @@ class Dataset:
raise TypeError("Please set device_type in context")
if device_type not in ('Ascend', 'GPU', 'CPU'):
-raise ValueError("only support CPU, Ascend, GPU")
+raise ValueError("Only support CPU, Ascend, GPU")
if num_batch is None or num_batch == 0:
raise ValueError("num_batch is None or 0.")
@@ -1089,7 +1094,7 @@ class Dataset:
def _get_pipeline_info(self):
"""
-Gets pipeline information.
+Get pipeline information.
"""
device_iter = TupleIterator(self)
self._output_shapes = device_iter.get_output_shapes()
@@ -1344,7 +1349,7 @@ class MappableDataset(SourceDataset):
@check_split
def split(self, sizes, randomize=True):
"""
-Splits the dataset into smaller, non-overlapping datasets.
+Split the dataset into smaller, non-overlapping datasets.
There is the optimized split function, which will be called automatically when the dataset
that calls this function is a MappableDataset.
@@ -1411,10 +1416,10 @@ class MappableDataset(SourceDataset):
>>> train.use_sampler(train_sampler)
"""
if self.is_shuffled():
-logger.warning("dataset is shuffled before split.")
+logger.warning("Dataset is shuffled before split.")
if self.is_sharded():
-raise RuntimeError("dataset should not be sharded before split.")
+raise RuntimeError("Dataset should not be sharded before split.")
absolute_sizes = self._get_absolute_split_sizes(sizes)
splits = []
@@ -1633,7 +1638,7 @@ class BlockReleasePair:
def __init__(self, init_release_rows, callback=None):
if isinstance(init_release_rows, int) and init_release_rows <= 0:
raise ValueError("release_rows need to be greater than 0.")
self.row_count = -init_release_rows
self.cv = threading.Condition()
self.callback = callback
@@ -2699,10 +2704,10 @@ class MindDataset(MappableDataset):
self.shard_id = shard_id
if block_reader is True and num_shards is not None:
-raise ValueError("block reader not allowed true when use partitions")
+raise ValueError("block_reader not allowed true when use partitions")
if block_reader is True and shuffle is True:
-raise ValueError("block reader not allowed true when use shuffle")
+raise ValueError("block_reader not allowed true when use shuffle")
if block_reader is True:
logger.warning("WARN: global shuffle is not used.")
@@ -2711,14 +2716,14 @@ class MindDataset(MappableDataset):
if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
samplers.DistributedSampler, samplers.RandomSampler,
samplers.SequentialSampler)) is False:
-raise ValueError("the sampler is not supported yet.")
+raise ValueError("The sampler is not supported yet.")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
# sampler exclusive
if block_reader is True and sampler is not None:
-raise ValueError("block reader not allowed true when use sampler")
+raise ValueError("block_reader not allowed true when use sampler")
if num_padded is None:
num_padded = 0
@@ -2770,7 +2775,7 @@ class MindDataset(MappableDataset):
if value >= 0:
self._dataset_size = value
else:
-raise ValueError('set dataset_size with negative value {}'.format(value))
+raise ValueError('Set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
if self.shuffle_option is None:
@@ -2872,7 +2877,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
def _fetch_py_sampler_indices(sampler, num_samples):
"""
-Indices fetcher for python sampler.
+Indice fetcher for python sampler.
"""
if num_samples is not None:
sampler_iter = iter(sampler)
@@ -3163,7 +3168,7 @@ class GeneratorDataset(MappableDataset):
if value >= 0:
self._dataset_size = value
else:
-raise ValueError('set dataset_size with negative value {}'.format(value))
+raise ValueError('Set dataset_size with negative value {}'.format(value))
def __deepcopy__(self, memodict):
if id(self) in memodict:
@@ -3313,7 +3318,7 @@ class TFRecordDataset(SourceDataset):
if value >= 0:
self._dataset_size = value
else:
-raise ValueError('set dataset_size with negative value {}'.format(value))
+raise ValueError('Set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
return self.shuffle_files
@@ -4382,7 +4387,9 @@ class CelebADataset(MappableDataset):
try:
with open(attr_file, 'r') as f:
num_rows = int(f.readline())
-except Exception:
+except FileNotFoundError:
+raise RuntimeError("attr_file not found.")
+except BaseException:
raise RuntimeError("Get dataset size failed from attribution file.")
rows_per_shard = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None:
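The CelebA hunk above narrows the exception handling so a missing attribute file gets its own message. A standalone sketch of the pattern, with a placeholder path in the usage comment:

    def count_attr_rows(attr_file):
        """Read the row count from the first line of a CelebA-style attr file."""
        try:
            with open(attr_file, 'r') as f:
                return int(f.readline())
        except FileNotFoundError:
            raise RuntimeError("attr_file not found.")
        except BaseException:
            raise RuntimeError("Get dataset size failed from attribution file.")

    # count_attr_rows("/path/to/list_attr_celeba.txt")  # placeholder path

Catching Exception (or OSError/ValueError) rather than BaseException would avoid swallowing KeyboardInterrupt, but the sketch mirrors the diff.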

View File

@@ -319,7 +319,7 @@ class PKSampler(BuiltinSampler):
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
if num_class is not None:
-raise NotImplementedError
+raise NotImplementedError("Not support specify num_class")
if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
@@ -551,8 +551,8 @@ class WeightedRandomSampler(BuiltinSampler):
Args:
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
-num_samples (int): Number of elements to sample (default=None, all elements).
-replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
+num_samples (int, optional): Number of elements to sample (default=None, all elements).
+replacement (bool): If True, put the sample ID back for the next draw (default=True).
Examples:
>>> import mindspore.dataset as ds

View File

@@ -50,7 +50,7 @@ def check_filename(path):
Exception: when error
"""
if not isinstance(path, str):
-raise ValueError("path: {} is not string".format(path))
+raise TypeError("path: {} is not string".format(path))
filename = os.path.basename(path)
# '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
@@ -143,7 +143,7 @@ def check_sampler_shuffle_shard_options(param_dict):
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
-raise ValueError("sampler is not a valid Sampler type.")
+raise TypeError("sampler is not a valid Sampler type.")
if sampler is not None:
if shuffle is not None:
@@ -328,13 +328,13 @@ def check_vocdataset(method):
if task is None:
raise ValueError("task is not provided.")
if not isinstance(task, str):
-raise ValueError("task is not str type.")
+raise TypeError("task is not str type.")
# check mode; required argument
mode = param_dict.get('mode')
if mode is None:
raise ValueError("mode is not provided.")
if not isinstance(mode, str):
-raise ValueError("mode is not str type.")
+raise TypeError("mode is not str type.")
imagesets_file = ""
if task == "Segmentation":
@@ -388,7 +388,7 @@ def check_cocodataset(method):
if task is None:
raise ValueError("task is not provided.")
if not isinstance(task, str):
-raise ValueError("task is not str type.")
+raise TypeError("task is not str type.")
if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
raise ValueError("Invalid task type")
@@ -556,7 +556,7 @@ def check_generatordataset(method):
def check_batch_size(batch_size):
if not (isinstance(batch_size, int) or (callable(batch_size))):
-raise ValueError("batch_size should either be an int or a callable.")
+raise TypeError("batch_size should either be an int or a callable.")
if callable(batch_size):
sig = ins.signature(batch_size)
if len(sig.parameters) != 1:
@@ -706,6 +706,7 @@ def check_batch(method):
def check_sync_wait(method):
"""check the input arguments of sync_wait."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
@@ -773,7 +774,7 @@ def check_filter(method):
param_dict = make_param_dict(method, args, kwargs)
predicate = param_dict.get("predicate")
if not callable(predicate):
-raise ValueError("Predicate should be a python function or a callable python object.")
+raise TypeError("Predicate should be a python function or a callable python object.")
nreq_param_int = ['num_parallel_workers']
check_param_type(nreq_param_int, param_dict, int)
@@ -865,7 +866,7 @@ def check_zip_dataset(method):
raise ValueError("datasets is not provided.")
if not isinstance(ds, (tuple, datasets.Dataset)):
-raise ValueError("datasets is not tuple or of type Dataset.")
+raise TypeError("datasets is not tuple or of type Dataset.")
return method(*args, **kwargs)
@@ -885,7 +886,7 @@ def check_concat(method):
raise ValueError("datasets is not provided.")
if not isinstance(ds, (list, datasets.Dataset)):
-raise ValueError("datasets is not list or of type Dataset.")
+raise TypeError("datasets is not list or of type Dataset.")
return method(*args, **kwargs)
@@ -964,7 +965,7 @@ def check_add_column(method):
de_type = param_dict.get("de_type")
if de_type is not None:
if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
-raise ValueError("Unknown column type.")
+raise TypeError("Unknown column type.")
else:
raise TypeError("Expected non-empty string.")

View File

@@ -10,6 +10,6 @@ wheel >= 0.32.0
decorator >= 4.4.0
setuptools >= 40.8.0
matplotlib >= 3.1.3 # for ut test
-opencv-python >= 4.2.0.32 # for ut test
+opencv-python >= 4.1.2.30 # for ut test
sklearn >= 0.0 # for st test
pandas >= 1.0.2 # for ut test

View File

@@ -42,15 +42,15 @@ def split_with_invalid_inputs(d):
with pytest.raises(RuntimeError) as info:
_, _ = d.split([3, 1])
-assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
+assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([5, 1])
-assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
+assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
-assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
+assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(ValueError) as info:
_, _ = d.split([-0.5, 0.5])
@@ -80,7 +80,7 @@ def test_unmappable_invalid_input():
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([4, 1])
-assert "dataset should not be sharded before split" in str(info.value)
+assert "Dataset should not be sharded before split" in str(info.value)
def test_unmappable_split():
@@ -274,7 +274,7 @@ def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([4, 1])
-assert "dataset should not be sharded before split" in str(info.value)
+assert "Dataset should not be sharded before split" in str(info.value)
def test_mappable_split_general():