forked from mindspore-Ecosystem/mindspore
!2306 [Dataset] Code review & improve quality
This commit is contained in:
parent 83b53559f5
commit 4e3bfcf4c9
@@ -20,6 +20,7 @@ import os
 import pickle as pkl
 import numpy as np
 import scipy.sparse as sp
+from mindspore import log as logger

 # parse args from command line parameter 'graph_api_args'
 # args delimiter is ':'
@@ -58,7 +59,7 @@ def yield_nodes(task_id=0):
 Yields:
 data (dict): data row which is dict.
 """
-print("Node task is {}".format(task_id))
+logger.info("Node task is {}".format(task_id))
 names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
 objects = []
 for name in names:
@@ -98,7 +99,7 @@ def yield_nodes(task_id=0):
 line_count += 1
 node_ids.append(i)
 yield node
-print('Processed {} lines for nodes.'.format(line_count))
+logger.info('Processed {} lines for nodes.'.format(line_count))


 def yield_edges(task_id=0):
@@ -108,21 +109,21 @@ def yield_edges(task_id=0):
 Yields:
 data (dict): data row which is dict.
 """
-print("Edge task is {}".format(task_id))
+logger.info("Edge task is {}".format(task_id))
 with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
 graph = pkl.load(f, encoding='latin1')
 line_count = 0
 for i in graph:
 for dst_id in graph[i]:
 if not i in node_ids:
-print('Source node {} does not exist.'.format(i))
+logger.info('Source node {} does not exist.'.format(i))
 continue
 if not dst_id in node_ids:
-print('Destination node {} does not exist.'.format(
+logger.info('Destination node {} does not exist.'.format(
 dst_id))
 continue
 edge = {'id': line_count,
 'src_id': i, 'dst_id': dst_id, 'type': 0}
 line_count += 1
 yield edge
-print('Processed {} lines for edges.'.format(line_count))
+logger.info('Processed {} lines for edges.'.format(line_count))
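
The four hunks above all make the same substitution: bare `print` calls become `logger.info` calls on the `mindspore.log` alias imported at the top of the file. A minimal self-contained sketch of the pattern, with hypothetical names (`yield_rows`, `rows`) that are not part of the commit:

```python
from mindspore import log as logger


def yield_rows(rows, task_id=0):
    """Placeholder generator showing the print-to-logger substitution."""
    logger.info("Task is {}".format(task_id))   # was: print("Task is {}".format(task_id))
    line_count = 0
    for row in rows:
        line_count += 1
        yield row
    logger.info('Processed {} lines.'.format(line_count))
```

Routing progress messages through the logger keeps them subject to MindSpore's log-level and destination settings instead of always writing to stdout.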
@@ -16,6 +16,7 @@
 Graph data convert tool for MindRecord.
 """
 import numpy as np
+from mindspore import log as logger

 __all__ = ['GraphMapSchema']

@@ -41,6 +42,7 @@ class GraphMapSchema:
 "edge_feature_index": {"type": "int32", "shape": [-1]}
 }

+@property
 def get_schema(self):
 """
 Get schema
@@ -52,6 +54,7 @@ class GraphMapSchema:
 Set node features profile
 """
 if num_features != len(features_data_type) or num_features != len(features_shape):
+logger.info("Node feature profile is not match.")
 raise ValueError("Node feature profile is not match.")

 self.num_node_features = num_features
@@ -66,6 +69,7 @@ class GraphMapSchema:
 Set edge features profile
 """
 if num_features != len(features_data_type) or num_features != len(features_shape):
+logger.info("Edge feature profile is not match.")
 raise ValueError("Edge feature profile is not match.")

 self.num_edge_features = num_features
@@ -83,6 +87,10 @@ class GraphMapSchema:
 Returns:
 graph data with union schema
 """
+if node is None:
+logger.info("node cannot be None.")
+raise ValueError("node cannot be None.")
+
 node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
 "node_feature_index": []}
 for i in range(self.num_node_features):
@@ -117,6 +125,10 @@ class GraphMapSchema:
 Returns:
 graph data with union schema
 """
+if edge is None:
+logger.info("edge cannot be None.")
+raise ValueError("edge cannot be None.")
+
 edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
 "type": edge["type"], "edge_feature_index": []}

@@ -164,7 +164,7 @@ if __name__ == "__main__":
 num_features, feature_data_types, feature_shapes = mr_api.edge_profile
 graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)

-graph_schema = graph_map_schema.get_schema()
+graph_schema = graph_map_schema.get_schema

 # init writer
 writer = init_writer(graph_schema)
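
The `@property` added at `@@ -41,6 +42,7 @@` is why the writer hunk above drops the parentheses: `graph_map_schema.get_schema` is now plain attribute access. A reduced sketch with a hypothetical stand-in class (not the real `GraphMapSchema`):

```python
class SchemaHolder:
    """Hypothetical stand-in for GraphMapSchema, reduced to the schema accessor."""

    def __init__(self):
        self.union_schema = {"first_id": {"type": "int64"}}

    @property
    def get_schema(self):
        """Return the union schema; as a property it is read like an attribute."""
        return self.union_schema


holder = SchemaHolder()
graph_schema = holder.get_schema      # new call site: no parentheses
# holder.get_schema() would now fail: the property already returned a dict,
# and calling a dict raises TypeError.
assert isinstance(graph_schema, dict)
```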
@@ -983,7 +983,9 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
 continue;
 }
 }
-memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
+int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size,
+count * type_size);
+CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric");
 out_index += count;
 if (i < indices.size() - 1) {
 src_start = HandleNeg(indices[i + 1], dim_length); // next index
@@ -101,6 +101,9 @@ class BucketBatchByLengthOp : public PipelineOp {
 std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info,
 bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size);

+// Destructor
+~BucketBatchByLengthOp() = default;
+
 // Might need to batch remaining buckets after receiving eoe, so override this method.
 // @param int32_t workerId
 // @return Status - The error code returned
@@ -36,6 +36,7 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
 : mr_path_(mr_filepath),
 num_workers_(num_workers),
 row_id_(0),
+shard_reader_(nullptr),
 keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}

 Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
@@ -37,7 +37,7 @@ namespace dataset {

 // Driver method for TreePass
 Status TreePass::Run(ExecutionTree *tree, bool *modified) {
-if (!tree || !modified) {
+if (tree == nullptr || modified == nullptr) {
 return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
 }
 return this->RunOnTree(tree, modified);
@@ -45,7 +45,7 @@ Status TreePass::Run(ExecutionTree *tree, bool *modified) {

 // Driver method for NodePass
 Status NodePass::Run(ExecutionTree *tree, bool *modified) {
-if (!tree || !modified) {
+if (tree == nullptr || modified == nullptr) {
 return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
 }
 std::shared_ptr<DatasetOp> root = tree->root();
@@ -44,7 +44,7 @@ class ConnectorSize : public Sampling {
 public:
 explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {}

-~ConnectorSize() = default;
+~ConnectorSize() override = default;

 // Driver function for connector size sampling.
 // This function samples the connector size of every nodes within the ExecutionTree
@@ -26,6 +26,7 @@ Monitor::Monitor(ExecutionTree *tree) : tree_(tree) {
 std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
 sampling_interval_ = cfg->monitor_sampling_interval();
 max_samples_ = 0;
+cur_row_ = 0;
 }

 Status Monitor::operator()() {
@@ -34,6 +34,8 @@ class Slice {
 Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
 explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}

+~Slice() = default;
+
 std::vector<dsize_t> Indices(dsize_t length) {
 std::vector<dsize_t> indices;
 dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);
@@ -29,8 +29,8 @@ Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow
 BOUNDING_BOX_CHECK(input);
 if (distribution_(rnd_)) {
 // To test bounding boxes algorithm, create random bboxes from image dims
 size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes
-float img_center = (input[0]->shape()[1] / 2); // get the center of the image
+float img_center = (input[0]->shape()[1] / 2.); // get the center of the image

 for (int i = 0; i < num_of_boxes; i++) {
 uint32_t b_w = 0; // bounding box width
@@ -49,6 +49,7 @@ BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, Normal
 preserve_unused_token_(preserve_unused_token),
 case_fold_(std::make_unique<CaseFoldOp>()),
 nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)),
+normalization_form_(normalization_form),
 common_normalize_(std::make_unique<NormalizeUTF8Op>(normalization_form)),
 replace_accent_chars_(std::make_unique<RegexReplaceOp>("\\p{Mn}", "")),
 replace_control_chars_(std::make_unique<RegexReplaceOp>("\\p{Cc}|\\p{Cf}", " ")) {
@@ -35,9 +35,9 @@ class BasicTokenizerOp : public TensorOp {
 static const bool kDefKeepWhitespace;
 static const NormalizeForm kDefNormalizationForm;
 static const bool kDefPreserveUnusedToken;
-BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
+explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace,
 NormalizeForm normalization_form = kDefNormalizationForm,
 bool preserve_unused_token = kDefPreserveUnusedToken);

 ~BasicTokenizerOp() override = default;

@@ -28,14 +28,14 @@ namespace mindspore {
 namespace dataset {
 class BertTokenizerOp : public TensorOp {
 public:
-BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
+explicit BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
 const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
 const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
 const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
 bool lower_case = BasicTokenizerOp::kDefLowerCase,
 bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
 NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm,
 bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken)
 : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token),
 basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {}

@@ -48,7 +48,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
 // @return
 Status insert(const value_type &val, key_type *key = nullptr) {
 key_type my_inx = inx_.fetch_add(1);
-if (key) {
+if (key != nullptr) {
 *key = my_inx;
 }
 return my_tree::DoInsert(my_inx, val);
@@ -323,7 +323,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
 }

 vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
-uint64_t i_size = kUnsignedOne << int_type;
+uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type);
 // Get number of elements
 uint64_t src_n_int = src_bytes.size() / i_size;
 // Calculate bitmap size (bytes)
@@ -344,7 +344,7 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
 // Initialize destination data type
 IntegerType dst_int_type = kInt8Type;
 // Shift to next int position
-uint64_t pos = i * (kUnsignedOne << int_type);
+uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type));
 // Narrow down this int
 int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);

@@ -61,7 +61,7 @@ class Shuffle(str, Enum):
 @check_zip
 def zip(datasets):
 """
-Zips the datasets in the input tuple of datasets.
+Zip the datasets in the input tuple of datasets.

 Args:
 datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
@@ -152,7 +152,7 @@ class Dataset:

 def get_args(self):
 """
-Returns attributes (member variables) related to the current class.
+Return attributes (member variables) related to the current class.

 Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.

@@ -239,7 +239,7 @@ class Dataset:
 def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
 input_columns=None, pad_info=None):
 """
-Combines batch_size number of consecutive rows into batches.
+Combine batch_size number of consecutive rows into batches.

 For any child node, a batch is treated as a single row.
 For any column, all the elements within that column must have the same shape.
@@ -340,7 +340,7 @@ class Dataset:

 def flat_map(self, func):
 """
-Maps `func` to each row in dataset and flatten the result.
+Map `func` to each row in dataset and flatten the result.

 The specified `func` is a function that must take one 'Ndarray' as input
 and return a 'Dataset'.
@@ -370,6 +370,7 @@ class Dataset:
 """
 dataset = None
 if not hasattr(func, '__call__'):
+logger.error("func must be a function.")
 raise TypeError("func must be a function.")

 for row_data in self:
@@ -379,6 +380,7 @@ class Dataset:
 dataset += func(row_data)

 if not isinstance(dataset, Dataset):
+logger.error("flat_map must return a Dataset object.")
 raise TypeError("flat_map must return a Dataset object.")
 return dataset
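
The two `flat_map` hunks add a `logger.error` immediately before each raise, so the failure is recorded in the MindSpore log as well as in the traceback. A minimal sketch of that log-then-raise guard, reusing the same check and message as the diff:

```python
from mindspore import log as logger


def check_flat_map_func(func):
    """Sketch of the guard added in flat_map: log the problem, then raise."""
    if not hasattr(func, '__call__'):
        logger.error("func must be a function.")
        raise TypeError("func must be a function.")
```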
@@ -386,7 +388,7 @@ class Dataset:
 def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
 num_parallel_workers=None, python_multiprocessing=False):
 """
-Applies each operation in operations to this dataset.
+Apply each operation in operations to this dataset.

 The order of operations is determined by the position of each operation in operations.
 operations[0] will be applied first, then operations[1], then operations[2], etc.
@@ -570,7 +572,7 @@ class Dataset:
 @check_repeat
 def repeat(self, count=None):
 """
-Repeats this dataset count times. Repeat indefinitely if the count is None or -1.
+Repeat this dataset count times. Repeat indefinitely if the count is None or -1.

 Note:
 The order of using repeat and batch reflects the number of batches. Recommend that
@@ -662,13 +664,16 @@ class Dataset:
 dataset_size = self.get_dataset_size()

 if dataset_size is None or dataset_size <= 0:
-raise RuntimeError("dataset size unknown, unable to split.")
+raise RuntimeError("dataset_size is unknown, unable to split.")

+if not isinstance(sizes, list):
+raise RuntimeError("sizes should be a list.")
+
 all_int = all(isinstance(item, int) for item in sizes)
 if all_int:
 sizes_sum = sum(sizes)
 if sizes_sum != dataset_size:
-raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
+raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
 .format(sizes_sum, dataset_size))
 return sizes

@@ -676,7 +681,7 @@ class Dataset:
 for item in sizes:
 absolute_size = int(round(item * dataset_size))
 if absolute_size == 0:
-raise RuntimeError("split percentage {} is too small.".format(item))
+raise RuntimeError("Split percentage {} is too small.".format(item))
 absolute_sizes.append(absolute_size)

 absolute_sizes_sum = sum(absolute_sizes)
@@ -694,7 +699,7 @@ class Dataset:
 break

 if sum(absolute_sizes) != dataset_size:
-raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
+raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
 .format(absolute_sizes_sum, dataset_size))

 return absolute_sizes
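
Taken together, the three hunks above tighten how `Dataset.split` turns `sizes` into absolute row counts: `sizes` must be a list, integer sizes must sum to the dataset size, and fractional sizes are rounded to counts that may not collapse to zero. A rough standalone sketch of that validation (it omits the rounding adjustment the real method applies afterwards, and the function name is a placeholder):

```python
def absolute_split_sizes(sizes, dataset_size):
    """Sketch of the size validation added for Dataset.split."""
    if dataset_size is None or dataset_size <= 0:
        raise RuntimeError("dataset_size is unknown, unable to split.")
    if not isinstance(sizes, list):
        raise RuntimeError("sizes should be a list.")

    if all(isinstance(item, int) for item in sizes):
        if sum(sizes) != dataset_size:
            raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
                               .format(sum(sizes), dataset_size))
        return sizes

    absolute_sizes = []
    for item in sizes:
        absolute_size = int(round(item * dataset_size))
        if absolute_size == 0:
            raise RuntimeError("Split percentage {} is too small.".format(item))
        absolute_sizes.append(absolute_size)
    return absolute_sizes


print(absolute_split_sizes([3, 2], 5))       # [3, 2]
print(absolute_split_sizes([0.9, 0.1], 10))  # [9, 1]
```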
@@ -702,7 +707,7 @@ class Dataset:
 @check_split
 def split(self, sizes, randomize=True):
 """
-Splits the dataset into smaller, non-overlapping datasets.
+Split the dataset into smaller, non-overlapping datasets.

 This is a general purpose split function which can be called from any operator in the pipeline.
 There is another, optimized split function, which will be called automatically if ds.split is
@@ -759,10 +764,10 @@ class Dataset:
 >>> train, test = data.split([0.9, 0.1])
 """
 if self.is_shuffled():
-logger.warning("dataset is shuffled before split.")
+logger.warning("Dataset is shuffled before split.")

 if self.is_sharded():
-raise RuntimeError("dataset should not be sharded before split.")
+raise RuntimeError("Dataset should not be sharded before split.")

 absolute_sizes = self._get_absolute_split_sizes(sizes)
 splits = []
@@ -788,7 +793,7 @@ class Dataset:
 @check_zip_dataset
 def zip(self, datasets):
 """
-Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
+Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.

 Args:
 datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset
@@ -845,7 +850,7 @@ class Dataset:
 @check_rename
 def rename(self, input_columns, output_columns):
 """
-Renames the columns in input datasets.
+Rename the columns in input datasets.

 Args:
 input_columns (list[str]): list of names of the input columns.
@@ -871,7 +876,7 @@ class Dataset:
 @check_project
 def project(self, columns):
 """
-Projects certain columns in input datasets.
+Project certain columns in input datasets.

 The specified columns will be selected from the dataset and passed down
 the pipeline in the order specified. The other columns are discarded.
@@ -936,7 +941,7 @@ class Dataset:

 def device_que(self, prefetch_size=None):
 """
-Returns a transferredDataset that transfer data through device.
+Return a transferredDataset that transfer data through device.

 Args:
 prefetch_size (int, optional): prefetch number of records ahead of the
@@ -953,7 +958,7 @@ class Dataset:

 def to_device(self, num_batch=None):
 """
-Transfers data through CPU, GPU or Ascend devices.
+Transfer data through CPU, GPU or Ascend devices.

 Args:
 num_batch (int, optional): limit the number of batch to be sent to device (default=None).
@@ -988,7 +993,7 @@ class Dataset:
 raise TypeError("Please set device_type in context")

 if device_type not in ('Ascend', 'GPU', 'CPU'):
-raise ValueError("only support CPU, Ascend, GPU")
+raise ValueError("Only support CPU, Ascend, GPU")

 if num_batch is None or num_batch == 0:
 raise ValueError("num_batch is None or 0.")
@@ -1089,7 +1094,7 @@ class Dataset:

 def _get_pipeline_info(self):
 """
-Gets pipeline information.
+Get pipeline information.
 """
 device_iter = TupleIterator(self)
 self._output_shapes = device_iter.get_output_shapes()
@@ -1344,7 +1349,7 @@ class MappableDataset(SourceDataset):
 @check_split
 def split(self, sizes, randomize=True):
 """
-Splits the dataset into smaller, non-overlapping datasets.
+Split the dataset into smaller, non-overlapping datasets.

 There is the optimized split function, which will be called automatically when the dataset
 that calls this function is a MappableDataset.
@@ -1411,10 +1416,10 @@ class MappableDataset(SourceDataset):
 >>> train.use_sampler(train_sampler)
 """
 if self.is_shuffled():
-logger.warning("dataset is shuffled before split.")
+logger.warning("Dataset is shuffled before split.")

 if self.is_sharded():
-raise RuntimeError("dataset should not be sharded before split.")
+raise RuntimeError("Dataset should not be sharded before split.")

 absolute_sizes = self._get_absolute_split_sizes(sizes)
 splits = []
@@ -1633,7 +1638,7 @@ class BlockReleasePair:

 def __init__(self, init_release_rows, callback=None):
 if isinstance(init_release_rows, int) and init_release_rows <= 0:
 raise ValueError("release_rows need to be greater than 0.")
 self.row_count = -init_release_rows
 self.cv = threading.Condition()
 self.callback = callback
@@ -2699,10 +2704,10 @@ class MindDataset(MappableDataset):
 self.shard_id = shard_id

 if block_reader is True and num_shards is not None:
-raise ValueError("block reader not allowed true when use partitions")
+raise ValueError("block_reader not allowed true when use partitions")

 if block_reader is True and shuffle is True:
-raise ValueError("block reader not allowed true when use shuffle")
+raise ValueError("block_reader not allowed true when use shuffle")

 if block_reader is True:
 logger.warning("WARN: global shuffle is not used.")
@@ -2711,14 +2716,14 @@ class MindDataset(MappableDataset):
 if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
 samplers.DistributedSampler, samplers.RandomSampler,
 samplers.SequentialSampler)) is False:
-raise ValueError("the sampler is not supported yet.")
+raise ValueError("The sampler is not supported yet.")

 self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
 self.num_samples = num_samples

 # sampler exclusive
 if block_reader is True and sampler is not None:
-raise ValueError("block reader not allowed true when use sampler")
+raise ValueError("block_reader not allowed true when use sampler")

 if num_padded is None:
 num_padded = 0
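
The MindDataset hunks above mostly reword error strings, but the checks themselves encode a real constraint: `block_reader=True` cannot be combined with sharding (`num_shards`), `shuffle=True`, or an explicit `sampler`. Condensed into a standalone sketch (the function name is a placeholder, not MindSpore API):

```python
def validate_block_reader(block_reader, num_shards=None, shuffle=None, sampler=None):
    """Sketch of the mutually exclusive MindDataset options checked above."""
    if block_reader is True and num_shards is not None:
        raise ValueError("block_reader not allowed true when use partitions")
    if block_reader is True and shuffle is True:
        raise ValueError("block_reader not allowed true when use shuffle")
    if block_reader is True and sampler is not None:
        raise ValueError("block_reader not allowed true when use sampler")
```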
@@ -2770,7 +2775,7 @@ class MindDataset(MappableDataset):
 if value >= 0:
 self._dataset_size = value
 else:
-raise ValueError('set dataset_size with negative value {}'.format(value))
+raise ValueError('Set dataset_size with negative value {}'.format(value))

 def is_shuffled(self):
 if self.shuffle_option is None:
@@ -2872,7 +2877,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):

 def _fetch_py_sampler_indices(sampler, num_samples):
 """
-Indices fetcher for python sampler.
+Indice fetcher for python sampler.
 """
 if num_samples is not None:
 sampler_iter = iter(sampler)
@@ -3163,7 +3168,7 @@ class GeneratorDataset(MappableDataset):
 if value >= 0:
 self._dataset_size = value
 else:
-raise ValueError('set dataset_size with negative value {}'.format(value))
+raise ValueError('Set dataset_size with negative value {}'.format(value))

 def __deepcopy__(self, memodict):
 if id(self) in memodict:
@@ -3313,7 +3318,7 @@ class TFRecordDataset(SourceDataset):
 if value >= 0:
 self._dataset_size = value
 else:
-raise ValueError('set dataset_size with negative value {}'.format(value))
+raise ValueError('Set dataset_size with negative value {}'.format(value))

 def is_shuffled(self):
 return self.shuffle_files
@@ -4382,7 +4387,9 @@ class CelebADataset(MappableDataset):
 try:
 with open(attr_file, 'r') as f:
 num_rows = int(f.readline())
-except Exception:
+except FileNotFoundError:
+raise RuntimeError("attr_file not found.")
+except BaseException:
 raise RuntimeError("Get dataset size failed from attribution file.")
 rows_per_shard = get_num_rows(num_rows, self.num_shards)
 if self.num_samples is not None:
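
The CelebA hunk splits a bare `except Exception` into a specific handler for a missing attribute file plus a broad fallback. Order matters: Python takes the first matching clause, so the narrower `FileNotFoundError` has to come before `BaseException`. A minimal sketch of the same structure (`attr_file` is a placeholder path and the function name is hypothetical):

```python
def read_num_rows(attr_file):
    """Sketch mirroring the exception handling in the CelebA hunk."""
    try:
        with open(attr_file, 'r') as f:
            return int(f.readline())
    except FileNotFoundError:
        # Specific failure: the attribute file does not exist.
        raise RuntimeError("attr_file not found.")
    except BaseException:
        # Anything else (bad header, permissions, ...) falls through here.
        raise RuntimeError("Get dataset size failed from attribution file.")
```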
@@ -319,7 +319,7 @@ class PKSampler(BuiltinSampler):
 raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))

 if num_class is not None:
-raise NotImplementedError
+raise NotImplementedError("Not support specify num_class")

 if not isinstance(shuffle, bool):
 raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
@@ -551,8 +551,8 @@ class WeightedRandomSampler(BuiltinSampler):

 Args:
 weights (list[float]): A sequence of weights, not necessarily summing up to 1.
-num_samples (int): Number of elements to sample (default=None, all elements).
-replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
+num_samples (int, optional): Number of elements to sample (default=None, all elements).
+replacement (bool): If True, put the sample ID back for the next draw (default=True).

 Examples:
 >>> import mindspore.dataset as ds
@@ -50,7 +50,7 @@ def check_filename(path):
 Exception: when error
 """
 if not isinstance(path, str):
-raise ValueError("path: {} is not string".format(path))
+raise TypeError("path: {} is not string".format(path))
 filename = os.path.basename(path)

 # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
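
This and the following validator hunks consistently swap `ValueError` for `TypeError` whenever the complaint is about an argument's type rather than its value, which matches the usual Python convention. A small illustrative sketch built from the checks in the surrounding hunks (the standalone function is hypothetical):

```python
def check_task(task):
    """Sketch of the ValueError-vs-TypeError convention applied in these validators."""
    if task is None:
        raise ValueError("task is not provided.")     # missing value
    if not isinstance(task, str):
        raise TypeError("task is not str type.")      # wrong type
    if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
        raise ValueError("Invalid task type")         # wrong value
```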
@@ -143,7 +143,7 @@ def check_sampler_shuffle_shard_options(param_dict):
 num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')

 if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
-raise ValueError("sampler is not a valid Sampler type.")
+raise TypeError("sampler is not a valid Sampler type.")

 if sampler is not None:
 if shuffle is not None:
@@ -328,13 +328,13 @@ def check_vocdataset(method):
 if task is None:
 raise ValueError("task is not provided.")
 if not isinstance(task, str):
-raise ValueError("task is not str type.")
+raise TypeError("task is not str type.")
 # check mode; required argument
 mode = param_dict.get('mode')
 if mode is None:
 raise ValueError("mode is not provided.")
 if not isinstance(mode, str):
-raise ValueError("mode is not str type.")
+raise TypeError("mode is not str type.")

 imagesets_file = ""
 if task == "Segmentation":
@@ -388,7 +388,7 @@ def check_cocodataset(method):
 if task is None:
 raise ValueError("task is not provided.")
 if not isinstance(task, str):
-raise ValueError("task is not str type.")
+raise TypeError("task is not str type.")

 if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
 raise ValueError("Invalid task type")
@@ -556,7 +556,7 @@ def check_generatordataset(method):

 def check_batch_size(batch_size):
 if not (isinstance(batch_size, int) or (callable(batch_size))):
-raise ValueError("batch_size should either be an int or a callable.")
+raise TypeError("batch_size should either be an int or a callable.")
 if callable(batch_size):
 sig = ins.signature(batch_size)
 if len(sig.parameters) != 1:
@@ -706,6 +706,7 @@ def check_batch(method):
+

 def check_sync_wait(method):
 """check the input arguments of sync_wait."""

 @wraps(method)
 def new_method(*args, **kwargs):
 param_dict = make_param_dict(method, args, kwargs)
@@ -773,7 +774,7 @@ def check_filter(method):
 param_dict = make_param_dict(method, args, kwargs)
 predicate = param_dict.get("predicate")
 if not callable(predicate):
-raise ValueError("Predicate should be a python function or a callable python object.")
+raise TypeError("Predicate should be a python function or a callable python object.")

 nreq_param_int = ['num_parallel_workers']
 check_param_type(nreq_param_int, param_dict, int)
@@ -865,7 +866,7 @@ def check_zip_dataset(method):
 raise ValueError("datasets is not provided.")

 if not isinstance(ds, (tuple, datasets.Dataset)):
-raise ValueError("datasets is not tuple or of type Dataset.")
+raise TypeError("datasets is not tuple or of type Dataset.")

 return method(*args, **kwargs)

@@ -885,7 +886,7 @@ def check_concat(method):
 raise ValueError("datasets is not provided.")

 if not isinstance(ds, (list, datasets.Dataset)):
-raise ValueError("datasets is not list or of type Dataset.")
+raise TypeError("datasets is not list or of type Dataset.")

 return method(*args, **kwargs)

@@ -964,7 +965,7 @@ def check_add_column(method):
 de_type = param_dict.get("de_type")
 if de_type is not None:
 if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
-raise ValueError("Unknown column type.")
+raise TypeError("Unknown column type.")
 else:
 raise TypeError("Expected non-empty string.")

@@ -10,6 +10,6 @@ wheel >= 0.32.0
 decorator >= 4.4.0
 setuptools >= 40.8.0
 matplotlib >= 3.1.3 # for ut test
-opencv-python >= 4.2.0.32 # for ut test
+opencv-python >= 4.1.2.30 # for ut test
 sklearn >= 0.0 # for st test
 pandas >= 1.0.2 # for ut test
@@ -42,15 +42,15 @@ def split_with_invalid_inputs(d):

 with pytest.raises(RuntimeError) as info:
 _, _ = d.split([3, 1])
-assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value)
+assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

 with pytest.raises(RuntimeError) as info:
 _, _ = d.split([5, 1])
-assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value)
+assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

 with pytest.raises(RuntimeError) as info:
 _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
-assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)
+assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

 with pytest.raises(ValueError) as info:
 _, _ = d.split([-0.5, 0.5])
@@ -80,7 +80,7 @@ def test_unmappable_invalid_input():
 d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
 with pytest.raises(RuntimeError) as info:
 _, _ = d.split([4, 1])
-assert "dataset should not be sharded before split" in str(info.value)
+assert "Dataset should not be sharded before split" in str(info.value)


 def test_unmappable_split():
@@ -274,7 +274,7 @@ def test_mappable_invalid_input():
 d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
 with pytest.raises(RuntimeError) as info:
 _, _ = d.split([4, 1])
-assert "dataset should not be sharded before split" in str(info.value)
+assert "Dataset should not be sharded before split" in str(info.value)


 def test_mappable_split_general():