!2306 [Dataset] Code review & improve quality

YangLuo 2020-06-18 21:03:55 +08:00
parent 83b53559f5
commit 4e3bfcf4c9
21 changed files with 110 additions and 79 deletions

View File

@ -20,6 +20,7 @@ import os
import pickle as pkl
import numpy as np
import scipy.sparse as sp
from mindspore import log as logger
# parse args from command line parameter 'graph_api_args'
# args delimiter is ':'
@ -58,7 +59,7 @@ def yield_nodes(task_id=0):
Yields:
data (dict): data row which is dict.
"""
print("Node task is {}".format(task_id))
logger.info("Node task is {}".format(task_id))
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
objects = []
for name in names:
@ -98,7 +99,7 @@ def yield_nodes(task_id=0):
line_count += 1
node_ids.append(i)
yield node
print('Processed {} lines for nodes.'.format(line_count))
logger.info('Processed {} lines for nodes.'.format(line_count))
def yield_edges(task_id=0):
@ -108,21 +109,21 @@ def yield_edges(task_id=0):
Yields:
data (dict): data row which is dict.
"""
print("Edge task is {}".format(task_id))
logger.info("Edge task is {}".format(task_id))
with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
graph = pkl.load(f, encoding='latin1')
line_count = 0
for i in graph:
for dst_id in graph[i]:
if not i in node_ids:
print('Source node {} does not exist.'.format(i))
logger.info('Source node {} does not exist.'.format(i))
continue
if not dst_id in node_ids:
print('Destination node {} does not exist.'.format(
logger.info('Destination node {} does not exist.'.format(
dst_id))
continue
edge = {'id': line_count,
'src_id': i, 'dst_id': dst_id, 'type': 0}
line_count += 1
yield edge
print('Processed {} lines for edges.'.format(line_count))
logger.info('Processed {} lines for edges.'.format(line_count))
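
The change above replaces bare print() calls with the MindSpore logger, so progress and skip messages carry a severity level and can be filtered or redirected like the rest of the framework's output. A minimal sketch of the same pattern, assuming MindSpore is installed (the function and variable names here are illustrative, not taken from the script):

from mindspore import log as logger

def yield_valid_edges(edges, node_ids):
    """Yield edges whose endpoints exist, logging progress instead of printing it."""
    line_count = 0
    for src_id, dst_id in edges:
        if src_id not in node_ids:
            # info-level output can be silenced via the logger's level,
            # which a plain print() does not allow
            logger.info("Source node {} does not exist.".format(src_id))
            continue
        line_count += 1
        yield src_id, dst_id
    logger.info("Processed {} lines for edges.".format(line_count))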

View File

@ -16,6 +16,7 @@
Graph data convert tool for MindRecord.
"""
import numpy as np
from mindspore import log as logger
__all__ = ['GraphMapSchema']
@ -41,6 +42,7 @@ class GraphMapSchema:
"edge_feature_index": {"type": "int32", "shape": [-1]}
}
@property
def get_schema(self):
"""
Get schema
@ -52,6 +54,7 @@ class GraphMapSchema:
Set node features profile
"""
if num_features != len(features_data_type) or num_features != len(features_shape):
logger.info("Node feature profile is not match.")
raise ValueError("Node feature profile is not match.")
self.num_node_features = num_features
@ -66,6 +69,7 @@ class GraphMapSchema:
Set edge features profile
"""
if num_features != len(features_data_type) or num_features != len(features_shape):
logger.info("Edge feature profile is not match.")
raise ValueError("Edge feature profile is not match.")
self.num_edge_features = num_features
@ -83,6 +87,10 @@ class GraphMapSchema:
Returns:
graph data with union schema
"""
if node is None:
logger.info("node cannot be None.")
raise ValueError("node cannot be None.")
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
"node_feature_index": []}
for i in range(self.num_node_features):
@ -117,6 +125,10 @@ class GraphMapSchema:
Returns:
graph data with union schema
"""
if edge is None:
logger.info("edge cannot be None.")
raise ValueError("edge cannot be None.")
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
"type": edge["type"], "edge_feature_index": []}

View File

@ -164,7 +164,7 @@ if __name__ == "__main__":
num_features, feature_data_types, feature_shapes = mr_api.edge_profile
graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
graph_schema = graph_map_schema.get_schema()
graph_schema = graph_map_schema.get_schema
# init writer
writer = init_writer(graph_schema)
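
Because get_schema is now decorated with @property, this call site drops the parentheses: the schema is read as an attribute, and calling the returned dict would fail. A standalone sketch of that behavior (the attribute name union_schema is illustrative):

class GraphMapSchema:
    def __init__(self):
        self.union_schema = {"first_id": {"type": "int64"}}

    @property
    def get_schema(self):
        """Return the union schema."""
        return self.union_schema

schema = GraphMapSchema()
graph_schema = schema.get_schema    # property: accessed without parentheses
# schema.get_schema() would now raise TypeError: 'dict' object is not callable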

View File

@ -983,7 +983,9 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
continue;
}
}
memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size,
count * type_size);
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric");
out_index += count;
if (i < indices.size() - 1) {
src_start = HandleNeg(indices[i + 1], dim_length); // next index

View File

@ -101,6 +101,9 @@ class BucketBatchByLengthOp : public PipelineOp {
std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info,
bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size);
// Destructor
~BucketBatchByLengthOp() = default;
// Might need to batch remaining buckets after receiving eoe, so override this method.
// @param int32_t workerId
// @return Status - The error code returned

View File

@ -36,6 +36,7 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
: mr_path_(mr_filepath),
num_workers_(num_workers),
row_id_(0),
shard_reader_(nullptr),
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,

View File

@ -37,7 +37,7 @@ namespace dataset {
// Driver method for TreePass
Status TreePass::Run(ExecutionTree *tree, bool *modified) {
if (!tree || !modified) {
if (tree == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
}
return this->RunOnTree(tree, modified);
@ -45,7 +45,7 @@ Status TreePass::Run(ExecutionTree *tree, bool *modified) {
// Driver method for NodePass
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
if (!tree || !modified) {
if (tree == nullptr || modified == nullptr) {
return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
}
std::shared_ptr<DatasetOp> root = tree->root();

View File

@ -44,7 +44,7 @@ class ConnectorSize : public Sampling {
public:
explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {}
~ConnectorSize() = default;
~ConnectorSize() override = default;
// Driver function for connector size sampling.
// This function samples the connector size of every nodes within the ExecutionTree

View File

@ -26,6 +26,7 @@ Monitor::Monitor(ExecutionTree *tree) : tree_(tree) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
sampling_interval_ = cfg->monitor_sampling_interval();
max_samples_ = 0;
cur_row_ = 0;
}
Status Monitor::operator()() {

View File

@ -34,6 +34,8 @@ class Slice {
Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
~Slice() = default;
std::vector<dsize_t> Indices(dsize_t length) {
std::vector<dsize_t> indices;
dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);

View File

@ -29,8 +29,8 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow
BOUNDING_BOX_CHECK(input);
if (distribution_(rnd_)) {
// To test bounding boxes algorithm, create random bboxes from image dims
size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes
float img_center = (input[0]->shape()[1] / 2); // get the center of the image
size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes
float img_center = (input[0]->shape()[1] / 2.); // get the center of the image
for (int i = 0; i < num_of_boxes; i++) {
uint32_t b_w = 0; // bounding box width

View File

@ -49,6 +49,7 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
preserve_unused_token_(preserve_unused_token),
case_fold_(std::make_unique<CaseFoldOp>()),
nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)),
normalization_form_(normalization_form),
common_normalize_(std::make_unique<NormalizeUTF8Op>(normalization_form)),
replace_accent_chars_(std::make_unique<RegexReplaceOp>("\\p{Mn}", "")),
replace_control_chars_(std::make_unique<RegexReplaceOp>("\\p{Cc}|\\p{Cf}", " ")) {

View File

@ -35,9 +35,9 @@ class BasicTokenizerOp : public TensorOp {
static const bool kDefKeepWhitespace;
static const NormalizeForm kDefNormalizationForm;
static const bool kDefPreserveUnusedToken;
BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
NormalizeForm normalization_form = kDefNormalizationForm,
bool preserve_unused_token = kDefPreserveUnusedToken);
explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
NormalizeForm normalization_form = kDefNormalizationForm,
bool preserve_unused_token = kDefPreserveUnusedToken);
~BasicTokenizerOp() override = default;

View File

@ -28,14 +28,14 @@ namespace mindspore {
namespace dataset {
class BertTokenizerOp : public TensorOp {
public:
BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
bool lower_case = BasicTokenizerOp::kDefLowerCase,
bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
explicit BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
bool lower_case = BasicTokenizerOp::kDefLowerCase,
bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
: wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token),
basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {}

View File

@ -48,7 +48,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
// @return
Status insert(const value_type &val, key_type *key = nullptr) {
key_type my_inx = inx_.fetch_add(1);
if (key) {
if (key != nullptr) {
*key = my_inx;
}
return my_tree::DoInsert(my_inx, val);

View File

@ -323,7 +323,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
}
vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
uint64_t i_size = kUnsignedOne << int_type;
uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type);
// Get number of elements
uint64_t src_n_int = src_bytes.size() / i_size;
// Calculate bitmap size (bytes)
@ -344,7 +344,7 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
// Initialize destination data type
IntegerType dst_int_type = kInt8Type;
// Shift to next int position
uint64_t pos = i * (kUnsignedOne << int_type);
uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type));
// Narrow down this int
int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);

View File

@ -61,7 +61,7 @@ class Shuffle(str, Enum):
@check_zip
def zip(datasets):
"""
Zips the datasets in the input tuple of datasets.
Zip the datasets in the input tuple of datasets.
Args:
datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
@ -152,7 +152,7 @@ class Dataset:
def get_args(self):
"""
Returns attributes (member variables) related to the current class.
Return attributes (member variables) related to the current class.
Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.
@ -239,7 +239,7 @@ class Dataset:
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
input_columns=None, pad_info=None):
"""
Combines batch_size number of consecutive rows into batches.
Combine batch_size number of consecutive rows into batches.
For any child node, a batch is treated as a single row.
For any column, all the elements within that column must have the same shape.
@ -340,7 +340,7 @@ class Dataset:
def flat_map(self, func):
"""
Maps `func` to each row in dataset and flatten the result.
Map `func` to each row in dataset and flatten the result.
The specified `func` is a function that must take one 'Ndarray' as input
and return a 'Dataset'.
@ -370,6 +370,7 @@ class Dataset:
"""
dataset = None
if not hasattr(func, '__call__'):
logger.error("func must be a function.")
raise TypeError("func must be a function.")
for row_data in self:
@ -379,6 +380,7 @@ class Dataset:
dataset += func(row_data)
if not isinstance(dataset, Dataset):
logger.error("flat_map must return a Dataset object.")
raise TypeError("flat_map must return a Dataset object.")
return dataset
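
The check added to flat_map pairs the raise with a logger call so the failure is also recorded in the MindSpore log. A self-contained sketch of the same validate-then-raise shape, using the standard logging module as a stand-in and plain lists instead of Dataset objects:

import logging

logger = logging.getLogger(__name__)

def flat_map(rows, func):
    """Apply func to each row and concatenate the results (list-based stand-in)."""
    if not callable(func):
        logger.error("func must be a function.")
        raise TypeError("func must be a function.")
    result = []
    for row in rows:
        result += func(row)   # analogous to `dataset += func(row_data)` above
    return result
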
@ -386,7 +388,7 @@ class Dataset:
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False):
"""
Applies each operation in operations to this dataset.
Apply each operation in operations to this dataset.
The order of operations is determined by the position of each operation in operations.
operations[0] will be applied first, then operations[1], then operations[2], etc.
@ -570,7 +572,7 @@ class Dataset:
@check_repeat
def repeat(self, count=None):
"""
Repeats this dataset count times. Repeat indefinitely if the count is None or -1.
Repeat this dataset count times. Repeat indefinitely if the count is None or -1.
Note:
The order of using repeat and batch reflects the number of batches. Recommend that
@ -662,13 +664,16 @@ class Dataset:
dataset_size = self.get_dataset_size()
if dataset_size is None or dataset_size <= 0:
raise RuntimeError("dataset size unknown, unable to split.")
raise RuntimeError("dataset_size is unknown, unable to split.")
if not isinstance(sizes, list):
raise RuntimeError("sizes should be a list.")
all_int = all(isinstance(item, int) for item in sizes)
if all_int:
sizes_sum = sum(sizes)
if sizes_sum != dataset_size:
raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
.format(sizes_sum, dataset_size))
return sizes
@ -676,7 +681,7 @@ class Dataset:
for item in sizes:
absolute_size = int(round(item * dataset_size))
if absolute_size == 0:
raise RuntimeError("split percentage {} is too small.".format(item))
raise RuntimeError("Split percentage {} is too small.".format(item))
absolute_sizes.append(absolute_size)
absolute_sizes_sum = sum(absolute_sizes)
@ -694,7 +699,7 @@ class Dataset:
break
if sum(absolute_sizes) != dataset_size:
raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
.format(absolute_sizes_sum, dataset_size))
return absolute_sizes
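
For reference, the arithmetic behind these reworded messages: integer sizes must sum exactly to the dataset size, while fractional sizes are rounded to absolute row counts and each must map to at least one row. A simplified sketch (the real method additionally reconciles rounding drift before the final "Sum of calculated split sizes" check):

def get_absolute_split_sizes(sizes, dataset_size):
    if all(isinstance(item, int) for item in sizes):
        if sum(sizes) != dataset_size:
            raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
                               .format(sum(sizes), dataset_size))
        return sizes
    absolute_sizes = []
    for item in sizes:
        absolute_size = int(round(item * dataset_size))
        if absolute_size == 0:
            raise RuntimeError("Split percentage {} is too small.".format(item))
        absolute_sizes.append(absolute_size)
    return absolute_sizes

# get_absolute_split_sizes([0.9, 0.1], 10) -> [9, 1]
# get_absolute_split_sizes([3, 1], 5) raises "Sum of split sizes 4 is not equal to dataset size 5."
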
@ -702,7 +707,7 @@ class Dataset:
@check_split
def split(self, sizes, randomize=True):
"""
Splits the dataset into smaller, non-overlapping datasets.
Split the dataset into smaller, non-overlapping datasets.
This is a general purpose split function which can be called from any operator in the pipeline.
There is another, optimized split function, which will be called automatically if ds.split is
@ -759,10 +764,10 @@ class Dataset:
>>> train, test = data.split([0.9, 0.1])
"""
if self.is_shuffled():
logger.warning("dataset is shuffled before split.")
logger.warning("Dataset is shuffled before split.")
if self.is_sharded():
raise RuntimeError("dataset should not be sharded before split.")
raise RuntimeError("Dataset should not be sharded before split.")
absolute_sizes = self._get_absolute_split_sizes(sizes)
splits = []
@ -788,7 +793,7 @@ class Dataset:
@check_zip_dataset
def zip(self, datasets):
"""
Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
Args:
datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset
@ -845,7 +850,7 @@ class Dataset:
@check_rename
def rename(self, input_columns, output_columns):
"""
Renames the columns in input datasets.
Rename the columns in input datasets.
Args:
input_columns (list[str]): list of names of the input columns.
@ -871,7 +876,7 @@ class Dataset:
@check_project
def project(self, columns):
"""
Projects certain columns in input datasets.
Project certain columns in input datasets.
The specified columns will be selected from the dataset and passed down
the pipeline in the order specified. The other columns are discarded.
@ -936,7 +941,7 @@ class Dataset:
def device_que(self, prefetch_size=None):
"""
Returns a transferredDataset that transfer data through device.
Return a transferredDataset that transfer data through device.
Args:
prefetch_size (int, optional): prefetch number of records ahead of the
@ -953,7 +958,7 @@ class Dataset:
def to_device(self, num_batch=None):
"""
Transfers data through CPU, GPU or Ascend devices.
Transfer data through CPU, GPU or Ascend devices.
Args:
num_batch (int, optional): limit the number of batch to be sent to device (default=None).
@ -988,7 +993,7 @@ class Dataset:
raise TypeError("Please set device_type in context")
if device_type not in ('Ascend', 'GPU', 'CPU'):
raise ValueError("only support CPU, Ascend, GPU")
raise ValueError("Only support CPU, Ascend, GPU")
if num_batch is None or num_batch == 0:
raise ValueError("num_batch is None or 0.")
@ -1089,7 +1094,7 @@ class Dataset:
def _get_pipeline_info(self):
"""
Gets pipeline information.
Get pipeline information.
"""
device_iter = TupleIterator(self)
self._output_shapes = device_iter.get_output_shapes()
@ -1344,7 +1349,7 @@ class MappableDataset(SourceDataset):
@check_split
def split(self, sizes, randomize=True):
"""
Splits the dataset into smaller, non-overlapping datasets.
Split the dataset into smaller, non-overlapping datasets.
There is the optimized split function, which will be called automatically when the dataset
that calls this function is a MappableDataset.
@ -1411,10 +1416,10 @@ class MappableDataset(SourceDataset):
>>> train.use_sampler(train_sampler)
"""
if self.is_shuffled():
logger.warning("dataset is shuffled before split.")
logger.warning("Dataset is shuffled before split.")
if self.is_sharded():
raise RuntimeError("dataset should not be sharded before split.")
raise RuntimeError("Dataset should not be sharded before split.")
absolute_sizes = self._get_absolute_split_sizes(sizes)
splits = []
@ -1633,7 +1638,7 @@ class BlockReleasePair:
def __init__(self, init_release_rows, callback=None):
if isinstance(init_release_rows, int) and init_release_rows <= 0:
raise ValueError("release_rows need to be greater than 0.")
raise ValueError("release_rows need to be greater than 0.")
self.row_count = -init_release_rows
self.cv = threading.Condition()
self.callback = callback
@ -2699,10 +2704,10 @@ class MindDataset(MappableDataset):
self.shard_id = shard_id
if block_reader is True and num_shards is not None:
raise ValueError("block reader not allowed true when use partitions")
raise ValueError("block_reader not allowed true when use partitions")
if block_reader is True and shuffle is True:
raise ValueError("block reader not allowed true when use shuffle")
raise ValueError("block_reader not allowed true when use shuffle")
if block_reader is True:
logger.warning("WARN: global shuffle is not used.")
@ -2711,14 +2716,14 @@ class MindDataset(MappableDataset):
if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
samplers.DistributedSampler, samplers.RandomSampler,
samplers.SequentialSampler)) is False:
raise ValueError("the sampler is not supported yet.")
raise ValueError("The sampler is not supported yet.")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
# sampler exclusive
if block_reader is True and sampler is not None:
raise ValueError("block reader not allowed true when use sampler")
raise ValueError("block_reader not allowed true when use sampler")
if num_padded is None:
num_padded = 0
@ -2770,7 +2775,7 @@ class MindDataset(MappableDataset):
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
raise ValueError('Set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
if self.shuffle_option is None:
@ -2872,7 +2877,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
def _fetch_py_sampler_indices(sampler, num_samples):
"""
Indices fetcher for python sampler.
Indice fetcher for python sampler.
"""
if num_samples is not None:
sampler_iter = iter(sampler)
@ -3163,7 +3168,7 @@ class GeneratorDataset(MappableDataset):
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
raise ValueError('Set dataset_size with negative value {}'.format(value))
def __deepcopy__(self, memodict):
if id(self) in memodict:
@ -3313,7 +3318,7 @@ class TFRecordDataset(SourceDataset):
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
raise ValueError('Set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
return self.shuffle_files
@ -4382,7 +4387,9 @@ class CelebADataset(MappableDataset):
try:
with open(attr_file, 'r') as f:
num_rows = int(f.readline())
except Exception:
except FileNotFoundError:
raise RuntimeError("attr_file not found.")
except BaseException:
raise RuntimeError("Get dataset size failed from attribution file.")
rows_per_shard = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None:

View File

@ -319,7 +319,7 @@ class PKSampler(BuiltinSampler):
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
if num_class is not None:
raise NotImplementedError
raise NotImplementedError("Not support specify num_class")
if not isinstance(shuffle, bool):
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
@ -551,8 +551,8 @@ class WeightedRandomSampler(BuiltinSampler):
Args:
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
num_samples (int): Number of elements to sample (default=None, all elements).
replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
num_samples (int, optional): Number of elements to sample (default=None, all elements).
replacement (bool): If True, put the sample ID back for the next draw (default=True).
Examples:
>>> import mindspore.dataset as ds
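
A hedged usage sketch matching the corrected docstring, where num_samples is the optional argument and replacement defaults to True (assumes mindspore.dataset is importable):

import mindspore.dataset as ds

weights = [0.9, 0.01, 0.4, 0.8, 0.1]
# draw 4 sample IDs, with replacement, weighted by the list above
sampler = ds.WeightedRandomSampler(weights, num_samples=4, replacement=True)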

View File

@ -50,7 +50,7 @@ def check_filename(path):
Exception: when error
"""
if not isinstance(path, str):
raise ValueError("path: {} is not string".format(path))
raise TypeError("path: {} is not string".format(path))
filename = os.path.basename(path)
# '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
@ -143,7 +143,7 @@ def check_sampler_shuffle_shard_options(param_dict):
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise ValueError("sampler is not a valid Sampler type.")
raise TypeError("sampler is not a valid Sampler type.")
if sampler is not None:
if shuffle is not None:
@ -328,13 +328,13 @@ def check_vocdataset(method):
if task is None:
raise ValueError("task is not provided.")
if not isinstance(task, str):
raise ValueError("task is not str type.")
raise TypeError("task is not str type.")
# check mode; required argument
mode = param_dict.get('mode')
if mode is None:
raise ValueError("mode is not provided.")
if not isinstance(mode, str):
raise ValueError("mode is not str type.")
raise TypeError("mode is not str type.")
imagesets_file = ""
if task == "Segmentation":
@ -388,7 +388,7 @@ def check_cocodataset(method):
if task is None:
raise ValueError("task is not provided.")
if not isinstance(task, str):
raise ValueError("task is not str type.")
raise TypeError("task is not str type.")
if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
raise ValueError("Invalid task type")
@ -556,7 +556,7 @@ def check_generatordataset(method):
def check_batch_size(batch_size):
if not (isinstance(batch_size, int) or (callable(batch_size))):
raise ValueError("batch_size should either be an int or a callable.")
raise TypeError("batch_size should either be an int or a callable.")
if callable(batch_size):
sig = ins.signature(batch_size)
if len(sig.parameters) != 1:
@ -706,6 +706,7 @@ def check_batch(method):
def check_sync_wait(method):
"""check the input arguments of sync_wait."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
@ -773,7 +774,7 @@ def check_filter(method):
param_dict = make_param_dict(method, args, kwargs)
predicate = param_dict.get("predicate")
if not callable(predicate):
raise ValueError("Predicate should be a python function or a callable python object.")
raise TypeError("Predicate should be a python function or a callable python object.")
nreq_param_int = ['num_parallel_workers']
check_param_type(nreq_param_int, param_dict, int)
@ -865,7 +866,7 @@ def check_zip_dataset(method):
raise ValueError("datasets is not provided.")
if not isinstance(ds, (tuple, datasets.Dataset)):
raise ValueError("datasets is not tuple or of type Dataset.")
raise TypeError("datasets is not tuple or of type Dataset.")
return method(*args, **kwargs)
@ -885,7 +886,7 @@ def check_concat(method):
raise ValueError("datasets is not provided.")
if not isinstance(ds, (list, datasets.Dataset)):
raise ValueError("datasets is not list or of type Dataset.")
raise TypeError("datasets is not list or of type Dataset.")
return method(*args, **kwargs)
@ -964,7 +965,7 @@ def check_add_column(method):
de_type = param_dict.get("de_type")
if de_type is not None:
if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
raise ValueError("Unknown column type.")
raise TypeError("Unknown column type.")
else:
raise TypeError("Expected non-empty string.")
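
The validator changes apply one convention consistently: an argument of the wrong type raises TypeError, while a well-typed argument carrying a bad value raises ValueError. A standalone sketch of the distinction (the positivity check is illustrative, not copied from the validators):

def check_batch_size(batch_size):
    if not (isinstance(batch_size, int) or callable(batch_size)):
        # wrong kind of object -> TypeError
        raise TypeError("batch_size should either be an int or a callable.")
    if isinstance(batch_size, int) and batch_size <= 0:
        # right type, unusable value -> ValueError
        raise ValueError("batch_size must be positive.")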

View File

@ -10,6 +10,6 @@ wheel >= 0.32.0
decorator >= 4.4.0
setuptools >= 40.8.0
matplotlib >= 3.1.3 # for ut test
opencv-python >= 4.2.0.32 # for ut test
opencv-python >= 4.1.2.30 # for ut test
sklearn >= 0.0 # for st test
pandas >= 1.0.2 # for ut test

View File

@ -42,15 +42,15 @@ def split_with_invalid_inputs(d):
with pytest.raises(RuntimeError) as info:
_, _ = d.split([3, 1])
assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([5, 1])
assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
with pytest.raises(ValueError) as info:
_, _ = d.split([-0.5, 0.5])
@ -80,7 +80,7 @@ def test_unmappable_invalid_input():
d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
assert "Dataset should not be sharded before split" in str(info.value)
def test_unmappable_split():
@ -274,7 +274,7 @@ def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
with pytest.raises(RuntimeError) as info:
_, _ = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
assert "Dataset should not be sharded before split" in str(info.value)
def test_mappable_split_general():
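
Because these tests assert on substrings of the reworded messages, they have to move in lockstep with the source. pytest's match argument is an equivalent way to pin such expectations (illustrative sketch; d stands for any dataset object exposing split()):

import pytest

def check_split_rejects_bad_sizes(d):
    with pytest.raises(RuntimeError, match=r"Sum of split sizes 4 is not equal to dataset size 5"):
        _, _ = d.split([3, 1])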