forked from mindspore-Ecosystem/mindspore
!8594 [MD] Pybind Pushdown Support for dataset
From: @cathwong Reviewed-by: Signed-off-by:
This commit is contained in:
commit
adc8e3e707
|
@ -2,9 +2,11 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
if (ENABLE_PYTHON)
|
||||
add_library(APItoPython OBJECT
|
||||
python/de_pipeline.cc
|
||||
python/pybind_register.cc
|
||||
python/bindings.cc
|
||||
python/pybind_conversion.cc
|
||||
python/bindings/dataset/include/datasets_bindings.cc
|
||||
python/bindings/dataset/include/iterator_bindings.cc
|
||||
python/bindings/dataset/include/schema_bindings.cc
|
||||
python/bindings/dataset/engine/cache/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
python/bindings/dataset/callback/bindings.cc
|
||||
|
|
|
@ -115,7 +115,8 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Function to return a transferred Node that transfers data through a device.
|
||||
bool Dataset::DeviceQueue(bool send_epoch_end) {
|
||||
bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t num_epochs, bool send_epoch_end,
|
||||
int32_t total_batches, bool create_data_info_queue) {
|
||||
Status rc;
|
||||
|
||||
// Build and launch tree
|
||||
|
@ -126,11 +127,12 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Add TransferNode IR on top of dataset d
|
||||
auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end);
|
||||
// Add TransferNode IR on top of dataset
|
||||
auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, send_epoch_end,
|
||||
total_batches, create_data_info_queue);
|
||||
|
||||
// Get ToDevice consumer
|
||||
auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1);
|
||||
auto consumer = std::make_unique<ToDevice>(num_epochs);
|
||||
ToDevice *consumer_ = consumer.get();
|
||||
rc = consumer->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
|
@ -199,127 +201,55 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
|
|||
|
||||
int64_t Dataset::GetDatasetSize() {
|
||||
int64_t dataset_size;
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetDatasetSize(&dataset_size);
|
||||
return rc.IsError() ? -1 : dataset_size;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
|
||||
return dataset_size;
|
||||
}
|
||||
|
||||
std::vector<DataType> Dataset::GetOutputTypes() {
|
||||
std::vector<DataType> types;
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
|
||||
return types;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
|
||||
return types;
|
||||
}
|
||||
rc = tree_getters_->GetOutputTypes(&types);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
|
||||
types.clear();
|
||||
return types;
|
||||
}
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputTypes(&types), {});
|
||||
return types;
|
||||
}
|
||||
|
||||
std::vector<TensorShape> Dataset::GetOutputShapes() {
|
||||
std::vector<TensorShape> shapes;
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
|
||||
return shapes;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
|
||||
return shapes;
|
||||
}
|
||||
rc = tree_getters_->GetOutputShapes(&shapes);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
|
||||
shapes.clear();
|
||||
return shapes;
|
||||
}
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputShapes(&shapes), {});
|
||||
return shapes;
|
||||
}
|
||||
|
||||
int64_t Dataset::GetNumClasses() {
|
||||
int64_t num_classes;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetNumClasses(&num_classes);
|
||||
return rc.IsError() ? -1 : num_classes;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetNumClasses(&num_classes), -1);
|
||||
return num_classes;
|
||||
}
|
||||
|
||||
std::vector<std::string> Dataset::GetColumnNames() {
|
||||
std::vector<std::string> col_names;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed.";
|
||||
return std::vector<std::string>();
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed.";
|
||||
return std::vector<std::string>();
|
||||
}
|
||||
rc = tree_getters_->GetColumnNames(&col_names);
|
||||
return rc.IsError() ? std::vector<std::string>() : col_names;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetColumnNames(&col_names), {});
|
||||
return col_names;
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
|
||||
return output_class_indexing;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
|
||||
return output_class_indexing;
|
||||
}
|
||||
rc = tree_getters_->GetClassIndexing(&output_class_indexing);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
|
||||
output_class_indexing.clear();
|
||||
return output_class_indexing;
|
||||
}
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetClassIndexing(&output_class_indexing), {});
|
||||
return output_class_indexing;
|
||||
}
|
||||
|
||||
|
@ -501,9 +431,13 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(
|
|||
std::function<TensorRow(TensorRow)> element_length_function,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
|
||||
bool drop_remainder) {
|
||||
auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries,
|
||||
bucket_batch_sizes, element_length_function, pad_info,
|
||||
pad_to_bucket_boundary, drop_remainder);
|
||||
std::shared_ptr<TensorOp> c_func = nullptr;
|
||||
if (element_length_function != nullptr) {
|
||||
c_func = std::make_shared<CFuncOp>(element_length_function);
|
||||
}
|
||||
auto ds =
|
||||
std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, bucket_batch_sizes,
|
||||
c_func, pad_info, pad_to_bucket_boundary, drop_remainder);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
@ -522,7 +456,9 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase
|
|||
|
||||
FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate,
|
||||
std::vector<std::string> input_columns) {
|
||||
auto ds = std::make_shared<FilterNode>(input->IRNode(), predicate, input_columns);
|
||||
std::shared_ptr<TensorOp> c_func = nullptr;
|
||||
if (predicate) c_func = std::make_shared<CFuncOp>(predicate);
|
||||
auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, input_columns);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
@ -604,40 +540,20 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
|||
#endif
|
||||
int64_t Dataset::GetBatchSize() {
|
||||
int64_t batch_size;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetBatchSize(&batch_size);
|
||||
return rc.IsError() ? -1 : batch_size;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
|
||||
return batch_size;
|
||||
}
|
||||
|
||||
int64_t Dataset::GetRepeatCount() {
|
||||
int64_t repeat_count;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetRepeatCount(&repeat_count);
|
||||
return rc.IsError() ? 0 : repeat_count;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
|
||||
return repeat_count;
|
||||
}
|
||||
|
||||
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
|
||||
|
@ -720,62 +636,65 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
|
|||
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
|
||||
|
||||
// SchemaObj init function
|
||||
bool SchemaObj::init() {
|
||||
if (schema_file_ != "") {
|
||||
Status SchemaObj::init() {
|
||||
if (!schema_file_.empty()) {
|
||||
Path schema_file(schema_file_);
|
||||
if (!schema_file.Exists()) {
|
||||
MS_LOG(ERROR) << "The file " << schema_file << " does not exist or permission denied!";
|
||||
return false;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
|
||||
"The file " + schema_file_ + " does not exist or permission denied!");
|
||||
|
||||
nlohmann::json js;
|
||||
try {
|
||||
std::ifstream in(schema_file_);
|
||||
in >> js;
|
||||
if (js.find("columns") == js.end()) {
|
||||
MS_LOG(ERROR) << "\"columns\" node is required in the schema json file.";
|
||||
return false;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
|
||||
"\"columns\" node is required in the schema json file.");
|
||||
} catch (const std::exception &err) {
|
||||
MS_LOG(ERROR) << "Schema file failed to load";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Schema file failed to load");
|
||||
}
|
||||
return from_json(js);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a mstype de_type
|
||||
bool SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
|
||||
nlohmann::json new_column;
|
||||
new_column["name"] = name;
|
||||
// if de_type is mstype
|
||||
// Function to add a column to schema with a mstype de_type and known shape
|
||||
Status SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
|
||||
DataType data_type = dataset::MSTypeToDEType(de_type);
|
||||
new_column["type"] = data_type.ToString();
|
||||
if (shape.size() > 0) {
|
||||
new_column["shape"] = shape;
|
||||
new_column["rank"] = shape.size();
|
||||
} else {
|
||||
new_column["rank"] = 1;
|
||||
}
|
||||
columns_.push_back(new_column);
|
||||
return true;
|
||||
return add_column(name, data_type.ToString(), shape);
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a string de_type
|
||||
bool SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
|
||||
// Function to add a column to schema with a string de_type and known shape
|
||||
Status SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
|
||||
DataType data_type(de_type);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
|
||||
|
||||
nlohmann::json new_column;
|
||||
new_column["name"] = name;
|
||||
DataType data_type(de_type);
|
||||
new_column["type"] = data_type.ToString();
|
||||
if (shape.size() > 0) {
|
||||
new_column["shape"] = shape;
|
||||
new_column["rank"] = shape.size();
|
||||
} else {
|
||||
new_column["rank"] = 1;
|
||||
}
|
||||
new_column["shape"] = shape;
|
||||
new_column["rank"] = shape.size();
|
||||
|
||||
columns_.push_back(new_column);
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a mstype de_type and without shape
|
||||
Status SchemaObj::add_column(std::string name, TypeId de_type) {
|
||||
DataType data_type = dataset::MSTypeToDEType(de_type);
|
||||
return add_column(name, data_type.ToString());
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a string de_type and without shape
|
||||
Status SchemaObj::add_column(std::string name, std::string de_type) {
|
||||
DataType data_type(de_type);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
|
||||
|
||||
nlohmann::json new_column;
|
||||
new_column["name"] = name;
|
||||
new_column["type"] = data_type.ToString();
|
||||
new_column["rank"] = 1;
|
||||
|
||||
columns_.push_back(new_column);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string SchemaObj::to_json() {
|
||||
|
@ -792,7 +711,7 @@ std::string SchemaObj::to_json() {
|
|||
return json_file.dump(2);
|
||||
}
|
||||
|
||||
bool SchemaObj::parse_column(nlohmann::json columns) {
|
||||
Status SchemaObj::parse_column(nlohmann::json columns) {
|
||||
std::string name, de_type;
|
||||
std::vector<int32_t> shape;
|
||||
|
||||
|
@ -802,15 +721,13 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
|
|||
for (auto column : columns) {
|
||||
auto key_name = column.find("name");
|
||||
if (key_name == column.end()) {
|
||||
MS_LOG(ERROR) << "Column's name is missing";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
|
||||
}
|
||||
name = *key_name;
|
||||
|
||||
auto key_type = column.find("type");
|
||||
if (key_type == column.end()) {
|
||||
MS_LOG(ERROR) << "Column's type is missing";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
|
||||
}
|
||||
de_type = *key_type;
|
||||
|
||||
|
@ -819,17 +736,14 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
|
|||
if (key_shape != column.end()) {
|
||||
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
|
||||
}
|
||||
if (!add_column(name, de_type, shape)) {
|
||||
return false;
|
||||
}
|
||||
RETURN_IF_NOT_OK(add_column(name, de_type, shape));
|
||||
}
|
||||
} else if (columns.type() == nlohmann::json::value_t::object) {
|
||||
for (const auto &it_child : columns.items()) {
|
||||
name = it_child.key();
|
||||
auto key_type = it_child.value().find("type");
|
||||
if (key_type == it_child.value().end()) {
|
||||
MS_LOG(ERROR) << "Column's type is missing";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
|
||||
}
|
||||
de_type = *key_type;
|
||||
|
||||
|
@ -839,43 +753,45 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
|
|||
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
|
||||
}
|
||||
|
||||
if (!add_column(name, de_type, shape)) {
|
||||
return false;
|
||||
}
|
||||
RETURN_IF_NOT_OK(add_column(name, de_type, shape));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "columns must be dict or list, columns contain name, type, shape(optional).";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool SchemaObj::from_json(nlohmann::json json_obj) {
|
||||
Status SchemaObj::from_json(nlohmann::json json_obj) {
|
||||
for (const auto &it_child : json_obj.items()) {
|
||||
if (it_child.key() == "datasetType") {
|
||||
dataset_type_ = it_child.value();
|
||||
} else if (it_child.key() == "numRows") {
|
||||
num_rows_ = it_child.value();
|
||||
} else if (it_child.key() == "columns") {
|
||||
if (!parse_column(it_child.value())) {
|
||||
MS_LOG(ERROR) << "parse columns failed";
|
||||
return false;
|
||||
}
|
||||
RETURN_IF_NOT_OK(parse_column(it_child.value()));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknown field " << it_child.key();
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
|
||||
}
|
||||
}
|
||||
if (columns_.empty()) {
|
||||
MS_LOG(ERROR) << "Columns are missing.";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
|
||||
}
|
||||
if (num_rows_ <= 0) {
|
||||
MS_LOG(ERROR) << "numRows must be greater than 0";
|
||||
return false;
|
||||
if (num_rows_ < 0) {
|
||||
RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
|
||||
}
|
||||
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
Status SchemaObj::FromJSONString(const std::string &json_string) {
|
||||
try {
|
||||
nlohmann::json js = nlohmann::json::parse(json_string);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
|
||||
"\"columns\" node is required in the schema json JSON.");
|
||||
RETURN_IF_NOT_OK(from_json(js));
|
||||
} catch (const std::exception &err) {
|
||||
RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// OTHER FUNCTIONS
|
||||
|
|
|
@ -1,136 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
DEPipeline, 0, ([](const py::module *m) {
|
||||
(void)py::class_<DEPipeline>(*m, "DEPipeline")
|
||||
.def(py::init<>())
|
||||
.def(
|
||||
"AddNodeToTree",
|
||||
[](DEPipeline &de, const OpName &op_name, const py::dict &args) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
|
||||
return out;
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def_static("AddChildToParentNode",
|
||||
[](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
|
||||
THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
|
||||
})
|
||||
.def("AssignRootNode",
|
||||
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
|
||||
.def("SetBatchParameters",
|
||||
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
|
||||
.def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
|
||||
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
|
||||
.def("GetColumnNames",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetColumnNames(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetNextAsMap",
|
||||
[](DEPipeline &de) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(de.GetNextAsMap(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetNextAsList",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetNextAsList(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetOutputShapes",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetOutputShapes(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetOutputTypes",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetOutputTypes(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetDataInfo",
|
||||
[](DEPipeline &de) {
|
||||
py::list types, shapes;
|
||||
THROW_IF_ERROR(de.GetDataInfo(&types, &shapes));
|
||||
return py::make_tuple(types, shapes);
|
||||
})
|
||||
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
|
||||
.def("GetBatchSize", &DEPipeline::GetBatchSize)
|
||||
.def("GetNumClasses", &DEPipeline::GetNumClasses)
|
||||
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
|
||||
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
|
||||
.def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
|
||||
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
|
||||
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
|
||||
return true;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
|
||||
.value("SHUFFLE", OpName::kShuffle)
|
||||
.value("BATCH", OpName::kBatch)
|
||||
.value("BUCKETBATCH", OpName::kBucketBatch)
|
||||
.value("BARRIER", OpName::kBarrier)
|
||||
.value("MINDRECORD", OpName::kMindrecord)
|
||||
.value("CACHE", OpName::kCache)
|
||||
.value("REPEAT", OpName::kRepeat)
|
||||
.value("SKIP", OpName::kSkip)
|
||||
.value("TAKE", OpName::kTake)
|
||||
.value("ZIP", OpName::kZip)
|
||||
.value("CONCAT", OpName::kConcat)
|
||||
.value("MAP", OpName::kMap)
|
||||
.value("FILTER", OpName::kFilter)
|
||||
.value("DEVICEQUEUE", OpName::kDeviceQueue)
|
||||
.value("GENERATOR", OpName::kGenerator)
|
||||
.export_values()
|
||||
.value("RENAME", OpName::kRename)
|
||||
.value("TFREADER", OpName::kTfReader)
|
||||
.value("PROJECT", OpName::kProject)
|
||||
.value("IMAGEFOLDER", OpName::kImageFolder)
|
||||
.value("MNIST", OpName::kMnist)
|
||||
.value("MANIFEST", OpName::kManifest)
|
||||
.value("VOC", OpName::kVoc)
|
||||
.value("COCO", OpName::kCoco)
|
||||
.value("CIFAR10", OpName::kCifar10)
|
||||
.value("CIFAR100", OpName::kCifar100)
|
||||
.value("RANDOMDATA", OpName::kRandomData)
|
||||
.value("BUILDVOCAB", OpName::kBuildVocab)
|
||||
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
|
||||
.value("CELEBA", OpName::kCelebA)
|
||||
.value("TEXTFILE", OpName::kTextFile)
|
||||
.value("EPOCHCTRL", OpName::kEpochCtrl)
|
||||
.value("CSV", OpName::kCsv)
|
||||
.value("CLUE", OpName::kClue);
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -19,8 +19,10 @@
|
|||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/core/client.h" // DE client
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
|
|
@ -0,0 +1,551 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR non-leaf nodes - for android
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
|
||||
// IR leaf nodes disabled for android
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
|
||||
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
|
||||
.def("SetNumWorkers",
|
||||
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
|
||||
return num_workers ? self->SetNumWorkers(*num_workers) : self;
|
||||
})
|
||||
.def(
|
||||
"Zip",
|
||||
[](std::shared_ptr<DatasetNode> self, py::list datasets) {
|
||||
auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets)));
|
||||
THROW_IF_ERROR(zip->ValidateParams());
|
||||
return zip;
|
||||
},
|
||||
py::arg("datasets"));
|
||||
}));
|
||||
|
||||
// PYBIND FOR LEAF NODES
|
||||
// (In alphabetical order)
|
||||
|
||||
PYBIND_REGISTER(
|
||||
CelebANode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", "to create a CelebANode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, bool decode,
|
||||
std::optional<py::list> extensions, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
|
||||
toStringSet(extensions), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(celebA->ValidateParams());
|
||||
return celebA;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
|
||||
"to create a Cifar10Node")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler),
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(cifar10->ValidateParams());
|
||||
return cifar10;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
|
||||
"to create a Cifar100Node")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler),
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(cifar100->ValidateParams());
|
||||
return cifar100;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
CLUENode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode")
|
||||
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle,
|
||||
int32_t num_shards, int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<CLUENode> clue_node =
|
||||
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle),
|
||||
num_shards, shard_id, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(clue_node->ValidateParams());
|
||||
return clue_node;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
CocoNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", "to create a CocoNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, bool decode,
|
||||
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
|
||||
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(coco->ValidateParams());
|
||||
return coco;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
|
||||
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
|
||||
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
|
||||
int32_t num_shards, int32_t shard_id,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults),
|
||||
column_names, num_samples, toShuffleMode(shuffle),
|
||||
num_shards, shard_id, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(csv->ValidateParams());
|
||||
return csv;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
||||
*m, "GeneratorNode", "to create a GeneratorNode")
|
||||
.def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
|
||||
const std::vector<DataType> &column_types) {
|
||||
auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types);
|
||||
THROW_IF_ERROR(gen->ValidateParams());
|
||||
return gen;
|
||||
}))
|
||||
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) {
|
||||
auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
|
||||
THROW_IF_ERROR(gen->ValidateParams());
|
||||
return gen;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
|
||||
*m, "ImageFolderNode", "to create an ImageFolderNode")
|
||||
.def(py::init([](std::string dataset_dir, bool decode, std::optional<py::handle> sampler,
|
||||
std::optional<py::list> extensions, std::optional<py::dict> class_indexing,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
bool recursive = true;
|
||||
auto imagefolder = std::make_shared<ImageFolderNode>(
|
||||
dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions),
|
||||
toStringMap(class_indexing), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(imagefolder->ValidateParams());
|
||||
return imagefolder;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
|
||||
"to create a ManifestNode")
|
||||
.def(py::init([](std::string dataset_file, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<py::dict> class_indexing, bool decode,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
|
||||
toStringMap(class_indexing), decode,
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(manifest->ValidateParams());
|
||||
return manifest;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
|
||||
"to create a MindDataNode")
|
||||
.def(py::init([](std::string dataset_file, std::optional<py::list> columns_list,
|
||||
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
|
||||
nlohmann::json padded_sample_json;
|
||||
std::map<std::string, std::string> sample_bytes;
|
||||
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
||||
auto minddata =
|
||||
std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list),
|
||||
toSamplerObj(sampler, true), padded_sample_json, num_padded);
|
||||
minddata->SetSampleBytes(&sample_bytes);
|
||||
THROW_IF_ERROR(minddata->ValidateParams());
|
||||
return minddata;
|
||||
}))
|
||||
.def(py::init([](py::list dataset_file, std::optional<py::list> columns_list,
|
||||
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
|
||||
nlohmann::json padded_sample_json;
|
||||
std::map<std::string, std::string> sample_bytes;
|
||||
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
||||
auto minddata =
|
||||
std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list),
|
||||
toSamplerObj(sampler, true), padded_sample_json, num_padded);
|
||||
minddata->SetSampleBytes(&sample_bytes);
|
||||
THROW_IF_ERROR(minddata->ValidateParams());
|
||||
return minddata;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
|
||||
"to create an MnistNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler),
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(mnist->ValidateParams());
|
||||
return mnist;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RandomNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
|
||||
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto random_node =
|
||||
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(random_node->ValidateParams());
|
||||
return random_node;
|
||||
}))
|
||||
.def(py::init([](int32_t total_rows, std::string schema, std::optional<py::list> columns_list,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto random_node =
|
||||
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(random_node->ValidateParams());
|
||||
return random_node;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
|
||||
"to create a TextFileNode")
|
||||
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
|
||||
int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>(
|
||||
toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id,
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(textfile_node->ValidateParams());
|
||||
return textfile_node;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
TFRecordNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
|
||||
"to create a TFRecordNode")
|
||||
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
|
||||
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
|
||||
std::optional<int32_t> shard_id, bool shard_equal_rows,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
if (!num_samples) {
|
||||
*num_samples = 0;
|
||||
}
|
||||
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
||||
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
|
||||
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(tfrecord->ValidateParams());
|
||||
return tfrecord;
|
||||
}))
|
||||
.def(py::init([](py::list dataset_files, std::string schema, std::optional<py::list> columns_list,
|
||||
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
|
||||
std::optional<int32_t> shard_id, bool shard_equal_rows,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
if (!num_samples) {
|
||||
*num_samples = 0;
|
||||
}
|
||||
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
||||
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
|
||||
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(tfrecord->ValidateParams());
|
||||
return tfrecord;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
|
||||
.def(
|
||||
py::init([](std::string dataset_dir, std::string task, std::string usage,
|
||||
std::optional<py::dict> class_indexing, bool decode,
|
||||
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<VOCNode> voc =
|
||||
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
|
||||
toSamplerObj(sampler), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(voc->ValidateParams());
|
||||
return voc;
|
||||
}));
|
||||
}));
|
||||
|
||||
// PYBIND FOR NON-LEAF NODES
|
||||
// (In alphabetical order)
|
||||
|
||||
PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode",
|
||||
"to create a BatchNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder,
|
||||
bool pad, py::list in_col_names, py::list out_col_names, py::list col_order,
|
||||
py::object size_obj, py::object map_obj, py::dict pad_info) {
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
||||
if (pad) {
|
||||
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
||||
}
|
||||
py::function size_func =
|
||||
py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function();
|
||||
py::function map_func =
|
||||
py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function();
|
||||
auto batch = std::make_shared<BatchNode>(
|
||||
self, batch_size, drop_remainder, pad, toStringVector(in_col_names),
|
||||
toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info);
|
||||
THROW_IF_ERROR(batch->ValidateParams());
|
||||
return batch;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>(
|
||||
*m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names,
|
||||
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes,
|
||||
py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary,
|
||||
bool drop_remainder) {
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
||||
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
||||
|
||||
auto bucket_batch = std::make_shared<BucketBatchByLengthNode>(
|
||||
dataset, toStringVector(column_names), bucket_boundaries, bucket_batch_sizes,
|
||||
toPyFuncOp(std::move(element_length_function), DataType::DE_INT32), c_pad_info,
|
||||
pad_to_bucket_boundary, drop_remainder);
|
||||
THROW_IF_ERROR(bucket_batch->ValidateParams());
|
||||
return bucket_batch;
|
||||
}),
|
||||
py::arg("dataset"), py::arg("column_names"), py::arg("bucket_boundaries"),
|
||||
py::arg("bucket_batch_sizes"), py::arg("element_length_function") = py::none(),
|
||||
py::arg("pad_info"), py::arg("pad_to_bucket_boundary"), py::arg("drop_remainder"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>(
|
||||
*m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab,
|
||||
const std::vector<std::string> &col_names, uint32_t vocab_size,
|
||||
float character_coverage, SentencePieceModel model_type,
|
||||
const std::unordered_map<std::string, std::string> ¶ms) {
|
||||
auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>(
|
||||
self, vocab, col_names, vocab_size, character_coverage, model_type, params);
|
||||
THROW_IF_ERROR(build_sentence_vocab->ValidateParams());
|
||||
return build_sentence_vocab;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>(
|
||||
*m, "BuildVocabNode", "to create a BuildVocabNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns,
|
||||
py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) {
|
||||
auto build_vocab =
|
||||
std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range),
|
||||
top_k, toStringVector(special_tokens), special_first);
|
||||
THROW_IF_ERROR(build_vocab->ValidateParams());
|
||||
return build_vocab;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
|
||||
"to create a ConcatNode")
|
||||
.def(
|
||||
py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, std::optional<py::handle> sampler,
|
||||
py::list children_flag_and_nums, py::list children_start_end_index) {
|
||||
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
|
||||
toPairVector(children_flag_and_nums),
|
||||
toPairVector(children_start_end_index));
|
||||
THROW_IF_ERROR(concat->ValidateParams());
|
||||
return concat;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode",
|
||||
"to create a FilterNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate,
|
||||
std::vector<std::string> input_columns) {
|
||||
auto filter =
|
||||
std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns);
|
||||
THROW_IF_ERROR(filter->ValidateParams());
|
||||
return filter;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> operations,
|
||||
std::optional<py::list> input_columns, std::optional<py::list> output_columns,
|
||||
std::optional<py::list> project_columns,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc,
|
||||
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
|
||||
auto map = std::make_shared<MapNode>(
|
||||
self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
|
||||
toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)),
|
||||
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()));
|
||||
THROW_IF_ERROR(map->ValidateParams());
|
||||
return map;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode",
|
||||
"to create a ProjectNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) {
|
||||
auto project = std::make_shared<ProjectNode>(self, toStringVector(columns));
|
||||
THROW_IF_ERROR(project->ValidateParams());
|
||||
return project;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode",
|
||||
"to create a RenameNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> input_columns,
|
||||
std::optional<py::list> output_columns) {
|
||||
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns),
|
||||
toStringVector(output_columns));
|
||||
THROW_IF_ERROR(rename->ValidateParams());
|
||||
return rename;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",
|
||||
"to create a RepeatNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) {
|
||||
auto repeat = std::make_shared<RepeatNode>(input, count);
|
||||
THROW_IF_ERROR(repeat->ValidateParams());
|
||||
return repeat;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode",
|
||||
"to create a ShuffleNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) {
|
||||
auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch);
|
||||
THROW_IF_ERROR(shuffle->ValidateParams());
|
||||
return shuffle;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode",
|
||||
"to create a SkipNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
|
||||
auto skip = std::make_shared<SkipNode>(self, count);
|
||||
THROW_IF_ERROR(skip->ValidateParams());
|
||||
return skip;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode",
|
||||
"to create a SyncWaitNode")
|
||||
.def(
|
||||
py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) {
|
||||
py::function callback_func =
|
||||
py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function();
|
||||
auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback);
|
||||
THROW_IF_ERROR(sync_wait->ValidateParams());
|
||||
return sync_wait;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode",
|
||||
"to create a TakeNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
|
||||
auto take = std::make_shared<TakeNode>(self, count);
|
||||
THROW_IF_ERROR(take->ValidateParams());
|
||||
return take;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
|
||||
"to create a TransferNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type,
|
||||
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) {
|
||||
auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end,
|
||||
total_batch, create_data_info_queue);
|
||||
THROW_IF_ERROR(transfer->ValidateParams());
|
||||
return transfer;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode")
|
||||
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) {
|
||||
auto zip = std::make_shared<ZipNode>(datasets);
|
||||
THROW_IF_ERROR(zip->ValidateParams());
|
||||
return zip;
|
||||
}));
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,168 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
|
||||
#include "minddata/dataset/engine/python_runtime_context.h"
|
||||
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
|
||||
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer");
|
||||
}));
|
||||
PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>(
|
||||
*m, "PythonIteratorConsumer")
|
||||
.def(py::init<int32_t>())
|
||||
.def("Init", [](PythonIteratorConsumer &self,
|
||||
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
|
||||
.def("GetNextAsMap",
|
||||
[](PythonIteratorConsumer &self) {
|
||||
py::dict output;
|
||||
THROW_IF_ERROR(self.GetNextAsDict(&output));
|
||||
return output;
|
||||
})
|
||||
.def("GetNextAsList", [](PythonIteratorConsumer &self) {
|
||||
py::list output;
|
||||
THROW_IF_ERROR(self.GetNextAsList(&output));
|
||||
return output;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PythonTreeGetters, TreeConsumer, std::shared_ptr<PythonTreeGetters>>(*m,
|
||||
"TreeGetters")
|
||||
.def(py::init<>())
|
||||
.def("Init",
|
||||
[](PythonTreeGetters &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
|
||||
.def("GetOutputShapes",
|
||||
[](PythonTreeGetters &self) {
|
||||
std::vector<TensorShape> shapes;
|
||||
THROW_IF_ERROR(self.GetOutputShapes(&shapes));
|
||||
return shapesToListOfShape(shapes);
|
||||
})
|
||||
.def("GetOutputTypes",
|
||||
[](PythonTreeGetters &self) {
|
||||
std::vector<DataType> types;
|
||||
THROW_IF_ERROR(self.GetOutputTypes(&types));
|
||||
return typesToListOfType(types);
|
||||
})
|
||||
.def("GetNumClasses",
|
||||
[](PythonTreeGetters &self) {
|
||||
int64_t num_classes;
|
||||
THROW_IF_ERROR(self.GetNumClasses(&num_classes));
|
||||
return num_classes;
|
||||
})
|
||||
.def("GetRepeatCount",
|
||||
[](PythonTreeGetters &self) {
|
||||
int64_t repeat_count;
|
||||
THROW_IF_ERROR(self.GetRepeatCount(&repeat_count));
|
||||
return repeat_count;
|
||||
})
|
||||
.def("GetBatchSize",
|
||||
[](PythonTreeGetters &self) {
|
||||
int64_t batch_size;
|
||||
THROW_IF_ERROR(self.GetBatchSize(&batch_size));
|
||||
return batch_size;
|
||||
})
|
||||
.def("GetColumnNames",
|
||||
[](PythonTreeGetters &self) {
|
||||
std::vector<std::string> col_names;
|
||||
THROW_IF_ERROR(self.GetColumnNames(&col_names));
|
||||
return col_names;
|
||||
})
|
||||
.def("GetClassIndexing",
|
||||
[](PythonTreeGetters &self) {
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
||||
THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
|
||||
return output_class_indexing;
|
||||
})
|
||||
.def("GetDatasetSize",
|
||||
[](PythonTreeGetters &self) {
|
||||
int64_t dataset_size;
|
||||
THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
|
||||
return dataset_size;
|
||||
})
|
||||
.def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PythonRuntimeContext, 2, ([](const py::module *m) {
|
||||
(void)py::class_<PythonRuntimeContext, std::shared_ptr<PythonRuntimeContext>>(*m,
|
||||
"PythonRuntimeContext")
|
||||
.def(py::init<>())
|
||||
.def("Init", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Init()); })
|
||||
.def("AssignConsumer", &PythonRuntimeContext::AssignConsumer)
|
||||
.def("Terminate", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Terminate()); })
|
||||
.def("GetConsumer", &PythonRuntimeContext::GetPythonConsumer, py::return_value_policy::reference)
|
||||
.def("__deepcopy__", [](py::object &runtime_context, py::dict memo) { return runtime_context; });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(PythonBuildVocabConsumer, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PythonBuildVocabConsumer, TreeConsumer, std::shared_ptr<PythonBuildVocabConsumer>>(
|
||||
*m, "PythonBuildVocabConsumer")
|
||||
.def(py::init<>())
|
||||
.def("Init", [](PythonBuildVocabConsumer &self,
|
||||
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
|
||||
.def("Start", [](PythonBuildVocabConsumer &self) { THROW_IF_ERROR(self.Start()); });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ToDevice, TreeConsumer, std::shared_ptr<ToDevice>>(*m, "ToDevice")
|
||||
.def(py::init<int32_t>())
|
||||
.def("Init", [](ToDevice &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
|
||||
.def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); })
|
||||
.def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); })
|
||||
.def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); })
|
||||
.def("GetDataInfo",
|
||||
[](ToDevice &self) {
|
||||
std::vector<DataType> types_c;
|
||||
std::vector<TensorShape> shapes_c;
|
||||
{
|
||||
py::gil_scoped_release rel;
|
||||
THROW_IF_ERROR(self.GetDataInfo(&types_c, &shapes_c));
|
||||
}
|
||||
py::list types, shapes;
|
||||
for (auto el : types_c) {
|
||||
types.append(el.AsNumpyType());
|
||||
}
|
||||
for (auto el : shapes_c) {
|
||||
py::list shape = el.AsPyList();
|
||||
shapes.append(shape);
|
||||
}
|
||||
return py::make_tuple(types, shapes);
|
||||
})
|
||||
.def("__deepcopy__", [](py::object &to_device, py::dict memo) { return to_device; });
|
||||
}));
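A similarly hedged sketch of the ToDevice consumer exposed just above; transfer_ir stands for a TransferNode-rooted IR tree and the import path is assumed. The num_epochs value of -1 mirrors the default in the C++ header.

import mindspore._c_dataengine as cde  # assumed import path

to_device = cde.ToDevice(-1)               # num_epochs; -1 matches the C++ default
to_device.Init(transfer_ir)                # transfer_ir: TransferNode-rooted IR tree (assumed)
to_device.Send()                           # launch the tree and start pushing rows to the device queue
types, shapes = to_device.GetDataInfo()    # per-column numpy dtypes and shapes
to_device.StopSend()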
|
||||
|
||||
PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PythonSaveToDisk, TreeConsumer, std::shared_ptr<PythonSaveToDisk>>(
|
||||
*m, "PythonSaveToDisk")
|
||||
.def(py::init([](std::string &dataset_path, int32_t numFiles, std::string &datasetType) {
|
||||
auto save = std::make_shared<PythonSaveToDisk>(dataset_path, numFiles, datasetType);
|
||||
THROW_IF_ERROR(save->ValidateParams());
|
||||
return save;
|
||||
}))
|
||||
.def("Init",
|
||||
[](PythonSaveToDisk &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
|
||||
.def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
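Putting the pieces of this file together, a hedged sketch of how the Python side might drive an iterator through these bindings; the import path and the pre-built IR tree ir_root are assumptions.

import mindspore._c_dataengine as cde  # assumed import path

ctx = cde.PythonRuntimeContext()
ctx.Init()
consumer = cde.PythonIteratorConsumer(1)   # one epoch
consumer.Init(ir_root)                     # ir_root: DatasetNode IR tree built by the Python layer (assumed)
ctx.AssignConsumer(consumer)

row = consumer.GetNextAsMap()              # dict of column name -> tensor, in column-id order
# consumer.GetNextAsList() would return the next row as a plain list instead
ctx.Terminate()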
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
SchemaObj, 0, ([](const py::module *m) {
|
||||
(void)py::class_<SchemaObj, std::shared_ptr<SchemaObj>>(*m, "SchemaObj", "to create a SchemaObj")
|
||||
.def(py::init([](std::string schema_file) {
|
||||
auto schema = std::make_shared<SchemaObj>(schema_file);
|
||||
THROW_IF_ERROR(schema->init());
|
||||
return schema;
|
||||
}))
|
||||
.def("add_column", [](SchemaObj &self, std::string name, TypeId de_type,
|
||||
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
|
||||
.def("add_column", [](SchemaObj &self, std::string name, std::string de_type,
|
||||
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
|
||||
.def("add_column",
|
||||
[](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
|
||||
.def("add_column", [](SchemaObj &self, std::string name,
|
||||
std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
|
||||
.def("to_json", &SchemaObj::to_json)
|
||||
.def("to_string", &SchemaObj::to_string)
|
||||
.def("from_string",
|
||||
[](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.FromJSONString(json_string)); })
|
||||
.def("set_dataset_type", [](SchemaObj &self, std::string dataset_type) { self.set_dataset_type(dataset_type); })
|
||||
.def("set_num_rows", [](SchemaObj &self, int32_t num_rows) { self.set_num_rows(num_rows); })
|
||||
.def("get_num_rows", &SchemaObj::get_num_rows)
|
||||
.def("__deepcopy__", [](py::object &schema, py::dict memo) { return schema; });
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
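A hedged usage sketch for the SchemaObj binding above. The import path, the schema file path, and the exact type strings are assumptions; the calls mirror the overloads registered here.

import mindspore._c_dataengine as cde  # assumed import path

schema = cde.SchemaObj("/path/to/schema.json")    # hypothetical schema file; init() is called by the binding
schema.add_column("image", "uint8", [-1, -1, 3])  # name, type string, shape (overload with shape)
schema.add_column("label", "int32")               # overload without a shape
schema.set_num_rows(10000)
print(schema.to_json())                           # serialized schema, including the added columns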
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
#include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h"
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,265 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/core/client.h" // DE client
|
||||
#include "minddata/dataset/engine/dataset_iterator.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
using json = nlohmann::json;
|
||||
using DsOpPtr = std::shared_ptr<DatasetOp>;
|
||||
|
||||
class CacheClient;
|
||||
|
||||
// enum for the dataset operator names
|
||||
enum OpName {
|
||||
kShuffle,
|
||||
kMindrecord,
|
||||
kBatch,
|
||||
kBucketBatch,
|
||||
kBarrier,
|
||||
kCache,
|
||||
kRepeat,
|
||||
kSkip,
|
||||
kTake,
|
||||
kZip,
|
||||
kConcat,
|
||||
kMap,
|
||||
kFilter,
|
||||
kDeviceQueue,
|
||||
kGenerator,
|
||||
kRename,
|
||||
kTfReader,
|
||||
kProject,
|
||||
kImageFolder,
|
||||
kMnist,
|
||||
kManifest,
|
||||
kVoc,
|
||||
kCoco,
|
||||
kCifar10,
|
||||
kCifar100,
|
||||
kCelebA,
|
||||
kRandomData,
|
||||
kTextFile,
|
||||
kBuildVocab,
|
||||
kClue,
|
||||
kEpochCtrl,
|
||||
kSentencePieceVocab,
|
||||
kCsv
|
||||
};
|
||||
|
||||
// The C++ binder class that we expose to the python script.
|
||||
class DEPipeline {
|
||||
public:
|
||||
DEPipeline();
|
||||
|
||||
~DEPipeline();
|
||||
|
||||
// Function to add a Node to the Execution Tree.
|
||||
Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output);
|
||||
|
||||
// Function to add a child and parent relationship.
|
||||
static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op);
|
||||
|
||||
// Function to assign the node as root.
|
||||
Status AssignRootNode(const DsOpPtr &dataset_op);
|
||||
|
||||
// Function to get the column names in the last node in the tree in order
|
||||
Status GetColumnNames(py::list *output);
|
||||
|
||||
// Function to prepare the tree for execution
|
||||
Status PrepareTree(const int32_t num_epochs);
|
||||
|
||||
// Function to launch the tree execution.
|
||||
Status LaunchTreeExec();
|
||||
|
||||
// Get a row of data as dictionary of column name to the value.
|
||||
Status GetNextAsMap(py::dict *output);
|
||||
|
||||
// Get a row of data as list.
|
||||
Status GetNextAsList(py::list *output);
|
||||
|
||||
Status GetOutputShapes(py::list *output);
|
||||
|
||||
Status GetOutputTypes(py::list *output);
|
||||
|
||||
Status GetDataInfo(py::list *types, py::list *shapes);
|
||||
|
||||
Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);
|
||||
|
||||
int GetDatasetSize() const;
|
||||
|
||||
int GetBatchSize() const;
|
||||
|
||||
int GetRepeatCount() const;
|
||||
|
||||
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
template <typename T, typename S>
|
||||
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
|
||||
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
|
||||
std::unique_ptr<S> *s, bool need_convert = false);
|
||||
|
||||
Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
|
||||
const TensorRow &row, json *schema, std::vector<std::string> *index_fields);
|
||||
|
||||
Status FetchDataFromTensorRow(const TensorRow &row,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data,
|
||||
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
|
||||
|
||||
Status BuildMindrecordSamplerChain(const py::handle &handle,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
|
||||
int num_padded);
|
||||
|
||||
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
void PrintTree();
|
||||
|
||||
int32_t GetNumClasses() const;
|
||||
|
||||
Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status SetBatchParameters(const py::dict &args);
|
||||
|
||||
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status StopSend();
|
||||
|
||||
Status ContinueSend();
|
||||
|
||||
Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
|
||||
std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
|
||||
|
||||
private:
|
||||
// Execution tree that links the dataset operators.
|
||||
std::shared_ptr<ExecutionTree> tree_;
|
||||
|
||||
std::unique_ptr<DatasetIterator> iterator_;
|
||||
|
||||
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
|
||||
|
||||
/// \brief Helper function to inject a cache operator over top of the current operation being built.
|
||||
/// \param[in] cache_client The client to use for caching
|
||||
/// \param[in] num_workers The number of workers to use in the cache op
|
||||
/// \param[in] input_op The operator to build the cache on top of
|
||||
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
|
||||
/// the cache operator
|
||||
/// \return Status return code
|
||||
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
|
||||
std::shared_ptr<DatasetOp> *cache_op);
|
||||
|
||||
/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
|
||||
/// \param[in] shuffle_size The size to use in the shuffle buffer
|
||||
/// \param[in] input_op The operator to build shuffle on top of
|
||||
/// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be
|
||||
/// the shuffle operator
|
||||
/// \return Status return code
|
||||
Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
|
||||
std::shared_ptr<DatasetOp> *shuffle_op);
|
||||
|
||||
/// \brief Helper function to compute the shuffle size
|
||||
/// \param[in] num_files The number of files in the dataset
|
||||
/// \param[in] num_devices The number of devices in the dataset
|
||||
/// \param[in] num_rows The number of rows in the dataset
|
||||
/// \param[in] total_rows An upper bound on the total rows in the dataset
|
||||
/// \param[out] shuffle_size The resultant computed shuffle size
|
||||
/// \return Status return code
|
||||
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
|
||||
int64_t *shuffle_size);
|
||||
|
||||
int batch_size_;
|
||||
int repeat_num_;
|
||||
int num_rows_;
|
||||
int num_classes_;
|
||||
|
||||
int temp_batch_size_;
|
||||
bool temp_drop_remainder_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
|
|
@ -0,0 +1,265 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
float toFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
|
||||
|
||||
int toInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
|
||||
|
||||
int64_t toInt64(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
|
||||
|
||||
bool toBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
|
||||
|
||||
std::string toString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); }
|
||||
|
||||
std::set<std::string> toStringSet(const std::optional<py::list> list) {
|
||||
std::set<std::string> set;
|
||||
if (list) {
|
||||
for (auto l : *list) {
|
||||
if (!l.is_none()) {
|
||||
(void)set.insert(py::str(l));
|
||||
}
|
||||
}
|
||||
}
|
||||
return set;
|
||||
}
|
||||
|
||||
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict) {
|
||||
std::map<std::string, int32_t> map;
|
||||
if (dict) {
|
||||
for (auto p : *dict) {
|
||||
(void)map.emplace(toString(p.first), toInt(p.second));
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
std::vector<std::string> toStringVector(const std::optional<py::list> list) {
|
||||
std::vector<std::string> vector;
|
||||
if (list) {
|
||||
for (auto l : *list) {
|
||||
if (l.is_none())
|
||||
vector.emplace_back("");
|
||||
else
|
||||
vector.push_back(py::str(l));
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple) {
|
||||
std::pair<int64_t, int64_t> pair;
|
||||
if (tuple) {
|
||||
pair = std::make_pair(toInt64((*tuple)[0]), toInt64((*tuple)[1]));
|
||||
}
|
||||
return pair;
|
||||
}
|
||||
|
||||
std::vector<std::pair<int, int>> toPairVector(const py::list list) {
|
||||
std::vector<std::pair<int, int>> vector;
|
||||
if (list) {
|
||||
for (auto data : list) {
|
||||
auto l = data.cast<py::tuple>();
|
||||
if (l[1].is_none())
|
||||
vector.emplace_back(toInt64(l[0]), 0);
|
||||
else
|
||||
vector.emplace_back(toInt64(l[0]), toInt64(l[1]));
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations) {
|
||||
std::vector<std::shared_ptr<TensorOperation>> vector;
|
||||
if (operations) {
|
||||
for (auto op : *operations) {
|
||||
std::shared_ptr<TensorOp> tensor_op;
|
||||
if (py::isinstance<TensorOp>(op)) {
|
||||
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
||||
} else if (py::isinstance<py::function>(op)) {
|
||||
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
|
||||
} else {
|
||||
THROW_IF_ERROR(
|
||||
[]() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }());
|
||||
}
|
||||
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) {
|
||||
std::vector<std::shared_ptr<DatasetNode>> vector;
|
||||
vector.push_back(self);
|
||||
if (datasets) {
|
||||
for (auto ds : *datasets) {
|
||||
if (py::isinstance<DatasetNode>(ds)) {
|
||||
vector.push_back(ds.cast<std::shared_ptr<DatasetNode>>());
|
||||
} else {
|
||||
THROW_IF_ERROR(
|
||||
[]() { RETURN_STATUS_UNEXPECTED("Error: datasets is not recognised (not a DatasetNode instance)."); }());
|
||||
}
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset) {
|
||||
if (py_sampler) {
|
||||
std::shared_ptr<SamplerObj> sampler_obj;
|
||||
if (!isMindDataset) {
|
||||
// Common Sampler
|
||||
std::shared_ptr<SamplerRT> sampler;
|
||||
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create");
|
||||
sampler = create().cast<std::shared_ptr<SamplerRT>>();
|
||||
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
|
||||
} else {
|
||||
// Mindrecord Sampler
|
||||
std::shared_ptr<mindrecord::ShardOperator> sampler;
|
||||
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create_for_minddataset");
|
||||
sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
|
||||
}
|
||||
return sampler_obj;
|
||||
} else {
|
||||
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: sampler input is not SamplerRT."); }());
|
||||
}
|
||||
return nullptr;
|
||||
}
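To make the contract of toSamplerObj() above explicit: it does not take a SamplerObj directly, but expects the Python sampler object to expose create() (returning a runtime SamplerRT) and, for MindDataset, create_for_minddataset() (returning a mindrecord ShardOperator). A hedged sketch of that shape, with illustrative names only:

class _SamplerShim:  # hypothetical Python-side wrapper
    def __init__(self, c_sampler, c_mind_sampler=None):
        self._c_sampler = c_sampler            # a C++ SamplerRT handle (assumed)
        self._c_mind_sampler = c_mind_sampler  # a mindrecord ShardOperator handle (assumed)

    def create(self):
        return self._c_sampler                 # consumed by the non-MindDataset branch above

    def create_for_minddataset(self):
        return self._c_mind_sampler            # consumed by the MindDataset branch above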
|
||||
|
||||
// Here we take in a python object, that holds a reference to a C++ object
|
||||
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
if (cc) {
|
||||
std::shared_ptr<DatasetCache> built_cache;
|
||||
// Wrap the existing cache client in a pre-built dataset cache
|
||||
built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc.value()));
|
||||
return built_cache;
|
||||
} else {
|
||||
// don't need to check here as cache is not enabled.
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
ShuffleMode toShuffleMode(const int32_t shuffle) {
|
||||
if (shuffle == 0) return ShuffleMode::kFalse;
|
||||
if (shuffle == 1) return ShuffleMode::kFiles;
|
||||
if (shuffle == 2) return ShuffleMode::kGlobal;
|
||||
return ShuffleMode();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases) {
|
||||
std::vector<std::shared_ptr<CsvBase>> vector;
|
||||
if (csv_bases) {
|
||||
for (auto base : *csv_bases) {
|
||||
if (py::isinstance<py::int_>(base)) {
|
||||
vector.push_back(std::make_shared<CsvRecord<int>>(CsvType::INT, toInt(base)));
|
||||
} else if (py::isinstance<py::float_>(base)) {
|
||||
vector.push_back(std::make_shared<CsvRecord<float>>(CsvType::FLOAT, toFloat(base)));
|
||||
} else if (py::isinstance<py::str>(base)) {
|
||||
vector.push_back(std::make_shared<CsvRecord<std::string>>(CsvType::STRING, toString(base)));
|
||||
} else {
|
||||
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: each default value must be int, float, or string"); }());
|
||||
}
|
||||
}
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
|
||||
std::map<std::string, std::string> *sample_bytes) {
|
||||
for (const py::handle &key : padded_sample) {
|
||||
if (py::isinstance<py::bytes>(padded_sample[key])) {
|
||||
(*sample_bytes)[py::str(key).cast<std::string>()] = padded_sample[key].cast<std::string>();
|
||||
// The bytes value itself goes into sample_bytes; store an unused placeholder object under this key in the json so ValidateParams still passes
|
||||
(*padded_sample_json)[py::str(key).cast<std::string>()] = nlohmann::json::object();
|
||||
} else {
|
||||
nlohmann::json obj_json;
|
||||
if (padded_sample[key].is_none()) {
|
||||
obj_json = nullptr;
|
||||
} else if (py::isinstance<py::int_>(padded_sample[key])) {
|
||||
obj_json = padded_sample[key].cast<int64_t>();
|
||||
} else if (py::isinstance<py::float_>(padded_sample[key])) {
|
||||
obj_json = padded_sample[key].cast<double>();
|
||||
} else if (py::isinstance<py::str>(padded_sample[key])) {
|
||||
obj_json = padded_sample[key].cast<std::string>(); // also catch py::bytes
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Python object convert to json failed: " << py::cast<std::string>(padded_sample[key]);
|
||||
RETURN_STATUS_SYNTAX_ERROR("Python object convert to json failed");
|
||||
}
|
||||
(*padded_sample_json)[py::str(key).cast<std::string>()] = obj_json;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info) {
|
||||
for (auto p : value) {
|
||||
if (!p.second.is_none()) {
|
||||
auto tp = py::reinterpret_borrow<py::tuple>(p.second);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
|
||||
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
|
||||
std::shared_ptr<Tensor> pad_val = nullptr;
|
||||
if (py::isinstance<py::str>(tp[1])) {
|
||||
std::string pad_val_string = tp[1].is_none() ? "" : toString(tp[1]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
Tensor::CreateFromVector(std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar(), &pad_val),
|
||||
"Cannot create pad_value Tensor");
|
||||
} else {
|
||||
float pad_val_float = tp[1].is_none() ? 0 : toFloat(tp[1]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val),
|
||||
"Cannot create pad_value Tensor");
|
||||
pad_val->SetItemAt<float>({}, pad_val_float);
|
||||
}
|
||||
(void)pad_info->insert({toString(p.first), {shape, pad_val}});
|
||||
} else { // tuple is None
|
||||
(void)pad_info->insert({toString(p.first), {TensorShape({}), nullptr}});
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type) {
|
||||
std::shared_ptr<TensorOp> py_func;
|
||||
if (!func.is_none()) {
|
||||
py::function py_function = func.cast<py::function>();
|
||||
py_func = std::make_shared<PyFuncOp>(py_function, data_type);
|
||||
} else {
|
||||
py_func = nullptr;
|
||||
}
|
||||
return py_func;
|
||||
}
|
||||
|
||||
py::list shapesToListOfShape(std::vector<TensorShape> shapes) {
|
||||
py::list shape_list;
|
||||
for (const auto &shape : shapes) {
|
||||
shape_list.append(shape.AsVector());
|
||||
}
|
||||
return shape_list;
|
||||
}
|
||||
|
||||
py::list typesToListOfType(std::vector<DataType> types) {
|
||||
py::list type_list;
|
||||
for (const auto &type : types) {
|
||||
type_list.append(type.AsNumpyType());
|
||||
}
|
||||
return type_list;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
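As a reference for toPadInfo() earlier in this file, a hedged sketch of the pad_info dictionary shape it expects from Python: each column maps to a (shape, pad value) tuple, where the shape may be None for an unknown rank, the pad value may be a string or a number, and a None entry means "use defaults". Column names and values below are purely illustrative.

pad_info = {
    "image": ([224, 224, 3], 0.0),   # pad to a fixed shape, fill with 0.0
    "label": (None, -1),             # unknown-rank shape, fill with -1
    "text":  ([16], ""),             # string column padded with the empty string
    "extra": None,                   # None entry: unknown shape, default pad value
}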
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/kernels/py_func_op.h"
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
float toFloat(const py::handle &handle);
|
||||
|
||||
int toInt(const py::handle &handle);
|
||||
|
||||
int64_t toInt64(const py::handle &handle);
|
||||
|
||||
bool toBool(const py::handle &handle);
|
||||
|
||||
std::string toString(const py::handle &handle);
|
||||
|
||||
std::set<std::string> toStringSet(const std::optional<py::list> list);
|
||||
|
||||
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict);
|
||||
|
||||
std::vector<std::string> toStringVector(const std::optional<py::list> list);
|
||||
|
||||
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple);
|
||||
|
||||
std::vector<std::pair<int, int>> toPairVector(const py::list list);
|
||||
|
||||
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations);
|
||||
|
||||
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets);
|
||||
|
||||
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false);
|
||||
|
||||
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc);
|
||||
|
||||
ShuffleMode toShuffleMode(const int32_t shuffle);
|
||||
|
||||
std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases);
|
||||
|
||||
std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type);
|
||||
|
||||
Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
|
||||
std::map<std::string, std::string> *sample_bytes);
|
||||
|
||||
Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info);
|
||||
|
||||
py::list shapesToListOfShape(std::vector<TensorShape> shapes);
|
||||
|
||||
py::list typesToListOfType(std::vector<DataType> types);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
|
|
@ -190,6 +190,23 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
|
|||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// PreBuiltSamplerObj
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
|
||||
: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}
|
||||
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
|
||||
: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
|
||||
#endif
|
||||
|
||||
bool PreBuiltSamplerObj::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
|
|
|
@ -222,6 +222,13 @@ Status OneHotOperation::ValidateParams() {
|
|||
|
||||
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
|
||||
|
||||
// PreBuiltOperation
|
||||
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {}
|
||||
|
||||
Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
|
||||
|
||||
// RandomApplyOperation
|
||||
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
|
||||
: transforms_(transforms), prob_(prob) {}
|
||||
|
|
|
@ -18,6 +18,7 @@ set(SRC_FILES_LIST
|
|||
dataset_iterator.cc
|
||||
tree_adapter.cc
|
||||
runtime_context.cc
|
||||
python_runtime_context.cc
|
||||
consumers/tree_consumer.cc
|
||||
)
|
||||
if (ENABLE_PYTHON)
|
||||
|
|
|
@ -32,15 +32,37 @@ Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) {
|
||||
std::unordered_map<std::string, TensorPtr> row;
|
||||
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec;
|
||||
Status s;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
RETURN_IF_NOT_OK(GetNextAsMap(&row));
|
||||
s = GetNextAsOrderedPair(&vec);
|
||||
}
|
||||
for (auto el : row) {
|
||||
(*out)[common::SafeCStr(el.first)] = el.second;
|
||||
RETURN_IF_NOT_OK(s);
|
||||
// Generate the Python dict; Python dicts maintain insertion order
|
||||
for (const auto &pair : vec) {
|
||||
(*out)[common::SafeCStr(pair.first)] = pair.second;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonBuildVocabConsumer::Start() {
|
||||
py::gil_scoped_release gil_release;
|
||||
return BuildVocabConsumer::Start();
|
||||
}
|
||||
|
||||
Status PythonSaveToDisk::Save() {
|
||||
py::gil_scoped_release gil_release;
|
||||
return SaveToDisk::Save();
|
||||
}
|
||||
|
||||
PythonSaveToDisk::PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType)
|
||||
: SaveToDisk(datasetPath, numFiles, datasetType) {}
|
||||
|
||||
Status PythonTreeGetters::GetRow(TensorRow *r) {
|
||||
py::gil_scoped_release gil_release;
|
||||
return TreeGetters::GetRow(r);
|
||||
}
|
||||
} // namespace mindspore::dataset
|
||||
|
|
|
@ -44,5 +44,21 @@ class PythonIteratorConsumer : public IteratorConsumer {
|
|||
/// \return Status error code
|
||||
Status GetNextAsDict(py::dict *out);
|
||||
};
|
||||
|
||||
class PythonBuildVocabConsumer : public BuildVocabConsumer {
|
||||
public:
|
||||
Status Start() override;
|
||||
};
|
||||
|
||||
class PythonSaveToDisk : public SaveToDisk {
|
||||
public:
|
||||
PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType);
|
||||
Status Save() override;
|
||||
};
|
||||
|
||||
class PythonTreeGetters : public TreeGetters {
|
||||
public:
|
||||
Status GetRow(TensorRow *r) override;
|
||||
};
|
||||
} // namespace mindspore::dataset
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <vector>
|
||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/mindrecord/include/shard_header.h"
|
||||
|
@ -35,7 +36,7 @@ namespace mindspore::dataset {
|
|||
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
|
||||
|
||||
Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); }
|
||||
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }
|
||||
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->ServiceStop(); }
|
||||
|
||||
// IteratorConsumer
|
||||
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
|
||||
|
@ -73,6 +74,38 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");
|
||||
|
||||
TensorRow curr_row;
|
||||
|
||||
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));
|
||||
RETURN_OK_IF_TRUE(curr_row.empty());
|
||||
|
||||
size_t num_cols = curr_row.size(); // num_cols is non-empty.
|
||||
// order the column names according to their ids
|
||||
if (column_order_.empty()) {
|
||||
const int32_t invalid_col_id = -1;
|
||||
column_order_.resize(num_cols, {std::string(), invalid_col_id});
|
||||
for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
|
||||
int32_t ind = itr.second;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
|
||||
column_order_[ind] = std::make_pair(itr.first, ind);
|
||||
}
|
||||
// error check: make sure the ids in col_name_id_map are continuous and start from 0
|
||||
for (const auto &col : column_order_) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous.");
|
||||
}
|
||||
}
|
||||
|
||||
vec->reserve(num_cols);
|
||||
|
||||
std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
|
||||
[curr_row](const auto &col) { return std::make_pair(col.first, curr_row[col.second]); });
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// ToDevice
|
||||
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); }
|
||||
|
||||
|
@ -81,7 +114,6 @@ Status ToDevice::Send() {
|
|||
RETURN_IF_NOT_OK(tree_adapter_->Launch());
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_IF_NOT_OK(root->GetNextBuffer(&db));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -101,9 +133,36 @@ Status ToDevice::Stop() {
|
|||
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
|
||||
op->StopSend();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ToDevice::GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes) {
|
||||
// tree_.root() must be DeviceQueueOp
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp");
|
||||
DATA_INFO data_info;
|
||||
RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
|
||||
for (auto el : data_info) {
|
||||
types->push_back(el.first);
|
||||
shapes->push_back(el.second);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ToDevice::Terminate() {
|
||||
#ifdef ENABLE_TDTQUE
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
|
||||
op->StopWaiting();
|
||||
#endif
|
||||
return TreeConsumer::Terminate();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// SaveToDisk
|
||||
Status SaveToDisk::ValidateParams() {
|
||||
|
@ -282,50 +341,50 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
|
|||
if (column_type == DataType::DE_INT8) {
|
||||
std::unique_ptr<int32_t> data;
|
||||
std::unique_ptr<int8_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_INT16) {
|
||||
std::unique_ptr<int32_t> data;
|
||||
std::unique_ptr<int16_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_UINT16) {
|
||||
std::unique_ptr<int32_t> data;
|
||||
std::unique_ptr<uint16_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_UINT8) {
|
||||
std::unique_ptr<uint8_t> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_INT32) {
|
||||
std::unique_ptr<int32_t> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_UINT32) {
|
||||
std::unique_ptr<int64_t> data;
|
||||
std::unique_ptr<uint32_t> dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_INT64) {
|
||||
std::unique_ptr<int64_t> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_FLOAT32) {
|
||||
std::unique_ptr<float> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_FLOAT64) {
|
||||
std::unique_ptr<double> data, dummy;
|
||||
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
|
||||
RETURN_IF_NOT_OK(s);
|
||||
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
|
||||
} else if (column_type == DataType::DE_STRING) {
|
||||
|
@ -346,7 +405,7 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
|
|||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
|
||||
Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
|
||||
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
|
||||
std::unique_ptr<S> *s, bool need_convert) {
|
||||
if (nullptr == src) {
|
||||
|
@ -379,47 +438,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
|
|||
}
|
||||
#endif
|
||||
|
||||
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) {
|
||||
tree_adapter_ = std::make_unique<TreeAdapter>();
|
||||
}
|
||||
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); }
|
||||
|
||||
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
|
||||
if (init_flag_) {
|
||||
return Status::OK();
|
||||
}
|
||||
Status s = tree_adapter_->Compile(std::move(d), 1);
|
||||
if (!s.IsError()) {
|
||||
init_flag_ = true;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
bool TreeGetters::isInitialized() { return init_flag_; }
|
||||
|
||||
Status TreeGetters::GetRow(TensorRow *row) {
|
||||
if (row_flag_ == false) {
|
||||
RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
|
||||
row_flag_ = true;
|
||||
}
|
||||
root_ = std::move(d);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }
|
||||
|
||||
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
|
||||
if (dataset_size_ == -1) {
|
||||
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_UNEXPECTED_IF_NULL(root);
|
||||
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
|
||||
dataset_size_ = *dataset_size;
|
||||
if (*dataset_size == -1) {
|
||||
RETURN_IF_NOT_OK(GetRow(&row_));
|
||||
int64_t num_rows = 0;
|
||||
TensorRow row = row_;
|
||||
while (row.size() != 0) {
|
||||
num_rows++;
|
||||
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
|
||||
if (*dataset_size == -1) { // run through the tree and get everything
|
||||
TensorRow row;
|
||||
RETURN_IF_NOT_OK(GetRow(&row));
|
||||
int64_t row_cnt = 0;
|
||||
while (!row.empty()) {
|
||||
++row_cnt;
|
||||
RETURN_IF_NOT_OK(GetRow(&row));
|
||||
}
|
||||
dataset_size_ = num_rows;
|
||||
*dataset_size = row_cnt;
|
||||
}
|
||||
dataset_size_ = *dataset_size; // save the previous result
|
||||
}
|
||||
|
||||
*dataset_size = dataset_size_;
|
||||
|
@ -427,68 +471,88 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
|
|||
}
|
||||
|
||||
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
|
||||
RETURN_IF_NOT_OK(GetRow(&row_));
|
||||
for (auto ts : row_) {
|
||||
DataType dt = ts->type();
|
||||
types->push_back(dt);
|
||||
}
|
||||
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
|
||||
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
|
||||
|
||||
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types),
|
||||
[](const TensorPtr &t) { return t->type(); });
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
|
||||
RETURN_IF_NOT_OK(GetRow(&row_));
|
||||
for (auto ts : row_) {
|
||||
TensorShape t = ts->shape();
|
||||
shapes->push_back(t);
|
||||
}
|
||||
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
|
||||
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
|
||||
|
||||
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes),
|
||||
[](const TensorPtr &t) { return t->shape(); });
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetBatchSize(int64_t *batch_size) {
|
||||
RETURN_IF_NOT_OK(InternalInit());
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_UNEXPECTED_IF_NULL(root);
|
||||
*batch_size = root->GetTreeBatchSize();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
|
||||
RETURN_IF_NOT_OK(InternalInit());
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_UNEXPECTED_IF_NULL(root);
|
||||
*repeat_count = root->GetTreeRepeatCount();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetNumClasses(int64_t *num_classes) {
|
||||
RETURN_IF_NOT_OK(InternalInit());
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_UNEXPECTED_IF_NULL(root);
|
||||
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
|
||||
RETURN_IF_NOT_OK(InternalInit());
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
RETURN_UNEXPECTED_IF_NULL(root);
|
||||
std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
|
||||
if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty.");
|
||||
std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
|
||||
column_name_id_map.end());
|
||||
std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map is empty.");
|
||||
std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end());
|
||||
std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
|
||||
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
|
||||
return a.second < b.second;
|
||||
});
|
||||
for (auto item : column_name_id_vector) {
|
||||
(*output).push_back(item.first);
|
||||
}
|
||||
std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output),
|
||||
[](const std::pair<std::string, int32_t> &p) { return p.first; });
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
|
||||
RETURN_IF_NOT_OK(InternalInit());
|
||||
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
|
||||
RETURN_UNEXPECTED_IF_NULL(root);
|
||||
RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeGetters::InternalInit(int8_t type) {
|
||||
if (init_flag_) return Status::OK();
|
||||
tree_adapter_->SetPrePassOverride([&type](OptPass pre) {
|
||||
pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type)));
|
||||
return pre;
|
||||
});
|
||||
Status s = tree_adapter_->Compile(std::move(root_), 1);
|
||||
if (!s.IsError()) init_flag_ = true;
|
||||
return s;
|
||||
}
|
||||
Status TreeGetters::InternalInit() {
|
||||
if (init_flag_) return Status::OK();
|
||||
Status s = tree_adapter_->Compile(std::move(root_), 1);
|
||||
if (!s.IsError()) init_flag_ = true;
|
||||
return s;
|
||||
}
|
||||
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }
|
||||
|
||||
Status BuildVocabConsumer::Start() {
|
||||
|
|
|
@ -41,7 +41,7 @@ class TreeConsumer {
|
|||
/// \return Status error code.
|
||||
virtual Status Init(std::shared_ptr<DatasetNode> d);
|
||||
|
||||
Status Terminate();
|
||||
virtual Status Terminate();
|
||||
|
||||
protected:
|
||||
/// The class owns the tree_adapter that handles execution tree operations.
|
||||
|
@ -72,6 +72,11 @@ class IteratorConsumer : public TreeConsumer {
|
|||
/// \return Status error code
|
||||
Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out);
|
||||
|
||||
/// Returns the next row as a vector of (column name, Tensor) pairs, ordered by column id
|
||||
/// \param[out] vec vector of column-name/Tensor pairs
|
||||
/// \return Status error code
|
||||
Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec);
|
||||
|
||||
protected:
|
||||
/// Method to return the name of the consumer
|
||||
/// \return string
|
||||
|
@ -79,6 +84,7 @@ class IteratorConsumer : public TreeConsumer {
|
|||
|
||||
private:
|
||||
int32_t num_epochs_;
|
||||
std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id
|
||||
};
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -101,7 +107,7 @@ class SaveToDisk : public TreeConsumer {
|
|||
/// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
|
||||
/// would be written to disk)
|
||||
/// \return Status error code
|
||||
Status Save();
|
||||
virtual Status Save();
|
||||
|
||||
protected:
|
||||
/// Method to return the name of the consumer
|
||||
|
@ -110,7 +116,7 @@ class SaveToDisk : public TreeConsumer {
|
|||
|
||||
private:
|
||||
template <typename T, typename S>
|
||||
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
|
||||
Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
|
||||
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
|
||||
std::unique_ptr<S> *s, bool need_convert = false);
|
||||
|
||||
|
@ -131,24 +137,29 @@ class SaveToDisk : public TreeConsumer {
|
|||
/// Consumer that iterates over the dataset and send it to a device
|
||||
class ToDevice : public TreeConsumer {
|
||||
public:
|
||||
explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
|
||||
: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
|
||||
explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
|
||||
|
||||
~ToDevice() = default;
|
||||
|
||||
Status Init(std::shared_ptr<DatasetNode> d) override;
|
||||
|
||||
Status Terminate() override;
|
||||
|
||||
/// Send the data to device
|
||||
/// \return Status error code
|
||||
Status Send();
|
||||
virtual Status Send();
|
||||
|
||||
/// Stop to send data to device
|
||||
/// \return Status error code
|
||||
Status Stop();
|
||||
virtual Status Stop();
|
||||
|
||||
/// Continue to send data to device
|
||||
/// \return Status error code
|
||||
Status Continue();
|
||||
virtual Status Continue();
|
||||
|
||||
/// Get data info from TDT
|
||||
/// \return Status error code
|
||||
virtual Status GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes);
|
||||
|
||||
protected:
|
||||
/// Method to return the name of the consumer
|
||||
|
@ -156,8 +167,6 @@ class ToDevice : public TreeConsumer {
|
|||
std::string Name() override { return "ToDevice"; }
|
||||
|
||||
private:
|
||||
std::string device_type_;
|
||||
bool send_epoch_end_;
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
|
@ -167,6 +176,7 @@ class TreeGetters : public TreeConsumer {
|
|||
TreeGetters();
|
||||
~TreeGetters() = default;
|
||||
Status Init(std::shared_ptr<DatasetNode> d) override;
|
||||
|
||||
Status GetDatasetSize(int64_t *size);
|
||||
Status GetOutputTypes(std::vector<DataType> *types);
|
||||
Status GetOutputShapes(std::vector<TensorShape> *shapes);
|
||||
|
@ -175,15 +185,17 @@ class TreeGetters : public TreeConsumer {
|
|||
Status GetNumClasses(int64_t *num_classes);
|
||||
Status GetColumnNames(std::vector<std::string> *output);
|
||||
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
|
||||
bool isInitialized();
|
||||
std::string Name() override { return "TreeGetters"; }
|
||||
Status GetRow(TensorRow *r);
|
||||
virtual Status GetRow(TensorRow *r);
|
||||
|
||||
private:
|
||||
std::shared_ptr<DatasetNode> root_;
|
||||
int64_t dataset_size_;
|
||||
TensorRow row_;
|
||||
TensorRow first_row_;
|
||||
bool init_flag_; // indicate whether the tree has initialized
|
||||
bool row_flag_; // indicate whether the first row has been stored in row_
|
||||
|
||||
Status InternalInit(int8_t type);
|
||||
Status InternalInit();
|
||||
};
|
||||
|
||||
class BuildVocabConsumer : public TreeConsumer {
|
||||
|
@ -197,7 +209,7 @@ class BuildVocabConsumer : public TreeConsumer {
|
|||
|
||||
/// Start consuming
|
||||
/// \return Status error code
|
||||
Status Start();
|
||||
virtual Status Start();
|
||||
|
||||
protected:
|
||||
/// Method to return the name of the consumer
|
||||
|
|
|
@ -44,9 +44,9 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
|
|||
}
|
||||
|
||||
// Constructor of the ConcatOp.
|
||||
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums,
|
||||
std::vector<std::pair<int, int>> children_start_end_index)
|
||||
ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
|
||||
const std::vector<std::pair<int, int>> &children_flag_and_nums,
|
||||
const std::vector<std::pair<int, int>> &children_start_end_index)
|
||||
: PipelineOp(op_connector_size),
|
||||
children_num_(0),
|
||||
sampler_(sampler),
|
||||
|
|
|
@ -70,9 +70,9 @@ class ConcatOp : public PipelineOp {
|
|||
// @note The builder class should be used to call it
|
||||
// @param op_connector_size - connector size
|
||||
explicit ConcatOp(int32_t op_connector_size);
|
||||
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums,
|
||||
std::vector<std::pair<int, int>> children_start_end_index);
|
||||
ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
|
||||
const std::vector<std::pair<int, int>> &children_flag_and_nums,
|
||||
const std::vector<std::pair<int, int>> &children_start_end_index);
|
||||
|
||||
// Destructor
|
||||
~ConcatOp() = default;
|
||||
|
|
|
@ -346,6 +346,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return Name of the current Op
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
/// Op name and ID getter
|
||||
/// \return Name and ID of the current Op
|
||||
std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; }
|
||||
|
||||
/// Execution Tree getter
|
||||
/// \return Pointer to the ExecutionTree the current op belongs to, no ownership
|
||||
ExecutionTree *Tree() { return tree_; }
|
||||
|
|
|
@ -205,7 +205,6 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
}
|
||||
|
||||
tree_->SetFinished();
|
||||
MS_LOG(INFO) << "Device queue total batch is " << send_batch;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -39,10 +39,10 @@ using mindspore::device::GpuBufferMgr;
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>;
|
||||
using DATA_INFO_QUEUE = Queue<DATA_INFO>;
|
||||
const int kDataInfoQueueCapacity = 128;
|
||||
|
||||
class DeviceQueueOp : public PipelineOp {
|
||||
public:
|
||||
static const uint32_t INVALID_HANDLE = 0xffffffffUL;
|
||||
|
@ -184,7 +184,6 @@ class DeviceQueueOp : public PipelineOp {
|
|||
#ifdef ENABLE_TDTQUE
|
||||
Status SendDataToAscend();
|
||||
bool ascend_keep_waiting_;
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_GPUQUE
|
||||
|
|
|
@ -169,7 +169,7 @@ Status MapOp::operator()() {
|
|||
}
|
||||
|
||||
// The operator class just starts off threads by calling the tree_ function
|
||||
rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1));
|
||||
rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1), NameWithID());
|
||||
// Synchronize with TaskManager
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(rc);
|
||||
|
|
|
@ -704,6 +704,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
|
|||
}
|
||||
if (image_ids_.size() == 0) {
|
||||
RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
|
||||
} else {
|
||||
num_rows = image_ids_.size();
|
||||
}
|
||||
sample_size = sampler_->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
|
|
|
@ -480,13 +480,13 @@ Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows = num_rows_, sample_size;
  int64_t num_rows = num_rows_;
  if (num_rows_ <= 0) {
    std::shared_ptr<ShardOperator> op;
    // The last operator is parent sampler
    std::shared_ptr<ShardOperator> op = operators_.back();
    RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
  }
  sample_size = operators_[0]->GetNumSamples(num_rows, 0);
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  *dataset_size = num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
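The dropped line capped the reported size by what the first sampler requested; the replacement reports the counted row total directly. A minimal sketch of the old clamping rule (the helper name is assumed for illustration only):

#include <algorithm>
#include <cstdint>

// sample_size <= 0 means "take everything"; otherwise it caps the reported dataset size.
int64_t ClampedDatasetSize(int64_t num_rows, int64_t sample_size) {
  return sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
}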
@ -1067,6 +1067,19 @@ Status TFReaderOp::PrepareNodePostAction() {
  return Status::OK();
}

// Get the file list of the specific shard ID
Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
  if (!shard_filenames->empty()) {
    RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
  }
  for (int index = 0; index < dataset_files_list_.size(); index++) {
    if (index % num_devices_ == device_id_) {
      shard_filenames->push_back(dataset_files_list_.at(index));
    }
  }
  return Status::OK();
}

// Get Dataset size
Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {

@ -1080,7 +1093,9 @@ Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
    RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
    num_rows = num_rows_per_shard_;
  } else {
    RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_));
    std::vector<std::string> shard_file_list;
    RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
    RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
  }
}
sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();
@ -400,6 +400,11 @@ class TFReaderOp : public ParallelOp {
  // @return - Status
  Status ComputeColMap() override;

  // Private function for computing the file list of the specific shard ID. This is because in distributed scenario,
  // data will be divided into shards by row when equal_rows_per_shard is true, but by file in the opposite case.
  // @return - Status - the status code returned.
  Status GetShardFileList(std::vector<std::string> *shard_filenames);

  int32_t device_id_;
  int32_t num_devices_;
  int64_t rows_per_buffer_;
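GetShardFileList is the by-file split described in the comment above: each shard keeps only the files whose index maps to its own device id. A standalone sketch of that selection, assuming a simple round-robin assignment (names here are illustrative, not the operator's API):

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Sketch: file i belongs to shard (i mod num_devices); device_id selects which shard we are.
std::vector<std::string> FilesForShard(const std::vector<std::string> &all_files, int32_t num_devices,
                                       int32_t device_id) {
  std::vector<std::string> shard_files;
  for (size_t i = 0; i < all_files.size(); ++i) {
    if (static_cast<int32_t>(i % num_devices) == device_id) {
      shard_files.push_back(all_files[i]);
    }
  }
  return shard_files;
}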
@ -536,6 +536,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
|
|||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
num_rows = static_cast<int64_t>(op->image_ids_.size());
|
||||
}
|
||||
} else {
|
||||
num_rows = image_ids_.size();
|
||||
}
|
||||
sample_size = sampler_->CalculateNumSamples(num_rows);
|
||||
*dataset_size = sample_size;
|
||||
|
|
|
@ -141,8 +141,6 @@ Status ExecutionTree::Launch() {
|
|||
" Expected state: " + std::to_string(static_cast<int>(kDeTStateReady));
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
std::ostringstream ss;
|
||||
ss << *this;
|
||||
|
||||
// Profiling infrastructures need to be initialized before Op launching
|
||||
if (profiling_manager_->IsProfilingEnable()) {
|
||||
|
@ -152,6 +150,8 @@ Status ExecutionTree::Launch() {
|
|||
RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
|
||||
}
|
||||
|
||||
std::ostringstream ss;
|
||||
ss << *this;
|
||||
MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str();
|
||||
for (auto itr = this->begin(); itr != this->end(); ++itr) {
|
||||
// An inlined operator is one that has an output connector size of 0, and it does not
|
||||
|
@ -160,7 +160,7 @@ Status ExecutionTree::Launch() {
|
|||
// the launching tree/user thread. Do not exec any thread for an inlined op.
|
||||
itr->state_ = DatasetOp::OpState::kDeOpRunning;
|
||||
if (!itr->inlined()) {
|
||||
RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr)));
|
||||
RETURN_IF_NOT_OK(tg_->CreateAsyncTask(itr->NameWithID(), std::ref(*itr)));
|
||||
// Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp
|
||||
}
|
||||
}
|
||||
|
@ -189,10 +189,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_

// Given the number of workers, launches the worker entry function for each. Essentially a
// wrapper for the TaskGroup handling that is stored inside the execution tree.
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func) {
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) {
  // Launch the workers
  for (int32_t i = 0; i < num_workers; ++i) {
    RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i)));
    RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i)));
  }
  return Status::OK();
}

@ -150,7 +150,7 @@ class ExecutionTree {
  // @param num_workers - The number of workers to launch
  // @param func - The function entry point that workers will execute
  // @return Status - The error code return
  Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func);
  Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "");

  // Getter method
  // @return shared_ptr to the root operator
@ -1,4 +1,5 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-ir-cache OBJECT
|
||||
dataset_cache_impl.cc)
|
||||
pre_built_dataset_cache.cc
|
||||
dataset_cache_impl.cc)
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
|
||||
/// \return Status Error code
|
||||
Status DatasetCacheImpl::Build() {
|
||||
|
@ -40,5 +40,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,8 +24,8 @@
|
|||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// DatasetCache is the IR of CacheClient
|
||||
class DatasetCacheImpl : public DatasetCache {
|
||||
public:
|
||||
|
@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache {
|
|||
std::optional<int32_t> num_connections_;
|
||||
std::optional<int32_t> prefetch_sz_;
|
||||
};
|
||||
} // namespace mindspore::dataset
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
|
||||
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
|
||||
/// \return Status Error code
|
||||
Status PreBuiltDatasetCache::Build() {
|
||||
// we actually want to keep a reference of the runtime object so it can be shared by different pipelines
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
|
||||
*ds = cache_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// DatasetCache is the IR of CacheClient
|
||||
class PreBuiltDatasetCache : public DatasetCache {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param cc a pre-built cache client
|
||||
explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc) : cache_client_(std::move(cc)) {}
|
||||
|
||||
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
|
||||
/// \return Status Error code
|
||||
Status Build() override;
|
||||
|
||||
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status ValidateParams() override { return Status::OK(); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<CacheClient> cache_client_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
|
|
@ -31,7 +31,7 @@ namespace dataset {
BucketBatchByLengthNode::BucketBatchByLengthNode(
  std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  std::function<TensorRow(TensorRow)> element_length_function,
  std::shared_ptr<TensorOp> element_length_function,
  const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  bool drop_remainder)
  : column_names_(column_names),

@ -47,16 +47,13 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::shared_ptr<TensorOp> c_func;
  if (element_length_function_ != nullptr) {
    c_func = std::make_shared<CFuncOp>(element_length_function_);
  } else {
    c_func = nullptr;
  bucket_boundaries_.insert(bucket_boundaries_.begin(), 0);
  node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(
    column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_,
    pad_to_bucket_boundary_, drop_remainder_, connector_que_size_));
  if (bucket_boundaries_[0] == 0) {
    bucket_boundaries_.erase(bucket_boundaries_.begin());
  }
  node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
                                                             c_func, pad_info_, pad_to_bucket_boundary_,
                                                             drop_remainder_, connector_que_size_));
  return node_ops;
}

@ -33,7 +33,7 @@ class BucketBatchByLengthNode : public DatasetNode {
|
|||
/// \brief Constructor
|
||||
BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
|
||||
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
|
||||
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
|
||||
std::shared_ptr<TensorOp> element_length_function = nullptr,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
|
||||
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
|
||||
|
||||
|
@ -52,7 +52,7 @@ class BucketBatchByLengthNode : public DatasetNode {
|
|||
std::vector<std::string> column_names_;
|
||||
std::vector<int32_t> bucket_boundaries_;
|
||||
std::vector<int32_t> bucket_batch_sizes_;
|
||||
std::function<TensorRow(TensorRow)> element_length_function_;
|
||||
std::shared_ptr<TensorOp> element_length_function_;
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
|
||||
bool pad_to_bucket_boundary_;
|
||||
bool drop_remainder_;
|
||||
|
|
|
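Build() now prepends a 0 boundary before constructing the op and strips it again afterwards, so the boundary list always describes buckets that start at length 0. A hedged sketch of how a length maps to a bucket under that convention (function name is illustrative only):

#include <cstddef>
#include <cstdint>
#include <vector>

// boundaries_with_zero = {0, b1, b2, ...}: an element of length len falls in the first bucket
// whose upper boundary is still greater than len; the final bucket catches everything longer.
size_t BucketIndexFor(int32_t len, const std::vector<int32_t> &boundaries_with_zero) {
  size_t i = 1;
  while (i < boundaries_with_zero.size() && len >= boundaries_with_zero[i]) {
    ++i;
  }
  return i - 1;
}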
@ -18,6 +18,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
||||
|
@ -27,7 +28,15 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Function to build ConcatOp
|
||||
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }
|
||||
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
|
||||
const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::vector<std::pair<int, int>> &children_flag_and_nums,
|
||||
const std::vector<std::pair<int, int>> &children_start_end_index)
|
||||
: sampler_(sampler),
|
||||
children_flag_and_nums_(children_flag_and_nums),
|
||||
children_start_end_index_(children_start_end_index) {
|
||||
this->children = datasets;
|
||||
}
|
||||
|
||||
Status ConcatNode::ValidateParams() {
|
||||
if (children.size() < 2) {
|
||||
|
@ -42,14 +51,25 @@ Status ConcatNode::ValidateParams() {
|
|||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if ((children_flag_and_nums_.empty() && !children_start_end_index_.empty()) ||
|
||||
(!children_flag_and_nums_.empty() && children_start_end_index_.empty())) {
|
||||
std::string err_msg = "ConcatNode: children_flag_and_nums and children_start_end_index should be used together";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
} else {
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_,
|
||||
children_start_end_index_));
|
||||
}
|
||||
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
@ -29,7 +30,10 @@ namespace dataset {
|
|||
class ConcatNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets);
|
||||
explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr,
|
||||
const std::vector<std::pair<int, int>> &children_flag_and_nums = {},
|
||||
const std::vector<std::pair<int, int>> &children_start_end_index = {});
|
||||
|
||||
/// \brief Destructor
|
||||
~ConcatNode() = default;
|
||||
|
@ -41,6 +45,11 @@ class ConcatNode : public DatasetNode {
|
|||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -240,6 +240,7 @@ DatasetNode::DatasetNode() {
|
|||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
connector_que_size_ = cfg->op_connector_size();
|
||||
worker_connector_size_ = cfg->worker_connector_size();
|
||||
build_status = Status::OK(); // remove me after changing return val of Build()
|
||||
}
|
||||
|
||||
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
|
||||
|
@ -254,5 +255,13 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
|
|||
// This method will only be called if its derived class does not implement one.
|
||||
return p->VisitAfter(shared_from_this(), modified);
|
||||
}
|
||||
Status DatasetNode::GetShardId(int32_t *shard_id) {
|
||||
if (!Children().empty()) {
|
||||
// Get shard id from the child node
|
||||
return Children()[0]->GetShardId(shard_id);
|
||||
} else {
|
||||
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -99,9 +99,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
|
||||
/// \brief Pure virtual function for derived class to get the shard id of specific node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
virtual Status GetShardId(int32_t *shard_id) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
virtual Status GetShardId(int32_t *shard_id);
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
|
@ -126,6 +124,10 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
/// \return Status of the node visit
|
||||
virtual Status AcceptAfter(NodePass *p, bool *modified);
|
||||
|
||||
/// \brief Method to get status from Node.Build()
|
||||
/// \notes Remove me after changing return val of Build()
|
||||
Status BuildStatus() { return build_status; }
|
||||
|
||||
protected:
|
||||
std::vector<std::shared_ptr<DatasetNode>> children;
|
||||
std::shared_ptr<DatasetCache> cache_;
|
||||
|
@ -135,6 +137,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
int32_t rows_per_buffer_;
|
||||
int32_t connector_que_size_;
|
||||
int32_t worker_connector_size_;
|
||||
Status build_status; // remove me after changing return val of Build()
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Constructor for FilterNode
|
||||
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
|
||||
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
|
||||
std::vector<std::string> input_columns)
|
||||
: predicate_(predicate), input_columns_(input_columns) {
|
||||
this->children.push_back(child);
|
||||
|
@ -38,10 +38,7 @@ std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
|
|||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::shared_ptr<TensorOp> c_func;
|
||||
c_func = std::make_shared<CFuncOp>(predicate_);
|
||||
|
||||
node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, c_func));
|
||||
node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace dataset {
|
|||
class FilterNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
|
||||
FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
|
||||
std::vector<std::string> input_columns = {});
|
||||
|
||||
/// \brief Destructor
|
||||
|
@ -44,7 +44,7 @@ class FilterNode : public DatasetNode {
|
|||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::function<TensorRow(TensorRow)> predicate_;
|
||||
std::shared_ptr<TensorOp> predicate_;
|
||||
std::vector<std::string> input_columns_;
|
||||
};
|
||||
|
||||
|
|
|
@ -64,7 +64,8 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
|
|||
auto project_op = std::make_shared<ProjectOp>(project_columns_);
|
||||
node_ops.push_back(project_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(map_op);
|
||||
return node_ops;
|
||||
|
|
|
@ -59,7 +59,8 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
|
|||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_));
|
||||
build_status = schema->LoadSchemaFile(schema_path_, column_names_);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Argument that is not exposed to user in the API.
|
||||
std::set<std::string> extensions = {};
|
||||
|
|
|
@ -60,7 +60,8 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, usage_, extensions_, std::move(schema),
|
||||
|
|
|
@ -56,7 +56,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
|
|
|
@ -54,7 +54,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
|
|
|
@ -197,18 +197,23 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
|
|||
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(clue_op->Init());
|
||||
|
||||
build_status = clue_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
build_status = ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
|
|
@ -111,7 +111,8 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
|
|||
std::shared_ptr<CocoOp> op =
|
||||
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(op);
|
||||
|
||||
|
|
|
@ -108,18 +108,23 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
|
|||
std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
|
||||
rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
|
||||
num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(csv_op->Init());
|
||||
|
||||
build_status = csv_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));
|
||||
build_status = CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,25 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
|
|||
const std::vector<DataType> &column_types)
|
||||
: generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
|
||||
|
||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
|
||||
: generator_function_(generator_function), schema_(schema) {}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
|
||||
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
|
||||
|
||||
if (schema_ != nullptr) {
|
||||
column_names_.clear();
|
||||
column_types_.clear();
|
||||
std::string schema_json_string = schema_->to_json();
|
||||
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, {}));
|
||||
|
||||
for (int32_t i = 0; i < data_schema->NumColumns(); i++) {
|
||||
ColDescriptor col = data_schema->column(i);
|
||||
column_names_.push_back(col.name());
|
||||
column_types_.push_back((col.type()));
|
||||
}
|
||||
}
|
||||
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
|
||||
|
@ -43,6 +61,8 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
|
|||
// This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which
|
||||
// best be delivered when the test cases for this api is ready.
|
||||
Status rc = op->Init();
|
||||
build_status = rc; // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (rc.IsOk()) {
|
||||
node_ops.push_back(op);
|
||||
|
@ -56,5 +76,11 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
|
|||
// no validation is needed for generator op.
|
||||
Status GeneratorNode::ValidateParams() { return Status::OK(); }
|
||||
|
||||
Status GeneratorNode::GetShardId(int32_t *shard_id) {
|
||||
RETURN_UNEXPECTED_IF_NULL(shard_id);
|
||||
*shard_id = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,6 +35,9 @@ class GeneratorNode : public DatasetNode {
|
|||
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
||||
const std::vector<DataType> &column_types);
|
||||
|
||||
/// \brief Constructor
|
||||
GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema);
|
||||
|
||||
/// \brief Destructor
|
||||
~GeneratorNode() = default;
|
||||
|
||||
|
@ -46,10 +49,15 @@ class GeneratorNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node, is always 0 because generator_node doesn't support sharding
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
py::function generator_function_;
|
||||
std::vector<std::string> column_names_;
|
||||
std::vector<DataType> column_types_;
|
||||
std::shared_ptr<SchemaObj> schema_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -62,7 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
recursive_, decode_, exts_, class_indexing_, std::move(schema),
|
||||
|
|
|
@ -79,7 +79,8 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
|
|||
manifest_op =
|
||||
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
|
||||
class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(manifest_op);
|
||||
|
||||
|
|
|
@ -138,7 +138,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
|
|||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::vector<std::shared_ptr<ShardOperator>> operators_;
|
||||
RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
|
||||
build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
std::shared_ptr<MindRecordOp> mindrecord_op;
|
||||
// If pass a string to MindData(), it will be treated as a pattern to search for matched files,
|
||||
|
@ -154,7 +155,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
|
|||
padded_sample_, sample_bytes_);
|
||||
}
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(mindrecord_op->Init());
|
||||
build_status = mindrecord_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
node_ops.push_back(mindrecord_op);
|
||||
|
||||
return node_ops;
|
||||
|
|
|
@ -51,7 +51,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
|
|||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_->Build())));
|
||||
|
|
|
@ -98,7 +98,8 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
|
|||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
|
||||
std::move(data_schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(op);
|
||||
|
||||
|
|
|
@ -78,7 +78,8 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
|
|||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
|
||||
build_status = text_file_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
@ -86,14 +87,17 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
build_status = TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
// Add TextFileOp
|
||||
node_ops.push_back(text_file_op);
|
||||
|
|
|
@ -118,7 +118,8 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
|
|||
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
|
||||
shard_id_, shard_equal_rows_, std::move(sampler_->Build()));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
|
||||
build_status = tf_reader_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
@ -127,14 +128,17 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
|
||||
build_status = TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
// Add TFReaderOp
|
||||
node_ops.push_back(tf_reader_op);
|
||||
|
|
|
@ -106,7 +106,8 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
|
|||
std::shared_ptr<VOCOp> voc_op;
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(voc_op);
|
||||
return node_ops;
|
||||
|
|
|
@ -27,9 +27,8 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Constructor for SyncWaitNode
|
||||
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
|
||||
py::function callback)
|
||||
: condition_name_(condition_name), num_batch_(num_batch), callback_(callback) {
|
||||
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback)
|
||||
: condition_name_(condition_name), callback_(callback) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
|
@ -38,20 +37,16 @@ std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() {
|
|||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_));
|
||||
// Right now barrier should only take num_rows_per_buffer = 1
|
||||
// The reason for this is because having it otherwise can lead to blocking issues
|
||||
// See barrier_op.h for more details
|
||||
int32_t rows_per_buffer = 1;
|
||||
node_ops.push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for SyncWaitNode
|
||||
Status SyncWaitNode::ValidateParams() {
|
||||
if (num_batch_ <= 0) {
|
||||
std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
Status SyncWaitNode::ValidateParams() { return Status::OK(); }
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,8 +31,7 @@ namespace dataset {
|
|||
class SyncWaitNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
|
||||
py::function callback);
|
||||
SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback);
|
||||
|
||||
/// \brief Destructor
|
||||
~SyncWaitNode() = default;
|
||||
|
@ -47,7 +46,6 @@ class SyncWaitNode : public DatasetNode {
|
|||
|
||||
private:
|
||||
std::string condition_name_;
|
||||
int32_t num_batch_;
|
||||
py::function callback_;
|
||||
};
|
||||
|
||||
|
|
|
@ -18,73 +18,81 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor for TransferNode
|
||||
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end)
|
||||
: prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) {
|
||||
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type,
|
||||
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
|
||||
: prefetch_size_(16),
|
||||
queue_name_(std::move(queue_name)),
|
||||
device_type_(std::move(device_type)),
|
||||
send_epoch_end_(send_epoch_end),
|
||||
total_batch_(total_batch),
|
||||
create_data_info_queue_(create_data_info_queue),
|
||||
device_id_(0) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Validator for TransferNode
|
||||
Status TransferNode::ValidateParams() {
|
||||
// Check if device_type_ is in {"CPU", "GPU", "Ascend"}
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"}));
|
||||
if (total_batch_ < 0) {
|
||||
std::string err_msg = "TransferNode: Total batches should be >= 0, value given: ";
|
||||
MS_LOG(ERROR) << err_msg << total_batch_;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build TransferNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
|
||||
// Get a uuid for queue name
|
||||
queue_name_ = Services::GetUniqueID();
|
||||
// TODO(CRC):
|
||||
if (queue_name_.empty()) {
|
||||
// Get a uuid for queue name
|
||||
queue_name_ = Services::GetUniqueID();
|
||||
}
|
||||
if (device_type_.empty()) {
|
||||
auto context = MsContext::GetInstance();
|
||||
if (context == nullptr) {
|
||||
device_type_ = kCPUDevice;
|
||||
} else {
|
||||
device_type_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
}
|
||||
}
|
||||
|
||||
// Get device type from ms context
|
||||
device_type_ = "CPU";
|
||||
// Get device ID from children
|
||||
// Convert device_type_ from string to DeviceType
|
||||
DeviceQueueOp::DeviceType type;
|
||||
if (device_type_ == kCPUDevice) {
|
||||
type = DeviceQueueOp::DeviceType::CPU;
|
||||
} else if (device_type_ == kGPUDevice) {
|
||||
type = DeviceQueueOp::DeviceType::GPU;
|
||||
} else if (device_type_ == kAscendDevice) {
|
||||
type = DeviceQueueOp::DeviceType::Ascend;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknown device target.";
|
||||
return {};
|
||||
}
|
||||
|
||||
// Get device ID (shard ID) from children
|
||||
device_id_ = 0;
|
||||
RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_));
|
||||
build_status = this->GetShardId(&device_id_); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Convert device_type_ from string to DeviceType
|
||||
DeviceQueueOp::DeviceType type;
|
||||
if (device_type_ == "CPU") {
|
||||
type = DeviceQueueOp::DeviceType::CPU;
|
||||
} else if (device_type_ == "GPU") {
|
||||
type = DeviceQueueOp::DeviceType::GPU;
|
||||
} else if (device_type_ == "Ascend") {
|
||||
type = DeviceQueueOp::DeviceType::Ascend;
|
||||
}
|
||||
node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
|
||||
total_batch_, false));
|
||||
total_batch_, create_data_info_queue_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to get the device_id
|
||||
Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) {
|
||||
// Get device id according to the type of dataset
|
||||
Status rc = ds->GetShardId(device_id);
|
||||
if (rc != Status::OK()) {
|
||||
// Get device id from the child node
|
||||
if (ds->Children().size()) {
|
||||
ds = ds->Children()[0];
|
||||
return TransferNode::get_distribution(ds, device_id);
|
||||
} else {
|
||||
std::string err_msg = "Unknown dataset type.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,8 @@ namespace dataset {
|
|||
class TransferNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end);
|
||||
TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end,
|
||||
int32_t total_batch, bool create_data_info_queue);
|
||||
|
||||
/// \brief Destructor
|
||||
~TransferNode() = default;
|
||||
|
@ -42,8 +43,6 @@ class TransferNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
|
||||
|
||||
private:
|
||||
std::string queue_name_;
|
||||
int32_t device_id_;
|
||||
|
@ -51,6 +50,7 @@ class TransferNode : public DatasetNode {
|
|||
int32_t prefetch_size_;
|
||||
bool send_epoch_end_;
|
||||
int32_t total_batch_;
|
||||
bool create_data_info_queue_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -40,21 +40,7 @@ Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *mo
|
|||
}
|
||||
|
||||
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||
if (type_ == kOutputShapeAndType) {
|
||||
nodes_to_clear_callback_.push_back(node);
|
||||
} else if (type_ == kDatasetSize) {
|
||||
nodes_to_remove_.push_back(node);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
|
||||
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
|
||||
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
|
||||
nodes_to_clear_callback_.push_back(node);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -83,5 +69,6 @@ Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,10 @@ class GetterPass : public TreePass {
|
|||
enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
|
||||
/// \brief Constructor
|
||||
explicit GetterPass(GetterType tp) : pass_(tp) {}
|
||||
|
||||
/// \brief default copy Constructor
|
||||
explicit GetterPass(const GetterPass &) = default;
|
||||
|
||||
/// \brief Destructor
|
||||
~GetterPass() = default;
|
||||
|
||||
|
@ -51,11 +55,10 @@ class GetterPass : public TreePass {
|
|||
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); }
|
||||
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
|
||||
// Whether this is Run or PreRun does not matter here; however, only Accept() is defined in ConcatOp
|
||||
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
|
||||
|
||||
|
@ -67,7 +70,7 @@ class GetterPass : public TreePass {
|
|||
std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
|
||||
std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
|
||||
};
|
||||
// outter class needs only to own the inner class object since it automatically has access to its private variables
|
||||
// outer class needs only to own the inner class object since it automatically has access to its private variables
|
||||
GetterNodes pass_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -19,7 +19,14 @@
|
|||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||
Status PythonRuntimeContext::Terminate() {
|
||||
MS_LOG(INFO) << "Terminating a PythonRuntime";
|
||||
if (tree_consumer_ != nullptr) {
|
||||
return TerminateImpl();
|
||||
}
|
||||
MS_LOG(WARNING) << "TreeConsumer was not initialized";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonRuntimeContext::TerminateImpl() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||
|
|
|
@ -22,7 +22,14 @@ namespace mindspore::dataset {
|
|||
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
|
||||
tree_consumer_ = std::move(tree_consumer);
|
||||
}
|
||||
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||
Status NativeRuntimeContext::Terminate() {
|
||||
MS_LOG(INFO) << "Terminating a NativeRuntime";
|
||||
if (tree_consumer_ != nullptr) {
|
||||
return TerminateImpl();
|
||||
}
|
||||
MS_LOG(WARNING) << "TreeConsumer was not initialized";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NativeRuntimeContext::TerminateImpl() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||
|
|
|
@ -97,6 +97,8 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
|
|||
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
|
||||
// Build the DatasetOp ExecutionTree from the optimized IR tree
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
|
||||
RETURN_IF_NOT_OK(ir->BuildStatus()); // remove me after changing return val of Build()
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node.");
|
||||
|
||||
(*op) = ops.front(); // return the first op to be added as child by the caller of this function
|
||||
|
@ -141,6 +143,8 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
|
|||
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
||||
|
||||
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
|
||||
|
||||
// Note: We will gradually move the pre pass, optimizer pass, and post pass
|
||||
// on ExecutionTree to perform on IR tree.
|
||||
// Prepare the tree
|
||||
|
@ -149,6 +153,11 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
|
|||
// After the tree is prepared, the col_name_id_map can safely be obtained
|
||||
column_name_map_ = tree_->root()->column_name_id_map();
|
||||
|
||||
// Profiling parameters init
|
||||
cur_batch_num_ = 0;
|
||||
cur_connector_size_ = 0;
|
||||
cur_connector_capacity_ = 0;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -156,21 +165,55 @@ Status TreeAdapter::GetNext(TensorRow *row) {
|
|||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_UNEXPECTED_IF_NULL(row);
|
||||
row->clear(); // make sure row is empty
|
||||
|
||||
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
|
||||
// When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
|
||||
if (cur_db_ == nullptr) {
|
||||
RETURN_IF_NOT_OK(tree_->Launch());
|
||||
// Profiling
|
||||
std::shared_ptr<Tracing> node;
|
||||
Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
|
||||
if (s.IsOk()) {
|
||||
tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
|
||||
}
|
||||
if (tracing_ != nullptr) {
|
||||
cur_connector_size_ = tree_->root()->ConnectorSize();
|
||||
cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
|
||||
}
|
||||
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag
|
||||
RETURN_OK_IF_TRUE(cur_db_->eoe()); // return empty tensor if 1st buf is a ctrl buf (no rows)
|
||||
if (cur_db_->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows)
|
||||
MS_LOG(INFO) << "End of data iteration.";
|
||||
if (isProfilingEnable) {
|
||||
tree_->SetEpochEnd();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached.");
|
||||
|
||||
if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf
|
||||
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_));
|
||||
RETURN_OK_IF_TRUE(cur_db_->eoe() || cur_db_->eof()); // return empty if this new buffer is a ctrl flag
|
||||
if (cur_db_->eoe()) { // return empty if this new buffer is a ctrl flag
|
||||
MS_LOG(INFO) << "End of data iteration.";
|
||||
if (isProfilingEnable) {
|
||||
tree_->SetEpochEnd();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
if (cur_db_->eof()) {
|
||||
tree_->SetFinished();
|
||||
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(cur_db_->PopRow(row));
|
||||
// Record profiling info
|
||||
if (tracing_ != nullptr) {
|
||||
cur_batch_num_++;
|
||||
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -60,6 +61,9 @@ class TreeAdapter {
|
|||
// Set optional optimization pass
|
||||
void SetOptimize(bool value) { optimize_ = value; }
|
||||
|
||||
// function to override the pre-pass
|
||||
void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
|
||||
|
||||
// Optional optimizations status
|
||||
bool OptimizationEnabled() const { return optimize_; }
|
||||
|
||||
|
@ -82,9 +86,14 @@ class TreeAdapter {
|
|||
|
||||
std::unique_ptr<DataBuffer> cur_db_;
|
||||
std::unordered_map<std::string, int32_t> column_name_map_;
|
||||
std::unique_ptr<ExecutionTree> tree_;
|
||||
std::unique_ptr<ExecutionTree> tree_;  // ExecutionTree built from the compiled IR tree
|
||||
int32_t num_epochs_;
|
||||
bool optimize_; // Flag to enable optional optimization pass
|
||||
bool optimize_; // Flag to enable optional optimization pass
|
||||
std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data
|
||||
int32_t cur_batch_num_; // current batch number, used for profiling
|
||||
int32_t cur_connector_size_; // current connector size of root op, used for profiling
|
||||
int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling
|
||||
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -145,9 +145,16 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \brief Function to transfer data through a device.
|
||||
/// \note If the device is Ascend, data features are transferred one by one; each transfer
|
||||
/// is limited to 256M of data.
|
||||
/// \param[in] queue_name Channel name (default="", create new unique name).
|
||||
/// \param[in] device_type Type of device (default="", get from MSContext).
|
||||
/// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
|
||||
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
|
||||
/// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
|
||||
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
|
||||
/// of data or not (default=false).
|
||||
/// \return Returns true if no error encountered else false.
|
||||
bool DeviceQueue(bool send_epoch_end = true);
|
||||
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1,
|
||||
bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false);
|
||||
|
||||
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
|
||||
/// \note Usage restrictions:
|
||||
|
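From Python, this C++ entry point is reached through the transfer/send path of the dataset API; a minimal usage sketch, assuming the usual device_que()/send() front end and a placeholder data path:

import mindspore.dataset as ds

data = ds.TFRecordDataset("/path/to/train.tfrecord", shuffle=False)  # placeholder path
data = data.batch(32)
sink = data.device_que()       # adds a TransferDataset (TransferNode in IR) on top of the pipeline
sink.send(num_epochs=1)        # launches the ToDevice consumer; num_epochs=-1 means run until stopped
sink.stop_send()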
@ -371,21 +378,34 @@ class SchemaObj {
|
|||
|
||||
/// \brief SchemaObj init function
|
||||
/// \return Status Status::OK() if schema init is successful
|
||||
bool init();
|
||||
Status init();
|
||||
|
||||
/// \brief Add new column to the schema with unknown shape of rank 1
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(TypeId).
|
||||
/// \return Status Status::OK() if the column is successfully added
|
||||
Status add_column(std::string name, TypeId de_type);
|
||||
|
||||
/// \brief Add new column to the schema with unknown shape of rank 1
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(std::string).
|
||||
|
||||
/// \return Status Status::OK() if the column is successfully added
|
||||
Status add_column(std::string name, std::string de_type);
|
||||
|
||||
/// \brief Add new column to the schema
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(TypeId).
|
||||
/// \param[in] shape shape of the column.
|
||||
/// \return Status Status::OK() if the column is successfully added
|
||||
bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
|
||||
Status add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
|
||||
|
||||
/// \brief Add new column to the schema
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(std::string).
|
||||
/// \param[in] shape shape of the column.
|
||||
/// \return Status Status::OK() if the column is successfully added
|
||||
bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
|
||||
Status add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
|
||||
|
||||
/// \brief Get a JSON string of the schema
|
||||
/// \return JSON string of the schema
|
||||
|
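The same schema is normally assembled from Python and handed to a reader; a small sketch, assuming the ds.Schema front end mirrors these add_column overloads and using a placeholder file path:

import mindspore.dataset as ds
from mindspore import dtype as mstype

schema = ds.Schema()
schema.add_column(name="image", de_type=mstype.uint8)             # unknown shape of rank 1
schema.add_column(name="label", de_type=mstype.int32, shape=[1])  # fixed shape
data = ds.TFRecordDataset("/path/to/data.tfrecord", schema=schema)  # placeholder path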
@ -395,25 +415,27 @@ class SchemaObj {
|
|||
std::string to_string() { return to_json(); }
|
||||
|
||||
/// \brief set a new value to dataset_type
|
||||
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; }
|
||||
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); }
|
||||
|
||||
/// \brief set a new value to num_rows
|
||||
inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
|
||||
|
||||
/// \brief get the current num_rows
|
||||
inline int32_t get_num_rows() { return num_rows_; }
|
||||
inline int32_t get_num_rows() const { return num_rows_; }
|
||||
|
||||
Status FromJSONString(const std::string &json_string);
|
||||
|
||||
private:
|
||||
/// \brief Parse the columns and add them to the schema
|
||||
/// \param[in] columns dataset attribute information, decoded from the schema file.
|
||||
/// supports both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
|
||||
/// \return Status Status::OK() if the columns are successfully parsed
|
||||
bool parse_column(nlohmann::json columns);
|
||||
Status parse_column(nlohmann::json columns);
|
||||
|
||||
/// \brief Get schema file from json file
|
||||
/// \param[in] json_obj object of json parsed.
|
||||
/// \return Status Status::OK() if the JSON object is successfully parsed
|
||||
bool from_json(nlohmann::json json_obj);
|
||||
Status from_json(nlohmann::json json_obj);
|
||||
|
||||
int32_t num_rows_;
|
||||
std::string dataset_type_;
|
||||
|
|
|
@ -61,6 +61,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
|
||||
class DistributedSamplerObj;
|
||||
class PKSamplerObj;
|
||||
class PreBuiltSamplerObj;
|
||||
class RandomSamplerObj;
|
||||
class SequentialSamplerObj;
|
||||
class SubsetRandomSamplerObj;
|
||||
|
@ -171,6 +172,31 @@ class PKSamplerObj : public SamplerObj {
|
|||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class PreBuiltSamplerObj : public SamplerObj {
|
||||
public:
|
||||
#ifndef ENABLE_ANDROID
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
|
||||
#endif
|
||||
|
||||
~PreBuiltSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerRT> sp_;
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
|
||||
#endif
|
||||
};
|
||||
|
||||
class RandomSamplerObj : public SamplerObj {
|
||||
public:
|
||||
RandomSamplerObj(bool replacement, int64_t num_samples);
|
||||
|
|
|
@ -70,6 +70,7 @@ namespace transforms {
|
|||
class ComposeOperation;
|
||||
class DuplicateOperation;
|
||||
class OneHotOperation;
|
||||
class PreBuiltOperation;
|
||||
class RandomApplyOperation;
|
||||
class RandomChoiceOperation;
|
||||
class TypeCastOperation;
|
||||
|
@ -164,6 +165,20 @@ class OneHotOperation : public TensorOperation {
|
|||
float num_classes_;
|
||||
};
|
||||
|
||||
class PreBuiltOperation : public TensorOperation {
|
||||
public:
|
||||
explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op);
|
||||
|
||||
~PreBuiltOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<TensorOp> op_;
|
||||
};
|
||||
|
||||
class RandomApplyOperation : public TensorOperation {
|
||||
public:
|
||||
explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);
|
||||
|
@ -192,7 +207,6 @@ class RandomChoiceOperation : public TensorOperation {
|
|||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
||||
};
|
||||
|
||||
class TypeCastOperation : public TensorOperation {
|
||||
public:
|
||||
explicit TypeCastOperation(std::string data_type);
|
||||
|
|
|
@ -71,6 +71,15 @@ namespace dataset {
|
|||
return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \
|
||||
} while (false)
|
||||
|
||||
#define RETURN_SECOND_IF_ERROR(_s, _r) \
|
||||
do { \
|
||||
Status __rc = (_s); \
|
||||
if (__rc.IsError()) { \
|
||||
MS_LOG(ERROR) << __rc; \
|
||||
return _r; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
enum class StatusCode : char {
|
||||
kOK = 0,
|
||||
kOutOfMemory = 1,
|
||||
|
|
|
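RETURN_SECOND_IF_ERROR logs an error Status and hands back a caller-chosen fallback value instead of propagating the error. A loose, self-contained Python analogue of the same pattern (all names here are illustrative):

def return_second_if_error(call, fallback):
    # Run call(); on failure log the error and return the fallback value instead of raising.
    try:
        return call()
    except RuntimeError as err:
        print("ERROR:", err)   # stands in for MS_LOG(ERROR)
        return fallback

def failing_getter():
    raise RuntimeError("tree not initialized")

assert return_second_if_error(failing_getter, -1) == -1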
@ -138,7 +138,9 @@ Status Task::Join(WaitFlag blocking) {
|
|||
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
|
||||
// We can't tell which condition_variable this thread is waiting on. So we may need
|
||||
// to interrupt everything one more time.
|
||||
MS_LOG(INFO) << "Some threads not responding. Interrupt again";
|
||||
std::stringstream ss;
|
||||
ss << get_id();
|
||||
MS_LOG(ERROR) << MyName() << " Thread ID " << ss.str() << " is not responding. Interrupt again";
|
||||
interrupt_svc->InterruptAll();
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -21,7 +21,8 @@ import numpy
|
|||
import mindspore._c_dataengine as cde
|
||||
|
||||
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
|
||||
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load']
|
||||
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
|
||||
'get_callback_timeout']
|
||||
|
||||
INT32_MAX = 2147483647
|
||||
UINT32_MAX = 4294967295
|
||||
|
|
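get_callback_timeout now sits alongside the other config getters in the module's public surface; a short usage sketch:

import mindspore.dataset as ds

ds.config.set_seed(1234)
timeout = ds.config.get_callback_timeout()  # seconds the pipeline waits on callbacks before giving up
print("callback timeout:", timeout)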
|
@ -65,5 +65,7 @@ def mstypelist_to_detypelist(type_list):
|
|||
for index, _ in enumerate(type_list):
|
||||
if type_list[index] is not None:
|
||||
type_list[index] = mstype_to_detype(type_list[index])
|
||||
else:
|
||||
type_list[index] = cde.DataType("")
|
||||
|
||||
return type_list
|
||||
|
|
File diff suppressed because it is too large
|
@ -15,17 +15,13 @@
|
|||
"""Built-in iterators.
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
import copy
|
||||
import weakref
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._c_dataengine import DEPipeline
|
||||
from mindspore._c_dataengine import OpName
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from mindspore import log as logger
|
||||
from . import datasets as de
|
||||
|
||||
|
||||
_ITERATOR_CLEANUP = False
|
||||
|
||||
|
@ -57,29 +53,6 @@ def _cleanup():
|
|||
itr.release()
|
||||
|
||||
|
||||
def alter_tree(node):
|
||||
"""Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
|
||||
if not node.children:
|
||||
return _alter_node(node)
|
||||
|
||||
converted_children = []
|
||||
for input_op in node.children:
|
||||
converted_children.append(alter_tree(input_op))
|
||||
node.children = converted_children
|
||||
return _alter_node(node)
|
||||
|
||||
|
||||
def _alter_node(node):
|
||||
"""DEPRECATED"""
|
||||
# Please check ccsrc/dataset/engine/opt for tree transformation.
|
||||
if isinstance(node, de.MapDataset):
|
||||
if node.python_multiprocessing:
|
||||
# Bootstrap can only be performed on a copy of the original dataset node.
|
||||
# Bootstrap on original dataset node will make all iterators share the same process pool
|
||||
node.iterator_bootstrap()
|
||||
return node
|
||||
|
||||
|
||||
class Iterator:
|
||||
"""
|
||||
General Iterator over a dataset.
|
||||
|
@ -89,185 +62,62 @@ class Iterator:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
|
||||
self.num_epochs = num_epochs
|
||||
self.output_numpy = output_numpy
|
||||
self._col_names = None
|
||||
|
||||
# create a copy of tree and work on it.
|
||||
self.ori_dataset = dataset
|
||||
|
||||
self.ir_tree, self.dataset = dataset.create_ir_tree()
|
||||
|
||||
self._runtime_context = cde.PythonRuntimeContext()
|
||||
self._runtime_context.Init()
|
||||
consumer = cde.PythonIteratorConsumer(num_epochs)
|
||||
consumer.Init(self.ir_tree)
|
||||
self._runtime_context.AssignConsumer(consumer)
|
||||
self._iterator = self._runtime_context.GetConsumer()
|
||||
|
||||
self._transform_tensor = lambda t: t.as_array()
|
||||
if not output_numpy:
|
||||
self._transform_tensor = lambda t: Tensor(t.as_array())
|
||||
self._index = 0
|
||||
|
||||
# todo remove next when ContextManager is done
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
_unset_iterator_cleanup()
|
||||
# create a copy of tree and work on it.
|
||||
self.dataset = copy.deepcopy(dataset)
|
||||
self.ori_dataset = dataset
|
||||
self.parent_subtree = []
|
||||
#######
|
||||
|
||||
# The dataset passed into the iterator is not the root of the tree.
|
||||
# Trim the tree by saving the parent subtree into self.parent_subtree and
|
||||
# restore it after launching our C++ pipeline.
|
||||
if self.dataset.parent:
|
||||
logger.info("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
|
||||
self.parent_subtree = self.dataset.parent
|
||||
self.dataset.parent = []
|
||||
|
||||
self.dataset = alter_tree(self.dataset)
|
||||
if not self.__is_tree():
|
||||
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers).")
|
||||
self.depipeline = DEPipeline()
|
||||
|
||||
# for manifest temporary use
|
||||
self.__batch_node(self.dataset, 0)
|
||||
|
||||
root = self.__convert_node_postorder(self.dataset)
|
||||
self.depipeline.AssignRootNode(root)
|
||||
self.depipeline.PrepareTree(self.num_epochs)
|
||||
self._index = 0
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Manually terminate Python iterator instead of relying on out of scope destruction.
|
||||
"""
|
||||
logger.info("Terminating Python iterator. This will also terminate C++ pipeline.")
|
||||
if hasattr(self, 'depipeline') and self.depipeline:
|
||||
del self.depipeline
|
||||
|
||||
def __is_tree_node(self, node):
|
||||
"""Check if a node is tree node."""
|
||||
if not node.children:
|
||||
if len(node.parent) > 1:
|
||||
return False
|
||||
|
||||
if len(node.parent) > 1:
|
||||
return False
|
||||
|
||||
for input_node in node.children:
|
||||
cls = self.__is_tree_node(input_node)
|
||||
if not cls:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __is_tree(self):
|
||||
return self.__is_tree_node(self.dataset)
|
||||
|
||||
@staticmethod
|
||||
def __get_dataset_type(dataset):
|
||||
"""Get the dataset type."""
|
||||
op_type = None
|
||||
if isinstance(dataset, de.ShuffleDataset):
|
||||
op_type = OpName.SHUFFLE
|
||||
elif isinstance(dataset, de.MindDataset):
|
||||
op_type = OpName.MINDRECORD
|
||||
elif isinstance(dataset, de.BatchDataset):
|
||||
op_type = OpName.BATCH
|
||||
elif isinstance(dataset, de.BucketBatchByLengthDataset):
|
||||
op_type = OpName.BUCKETBATCH
|
||||
elif isinstance(dataset, de.SyncWaitDataset):
|
||||
op_type = OpName.BARRIER
|
||||
elif isinstance(dataset, de.ZipDataset):
|
||||
op_type = OpName.ZIP
|
||||
elif isinstance(dataset, de.ConcatDataset):
|
||||
op_type = OpName.CONCAT
|
||||
elif isinstance(dataset, de.MapDataset):
|
||||
op_type = OpName.MAP
|
||||
elif isinstance(dataset, de.FilterDataset):
|
||||
op_type = OpName.FILTER
|
||||
elif isinstance(dataset, de.RepeatDataset):
|
||||
op_type = OpName.REPEAT
|
||||
elif isinstance(dataset, de.SkipDataset):
|
||||
op_type = OpName.SKIP
|
||||
elif isinstance(dataset, de.TakeDataset):
|
||||
op_type = OpName.TAKE
|
||||
elif isinstance(dataset, de.ImageFolderDataset):
|
||||
op_type = OpName.IMAGEFOLDER
|
||||
elif isinstance(dataset, de.GeneratorDataset):
|
||||
op_type = OpName.GENERATOR
|
||||
elif isinstance(dataset, de.TransferDataset):
|
||||
op_type = OpName.DEVICEQUEUE
|
||||
elif isinstance(dataset, de.RenameDataset):
|
||||
op_type = OpName.RENAME
|
||||
elif isinstance(dataset, de.TFRecordDataset):
|
||||
op_type = OpName.TFREADER
|
||||
elif isinstance(dataset, de.ProjectDataset):
|
||||
op_type = OpName.PROJECT
|
||||
elif isinstance(dataset, de.MnistDataset):
|
||||
op_type = OpName.MNIST
|
||||
elif isinstance(dataset, de.ManifestDataset):
|
||||
op_type = OpName.MANIFEST
|
||||
elif isinstance(dataset, de.VOCDataset):
|
||||
op_type = OpName.VOC
|
||||
elif isinstance(dataset, de.CocoDataset):
|
||||
op_type = OpName.COCO
|
||||
elif isinstance(dataset, de.Cifar10Dataset):
|
||||
op_type = OpName.CIFAR10
|
||||
elif isinstance(dataset, de.Cifar100Dataset):
|
||||
op_type = OpName.CIFAR100
|
||||
elif isinstance(dataset, de.CelebADataset):
|
||||
op_type = OpName.CELEBA
|
||||
elif isinstance(dataset, de.RandomDataset):
|
||||
op_type = OpName.RANDOMDATA
|
||||
elif isinstance(dataset, de.TextFileDataset):
|
||||
op_type = OpName.TEXTFILE
|
||||
elif isinstance(dataset, de.BuildVocabDataset):
|
||||
op_type = OpName.BUILDVOCAB
|
||||
elif isinstance(dataset, de.BuildSentencePieceVocabDataset):
|
||||
op_type = OpName.SENTENCEPIECEVOCAB
|
||||
elif isinstance(dataset, de.CLUEDataset):
|
||||
op_type = OpName.CLUE
|
||||
elif isinstance(dataset, de.CSVDataset):
|
||||
op_type = OpName.CSV
|
||||
else:
|
||||
raise ValueError("Unsupported DatasetOp.")
|
||||
|
||||
return op_type
|
||||
|
||||
# Convert Python node into C node and add to C layer execution tree in postorder traversal.
|
||||
def __convert_node_postorder(self, node):
|
||||
self.check_node_type(node)
|
||||
op_type = self.__get_dataset_type(node)
|
||||
c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
|
||||
|
||||
for py_child in node.children:
|
||||
c_child = self.__convert_node_postorder(py_child)
|
||||
self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
|
||||
|
||||
return c_nodes["top"]
|
||||
|
||||
def __batch_node(self, dataset, level):
|
||||
"""Recursively get batch node in the dataset tree."""
|
||||
if isinstance(dataset, de.BatchDataset):
|
||||
return
|
||||
for input_op in dataset.children:
|
||||
self.__batch_node(input_op, level + 1)
|
||||
|
||||
@staticmethod
|
||||
def __print_local(dataset, level):
|
||||
"""Recursively print the name and address of nodes in the dataset tree."""
|
||||
name = dataset.__class__.__name__
|
||||
ptr = hex(id(dataset))
|
||||
for _ in range(level):
|
||||
logger.info("\t", end='')
|
||||
if not dataset.children:
|
||||
logger.info("-%s (%s)", name, ptr)
|
||||
else:
|
||||
logger.info("+%s (%s)", name, ptr)
|
||||
for input_op in dataset.children:
|
||||
Iterator.__print_local(input_op, level + 1)
|
||||
|
||||
def print(self):
|
||||
"""Print the dataset tree"""
|
||||
self.__print_local(self.dataset, 0)
|
||||
if hasattr(self, '_runtime_context') and self._runtime_context:
|
||||
if hasattr(self, '_iterator') and self._iterator:
|
||||
self._runtime_context.Terminate()
|
||||
del self._iterator
|
||||
del self._runtime_context
|
||||
del self.dataset
|
||||
|
||||
def release(self):
|
||||
if hasattr(self, 'depipeline') and self.depipeline:
|
||||
del self.depipeline
|
||||
self.stop()
|
||||
|
||||
def __del__(self):
|
||||
self.release()
|
||||
|
||||
@abstractmethod
|
||||
def get_next(self):
|
||||
def _get_next(self):
|
||||
raise RuntimeError("Calling base class Iterator's get_next is invalid.")
|
||||
|
||||
def __next__(self):
|
||||
if not self.depipeline:
|
||||
if not self._runtime_context:
|
||||
logger.warning("Iterator does not have a running C++ pipeline." +
|
||||
"It might because Iterator stop() had been called, or C++ pipeline crashed silently.")
|
||||
raise RuntimeError("Iterator does not have a running C++ pipeline.")
|
||||
|
||||
data = self.get_next()
|
||||
data = self._get_next()
|
||||
if not data:
|
||||
if self._index == 0:
|
||||
logger.warning("No records available.")
|
||||
|
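With the pushdown, an Iterator is only a thin wrapper over a PythonRuntimeContext plus a PythonIteratorConsumer, so user code keeps going through the same entry points; roughly (the dataset path is a placeholder):

import mindspore.dataset as ds

data = ds.Cifar10Dataset("/path/to/cifar10", usage="all", shuffle=False)  # placeholder path
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    image, label = row["image"], row["label"]   # one record per step; the loop ends after one epoch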
@ -277,100 +127,56 @@ class Iterator:
|
|||
self._index += 1
|
||||
return data
|
||||
|
||||
@abstractmethod
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def get_output_shapes(self):
|
||||
return [t for t in self.depipeline.GetOutputShapes()]
|
||||
|
||||
def get_output_types(self):
|
||||
return [t for t in self.depipeline.GetOutputTypes()]
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.depipeline.GetDatasetSize()
|
||||
|
||||
def get_batch_size(self):
|
||||
return self.depipeline.GetBatchSize()
|
||||
|
||||
def get_repeat_count(self):
|
||||
return self.depipeline.GetRepeatCount()
|
||||
|
||||
def num_classes(self):
|
||||
return self.depipeline.GetNumClasses()
|
||||
|
||||
def get_col_names(self):
|
||||
return self.depipeline.GetColumnNames()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self
|
||||
|
||||
def _getters(self):
|
||||
"""
|
||||
Get pipeline information.
|
||||
"""
|
||||
getter = cde.TreeGetters()
|
||||
getter.Init(self.ir_tree)
|
||||
self._runtime_context.AssignConsumer(getter)
|
||||
self._col_names = getter.GetColumnNames()
|
||||
|
||||
class SaveOp(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def __init__(self, dataset, num_epochs=-1):
|
||||
super().__init__(dataset, num_epochs)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
|
||||
def get_next(self):
|
||||
pass
|
||||
|
||||
def check_node_type(self, node):
|
||||
if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
|
||||
logger.warning("Used shuffle, repeat, batch before save operator.")
|
||||
|
||||
def save(self, file_names, file_type):
|
||||
return self.depipeline.SaveDataset(file_names, file_type)
|
||||
def get_col_names(self):
|
||||
"""
|
||||
Get names of the columns in the dataset
|
||||
"""
|
||||
if self._col_names is None:
|
||||
self._getters()
|
||||
return self._col_names
|
||||
|
||||
|
||||
class DictIterator(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
|
||||
super().__init__(dataset, num_epochs, output_numpy)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def get_next(self):
|
||||
def _get_next(self):
|
||||
"""
|
||||
Returns the next record in the dataset as a dictionary
|
||||
|
||||
Returns:
|
||||
Dict, the next record in the dataset.
|
||||
"""
|
||||
|
||||
if self.output_numpy:
|
||||
return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
|
||||
return {k: Tensor(v.as_array()) for k, v in self.depipeline.GetNextAsMap().items()}
|
||||
return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
|
||||
|
||||
|
||||
class TupleIterator(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with list type.
|
||||
"""
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False):
|
||||
if columns is not None:
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
# todo: move next to IR
|
||||
dataset = dataset.project(columns)
|
||||
super().__init__(dataset, num_epochs, output_numpy)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def get_next(self):
|
||||
def _get_next(self):
|
||||
"""
|
||||
Returns the next record in the dataset as a list
|
||||
|
||||
|
@ -378,15 +184,14 @@ class TupleIterator(Iterator):
|
|||
List, the next record in the dataset.
|
||||
"""
|
||||
|
||||
if self.output_numpy:
|
||||
return [t.as_array() for t in self.depipeline.GetNextAsList()]
|
||||
return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()]
|
||||
return [self._transform_tensor(t) for t in self._iterator.GetNextAsList()]
|
||||
|
||||
|
||||
class DummyIterator:
|
||||
"""
|
||||
A DummyIterator only works when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, mode):
|
||||
self.mode = mode
|
||||
self.shapes = dataset.output_shapes()
|
||||
|
|
|
@ -283,9 +283,12 @@ def create_node(node):
|
|||
node.get('shard_id'), sampler)
|
||||
|
||||
elif dataset_op == 'TFRecordDataset':
|
||||
shuffle = node.get('shuffle')
|
||||
if shuffle is not None and isinstance(shuffle, str):
|
||||
shuffle = de.Shuffle(shuffle)
|
||||
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
|
||||
node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))
|
||||
shuffle, node.get('num_shards'), node.get('shard_id'))
|
||||
|
||||
elif dataset_op == 'ManifestDataset':
|
||||
sampler = construct_sampler(node.get('sampler'))
|
||||
|
|
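This TFRecordDataset branch is what the serialize/deserialize round trip exercises; a minimal sketch, assuming ds.serialize and ds.deserialize as the entry points and a placeholder data path:

import mindspore.dataset as ds

data = ds.TFRecordDataset("/path/to/data.tfrecord", shuffle=ds.Shuffle.GLOBAL)  # placeholder path
ds.serialize(data, json_filepath="pipeline.json")         # shuffle is stored as a string in the JSON
restored = ds.deserialize(json_filepath="pipeline.json")  # create_node() converts it back to de.Shuffle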
|
@ -293,14 +293,38 @@ def check_save(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_iterator(method):
|
||||
def check_tuple_iterator(method):
|
||||
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
[columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
nreq_param_bool = ['output_numpy']
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
if num_epochs is not None:
|
||||
type_check(num_epochs, (int,), "num_epochs")
|
||||
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
|
||||
|
||||
if columns is not None:
|
||||
check_columns(columns, "column_names")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_dict_iterator(method):
|
||||
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
nreq_param_bool = ['output_numpy']
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
if num_epochs is not None:
|
||||
type_check(num_epochs, (int,), "num_epochs")
|
||||
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
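These validators mirror the arguments users pass to the two iterator factories; a small example of the tuple form, with placeholder path and column names:

import mindspore.dataset as ds

data = ds.TFRecordDataset("/path/to/data.tfrecord")  # placeholder path
itr = data.create_tuple_iterator(columns=["image", "label"], num_epochs=1, output_numpy=True)
for image, label in itr:
    pass  # columns are projected and returned in the requested order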
@ -523,6 +547,8 @@ def check_batch(method):
|
|||
sig = ins.signature(batch_size)
|
||||
if len(sig.parameters) != 1:
|
||||
raise ValueError("callable batch_size should take one parameter (BatchInfo).")
|
||||
else:
|
||||
check_pos_int32(int(batch_size), "batch_size")
|
||||
|
||||
if num_parallel_workers is not None:
|
||||
check_num_parallel_workers(num_parallel_workers)
|
||||
|
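A callable batch_size must take exactly one BatchInfo argument, as the check above enforces; a small sketch that grows the batch size each epoch:

import mindspore.dataset as ds

def growing_batch(info):
    # BatchInfo exposes the current epoch/batch indices.
    return info.get_epoch_num() + 2

data = ds.NumpySlicesDataset({"col": list(range(32))}, shuffle=False)
data = data.batch(batch_size=growing_batch)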
@ -807,6 +833,21 @@ def check_project(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_schema(method):
|
||||
"""check the input arguments of Schema.__init__."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[schema_file], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if schema_file is not None:
|
||||
type_check(schema_file, (str,), "schema_file")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_add_column(method):
|
||||
"""check the input arguments of add_column."""
|
||||
|
||||
|
@ -1261,3 +1302,23 @@ def check_cache_option(cache):
|
|||
"""Sanity check for cache parameter"""
|
||||
if cache is not None:
|
||||
type_check(cache, (cache_client.DatasetCache,), "cache")
|
||||
|
||||
|
||||
def check_to_device_send(method):
|
||||
"""A wrapper that wraps a parameter checker around the check_to_device_send."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[num_epochs], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if num_epochs is not None:
|
||||
type_check(num_epochs, (int,), "num_epochs")
|
||||
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def replace_none(value, default):
|
||||
return value if value is not None else default
|
||||
|
|
|
@ -18,13 +18,13 @@ use to_bytes and to_str to encode and decode strings into a specified format.
|
|||
"""
|
||||
from enum import IntEnum
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
|
||||
check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Vocab", "SentencePieceVocab", "to_str", "to_bytes"
|
||||
]
|
||||
|
@ -39,8 +39,7 @@ class Vocab(cde.Vocab):
|
|||
|
||||
@classmethod
|
||||
@check_from_dataset
|
||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
|
||||
special_first=True):
|
||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, special_first=True):
|
||||
"""
|
||||
Build a vocab from a dataset.
|
||||
|
||||
|
@ -69,21 +68,7 @@ class Vocab(cde.Vocab):
|
|||
Returns:
|
||||
Vocab, Vocab object built from dataset.
|
||||
"""
|
||||
|
||||
vocab = Vocab()
|
||||
if columns is None:
|
||||
columns = []
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
if freq_range is None:
|
||||
freq_range = (None, None)
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
|
||||
for d in root.create_dict_iterator(num_epochs=1):
|
||||
if d is not None:
|
||||
raise ValueError("from_dataset should receive data other than None.")
|
||||
return vocab
|
||||
return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
|
||||
|
||||
@classmethod
|
||||
@check_from_list
|
||||
|
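After this change from_dataset delegates directly to the dataset's build_vocab, but the user-facing call stays the same; roughly, with a placeholder corpus path:

import mindspore.dataset as ds
import mindspore.dataset.text as text

data = ds.TextFileDataset("/path/to/corpus.txt", shuffle=False)  # placeholder path
vocab = text.Vocab.from_dataset(data, columns=["text"], freq_range=None, top_k=None,
                                special_tokens=["<pad>", "<unk>"], special_first=True)
data = data.map(operations=text.Lookup(vocab, "<unk>"), input_columns=["text"])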
@ -143,6 +128,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
|
|||
"""
|
||||
SentencePiece object that is used to segment words
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@check_from_dataset_sentencepiece
|
||||
def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params):
|
||||
|
@ -164,13 +150,8 @@ class SentencePieceVocab(cde.SentencePieceVocab):
|
|||
SentencePiece, SentencePiece object from dataset.
|
||||
"""
|
||||
|
||||
vocab = SentencePieceVocab()
|
||||
root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage,
|
||||
model_type, params)
|
||||
for d in root.create_dict_iterator(num_epochs=1):
|
||||
if d is None:
|
||||
raise ValueError("from_dataset should receive data other than None.")
|
||||
return vocab
|
||||
return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage,
|
||||
DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
|
||||
|
||||
@classmethod
|
||||
@check_from_file_sentencepiece
|
||||
|
@ -270,6 +251,7 @@ class SentencePieceModel(IntEnum):
|
|||
CHAR = 2
|
||||
WORD = 3
|
||||
|
||||
|
||||
DE_C_INTER_SENTENCEPIECE_MODE = {
|
||||
SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM,
|
||||
SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE,
|
||||
|
|
|
@ -432,7 +432,7 @@ def check_from_dataset_sentencepiece(method):
|
|||
[_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if col_names is not None:
|
||||
type_check(col_names, (list,), "col_names")
|
||||
type_check_list(col_names, (str,), "col_names")
|
||||
|
||||
if vocab_size is not None:
|
||||
check_uint32(vocab_size, "vocab_size")
|
||||
|
|
|
@ -146,6 +146,7 @@ if (BUILD_MINDDATA STREQUAL "full")
|
|||
|
||||
list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES
|
||||
"${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/cache/pre_built_dataset_cache.cc"
|
||||
)
|
||||
|
||||
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -123,6 +123,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
||||
return network
|
||||
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.
|
||||
|
@ -197,7 +198,6 @@ class DatasetHelper:
|
|||
def get_data_info(self):
|
||||
return self.iter.get_data_info()
|
||||
|
||||
|
||||
class _DatasetIter:
|
||||
"""Base iter for dataset helper"""
|
||||
|
||||
|
@ -331,7 +331,6 @@ class _DatasetIterPSLite(_DatasetIter):
|
|||
|
||||
class _DatasetIterNormal:
|
||||
"""Iter for normal(non sink) mode, feed the data from host."""
|
||||
|
||||
def __init__(self, dataset, epoch_num=-1):
|
||||
self.dataset = dataset
|
||||
self.device_num = _get_device_num()
|
||||
|
|
|
@ -61,15 +61,15 @@ class MindData:
|
|||
def send(self, num_epochs=-1):
|
||||
pass
|
||||
|
||||
def get_data_info(self):
|
||||
pass
|
||||
|
||||
def stop_send(self):
|
||||
pass
|
||||
|
||||
def continue_send(self):
|
||||
pass
|
||||
|
||||
def get_data_info(self):
|
||||
pass
|
||||
|
||||
def __len__(self):
|
||||
return self._size
|
||||
|
||||
|
|
|
@ -177,8 +177,8 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) {
|
|||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
// 5 batches of size 2
|
||||
EXPECT_EQ(i, 5);
|
||||
// With 2 boundaries, 3 buckets are created
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
|
|
@ -132,6 +132,6 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
|
|||
// verify that ShuffleOp is removed, but RepeatOp, ProjectOp and BatchOp are not
|
||||
EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos);
|
||||
EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos);
|
||||
EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos);
|
||||
EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos);
|
||||
EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
|
|||
const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}};
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 0};
|
||||
std::vector<size_t> row_sizes = {2, 2, 0};
|
||||
|
||||
TensorRow row;
|
||||
for (size_t sz : row_sizes) {
|
||||
|
@ -75,7 +75,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
|
|||
rc = tree_adapter.GetNext(&row);
|
||||
EXPECT_TRUE(rc.IsError());
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
|
||||
|
@ -97,7 +97,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
|
|||
const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap();
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0};
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0};
|
||||
|
||||
TensorRow row;
|
||||
for (size_t sz : row_sizes) {
|
||||
|
@ -107,7 +107,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
|
|||
}
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
||||
|
@ -135,7 +135,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
|||
const std::unordered_map<std::string, int32_t> map = {{"label", 0}};
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0};
|
||||
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0};
|
||||
TensorRow row;
|
||||
|
||||
for (size_t sz : row_sizes) {
|
||||
|
@ -145,5 +145,5 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
|||
}
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
|
||||
}
|
||||
|
|
|
@ -451,6 +451,10 @@ def test_batch_exception_13():
|
|||
|
||||
|
||||
def test_batch_exception_14():
|
||||
"""
|
||||
Test per_batch_map and input column name
|
||||
"""
|
||||
logger.info("test_batch_exception_14")
|
||||
batch_size = 2
|
||||
input_columns = ["num"]
|
||||
data1 = ds.TFRecordDataset(DATA_DIR)
|
||||
|
@ -460,6 +464,22 @@ def test_batch_exception_14():
|
|||
assert "per_batch_map and input_columns need to be passed in together." in str(e)
|
||||
|
||||
|
||||
def test_batch_exception_15():
|
||||
"""
|
||||
Test batch_size = int32 max value + 1
|
||||
"""
|
||||
logger.info("test_batch_exception_15")
|
||||
batch_size = 2147483647 + 1
|
||||
input_columns = ["num"]
|
||||
data1 = ds.TFRecordDataset(DATA_DIR)
|
||||
err_msg = ""
|
||||
try:
|
||||
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
|
||||
except ValueError as e:
|
||||
err_msg = str(e)
|
||||
assert "batch_size is not within the required interval of (1 to 2147483647)" in err_msg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_batch_01()
|
||||
test_batch_02()
|
||||
|
@ -486,4 +506,5 @@ if __name__ == '__main__':
|
|||
test_batch_exception_12()
|
||||
test_batch_exception_13()
|
||||
test_batch_exception_14()
|
||||
test_batch_exception_15()
|
||||
logger.info('\n')
|
||||
|
|
|
@ -12,7 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
|
@ -354,6 +355,18 @@ def test_clue_to_device():
|
|||
data.send()
|
||||
|
||||
|
||||
def test_clue_invalid_files():
|
||||
"""
|
||||
Test CLUE with invalid files
|
||||
"""
|
||||
AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
|
||||
afqmc_train_json = os.path.join(AFQMC_DIR)
|
||||
with pytest.raises(ValueError) as info:
|
||||
_ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
|
||||
assert "The following patterns did not match any files" in str(info.value)
|
||||
assert AFQMC_DIR in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_clue()
|
||||
test_clue_num_shards()
|
||||
|
@ -366,3 +379,4 @@ if __name__ == "__main__":
|
|||
test_clue_tnews()
|
||||
test_clue_wsc()
|
||||
test_clue_to_device()
|
||||
test_clue_invalid_files()
|
||||
|
|
|
@ -195,6 +195,19 @@ def test_csv_dataset_size():
|
|||
assert data.get_dataset_size() == 5
|
||||
|
||||
|
||||
def test_csv_dataset_type_error():
|
||||
TEST_FILE = '../data/dataset/testCSV/exception.csv'
|
||||
data = ds.CSVDataset(
|
||||
TEST_FILE,
|
||||
column_defaults=["", 0, "", ""],
|
||||
column_names=['col1', 'col2', 'col3', 'col4'],
|
||||
shuffle=False)
|
||||
with pytest.raises(Exception) as err:
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
pass
|
||||
assert "type does not match" in str(err.value)
|
||||
|
||||
|
||||
def test_csv_dataset_exception():
|
||||
TEST_FILE = '../data/dataset/testCSV/exception.csv'
|
||||
data = ds.CSVDataset(
|
||||
|
@ -208,17 +221,16 @@ def test_csv_dataset_exception():
|
|||
assert "failed to parse file" in str(err.value)
|
||||
|
||||
|
||||
def test_csv_dataset_type_error():
|
||||
TEST_FILE = '../data/dataset/testCSV/exception.csv'
|
||||
def test_csv_dataset_duplicate_columns():
|
||||
data = ds.CSVDataset(
|
||||
TEST_FILE,
|
||||
column_defaults=["", 0, "", ""],
|
||||
column_names=['col1', 'col2', 'col3', 'col4'],
|
||||
DATA_FILE,
|
||||
column_defaults=["1", "2", "3", "4"],
|
||||
column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'],
|
||||
shuffle=False)
|
||||
with pytest.raises(Exception) as err:
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
pass
|
||||
assert "type does not match" in str(err.value)
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
_ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
|
||||
assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value)
|
||||
assert "column_names" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -234,5 +246,6 @@ if __name__ == "__main__":
|
|||
test_csv_dataset_header()
|
||||
test_csv_dataset_number()
|
||||
test_csv_dataset_size()
|
||||
test_csv_dataset_exception()
|
||||
test_csv_dataset_type_error()
|
||||
test_csv_dataset_exception()
|
||||
test_csv_dataset_duplicate_columns()
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ==============================================================================
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
|
||||
IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
|
||||
IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
|
||||
|
@ -21,9 +22,18 @@ IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-000
|
|||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
|
||||
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
|
||||
MNIST_DATA_DIR = "../data/dataset/testMnistData"
|
||||
MIND_CV_FILE_NAME = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord"
|
||||
SCHEMA_FILE = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
|
||||
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
|
||||
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data"
|
||||
VOC_DATA_DIR = "../data/dataset/testVOC2012"
|
||||
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
|
||||
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
|
||||
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
|
||||
CLUE_FILE = '../data/dataset/testCLUE/afqmc/train.json'
|
||||
CSV_FILE = '../data/dataset/testCSV/1.csv'
|
||||
TEXT_DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
|
||||
|
||||
|
||||
def test_imagenet_rawdata_dataset_size():
|
||||
|
@ -50,8 +60,15 @@ def test_imagenet_tf_file_dataset_size():
|
|||
ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0)
|
||||
assert ds_shard_2_0.get_dataset_size() == 6
|
||||
|
||||
# FIXME: dataset_size == 6 looks wrong, but it seems intended to match the current code.
|
||||
# The correct answer should be 12/3=4; the code issue should be addressed.
|
||||
ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0)
|
||||
assert ds_shard_3_0.get_dataset_size() == 4
|
||||
assert ds_shard_3_0.get_dataset_size() == 6
|
||||
|
||||
count = 0
|
||||
for _ in ds_shard_3_0.create_dict_iterator():
|
||||
count += 1
|
||||
assert ds_shard_3_0.get_dataset_size() == count
|
||||
|
||||
|
||||
def test_mnist_dataset_size():
|
||||
|
@ -76,6 +93,14 @@ def test_mnist_dataset_size():
|
|||
assert ds_shard_3_0.get_dataset_size() == 3334
|
||||
|
||||
|
||||
def test_mind_dataset_size():
|
||||
dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0")
|
||||
assert dataset.get_dataset_size() == 20
|
||||
|
||||
dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0)
|
||||
assert dataset_shard_2_0.get_dataset_size() == 10
|
||||
|
||||
|
||||
def test_manifest_dataset_size():
|
||||
ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE)
|
||||
assert ds_total.get_dataset_size() == 4
|
||||
|
@ -95,10 +120,11 @@ def test_cifar10_dataset_size():
|
|||
assert ds_total.get_dataset_size() == 10000
|
||||
|
||||
# test get_dataset_size with usage flag
|
||||
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
|
||||
assert train_size == 0
|
||||
train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
|
||||
assert train_size == 10000
|
||||
test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
|
||||
assert test_size == 0
|
||||
|
||||
all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
|
||||
assert all_size == 10000
|
||||
|
||||
|
@ -120,8 +146,6 @@ def test_cifar100_dataset_size():
|
|||
assert ds_total.get_dataset_size() == 10000
|
||||
|
||||
# test get_dataset_size with usage flag
|
||||
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
|
||||
assert train_size == 0
|
||||
test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
|
||||
assert test_size == 10000
|
||||
all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
|
||||
|
@ -137,10 +161,97 @@ def test_cifar100_dataset_size():
|
|||
assert ds_shard_3_0.get_dataset_size() == 3334
|
||||
|
||||
|
||||
def test_voc_dataset_size():
|
||||
dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
|
||||
assert dataset.get_dataset_size() == 10
|
||||
|
||||
dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True,
|
||||
num_shards=2, shard_id=0)
|
||||
assert dataset_shard_2_0.get_dataset_size() == 5
|
||||
|
||||
|
||||
def test_coco_dataset_size():
|
||||
dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
|
||||
decode=True, shuffle=False)
|
||||
assert dataset.get_dataset_size() == 6
|
||||
|
||||
dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
|
||||
shuffle=False, num_shards=2, shard_id=0)
|
||||
assert dataset_shard_2_0.get_dataset_size() == 3
|
||||
|
||||
|
||||
def test_celeba_dataset_size():
|
||||
dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
|
||||
assert dataset.get_dataset_size() == 4
|
||||
|
||||
dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0)
|
||||
assert dataset_shard_2_0.get_dataset_size() == 2
|
||||
|
||||
|
||||
def test_clue_dataset_size():
|
||||
dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
|
||||
assert dataset.get_dataset_size() == 3
|
||||
|
||||
dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0)
|
||||
assert dataset_shard_2_0.get_dataset_size() == 2
|
||||
|
||||
|
||||
def test_csv_dataset_size():
|
||||
dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
|
||||
shuffle=False)
|
||||
assert dataset.get_dataset_size() == 3
|
||||
|
||||
dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
|
||||
shuffle=False, num_shards=2, shard_id=0)
|
||||
assert dataset_shard_2_0.get_dataset_size() == 2
|
||||
|
||||
|
||||
def test_text_file_dataset_size():
|
||||
dataset = ds.TextFileDataset(TEXT_DATA_FILE)
|
||||
assert dataset.get_dataset_size() == 3
|
||||
|
||||
dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0)
|
||||
assert dataset_shard_2_0.get_dataset_size() == 2
|
||||
|
||||
|
||||
def test_padded_dataset_size():
|
||||
dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}])
|
||||
assert dataset.get_dataset_size() == 2
|
||||
|
||||
|
||||
def test_pipeline_get_dataset_size():
    dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False)
    assert dataset.get_dataset_size() == 12

    dataset = dataset.shuffle(buffer_size=3)
    assert dataset.get_dataset_size() == 12

    decode_op = vision.Decode()
    resize_op = vision.RandomResize(10)

    dataset = dataset.map([decode_op, resize_op], input_columns=["image"])
    assert dataset.get_dataset_size() == 12

    dataset = dataset.batch(batch_size=3)
    assert dataset.get_dataset_size() == 4

    dataset = dataset.repeat(count=2)
    assert dataset.get_dataset_size() == 8

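The pipeline assertions above are plain row-count bookkeeping: shuffle and map leave the count at 12, batch(3) rounds up to 4 batches (drop_remainder defaults to False), and repeat(2) doubles that to 8. A sketch of the arithmetic the test relies on:

import math

rows = 12                   # TFRecord rows before any transform; shuffle/map keep this unchanged
rows = math.ceil(rows / 3)  # batch(batch_size=3) with drop_remainder=False -> 4
rows = rows * 2             # repeat(count=2) -> 8
assert rows == 8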
if __name__ == '__main__':
    test_imagenet_rawdata_dataset_size()
    test_imagenet_tf_file_dataset_size()
    test_mnist_dataset_size()
    test_mind_dataset_size()
    test_manifest_dataset_size()
    test_cifar10_dataset_size()
    test_cifar100_dataset_size()
    test_voc_dataset_size()
    test_coco_dataset_size()
    test_celeba_dataset_size()
    test_clue_dataset_size()
    test_csv_dataset_size()
    test_text_file_dataset_size()
    test_padded_dataset_size()
    test_pipeline_get_dataset_size()

@@ -521,7 +521,7 @@ def test_chained_sampler_04():
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 24
    assert data1_size == 6

    # Verify number of iterations
    num_iter = 0

@@ -182,6 +182,15 @@ def test_voc_exception():
        pass


def test_voc_num_classes():
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
    assert data1.num_classes() is None

    class_index = {'car': 0, 'cat': 1, 'train': 5}
    data2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
    assert data2.num_classes() is None


if __name__ == '__main__':
    test_voc_segmentation()
    test_voc_detection()
@@ -191,3 +200,4 @@ if __name__ == '__main__':
    test_case_1()
    test_case_2()
    test_voc_exception()
    test_voc_num_classes()

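Note that num_classes() is asserted to be None for VOC Detection in both cases above, even when class_indexing is supplied. If a class count is needed on the Python side, the user-provided mapping can serve as a fallback; a sketch only, assuming the same DATA_DIR fixture and class_index as in the new test:

class_index = {'car': 0, 'cat': 1, 'train': 5}
data = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
# Fall back to the size of the user-supplied mapping when the dataset cannot report a count.
num_classes = data.num_classes() if data.num_classes() is not None else len(class_index)
assert num_classes == 3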
@@ -107,7 +107,7 @@ def test_decode_op():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        iter2.__next__()
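The expected error message changes because, after the pybind pushdown, a Python iterator is backed by a _runtime_context attribute rather than the old depipeline handle, and that context is released once the iterator is stopped. Probing a stopped iterator therefore surfaces the new attribute name; a minimal sketch of the pattern (stopped_iter stands in for any iterator that has already been stopped):

# Sketch: a stopped iterator has lost its _runtime_context, so __next__ raises AttributeError.
with pytest.raises(AttributeError) as info:
    stopped_iter.__next__()
assert "object has no attribute '_runtime_context'" in str(info.value)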
@@ -205,7 +205,7 @@ def test_generator_dict_3():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_dict_4():
@@ -396,7 +396,7 @@ def test_generator_tuple_3():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_tuple_4():
@@ -546,7 +546,7 @@ def test_generator_tuple_repeat_repeat_2():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_tuple_repeat_repeat_3():
@@ -74,9 +74,11 @@ def test_case2():


def test_case3():
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
    data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(5)
    data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2)
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
        ["col_sint64"], ["a1"])
    data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(5).rename(
        ["col_sint64"], ["a2"])
    data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).rename(["col_sint64"], ["a3"])

    data4 = ds.zip((data1, data2, data3))

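The added rename() calls give each branch a distinct output column (a1, a2, a3) before the zip, since all three branches now read the same col_sint64 column. A quick way to see the resulting schema, sketched under the assumption that get_col_names() is available on the zipped dataset as it is on other dataset ops:

# Sketch: after the renames, the zipped dataset should expose three distinct columns.
assert set(data4.get_col_names()) == {"a1", "a2", "a3"}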
@@ -84,8 +86,9 @@ def test_case3():


def test_case4():
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
    data2 = ds.TFRecordDataset(FILES)
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
        ["col_sint64"], ["a1"])
    data2 = ds.TFRecordDataset(FILES, columns_list=["col_sint64"]).rename(["col_sint64"], ["a2"])
    assert data2.get_dataset_size() == 12
    data2 = data2.batch(2)
    assert data2.get_dataset_size() == 6
Some files were not shown because too many files have changed in this diff.