Added Python API based on C++ API

1st draft of python iterator

Added Cifar10 and Cifar100 pybind port

Change pybind to use IR for Skip and Manifest

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>

DatasetNode as a base for all IR nodes

namespace change

Fix the namespace issue and make ut tests work

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>

Add VOCDataset

!63 Added RandomDataset
* Added RandomDataset

add imagefolder ir

Pybind switch: CelebA and UT

!61 CLUE example with class definition
* Merge branch 'python-api' of gitee.com:ezphlow/mindspore into clue_class_pybind
* Passing testcases
* Added CLUE, not working

add ManifestDataset IR

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>

Update Coco & VOC & TFReader, Update clang-format, Reorder
datasets_binding

!69 Add Generator and move c_dataset.Iterator to dataset.Iterator
* Add GeneratorDataset to c_dataset
* Add GeneratorDataset to c_dataset

!67 Moving c_datasets and adding sampler wrapper
* Need to add create() method in datasets.py
* migration from c_dataset to dataset part 1

!71 Fix indent error
* Fix indentation error

!72 Fix c_api tests cases
* Fix c_api tests cases

!73 Added CSV Dataset
* Added CSVDataset

pybind switch: Take and CelebA fixes

!75 move c_dataset functionality to datasets
* Fixed existing testcases
* Added working clue and imagefolder
* Added sampler conversion from pybind
* Added sampler creation

!77 Add Python API tree
* Python API tree

add minddataset

TextFileDataset pybind

Rename to skip test_concat.py and test_minddataset_exception.py

!80 Add batch IR to python-api branch, most test cases work
* staging III
* staging, add pybind

Enable more c_api take and CelebA tests; delete util_c_api

!84 Schema changes in datasets.py
* Schema changes

!85 Remove input_indexes from sub-classes
* remove input_index from each subclass

!83 Remove C datasets
* Removed c_dataset package
* Remove c_datasets

!82  pybind switch: shuffle
* pybind switch: shuffle

!86 Add build_vocab
* Add build_vocab

Rebase with upstream/master
_shuffle conflict
BatchNode error

!88 Fix rebase problem
* fix rebase problem

Enable more unit tests; code typo/nit fixes

!91 Fix python vocab hang
* Fix python vocab hang

!89 Added BucketBatchByLength Pybind switch
* Added BucketBatchByLength

Update and enable more test_c_api_*.py tests

!95 Add BuildSentencePieceVocab
* - Add BuildSentencePieceVocab

!96 Fix more tests
* - Fix some tests

- Enable more test_c_api_*
- Add syncwait

!99 pybind switch for device op
* pybind switch for device op

!93 Add getters to python API
* Add getters to python API

!101 Validate tree, error if graph
* - Add sync wait

!103 TFrecord/Random Datasets schema problem
* - TfRecord/Random schema problem

!102 Added filter pybind switch
* Added Filter pybind switch

!104 Fix num_samples
* - TfRecord/Random schema problem

!105 Fix to_device hang
* Fix to_device hang

!94 Adds Cache support for CLUE dataset
* Added cache for all dataset ops
* format change
* Added CLUE cache support
* Added Cache conversion

Add save pybind

fix compile err

init modify concat_node

!107 Fix some tests cases
* Fix tests cases

Enable and fix more tests

!109 pybind switch for get dataset size
* pybind_get_dataset_size

some check-code fixes for pylint, cpplint and clang-format

!113 Add callback
* revert
* dataset_sz 1 line
* fix typo
* get callback to work

!114 Make Android compile clean
* Make Android Compile Clean

Fix build issues due to rebase

!115 Fix more tests
* Fix tests cases
* !93 Add getters to python API

fix test_profiling.py

!116 fix get dataset size
* fix get dataset size

!117 GetColumnNames pybind switch
* Added GetColumnNames pybind switch

code-check fixes: clangformat, cppcheck, cpplint, pylint

Delete duplicate test_c_api_*.py files; more lint fixes

!121 Fix cpp tests
* Remove extra call to getNext in cpp tests

!122 Fix Schema with Generator
* Fix Schema with Generator

fix some cases of csv & mindrecord

!124 fix tfrecord get_dataset_size and add some UTs
* fix tfrecord get dataset size and add some ut for get_dataset_size

!125 getter separation
* Getter separation

!126 Fix sampler.GetNumSamples
* Fix sampler.GetNumSamples

!127 Assign runtime getter to each get function
* Assign runtime getter to each get function

Fix compile issues

!128 Match master code
* Match master code

!129 Cleanup DeviceOp/save code
* Cleanup ToDevice/Save code

!130 Add cache fix
* Added cache fix for map and image folder

!132 Fix testing team issues
* Pass queue_name from python to C++
* Add Schema.from_json

!131 Fix Cache op issues and delete de_pipeline
* Roll back C++ change
* Removed de_pipeline and passing all cache tests.
* fixed cache tests

!134 Cleanup datasets.py part1
* Cleanup datasets.py part1

!133 Updated validation for SentencePieceVocab.from_dataset
* Added type_check for column names in SentencePieceVocab.from_dataset

Rebase on master 181120 10:20

fix profiling

Temporary solution for catching status from Node.Build()

!141 ToDevice Termination
* ToDevice termination

pylint fixes

!137 Fix test team issues and add some corresponding tests
* Fix test team issues and add some corresponding tests

!138 TreeGetter changes to use OptPass
* Getter changes to use OptPass (Zirui)

Rebase fix

!143 Fix cpplint issue
* Fix cpplint issue

pylint fixes in updated testcases

!145 Reset exceptions testcase
* reset exception test to master

!146 Fix Check_Pylint Error
* Fix Check_Pylint Error

!147 fix android
* fix android

!148 ToDevice changes
* Add ToDevice to the iterator List for cleanup at exit

!149 Pylint issue
* Add ToDevice to the iterator List for cleanup at exit

!150 Pylint 2
* Add ToDevice to the iterator List for cleanup at exit

!152 ExecutionTree error
* ET destructor error

!153 in getter_pass, only remove callback, without deleting map op
* getter pass no longer removes map

!156 early __del__ of iterator/to_device
* early __del__ of iterator

!155 Address review comments Eric 1
* Added one liner fix to validators.py
* roll back signature fix
* lint fix
* Eric Address comments 2
* C++ lint fix
* Address comments Eric 1

!158 Review rework for dataset bindings - part 1
* Reorder nodes repeat and rename
* Review rework for dataset bindings - part 1

!154 Fixing minor problems in the comments (datasets.py, python_tree_consumer.cc, iterators_bindings.cc, and iterators.py)
* Fixing minor problems in the comments (datasets.py, python_tree_consum…

!157 add replace none
* Add replace_none to datasets.py, address comments in tests

Trying to resolve copy

Override the deepcopy method of deviceop

Create_ir_tree method

Create_ir_tree method 2

Create_ir_tree method 2

del to_device if already exists

del to_device if already exists

cache getters shapes and types

Added yolov3 relaxation, to be rolled back

Get shapes and types together

bypass yolo

NumWorkers for MapOp

revert Yolo

revert Thor

Print more info

Debug code: Update LOG INFO to LOG ERROR

do not remove epochctrl for getter pass

Remove repeat(1)

print batch size

add log to tree_consumer and device_queue op

Revert PR 8744

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>

__del__ toDevice

__del__ toDevice2

!165 add ifndef ENABLE_ANDROID to device queue print
* Add ifndef ENABLE_ANDROID to device queue print

revert some changes

!166 getter: get_data_info
* getter: get_data_info

!168 add back tree print
* revert info to warning in one log
* add back the missed print tree log

Release GIL in GetDataInfo
Eric Zhang 2020-07-16 17:34:09 -04:00 committed by Eric
parent d9b4b5c750
commit 809e1d5086
113 changed files with 3254 additions and 4629 deletions


@ -2,9 +2,11 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON) if (ENABLE_PYTHON)
add_library(APItoPython OBJECT add_library(APItoPython OBJECT
python/de_pipeline.cc
python/pybind_register.cc python/pybind_register.cc
python/bindings.cc python/pybind_conversion.cc
python/bindings/dataset/include/datasets_bindings.cc
python/bindings/dataset/include/iterator_bindings.cc
python/bindings/dataset/include/schema_bindings.cc
python/bindings/dataset/engine/cache/bindings.cc python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/core/bindings.cc python/bindings/dataset/core/bindings.cc
python/bindings/dataset/callback/bindings.cc python/bindings/dataset/callback/bindings.cc


@ -115,7 +115,8 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// Function to return a transferred Node that transfers data through a device. // Function to return a transferred Node that transfers data through a device.
bool Dataset::DeviceQueue(bool send_epoch_end) { bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t num_epochs, bool send_epoch_end,
int32_t total_batches, bool create_data_info_queue) {
Status rc; Status rc;
// Build and launch tree // Build and launch tree
@ -126,11 +127,12 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
return false; return false;
} }
// Add TransferNode IR on top of dataset d // Add TransferNode IR on top of dataset
auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end); auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, send_epoch_end,
total_batches, create_data_info_queue);
// Get ToDevice consumer // Get ToDevice consumer
auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1); auto consumer = std::make_unique<ToDevice>(num_epochs);
ToDevice *consumer_ = consumer.get(); ToDevice *consumer_ = consumer.get();
rc = consumer->Init(ds); rc = consumer->Init(ds);
if (rc.IsError()) { if (rc.IsError()) {
@ -199,127 +201,55 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() { int64_t Dataset::GetDatasetSize() {
int64_t dataset_size; int64_t dataset_size;
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
return -1; return dataset_size;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetDatasetSize(&dataset_size);
return rc.IsError() ? -1 : dataset_size;
} }
std::vector<DataType> Dataset::GetOutputTypes() { std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types; std::vector<DataType> types;
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputTypes(&types), {});
return types;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
return types;
}
rc = tree_getters_->GetOutputTypes(&types);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
types.clear();
return types;
}
return types; return types;
} }
std::vector<TensorShape> Dataset::GetOutputShapes() { std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes; std::vector<TensorShape> shapes;
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputShapes(&shapes), {});
return shapes;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
return shapes;
}
rc = tree_getters_->GetOutputShapes(&shapes);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
shapes.clear();
return shapes;
}
return shapes; return shapes;
} }
int64_t Dataset::GetNumClasses() { int64_t Dataset::GetNumClasses() {
int64_t num_classes; int64_t num_classes;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetNumClasses(&num_classes), -1);
return -1; return num_classes;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetNumClasses(&num_classes);
return rc.IsError() ? -1 : num_classes;
} }
std::vector<std::string> Dataset::GetColumnNames() { std::vector<std::string> Dataset::GetColumnNames() {
std::vector<std::string> col_names; std::vector<std::string> col_names;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetColumnNames(&col_names), {});
return std::vector<std::string>(); return col_names;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed.";
return std::vector<std::string>();
}
rc = tree_getters_->GetColumnNames(&col_names);
return rc.IsError() ? std::vector<std::string>() : col_names;
} }
std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() { std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing; std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetClassIndexing(&output_class_indexing), {});
return output_class_indexing;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
return output_class_indexing;
}
rc = tree_getters_->GetClassIndexing(&output_class_indexing);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
output_class_indexing.clear();
return output_class_indexing;
}
return output_class_indexing; return output_class_indexing;
} }
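
The getters in the hunk above replace repeated `if (rc.IsError()) { log; return fallback; }` blocks with a single `RETURN_SECOND_IF_ERROR(expr, fallback)` line. The macro definition is not part of this diff, so the following is only a minimal sketch of how such a macro might look. The `Status` class here is a hypothetical stand-in, and `FakeGetSize`, `GetSize`, and `GetNames` are illustrative names that do not exist in the project.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the project's Status class (assumption).
class Status {
 public:
  Status() = default;
  explicit Status(std::string msg) : ok_(false), msg_(std::move(msg)) {}
  static Status OK() { return Status(); }
  bool IsError() const { return !ok_; }
  const std::string &ToString() const { return msg_; }

 private:
  bool ok_ = true;
  std::string msg_;
};

// Assumed shape of the macro: evaluate the expression once, log the error,
// and return the caller-supplied fallback value (the "second" argument).
#define RETURN_SECOND_IF_ERROR(_s, _r)          \
  do {                                          \
    Status rc_ = (_s);                          \
    if (rc_.IsError()) {                        \
      std::cerr << rc_.ToString() << std::endl; \
      return _r;                                \
    }                                           \
  } while (false)

// Illustrative getter backend that either fails or writes a size.
Status FakeGetSize(bool fail, int64_t *out) {
  if (fail) return Status("getter failed");
  *out = 42;
  return Status::OK();
}

// Scalar getter: falls back to -1 on error, as the int64_t getters in the diff do.
int64_t GetSize(bool fail) {
  int64_t size = 0;
  RETURN_SECOND_IF_ERROR(FakeGetSize(fail, &size), -1);
  return size;
}

// Container getter: falls back to an empty vector via `{}`.
std::vector<std::string> GetNames(bool fail) {
  std::vector<std::string> names;
  RETURN_SECOND_IF_ERROR(fail ? Status("names failed") : Status::OK(), {});
  names.push_back("image");
  return names;
}
```

Passing `{}` as the fallback works because the macro pastes it into a `return` statement, where it value-initializes the function's return type.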
@ -501,9 +431,13 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(
std::function<TensorRow(TensorRow)> element_length_function, std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) { bool drop_remainder) {
auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, std::shared_ptr<TensorOp> c_func = nullptr;
bucket_batch_sizes, element_length_function, pad_info, if (element_length_function != nullptr) {
pad_to_bucket_boundary, drop_remainder); c_func = std::make_shared<CFuncOp>(element_length_function);
}
auto ds =
std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, bucket_batch_sizes,
c_func, pad_info, pad_to_bucket_boundary, drop_remainder);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
@ -522,7 +456,9 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase
FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate, FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate,
std::vector<std::string> input_columns) { std::vector<std::string> input_columns) {
auto ds = std::make_shared<FilterNode>(input->IRNode(), predicate, input_columns); std::shared_ptr<TensorOp> c_func = nullptr;
if (predicate) c_func = std::make_shared<CFuncOp>(predicate);
auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, input_columns);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
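
Both `BucketBatchByLengthDataset` and `FilterDataset` now wrap the user callback in a `CFuncOp` only when the `std::function` is non-empty, and otherwise pass `nullptr` to the IR node. The sketch below illustrates that guard in isolation. `Row`, `TensorOpLike`, `CFuncOpLike`, and `WrapIfSet` are hypothetical names introduced for this example; the real `TensorOp` and `CFuncOp` interfaces are not shown in this diff.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

using Row = std::vector<int>;  // hypothetical stand-in for TensorRow

// Hypothetical minimal op interface (assumption, not the project's TensorOp).
struct TensorOpLike {
  virtual ~TensorOpLike() = default;
  virtual Row Compute(const Row &in) = 0;
};

// Adapter that stores a callable and invokes it from Compute().
class CFuncOpLike : public TensorOpLike {
 public:
  explicit CFuncOpLike(std::function<Row(Row)> f) : func_(std::move(f)) {}
  Row Compute(const Row &in) override { return func_(in); }

 private:
  std::function<Row(Row)> func_;
};

// Wrap the callable only if it is set; an empty std::function maps to nullptr,
// so downstream code can test the pointer instead of risking std::bad_function_call.
std::shared_ptr<TensorOpLike> WrapIfSet(std::function<Row(Row)> f) {
  if (!f) return nullptr;
  return std::make_shared<CFuncOpLike>(std::move(f));
}
```

Under these assumptions, a consumer only has to handle two cases, a null op or a valid op, rather than checking the `std::function` at every call site.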
@ -604,40 +540,20 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
#endif #endif
int64_t Dataset::GetBatchSize() { int64_t Dataset::GetBatchSize() {
int64_t batch_size; int64_t batch_size;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
return -1; return batch_size;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetBatchSize(&batch_size);
return rc.IsError() ? -1 : batch_size;
} }
int64_t Dataset::GetRepeatCount() { int64_t Dataset::GetRepeatCount() {
int64_t repeat_count; int64_t repeat_count;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init(); RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
if (rc.IsError()) { RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
return -1; return repeat_count;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetRepeatCount(&repeat_count);
return rc.IsError() ? 0 : repeat_count;
} }
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) { std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
@ -720,62 +636,65 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
// SchemaObj init function // SchemaObj init function
bool SchemaObj::init() { Status SchemaObj::init() {
if (schema_file_ != "") { if (!schema_file_.empty()) {
Path schema_file(schema_file_); Path schema_file(schema_file_);
if (!schema_file.Exists()) { CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
MS_LOG(ERROR) << "The file " << schema_file << " does not exist or permission denied!"; "The file " + schema_file_ + " does not exist or permission denied!");
return false;
}
nlohmann::json js; nlohmann::json js;
try { try {
std::ifstream in(schema_file_); std::ifstream in(schema_file_);
in >> js; in >> js;
if (js.find("columns") == js.end()) { CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
MS_LOG(ERROR) << "\"columns\" node is required in the schema json file."; "\"columns\" node is required in the schema json file.");
return false;
}
} catch (const std::exception &err) { } catch (const std::exception &err) {
MS_LOG(ERROR) << "Schema file failed to load"; RETURN_STATUS_SYNTAX_ERROR("Schema file failed to load");
return false;
} }
return from_json(js); return from_json(js);
} }
return true; return Status::OK();
} }
// Function to add a column to schema with a mstype de_type // Function to add a column to schema with a mstype de_type and known shape
bool SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) { Status SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
nlohmann::json new_column;
new_column["name"] = name;
// if de_type is mstype
DataType data_type = dataset::MSTypeToDEType(de_type); DataType data_type = dataset::MSTypeToDEType(de_type);
new_column["type"] = data_type.ToString(); return add_column(name, data_type.ToString(), shape);
if (shape.size() > 0) {
new_column["shape"] = shape;
new_column["rank"] = shape.size();
} else {
new_column["rank"] = 1;
}
columns_.push_back(new_column);
return true;
} }
// Function to add a column to schema with a string de_type // Function to add a column to schema with a string de_type and known shape
bool SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) { Status SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
DataType data_type(de_type);
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
nlohmann::json new_column; nlohmann::json new_column;
new_column["name"] = name; new_column["name"] = name;
DataType data_type(de_type);
new_column["type"] = data_type.ToString(); new_column["type"] = data_type.ToString();
if (shape.size() > 0) { new_column["shape"] = shape;
new_column["shape"] = shape; new_column["rank"] = shape.size();
new_column["rank"] = shape.size();
} else {
new_column["rank"] = 1;
}
columns_.push_back(new_column); columns_.push_back(new_column);
return true; return Status::OK();
}
// Function to add a column to schema with a mstype de_type and without shape
Status SchemaObj::add_column(std::string name, TypeId de_type) {
DataType data_type = dataset::MSTypeToDEType(de_type);
return add_column(name, data_type.ToString());
}
// Function to add a column to schema with a string de_type and without shape
Status SchemaObj::add_column(std::string name, std::string de_type) {
DataType data_type(de_type);
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
nlohmann::json new_column;
new_column["name"] = name;
new_column["type"] = data_type.ToString();
new_column["rank"] = 1;
columns_.push_back(new_column);
return Status::OK();
} }
std::string SchemaObj::to_json() { std::string SchemaObj::to_json() {
@ -792,7 +711,7 @@ std::string SchemaObj::to_json() {
return json_file.dump(2); return json_file.dump(2);
} }
bool SchemaObj::parse_column(nlohmann::json columns) { Status SchemaObj::parse_column(nlohmann::json columns) {
std::string name, de_type; std::string name, de_type;
std::vector<int32_t> shape; std::vector<int32_t> shape;
@ -802,15 +721,13 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
for (auto column : columns) { for (auto column : columns) {
auto key_name = column.find("name"); auto key_name = column.find("name");
if (key_name == column.end()) { if (key_name == column.end()) {
MS_LOG(ERROR) << "Column's name is missing"; RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
return false;
} }
name = *key_name; name = *key_name;
auto key_type = column.find("type"); auto key_type = column.find("type");
if (key_type == column.end()) { if (key_type == column.end()) {
MS_LOG(ERROR) << "Column's type is missing"; RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
return false;
} }
de_type = *key_type; de_type = *key_type;
@ -819,17 +736,14 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
if (key_shape != column.end()) { if (key_shape != column.end()) {
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
} }
if (!add_column(name, de_type, shape)) { RETURN_IF_NOT_OK(add_column(name, de_type, shape));
return false;
}
} }
} else if (columns.type() == nlohmann::json::value_t::object) { } else if (columns.type() == nlohmann::json::value_t::object) {
for (const auto &it_child : columns.items()) { for (const auto &it_child : columns.items()) {
name = it_child.key(); name = it_child.key();
auto key_type = it_child.value().find("type"); auto key_type = it_child.value().find("type");
if (key_type == it_child.value().end()) { if (key_type == it_child.value().end()) {
MS_LOG(ERROR) << "Column's type is missing"; RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
return false;
} }
de_type = *key_type; de_type = *key_type;
@ -839,43 +753,45 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
} }
if (!add_column(name, de_type, shape)) { RETURN_IF_NOT_OK(add_column(name, de_type, shape));
return false;
}
} }
} else { } else {
MS_LOG(ERROR) << "columns must be dict or list, columns contain name, type, shape(optional)."; RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
return false;
} }
return true; return Status::OK();
} }
bool SchemaObj::from_json(nlohmann::json json_obj) { Status SchemaObj::from_json(nlohmann::json json_obj) {
for (const auto &it_child : json_obj.items()) { for (const auto &it_child : json_obj.items()) {
if (it_child.key() == "datasetType") { if (it_child.key() == "datasetType") {
dataset_type_ = it_child.value(); dataset_type_ = it_child.value();
} else if (it_child.key() == "numRows") { } else if (it_child.key() == "numRows") {
num_rows_ = it_child.value(); num_rows_ = it_child.value();
} else if (it_child.key() == "columns") { } else if (it_child.key() == "columns") {
if (!parse_column(it_child.value())) { RETURN_IF_NOT_OK(parse_column(it_child.value()));
MS_LOG(ERROR) << "parse columns failed";
return false;
}
} else { } else {
MS_LOG(ERROR) << "Unknown field " << it_child.key(); RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
return false;
} }
} }
if (columns_.empty()) { if (columns_.empty()) {
MS_LOG(ERROR) << "Columns are missing."; RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
return false;
} }
if (num_rows_ <= 0) { if (num_rows_ < 0) {
MS_LOG(ERROR) << "numRows must be greater than 0"; RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
return false;
} }
return true; return Status::OK();
}
Status SchemaObj::FromJSONString(const std::string &json_string) {
try {
nlohmann::json js = nlohmann::json::parse(json_string);
CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
"\"columns\" node is required in the schema json JSON.");
RETURN_IF_NOT_OK(from_json(js));
} catch (const std::exception &err) {
RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse");
}
return Status::OK();
} }
// OTHER FUNCTIONS // OTHER FUNCTIONS
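
The `add_column` changes above split the old "empty shape means rank 1" behaviour into two overloads: one with an explicit shape, where `rank` is `shape.size()` (so an empty shape is rank 0), and one without a shape, which records rank 1 and no shape. Both now return a `Status` and reject unknown types. The following sketch mirrors that logic without `nlohmann::json`. `Status`, `ColumnSketch`, `SchemaSketch`, and the list of known type names are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, minimal stand-in for the project's Status type.
struct Status {
  bool ok;
  std::string msg;
  Status() : ok(true) {}
  explicit Status(std::string m) : ok(false), msg(std::move(m)) {}
  static Status OK() { return Status(); }
};

// Plain struct standing in for the JSON column entry.
struct ColumnSketch {
  std::string name;
  std::string type;
  bool has_shape;
  std::vector<int32_t> shape;
  std::size_t rank;
};

class SchemaSketch {
 public:
  // Known shape: rank follows the shape, so {} means a rank-0 column.
  Status add_column(const std::string &name, const std::string &type, const std::vector<int32_t> &shape) {
    Status rc = CheckType(type);
    if (!rc.ok) return rc;
    columns_.push_back(ColumnSketch{name, type, true, shape, shape.size()});
    return Status::OK();
  }

  // No shape given: record rank 1 and leave the shape unset.
  Status add_column(const std::string &name, const std::string &type) {
    Status rc = CheckType(type);
    if (!rc.ok) return rc;
    columns_.push_back(ColumnSketch{name, type, false, {}, 1});
    return Status::OK();
  }

  const std::vector<ColumnSketch> &columns() const { return columns_; }

 private:
  // Illustrative type check; the real code compares against DataType::DE_UNKNOWN.
  static Status CheckType(const std::string &type) {
    static const std::set<std::string> known = {"int32", "float32", "uint8"};
    if (known.count(type) == 0) return Status("Type is unknown.");
    return Status::OK();
  }

  std::vector<ColumnSketch> columns_;
};
```

Returning a status object instead of `bool` lets the caller propagate the error message itself, which is what the `RETURN_IF_NOT_OK(add_column(...))` calls in `parse_column` rely on.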


@ -1,136 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/api/python/de_pipeline.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
DEPipeline, 0, ([](const py::module *m) {
(void)py::class_<DEPipeline>(*m, "DEPipeline")
.def(py::init<>())
.def(
"AddNodeToTree",
[](DEPipeline &de, const OpName &op_name, const py::dict &args) {
py::dict out;
THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
return out;
},
py::return_value_policy::reference)
.def_static("AddChildToParentNode",
[](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
})
.def("AssignRootNode",
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
.def("SetBatchParameters",
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
.def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
.def("GetColumnNames",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetColumnNames(&out));
return out;
})
.def("GetNextAsMap",
[](DEPipeline &de) {
py::dict out;
THROW_IF_ERROR(de.GetNextAsMap(&out));
return out;
})
.def("GetNextAsList",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetNextAsList(&out));
return out;
})
.def("GetOutputShapes",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetOutputShapes(&out));
return out;
})
.def("GetOutputTypes",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetOutputTypes(&out));
return out;
})
.def("GetDataInfo",
[](DEPipeline &de) {
py::list types, shapes;
THROW_IF_ERROR(de.GetDataInfo(&types, &shapes));
return py::make_tuple(types, shapes);
})
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
.def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;
});
}));
PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
(void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BUCKETBATCH", OpName::kBucketBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
.value("REPEAT", OpName::kRepeat)
.value("SKIP", OpName::kSkip)
.value("TAKE", OpName::kTake)
.value("ZIP", OpName::kZip)
.value("CONCAT", OpName::kConcat)
.value("MAP", OpName::kMap)
.value("FILTER", OpName::kFilter)
.value("DEVICEQUEUE", OpName::kDeviceQueue)
.value("GENERATOR", OpName::kGenerator)
// Note: export_values() only exports the enumerators registered before this call.
.export_values()
.value("RENAME", OpName::kRename)
.value("TFREADER", OpName::kTfReader)
.value("PROJECT", OpName::kProject)
.value("IMAGEFOLDER", OpName::kImageFolder)
.value("MNIST", OpName::kMnist)
.value("MANIFEST", OpName::kManifest)
.value("VOC", OpName::kVoc)
.value("COCO", OpName::kCoco)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("EPOCHCTRL", OpName::kEpochCtrl)
.value("CSV", OpName::kCsv)
.value("CLUE", OpName::kClue);
}));
} // namespace dataset
} // namespace mindspore
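
Reviewer sketch, not part of the change: the second argument of PYBIND_REGISTER (0 for OpName, 1 for DatasetNode, 2 for the concrete nodes) reads like an ordering priority. pybind11 requires a base class to be bound before any py::class_ that names it as a parent, so an ordered registry is one plausible design. The macro and its registry are not shown in this diff; `SketchRegistry` and `DemoOrder` below are hypothetical names used only to illustrate the idea.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the registry behind PYBIND_REGISTER; the real
// macro and registry are defined elsewhere and are not shown in this diff.
class SketchRegistry {
 public:
  using InitFn = std::function<void(std::vector<std::string> *)>;

  // Store the callback under its priority; std::map keeps the keys sorted.
  void Register(uint8_t priority, const std::string &name, InitFn fn) {
    fns_[priority].emplace_back(name, std::move(fn));
  }

  // Run callbacks in ascending priority, so base types are bound first.
  std::vector<std::string> RunAll() const {
    std::vector<std::string> order;
    for (const auto &bucket : fns_) {
      for (const auto &entry : bucket.second) {
        entry.second(&order);
      }
    }
    return order;
  }

 private:
  std::map<uint8_t, std::vector<std::pair<std::string, InitFn>>> fns_;
};

// Registers a derived node before its base, then returns the actual run order.
inline std::vector<std::string> DemoOrder() {
  SketchRegistry reg;
  reg.Register(2, "ZipNode", [](std::vector<std::string> *out) { out->push_back("ZipNode"); });
  reg.Register(1, "DatasetNode", [](std::vector<std::string> *out) { out->push_back("DatasetNode"); });
  reg.Register(0, "OpName", [](std::vector<std::string> *out) { out->push_back("OpName"); });
  return reg.RunAll();
}
```

With this shape, the order in which translation units happen to register does not matter; only the priority does.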

@ -19,8 +19,10 @@
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/client.h" // DE client
#include "minddata/dataset/util/status.h"
#include "pybind11/numpy.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/api/python/de_pipeline.h"
namespace mindspore {
namespace dataset {

@ -0,0 +1,551 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"
// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
// IR non-leaf nodes disabled for android
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#endif
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
// IR leaf nodes disabled for android
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#endif
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
.def("SetNumWorkers",
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
return num_workers ? self->SetNumWorkers(*num_workers) : self;
})
.def(
"Zip",
[](std::shared_ptr<DatasetNode> self, py::list datasets) {
auto zip = std::make_shared<ZipNode>(toDatasetNode(self, datasets));
THROW_IF_ERROR(zip->ValidateParams());
return zip;
},
py::arg("datasets"));
}));
// PYBIND FOR LEAF NODES
// (In alphabetical order)
PYBIND_REGISTER(
CelebANode, 2, ([](const py::module *m) {
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", "to create a CelebANode")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, bool decode,
std::optional<py::list> extensions, std::optional<std::shared_ptr<CacheClient>> cc) {
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
toStringSet(extensions), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(celebA->ValidateParams());
return celebA;
}));
}));
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
"to create a Cifar10Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(cifar10->ValidateParams());
return cifar10;
}));
}));
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
"to create a Cifar100Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(cifar100->ValidateParams());
return cifar100;
}));
}));
PYBIND_REGISTER(
CLUENode, 2, ([](const py::module *m) {
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode")
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<CLUENode> clue_node =
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(clue_node->ValidateParams());
return clue_node;
}));
}));
PYBIND_REGISTER(
CocoNode, 2, ([](const py::module *m) {
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", "to create a CocoNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, bool decode,
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(coco->ValidateParams());
return coco;
}));
}));
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults),
column_names, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(csv->ValidateParams());
return csv;
}));
}));
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
*m, "GeneratorNode", "to create a GeneratorNode")
.def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
const std::vector<DataType> &column_types) {
auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types);
THROW_IF_ERROR(gen->ValidateParams());
return gen;
}))
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) {
auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
THROW_IF_ERROR(gen->ValidateParams());
return gen;
}));
}));
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
*m, "ImageFolderNode", "to create an ImageFolderNode")
.def(py::init([](std::string dataset_dir, bool decode, std::optional<py::handle> sampler,
std::optional<py::list> extensions, std::optional<py::dict> class_indexing,
std::optional<std::shared_ptr<CacheClient>> cc) {
bool recursive = true;
auto imagefolder = std::make_shared<ImageFolderNode>(
dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions),
toStringMap(class_indexing), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(imagefolder->ValidateParams());
return imagefolder;
}));
}));
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
"to create a ManifestNode")
.def(py::init([](std::string dataset_file, std::string usage, std::optional<py::handle> sampler,
std::optional<py::dict> class_indexing, bool decode,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
toStringMap(class_indexing), decode,
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(manifest->ValidateParams());
return manifest;
}));
}));
PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
"to create a MindDataNode")
.def(py::init([](std::string dataset_file, std::optional<py::list> columns_list,
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
auto minddata =
std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list),
toSamplerObj(sampler, true), padded_sample_json, num_padded);
minddata->SetSampleBytes(&sample_bytes);
THROW_IF_ERROR(minddata->ValidateParams());
return minddata;
}))
.def(py::init([](py::list dataset_file, std::optional<py::list> columns_list,
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
auto minddata =
std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list),
toSamplerObj(sampler, true), padded_sample_json, num_padded);
minddata->SetSampleBytes(&sample_bytes);
THROW_IF_ERROR(minddata->ValidateParams());
return minddata;
}));
}));
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
"to create an MnistNode")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(mnist->ValidateParams());
return mnist;
}));
}));
PYBIND_REGISTER(
RandomNode, 2, ([](const py::module *m) {
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto random_node =
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(random_node->ValidateParams());
return random_node;
}))
.def(py::init([](int32_t total_rows, std::string schema, std::optional<py::list> columns_list,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto random_node =
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(random_node->ValidateParams());
return random_node;
}));
}));
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
"to create a TextFileNode")
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>(
toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id,
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(textfile_node->ValidateParams());
return textfile_node;
}));
}));
PYBIND_REGISTER(
TFRecordNode, 2, ([](const py::module *m) {
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
"to create a TFRecordNode")
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
std::optional<int32_t> shard_id, bool shard_equal_rows,
std::optional<std::shared_ptr<CacheClient>> cc) {
if (!num_samples) {
num_samples = 0;  // assign a value; dereferencing an empty optional is undefined behavior
}
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}))
.def(py::init([](py::list dataset_files, std::string schema, std::optional<py::list> columns_list,
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
std::optional<int32_t> shard_id, bool shard_equal_rows,
std::optional<std::shared_ptr<CacheClient>> cc) {
if (!num_samples) {
num_samples = 0;  // assign a value; dereferencing an empty optional is undefined behavior
}
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}));
}));
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
.def(
py::init([](std::string dataset_dir, std::string task, std::string usage,
std::optional<py::dict> class_indexing, bool decode,
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<VOCNode> voc =
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
toSamplerObj(sampler), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(voc->ValidateParams());
return voc;
}));
}));
// PYBIND FOR NON-LEAF NODES
// (In alphabetical order)
PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) {
(void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode",
"to create a BatchNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder,
bool pad, py::list in_col_names, py::list out_col_names, py::list col_order,
py::object size_obj, py::object map_obj, py::dict pad_info) {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
if (pad) {
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
}
py::function size_func =
py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function();
py::function map_func =
py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function();
auto batch = std::make_shared<BatchNode>(
self, batch_size, drop_remainder, pad, toStringVector(in_col_names),
toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info);
THROW_IF_ERROR(batch->ValidateParams());
return batch;
}));
}));
PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
(void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>(
*m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode")
.def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names,
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes,
py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
auto bucket_batch = std::make_shared<BucketBatchByLengthNode>(
dataset, toStringVector(column_names), bucket_boundaries, bucket_batch_sizes,
toPyFuncOp(std::move(element_length_function), DataType::DE_INT32), c_pad_info,
pad_to_bucket_boundary, drop_remainder);
THROW_IF_ERROR(bucket_batch->ValidateParams());
return bucket_batch;
}),
py::arg("dataset"), py::arg("column_names"), py::arg("bucket_boundaries"),
py::arg("bucket_batch_sizes"), py::arg("element_length_function") = py::none(),
py::arg("pad_info"), py::arg("pad_to_bucket_boundary"), py::arg("drop_remainder"));
}));
PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>(
*m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size,
float character_coverage, SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params) {
auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>(
self, vocab, col_names, vocab_size, character_coverage, model_type, params);
THROW_IF_ERROR(build_sentence_vocab->ValidateParams());
return build_sentence_vocab;
}));
}));
PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>(
*m, "BuildVocabNode", "to create a BuildVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns,
py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) {
auto build_vocab =
std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range),
top_k, toStringVector(special_tokens), special_first);
THROW_IF_ERROR(build_vocab->ValidateParams());
return build_vocab;
}));
}));
PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
"to create a ConcatNode")
.def(
py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, std::optional<py::handle> sampler,
py::list children_flag_and_nums, py::list children_start_end_index) {
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
toPairVector(children_flag_and_nums),
toPairVector(children_start_end_index));
THROW_IF_ERROR(concat->ValidateParams());
return concat;
}));
}));
PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
(void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode",
"to create a FilterNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate,
std::vector<std::string> input_columns) {
auto filter =
std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns);
THROW_IF_ERROR(filter->ValidateParams());
return filter;
}));
}));
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> operations,
std::optional<py::list> input_columns, std::optional<py::list> output_columns,
std::optional<py::list> project_columns,
std::optional<std::shared_ptr<CacheClient>> cc,
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
auto map = std::make_shared<MapNode>(
self, toTensorOperations(operations), toStringVector(input_columns),
toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)),
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()));
THROW_IF_ERROR(map->ValidateParams());
return map;
}));
}));
PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
(void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode",
"to create a ProjectNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) {
auto project = std::make_shared<ProjectNode>(self, toStringVector(columns));
THROW_IF_ERROR(project->ValidateParams());
return project;
}));
}));
PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) {
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode",
"to create a RenameNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> input_columns,
std::optional<py::list> output_columns) {
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns),
toStringVector(output_columns));
THROW_IF_ERROR(rename->ValidateParams());
return rename;
}));
}));
PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",
"to create a RepeatNode")
.def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) {
auto repeat = std::make_shared<RepeatNode>(input, count);
THROW_IF_ERROR(repeat->ValidateParams());
return repeat;
}));
}));
PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
(void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode",
"to create a ShuffleNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) {
auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch);
THROW_IF_ERROR(shuffle->ValidateParams());
return shuffle;
}));
}));
PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
(void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode",
"to create a SkipNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
auto skip = std::make_shared<SkipNode>(self, count);
THROW_IF_ERROR(skip->ValidateParams());
return skip;
}));
}));
PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
(void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode",
"to create a SyncWaitNode")
.def(
py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) {
py::function callback_func =
py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function();
auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback_func);
THROW_IF_ERROR(sync_wait->ValidateParams());
return sync_wait;
}));
}));
PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
(void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode",
"to create a TakeNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
auto take = std::make_shared<TakeNode>(self, count);
THROW_IF_ERROR(take->ValidateParams());
return take;
}));
}));
PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
"to create a TransferNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type,
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) {
auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end,
total_batch, create_data_info_queue);
THROW_IF_ERROR(transfer->ValidateParams());
return transfer;
}));
}));
PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
(void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode")
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) {
auto zip = std::make_shared<ZipNode>(datasets);
THROW_IF_ERROR(zip->ValidateParams());
return zip;
}));
}));
} // namespace dataset
} // namespace mindspore
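
Reviewer sketch, not part of the change: the TFRecordNode bindings in this file take num_samples, num_shards and shard_id as std::optional and read them with operator*, which is only defined when the optional holds a value. One way to resolve such arguments without ever dereferencing an empty optional is value_or(). `ShardArgs`, `ResolveShardArgs` and the fallback values 0, 1 and 0 are hypothetical and chosen only for illustration; the defaults the real nodes expect are not stated in this diff.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical holder for the resolved sharding arguments.
struct ShardArgs {
  int64_t num_samples;
  int32_t num_shards;
  int32_t shard_id;
};

// value_or() returns the contained value or the fallback, so an empty
// optional is never dereferenced. The fallbacks are illustrative assumptions.
inline ShardArgs ResolveShardArgs(std::optional<int64_t> num_samples, std::optional<int32_t> num_shards,
                                  std::optional<int32_t> shard_id) {
  return ShardArgs{num_samples.value_or(0), num_shards.value_or(1), shard_id.value_or(0)};
}
```

Calling `ResolveShardArgs(std::nullopt, std::nullopt, std::nullopt)` yields the fallbacks, while explicit arguments pass through unchanged.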

@ -0,0 +1,168 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/engine/python_runtime_context.h"
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer");
}));
PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>(
*m, "PythonIteratorConsumer")
.def(py::init<int32_t>())
.def("Init", [](PythonIteratorConsumer &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetNextAsMap",
[](PythonIteratorConsumer &self) {
py::dict output;
THROW_IF_ERROR(self.GetNextAsDict(&output));
return output;
})
.def("GetNextAsList", [](PythonIteratorConsumer &self) {
py::list output;
THROW_IF_ERROR(self.GetNextAsList(&output));
return output;
});
}));
PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
(void)py::class_<PythonTreeGetters, TreeConsumer, std::shared_ptr<PythonTreeGetters>>(*m,
"TreeGetters")
.def(py::init<>())
.def("Init",
[](PythonTreeGetters &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetOutputShapes",
[](PythonTreeGetters &self) {
std::vector<TensorShape> shapes;
THROW_IF_ERROR(self.GetOutputShapes(&shapes));
return shapesToListOfShape(shapes);
})
.def("GetOutputTypes",
[](PythonTreeGetters &self) {
std::vector<DataType> types;
THROW_IF_ERROR(self.GetOutputTypes(&types));
return typesToListOfType(types);
})
.def("GetNumClasses",
[](PythonTreeGetters &self) {
int64_t num_classes = 0;
THROW_IF_ERROR(self.GetNumClasses(&num_classes));
return num_classes;
})
.def("GetRepeatCount",
[](PythonTreeGetters &self) {
int64_t repeat_count = 0;
THROW_IF_ERROR(self.GetRepeatCount(&repeat_count));
return repeat_count;
})
.def("GetBatchSize",
[](PythonTreeGetters &self) {
int64_t batch_size = 0;
THROW_IF_ERROR(self.GetBatchSize(&batch_size));
return batch_size;
})
.def("GetColumnNames",
[](PythonTreeGetters &self) {
std::vector<std::string> col_names;
THROW_IF_ERROR(self.GetColumnNames(&col_names));
return col_names;
})
.def("GetClassIndexing",
[](PythonTreeGetters &self) {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
return output_class_indexing;
})
.def("GetDatasetSize",
[](PythonTreeGetters &self) {
int64_t dataset_size = 0;
THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
return dataset_size;
})
.def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
}));
PYBIND_REGISTER(PythonRuntimeContext, 2, ([](const py::module *m) {
(void)py::class_<PythonRuntimeContext, std::shared_ptr<PythonRuntimeContext>>(*m,
"PythonRuntimeContext")
.def(py::init<>())
.def("Init", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Init()); })
.def("AssignConsumer", &PythonRuntimeContext::AssignConsumer)
.def("Terminate", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Terminate()); })
.def("GetConsumer", &PythonRuntimeContext::GetPythonConsumer, py::return_value_policy::reference)
.def("__deepcopy__", [](py::object &runtime_context, py::dict memo) { return runtime_context; });
}));
PYBIND_REGISTER(PythonBuildVocabConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonBuildVocabConsumer, TreeConsumer, std::shared_ptr<PythonBuildVocabConsumer>>(
*m, "PythonBuildVocabConsumer")
.def(py::init<>())
.def("Init", [](PythonBuildVocabConsumer &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Start", [](PythonBuildVocabConsumer &self) { THROW_IF_ERROR(self.Start()); });
}));
PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) {
(void)py::class_<ToDevice, TreeConsumer, std::shared_ptr<ToDevice>>(*m, "ToDevice")
.def(py::init<int32_t>())
.def("Init", [](ToDevice &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); })
.def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); })
.def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); })
.def("GetDataInfo",
[](ToDevice &self) {
std::vector<DataType> types_c;
std::vector<TensorShape> shapes_c;
{
py::gil_scoped_release rel;
THROW_IF_ERROR(self.GetDataInfo(&types_c, &shapes_c));
}
py::list types, shapes;
for (auto el : types_c) {
types.append(el.AsNumpyType());
}
for (auto el : shapes_c) {
py::list shape = el.AsPyList();
shapes.append(shape);
}
return py::make_tuple(types, shapes);
})
.def("__deepcopy__", [](py::object &to_device, py::dict memo) { return to_device; });
}));
PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
(void)py::class_<PythonSaveToDisk, TreeConsumer, std::shared_ptr<PythonSaveToDisk>>(
*m, "PythonSaveToDisk")
.def(py::init([](std::string &dataset_path, int32_t numFiles, std::string &datasetType) {
auto save = std::make_shared<PythonSaveToDisk>(dataset_path, numFiles, datasetType);
THROW_IF_ERROR(save->ValidateParams());
return save;
}))
.def("Init",
[](PythonSaveToDisk &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
}));
} // namespace dataset
} // namespace mindspore
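
Note: the bindings above wrap each Status-returning call in THROW_IF_ERROR so that a failure reaches Python as an exception. The macro's definition is not part of this diff, so the following is only a minimal standalone sketch of the idea. `Status`, `SKETCH_THROW_IF_ERROR`, `InitConsumer` and `CallInit` are hypothetical stand-ins, and `std::runtime_error` is chosen here because pybind11 translates it to a Python `RuntimeError`.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for mindspore::dataset::Status (assumption, not the real class).
class Status {
 public:
  static Status OK() { return Status(true, ""); }
  static Status Error(std::string msg) { return Status(false, std::move(msg)); }
  bool IsOk() const { return ok_; }
  const std::string &ToString() const { return msg_; }

 private:
  Status(bool ok, std::string msg) : ok_(ok), msg_(std::move(msg)) {}
  bool ok_;
  std::string msg_;
};

// Hypothetical macro: evaluate the expression once and throw if it failed.
#define SKETCH_THROW_IF_ERROR(expr)                        \
  do {                                                     \
    Status sketch_status_ = (expr);                        \
    if (!sketch_status_.IsOk()) {                          \
      throw std::runtime_error(sketch_status_.ToString()); \
    }                                                      \
  } while (false)

// Hypothetical consumer call that fails when no tree is available.
Status InitConsumer(bool has_tree) {
  return has_tree ? Status::OK() : Status::Error("tree is null");
}

// Mirrors the shape of the bound lambdas: returns the exception text, or "" on success.
std::string CallInit(bool has_tree) {
  try {
    SKETCH_THROW_IF_ERROR(InitConsumer(has_tree));
  } catch (const std::runtime_error &e) {
    return e.what();
  }
  return "";
}
```

With this shape the bound lambda can return void, and the Python caller only sees an exception when the underlying call fails.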

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
SchemaObj, 0, ([](const py::module *m) {
(void)py::class_<SchemaObj, std::shared_ptr<SchemaObj>>(*m, "SchemaObj", "to create a SchemaObj")
.def(py::init([](std::string schema_file) {
auto schema = std::make_shared<SchemaObj>(schema_file);
THROW_IF_ERROR(schema->init());
return schema;
}))
.def("add_column", [](SchemaObj &self, std::string name, TypeId de_type,
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
.def("add_column", [](SchemaObj &self, std::string name, std::string de_type,
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
.def("add_column",
[](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("add_column", [](SchemaObj &self, std::string name,
std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("to_json", &SchemaObj::to_json)
.def("to_string", &SchemaObj::to_string)
.def("from_string",
[](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.FromJSONString(json_string)); })
.def("set_dataset_type", [](SchemaObj &self, std::string dataset_type) { self.set_dataset_type(dataset_type); })
.def("set_num_rows", [](SchemaObj &self, int32_t num_rows) { self.set_num_rows(num_rows); })
.def("get_num_rows", &SchemaObj::get_num_rows)
.def("__deepcopy__", [](py::object &schema, py::dict memo) { return schema; });
}));
} // namespace dataset
} // namespace mindspore

@@ -17,7 +17,6 @@
 #include "minddata/dataset/api/python/pybind_register.h"
 #include "minddata/dataset/core/global_context.h"
-#include "minddata/dataset/api/python/de_pipeline.h"
 #include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h"
 #include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h"

File diff suppressed because it is too large

@@ -1,265 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
#include <iostream>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/core/client.h" // DE client
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/util/status.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using DsOpPtr = std::shared_ptr<DatasetOp>;
class CacheClient;
// enum for the dataset operator names
enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBucketBatch,
kBarrier,
kCache,
kRepeat,
kSkip,
kTake,
kZip,
kConcat,
kMap,
kFilter,
kDeviceQueue,
kGenerator,
kRename,
kTfReader,
kProject,
kImageFolder,
kMnist,
kManifest,
kVoc,
kCoco,
kCifar10,
kCifar100,
kCelebA,
kRandomData,
kTextFile,
kBuildVocab,
kClue,
kEpochCtrl,
kSentencePieceVocab,
kCsv
};
// The C++ binder class that we expose to the python script.
class DEPipeline {
public:
DEPipeline();
~DEPipeline();
// Function to add a Node to the Execution Tree.
Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output);
// Function to add a child and parent relationship.
static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op);
// Function to assign the node as root.
Status AssignRootNode(const DsOpPtr &dataset_op);
// Function to get the column names in the last node in the tree in order
Status GetColumnNames(py::list *output);
// Function to prepare the tree for execution
Status PrepareTree(const int32_t num_epochs);
// Function to launch the tree execution.
Status LaunchTreeExec();
// Get a row of data as dictionary of column name to the value.
Status GetNextAsMap(py::dict *output);
// Get a row of data as list.
Status GetNextAsList(py::list *output);
Status GetOutputShapes(py::list *output);
Status GetOutputTypes(py::list *output);
Status GetDataInfo(py::list *types, py::list *shapes);
Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);
int GetDatasetSize() const;
int GetBatchSize() const;
int GetRepeatCount() const;
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
template <typename T, typename S>
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert = false);
Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
const TensorRow &row, json *schema, std::vector<std::string> *index_fields);
Status FetchDataFromTensorRow(const TensorRow &row,
const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data,
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
Status BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded);
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);
Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
void PrintTree();
int32_t GetNumClasses() const;
Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status SetBatchParameters(const py::dict &args);
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status StopSend();
Status ContinueSend();
Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;
std::unique_ptr<DatasetIterator> iterator_;
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
/// \brief Helper function to inject a cache operator over top of the current operation being built.
/// \param[in] cache_client The client to use for caching
/// \param[in] num_workers The number of workers to use in the cache op
/// \param[in] input_op The operator to build the cache on top of
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
/// the cache operator
/// \return Status return code
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *cache_op);
/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
/// \param[in] shuffle_size The size to use in the shuffle buffer
/// \param[in] input_op The operator to build shuffle on top of
/// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be
/// the shuffle operator
/// \return Status return code
Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op);
/// \brief Helper function to compute the shuffle size
/// \param[in] num_files The number of files in the dataset
/// \param[in] num_devices The number of devices in the dataset
/// \param[in] num_rows The number of rows in the dataset
/// \param[in] total_rows An upper bound on the total rows in the dataset
/// \param[out] shuffle_size The resultant computed shuffle size
/// \return Status return code
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size);
int batch_size_;
int repeat_num_;
int num_rows_;
int num_classes_;
int temp_batch_size_;
bool temp_drop_remainder_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_

@@ -0,0 +1,265 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_conversion.h"
namespace mindspore {
namespace dataset {
float toFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
int toInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
int64_t toInt64(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
bool toBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
std::string toString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); }
std::set<std::string> toStringSet(const std::optional<py::list> list) {
std::set<std::string> set;
if (list) {
for (auto l : *list) {
if (!l.is_none()) {
(void)set.insert(py::str(l));
}
}
}
return set;
}
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict) {
std::map<std::string, int32_t> map;
if (dict) {
for (auto p : *dict) {
(void)map.emplace(toString(p.first), toInt(p.second));
}
}
return map;
}
std::vector<std::string> toStringVector(const std::optional<py::list> list) {
std::vector<std::string> vector;
if (list) {
for (auto l : *list) {
if (l.is_none())
vector.emplace_back("");
else
vector.push_back(py::str(l));
}
}
return vector;
}
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple) {
std::pair<int64_t, int64_t> pair;
if (tuple) {
pair = std::make_pair(toInt64((*tuple)[0]), toInt64((*tuple)[1]));
}
return pair;
}
std::vector<std::pair<int, int>> toPairVector(const py::list list) {
std::vector<std::pair<int, int>> vector;
if (list) {
for (auto data : list) {
auto l = data.cast<py::tuple>();
if (l[1].is_none())
vector.emplace_back(toInt64(l[0]), 0);
else
vector.emplace_back(toInt64(l[0]), toInt64(l[1]));
}
}
return vector;
}
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations) {
std::vector<std::shared_ptr<TensorOperation>> vector;
if (operations) {
for (auto op : *operations) {
std::shared_ptr<TensorOp> tensor_op;
if (py::isinstance<TensorOp>(op)) {
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
} else if (py::isinstance<py::function>(op)) {
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
} else {
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }());
}
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
}
}
return vector;
}
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) {
std::vector<std::shared_ptr<DatasetNode>> vector;
vector.push_back(self);
if (datasets) {
for (auto ds : *datasets) {
if (py::isinstance<DatasetNode>(ds)) {
vector.push_back(ds.cast<std::shared_ptr<DatasetNode>>());
} else {
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: datasets is not recognised (not a DatasetNode instance)."); }());
}
}
}
return vector;
}
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset) {
if (py_sampler) {
std::shared_ptr<SamplerObj> sampler_obj;
if (!isMindDataset) {
// Common Sampler
std::shared_ptr<SamplerRT> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create");
sampler = create().cast<std::shared_ptr<SamplerRT>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
} else {
// Mindrecord Sampler
std::shared_ptr<mindrecord::ShardOperator> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create_for_minddataset");
sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
}
return sampler_obj;
} else {
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: sampler input is not SamplerRT."); }());
}
return nullptr;
}
// Here we take in a python object, that holds a reference to a C++ object
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc) {
if (cc) {
std::shared_ptr<DatasetCache> built_cache;
// Wrap the existing cache client in a pre-built cache object
built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc.value()));
return built_cache;
} else {
// don't need to check here as cache is not enabled.
return nullptr;
}
}
ShuffleMode toShuffleMode(const int32_t shuffle) {
if (shuffle == 0) return ShuffleMode::kFalse;
if (shuffle == 1) return ShuffleMode::kFiles;
if (shuffle == 2) return ShuffleMode::kGlobal;
return ShuffleMode();
}
std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases) {
std::vector<std::shared_ptr<CsvBase>> vector;
if (csv_bases) {
for (auto base : *csv_bases) {
if (py::isinstance<py::int_>(base)) {
vector.push_back(std::make_shared<CsvRecord<int>>(CsvType::INT, toInt(base)));
} else if (py::isinstance<py::float_>(base)) {
vector.push_back(std::make_shared<CsvRecord<float>>(CsvType::FLOAT, toFloat(base)));
} else if (py::isinstance<py::str>(base)) {
vector.push_back(std::make_shared<CsvRecord<std::string>>(CsvType::STRING, toString(base)));
} else {
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: each default value must be int, float, or string"); }());
}
}
}
return vector;
}
Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
std::map<std::string, std::string> *sample_bytes) {
for (const py::handle &key : padded_sample) {
if (py::isinstance<py::bytes>(padded_sample[key])) {
(*sample_bytes)[py::str(key).cast<std::string>()] = padded_sample[key].cast<std::string>();
// Bytes values are stored in sample_bytes, so this key would otherwise be missing from the json;
// insert an unused placeholder entry under the same key so that ValidateParam passes
(*padded_sample_json)[py::str(key).cast<std::string>()] = nlohmann::json::object();
} else {
nlohmann::json obj_json;
if (padded_sample[key].is_none()) {
obj_json = nullptr;
} else if (py::isinstance<py::int_>(padded_sample[key])) {
obj_json = padded_sample[key].cast<int64_t>();
} else if (py::isinstance<py::float_>(padded_sample[key])) {
obj_json = padded_sample[key].cast<double>();
} else if (py::isinstance<py::str>(padded_sample[key])) {
obj_json = padded_sample[key].cast<std::string>(); // also catch py::bytes
} else {
MS_LOG(ERROR) << "Python object convert to json failed: " << py::cast<std::string>(padded_sample[key]);
RETURN_STATUS_SYNTAX_ERROR("Python object convert to json failed");
}
(*padded_sample_json)[py::str(key).cast<std::string>()] = obj_json;
}
}
return Status::OK();
}
Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info) {
for (auto p : value) {
if (!p.second.is_none()) {
auto tp = py::reinterpret_borrow<py::tuple>(p.second);
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
std::shared_ptr<Tensor> pad_val = nullptr;
if (py::isinstance<py::str>(tp[1])) {
std::string pad_val_string = tp[1].is_none() ? "" : toString(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(
Tensor::CreateFromVector(std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar(), &pad_val),
"Cannot create pad_value Tensor");
} else {
float pad_val_float = tp[1].is_none() ? 0 : toFloat(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(
Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val),
"Cannot create pad_value Tensor");
pad_val->SetItemAt<float>({}, pad_val_float);
}
(void)pad_info->insert({toString(p.first), {shape, pad_val}});
} else { // tuple is None
(void)pad_info->insert({toString(p.first), {TensorShape({}), nullptr}});
}
}
return Status::OK();
}
std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type) {
std::shared_ptr<TensorOp> py_func;
if (!func.is_none()) {
py::function py_function = func.cast<py::function>();
py_func = std::make_shared<PyFuncOp>(py_function, data_type);
} else {
py_func = nullptr;
}
return py_func;
}
py::list shapesToListOfShape(std::vector<TensorShape> shapes) {
py::list shape_list;
for (const auto &shape : shapes) {
shape_list.append(shape.AsVector());
}
return shape_list;
}
py::list typesToListOfType(std::vector<DataType> types) {
py::list type_list;
for (const auto &type : types) {
type_list.append(type.AsNumpyType());
}
return type_list;
}
} // namespace dataset
} // namespace mindspore
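
Note: the helpers toStringVector and toStringSet above treat a missing list and a None entry differently: an absent list yields an empty container, the vector keeps a None entry as an empty string, and the set drops it. The following is a minimal sketch of that behaviour without pybind11. `Item`, `ItemList`, `ToStringVectorSketch` and `ToStringSetSketch` are hypothetical names, with `std::nullopt` standing in for Python's None.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-ins: an Item models one list element, std::nullopt models None.
using Item = std::optional<std::string>;
using ItemList = std::vector<Item>;

// Keeps positions stable: a None entry becomes an empty string.
std::vector<std::string> ToStringVectorSketch(const std::optional<ItemList> &list) {
  std::vector<std::string> out;
  if (list) {
    for (const Item &item : *list) {
      out.push_back(item ? *item : std::string());
    }
  }
  return out;
}

// Positions do not matter for a set, so None entries are skipped.
std::set<std::string> ToStringSetSketch(const std::optional<ItemList> &list) {
  std::set<std::string> out;
  if (list) {
    for (const Item &item : *list) {
      if (item) {
        (void)out.insert(*item);
      }
    }
  }
  return out;
}
```

Keeping the placeholder in the vector preserves index alignment with whatever the caller pairs the names with, which is presumably why the two helpers differ.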

@@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/kernels/py_func_op.h"
namespace py = pybind11;
namespace mindspore {
namespace dataset {
float toFloat(const py::handle &handle);
int toInt(const py::handle &handle);
int64_t toInt64(const py::handle &handle);
bool toBool(const py::handle &handle);
std::string toString(const py::handle &handle);
std::set<std::string> toStringSet(const std::optional<py::list> list);
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict);
std::vector<std::string> toStringVector(const std::optional<py::list> list);
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple);
std::vector<std::pair<int, int>> toPairVector(const py::list list);
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations);
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets);
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false);
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc);
ShuffleMode toShuffleMode(const int32_t shuffle);
std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases);
std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type);
Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
std::map<std::string, std::string> *sample_bytes);
Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info);
py::list shapesToListOfShape(std::vector<TensorShape> shapes);
py::list typesToListOfType(std::vector<DataType> types);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_

@@ -190,6 +190,23 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
 return sampler;
 }
+#ifndef ENABLE_ANDROID
+// PreBuiltSamplerObj
+PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
+: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}
+PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
+: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
+#endif
+bool PreBuiltSamplerObj::ValidateParams() { return true; }
+std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
+#ifndef ENABLE_ANDROID
+std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
+#endif
 #ifndef ENABLE_ANDROID
 std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
 // runtime mindrecord sampler object

@@ -222,6 +222,13 @@ Status OneHotOperation::ValidateParams() {
 std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
+// PreBuiltOperation
+PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {}
+Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
+std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
 // RandomApplyOperation
 RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
 : transforms_(transforms), prob_(prob) {}

@@ -18,6 +18,7 @@ set(SRC_FILES_LIST
 dataset_iterator.cc
 tree_adapter.cc
 runtime_context.cc
+python_runtime_context.cc
 consumers/tree_consumer.cc
 )
 if (ENABLE_PYTHON)

@@ -32,15 +32,37 @@ Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
 }
 return Status::OK();
 }
 Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) {
-std::unordered_map<std::string, TensorPtr> row;
+std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec;
+Status s;
 {
 py::gil_scoped_release gil_release;
-RETURN_IF_NOT_OK(GetNextAsMap(&row));
+s = GetNextAsOrderedPair(&vec);
 }
-for (auto el : row) {
-(*out)[common::SafeCStr(el.first)] = el.second;
+RETURN_IF_NOT_OK(s);
+// Generate Python dict, python dict maintains its insertion order
+for (const auto &pair : vec) {
+(*out)[common::SafeCStr(pair.first)] = pair.second;
 }
 return Status::OK();
 }
+Status PythonBuildVocabConsumer::Start() {
+py::gil_scoped_release gil_release;
+return BuildVocabConsumer::Start();
+}
+Status PythonSaveToDisk::Save() {
+py::gil_scoped_release gil_release;
+return SaveToDisk::Save();
+}
+PythonSaveToDisk::PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType)
+: SaveToDisk(datasetPath, numFiles, datasetType) {}
+Status PythonTreeGetters::GetRow(TensorRow *r) {
+py::gil_scoped_release gil_release;
+return TreeGetters::GetRow(r);
+}
 } // namespace mindspore::dataset
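
Note: in GetNextAsDict the Status is recorded inside the block that releases the GIL and is checked only after that block closes, so the lock is held again before any Python object is touched. The following is a rough standalone sketch of that shape, using a plain `std::mutex` in place of the GIL. `ScopedRelease`, `SketchStatus`, `FetchRow` and `GetNext` are hypothetical names, not part of pybind11 or MindSpore.

```cpp
#include <cassert>
#include <mutex>
#include <string>

// Hypothetical stand-in for py::gil_scoped_release:
// unlocks in the constructor and re-locks in the destructor.
class ScopedRelease {
 public:
  explicit ScopedRelease(std::mutex *m) : m_(m) { m_->unlock(); }
  ~ScopedRelease() { m_->lock(); }
  ScopedRelease(const ScopedRelease &) = delete;
  ScopedRelease &operator=(const ScopedRelease &) = delete;

 private:
  std::mutex *m_;
};

// Hypothetical minimal status type.
struct SketchStatus {
  bool ok;
  std::string msg;
};

// Hypothetical slow call that needs no access to lock-protected state.
SketchStatus FetchRow(bool fail) {
  return fail ? SketchStatus{false, "fetch failed"} : SketchStatus{true, ""};
}

// The caller must already hold interpreter_lock, as a pybind11-bound function holds the GIL.
SketchStatus GetNext(std::mutex *interpreter_lock, bool fail, int *rows_seen) {
  SketchStatus s{true, ""};
  {
    ScopedRelease release(interpreter_lock);  // lock dropped for the slow work
    s = FetchRow(fail);
  }  // lock re-acquired here
  if (!s.ok) {
    return s;
  }
  ++(*rows_seen);  // state guarded by the lock is only touched while holding it
  return s;
}
```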

@@ -44,5 +44,21 @@ class PythonIteratorConsumer : public IteratorConsumer {
 /// \return Status error code
 Status GetNextAsDict(py::dict *out);
 };
+class PythonBuildVocabConsumer : public BuildVocabConsumer {
+public:
+Status Start() override;
+};
+class PythonSaveToDisk : public SaveToDisk {
+public:
+PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType);
+Status Save() override;
+};
+class PythonTreeGetters : public TreeGetters {
+public:
+Status GetRow(TensorRow *r) override;
+};
 } // namespace mindspore::dataset
 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_

@@ -23,6 +23,7 @@
 #include <vector>
 #include "minddata/dataset/engine/consumers/tree_consumer.h"
 #include "minddata/dataset/engine/tree_adapter.h"
+#include "minddata/dataset/engine/opt/pre/getter_pass.h"
 #ifndef ENABLE_ANDROID
 #include "minddata/mindrecord/include/shard_header.h"
@@ -35,7 +36,7 @@ namespace mindspore::dataset {
 TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
 Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); }
-Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }
+Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->ServiceStop(); }
 // IteratorConsumer
 Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
@ -73,6 +74,38 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
return Status::OK(); return Status::OK();
} }
Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) {
CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");
TensorRow curr_row;
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));
RETURN_OK_IF_TRUE(curr_row.empty());
size_t num_cols = curr_row.size(); // num_cols is non-empty.
// order the column names according to their ids
if (column_order_.empty()) {
const int32_t invalid_col_id = -1;
column_order_.resize(num_cols, {std::string(), invalid_col_id});
for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
int32_t ind = itr.second;
CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
column_order_[ind] = std::make_pair(itr.first, ind);
}
// error check: make sure the ids in col_name_id_map are contiguous and start from 0
for (const auto &col : column_order_) {
CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous.");
}
}
vec->reserve(num_cols);
std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
[curr_row](const auto &col) { return std::make_pair(col.first, curr_row[col.second]); });
return Status::OK();
}
// ToDevice // ToDevice
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); } Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); }
@ -81,7 +114,6 @@ Status ToDevice::Send() {
RETURN_IF_NOT_OK(tree_adapter_->Launch()); RETURN_IF_NOT_OK(tree_adapter_->Launch());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_IF_NOT_OK(root->GetNextBuffer(&db));
return Status::OK(); return Status::OK();
} }
@ -101,9 +133,36 @@ Status ToDevice::Stop() {
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp"); CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopSend(); op->StopSend();
return Status::OK(); return Status::OK();
} }
Status ToDevice::GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes) {
// tree_.root() must be DeviceQueueOp
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp");
DATA_INFO data_info;
RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
for (auto el : data_info) {
types->push_back(el.first);
shapes->push_back(el.second);
}
return Status::OK();
}
Status ToDevice::Terminate() {
#ifdef ENABLE_TDTQUE
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopWaiting only supported by DeviceQueueOp");
op->StopWaiting();
#endif
return TreeConsumer::Terminate();
}
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
// SaveToDisk // SaveToDisk
Status SaveToDisk::ValidateParams() { Status SaveToDisk::ValidateParams() {
@ -282,50 +341,50 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
if (column_type == DataType::DE_INT8) { if (column_type == DataType::DE_INT8) {
std::unique_ptr<int32_t> data; std::unique_ptr<int32_t> data;
std::unique_ptr<int8_t> dummy; std::unique_ptr<int8_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT16) { } else if (column_type == DataType::DE_INT16) {
std::unique_ptr<int32_t> data; std::unique_ptr<int32_t> data;
std::unique_ptr<int16_t> dummy; std::unique_ptr<int16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT16) { } else if (column_type == DataType::DE_UINT16) {
std::unique_ptr<int32_t> data; std::unique_ptr<int32_t> data;
std::unique_ptr<uint16_t> dummy; std::unique_ptr<uint16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT8) { } else if (column_type == DataType::DE_UINT8) {
std::unique_ptr<uint8_t> data, dummy; std::unique_ptr<uint8_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT32) { } else if (column_type == DataType::DE_INT32) {
std::unique_ptr<int32_t> data, dummy; std::unique_ptr<int32_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT32) { } else if (column_type == DataType::DE_UINT32) {
std::unique_ptr<int64_t> data; std::unique_ptr<int64_t> data;
std::unique_ptr<uint32_t> dummy; std::unique_ptr<uint32_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT64) { } else if (column_type == DataType::DE_INT64) {
std::unique_ptr<int64_t> data, dummy; std::unique_ptr<int64_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT32) { } else if (column_type == DataType::DE_FLOAT32) {
std::unique_ptr<float> data, dummy; std::unique_ptr<float> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT64) { } else if (column_type == DataType::DE_FLOAT64) {
std::unique_ptr<double> data, dummy; std::unique_ptr<double> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_STRING) { } else if (column_type == DataType::DE_STRING) {
@ -346,7 +405,7 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
} }
template <typename T, typename S> template <typename T, typename S>
Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert) { std::unique_ptr<S> *s, bool need_convert) {
if (nullptr == src) { if (nullptr == src) {
@ -379,47 +438,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
} }
#endif #endif
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) { TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); }
tree_adapter_ = std::make_unique<TreeAdapter>();
}
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) { Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
if (init_flag_) { root_ = std::move(d);
return Status::OK();
}
Status s = tree_adapter_->Compile(std::move(d), 1);
if (!s.IsError()) {
init_flag_ = true;
}
return s;
}
bool TreeGetters::isInitialized() { return init_flag_; }
Status TreeGetters::GetRow(TensorRow *row) {
if (row_flag_ == false) {
RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
row_flag_ = true;
}
return Status::OK(); return Status::OK();
} }
Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ == -1) { if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
dataset_size_ = *dataset_size; if (*dataset_size == -1) { // run through the tree and get everything
if (*dataset_size == -1) { TensorRow row;
RETURN_IF_NOT_OK(GetRow(&row_)); RETURN_IF_NOT_OK(GetRow(&row));
int64_t num_rows = 0; int64_t row_cnt = 0;
TensorRow row = row_; while (!row.empty()) {
while (row.size() != 0) { ++row_cnt;
num_rows++; RETURN_IF_NOT_OK(GetRow(&row));
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
} }
dataset_size_ = num_rows; *dataset_size = row_cnt;
} }
dataset_size_ = *dataset_size; // save the previous result
} }
*dataset_size = dataset_size_; *dataset_size = dataset_size_;
@ -427,68 +471,88 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
} }
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) { Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
RETURN_IF_NOT_OK(GetRow(&row_)); RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
for (auto ts : row_) { if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
DataType dt = ts->type();
types->push_back(dt); std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types),
} [](const TensorPtr &t) { return t->type(); });
return Status::OK(); return Status::OK();
} }
Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) { Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
RETURN_IF_NOT_OK(GetRow(&row_)); RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
for (auto ts : row_) { if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
TensorShape t = ts->shape();
shapes->push_back(t); std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes),
} [](const TensorPtr &t) { return t->shape(); });
return Status::OK(); return Status::OK();
} }
Status TreeGetters::GetBatchSize(int64_t *batch_size) { Status TreeGetters::GetBatchSize(int64_t *batch_size) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); RETURN_UNEXPECTED_IF_NULL(root);
*batch_size = root->GetTreeBatchSize(); *batch_size = root->GetTreeBatchSize();
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size."); CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
return Status::OK(); return Status::OK();
} }
Status TreeGetters::GetRepeatCount(int64_t *repeat_count) { Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); RETURN_UNEXPECTED_IF_NULL(root);
*repeat_count = root->GetTreeRepeatCount(); *repeat_count = root->GetTreeRepeatCount();
return Status::OK(); return Status::OK();
} }
Status TreeGetters::GetNumClasses(int64_t *num_classes) { Status TreeGetters::GetNumClasses(int64_t *num_classes) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK(); return Status::OK();
} }
Status TreeGetters::GetColumnNames(std::vector<std::string> *output) { Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
RETURN_UNEXPECTED_IF_NULL(root);
std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map(); std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty."); CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map is empty.");
std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(), std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end());
column_name_id_map.end()); std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) { [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
return a.second < b.second; return a.second < b.second;
}); });
for (auto item : column_name_id_vector) { std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output),
(*output).push_back(item.first); [](const std::pair<std::string, int32_t> &p) { return p.first; });
}
return Status::OK(); return Status::OK();
} }
Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) { Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing)); RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
return Status::OK(); return Status::OK();
} }
Status TreeGetters::InternalInit(int8_t type) {
if (init_flag_) return Status::OK();
tree_adapter_->SetPrePassOverride([&type](OptPass pre) {
pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type)));
return pre;
});
Status s = tree_adapter_->Compile(std::move(root_), 1);
if (!s.IsError()) init_flag_ = true;
return s;
}
Status TreeGetters::InternalInit() {
if (init_flag_) return Status::OK();
Status s = tree_adapter_->Compile(std::move(root_), 1);
if (!s.IsError()) init_flag_ = true;
return s;
}
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); } Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }
Status BuildVocabConsumer::Start() { Status BuildVocabConsumer::Start() {


@ -41,7 +41,7 @@ class TreeConsumer {
/// \return Status error code. /// \return Status error code.
virtual Status Init(std::shared_ptr<DatasetNode> d); virtual Status Init(std::shared_ptr<DatasetNode> d);
Status Terminate(); virtual Status Terminate();
protected: protected:
/// The class owns the tree_adapter that handles execution tree operations. /// The class owns the tree_adapter that handles execution tree operations.
@ -72,6 +72,11 @@ class IteratorConsumer : public TreeConsumer {
/// \return Status error code /// \return Status error code
Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out); Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out);
/// Returns the next row as a vector of (column name, Tensor) pairs, ordered by column id
/// \param[out] vec std::vector of pairs of string to Tensor
/// \return Status error code
Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec);
protected: protected:
/// Method to return the name of the consumer /// Method to return the name of the consumer
/// \return string /// \return string
@ -79,6 +84,7 @@ class IteratorConsumer : public TreeConsumer {
private: private:
int32_t num_epochs_; int32_t num_epochs_;
std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id
}; };
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
@ -101,7 +107,7 @@ class SaveToDisk : public TreeConsumer {
/// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
/// would be written to disk) /// would be written to disk)
/// \return Status error code /// \return Status error code
Status Save(); virtual Status Save();
protected: protected:
/// Method to return the name of the consumer /// Method to return the name of the consumer
@ -110,7 +116,7 @@ class SaveToDisk : public TreeConsumer {
private: private:
template <typename T, typename S> template <typename T, typename S>
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert = false); std::unique_ptr<S> *s, bool need_convert = false);
@ -131,24 +137,29 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and sends it to a device /// Consumer that iterates over the dataset and sends it to a device
class ToDevice : public TreeConsumer { class ToDevice : public TreeConsumer {
public: public:
explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
~ToDevice() = default; ~ToDevice() = default;
Status Init(std::shared_ptr<DatasetNode> d) override; Status Init(std::shared_ptr<DatasetNode> d) override;
Status Terminate() override;
/// Send the data to device /// Send the data to device
/// \return Status error code /// \return Status error code
Status Send(); virtual Status Send();
/// Stop to send data to device /// Stop to send data to device
/// \return Status error code /// \return Status error code
Status Stop(); virtual Status Stop();
/// Continue to send data to device /// Continue to send data to device
/// \return Status error code /// \return Status error code
Status Continue(); virtual Status Continue();
/// Get data info from TDT
/// \return Status error code
virtual Status GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes);
protected: protected:
/// Method to return the name of the consumer /// Method to return the name of the consumer
@ -156,8 +167,6 @@ class ToDevice : public TreeConsumer {
std::string Name() override { return "ToDevice"; } std::string Name() override { return "ToDevice"; }
private: private:
std::string device_type_;
bool send_epoch_end_;
int32_t num_epochs_; int32_t num_epochs_;
}; };
@ -167,6 +176,7 @@ class TreeGetters : public TreeConsumer {
TreeGetters(); TreeGetters();
~TreeGetters() = default; ~TreeGetters() = default;
Status Init(std::shared_ptr<DatasetNode> d) override; Status Init(std::shared_ptr<DatasetNode> d) override;
Status GetDatasetSize(int64_t *size); Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types); Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes); Status GetOutputShapes(std::vector<TensorShape> *shapes);
@ -175,15 +185,17 @@ class TreeGetters : public TreeConsumer {
Status GetNumClasses(int64_t *num_classes); Status GetNumClasses(int64_t *num_classes);
Status GetColumnNames(std::vector<std::string> *output); Status GetColumnNames(std::vector<std::string> *output);
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
bool isInitialized();
std::string Name() override { return "TreeGetters"; } std::string Name() override { return "TreeGetters"; }
Status GetRow(TensorRow *r); virtual Status GetRow(TensorRow *r);
private: private:
std::shared_ptr<DatasetNode> root_;
int64_t dataset_size_; int64_t dataset_size_;
TensorRow row_; TensorRow first_row_;
bool init_flag_; // indicate whether the tree has initialized bool init_flag_; // indicate whether the tree has initialized
bool row_flag_; // indicate whether the first row has been stored in row_
Status InternalInit(int8_t type);
Status InternalInit();
}; };
class BuildVocabConsumer : public TreeConsumer { class BuildVocabConsumer : public TreeConsumer {
@ -197,7 +209,7 @@ class BuildVocabConsumer : public TreeConsumer {
/// Start consuming /// Start consuming
/// \return Status error code /// \return Status error code
Status Start(); virtual Status Start();
protected: protected:
/// Method to return the name of the consumer /// Method to return the name of the consumer


@ -44,9 +44,9 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
} }
// Constructor of the ConcatOp. // Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
std::vector<std::pair<int, int>> children_flag_and_nums, const std::vector<std::pair<int, int>> &children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index) const std::vector<std::pair<int, int>> &children_start_end_index)
: PipelineOp(op_connector_size), : PipelineOp(op_connector_size),
children_num_(0), children_num_(0),
sampler_(sampler), sampler_(sampler),


@ -70,9 +70,9 @@ class ConcatOp : public PipelineOp {
// @note The builder class should be used to call it // @note The builder class should be used to call it
// @param op_connector_size - connector size // @param op_connector_size - connector size
explicit ConcatOp(int32_t op_connector_size); explicit ConcatOp(int32_t op_connector_size);
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
std::vector<std::pair<int, int>> children_flag_and_nums, const std::vector<std::pair<int, int>> &children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index); const std::vector<std::pair<int, int>> &children_start_end_index);
// Destructor // Destructor
~ConcatOp() = default; ~ConcatOp() = default;


@ -346,6 +346,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Name of the current Op /// \return Name of the current Op
virtual std::string Name() const = 0; virtual std::string Name() const = 0;
/// Op name and ID getter
/// \return Name and ID of the current Op
std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; }
/// Execution Tree getter /// Execution Tree getter
/// \return Pointer to the ExecutionTree the current op belongs to, no ownership /// \return Pointer to the ExecutionTree the current op belongs to, no ownership
ExecutionTree *Tree() { return tree_; } ExecutionTree *Tree() { return tree_; }


@ -205,7 +205,6 @@ Status DeviceQueueOp::SendDataToAscend() {
} }
tree_->SetFinished(); tree_->SetFinished();
MS_LOG(INFO) << "Device queue total batch is " << send_batch;
return Status::OK(); return Status::OK();
} }


@ -39,10 +39,10 @@ using mindspore::device::GpuBufferMgr;
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>; using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>;
using DATA_INFO_QUEUE = Queue<DATA_INFO>; using DATA_INFO_QUEUE = Queue<DATA_INFO>;
const int kDataInfoQueueCapacity = 128; const int kDataInfoQueueCapacity = 128;
class DeviceQueueOp : public PipelineOp { class DeviceQueueOp : public PipelineOp {
public: public:
static const uint32_t INVALID_HANDLE = 0xffffffffUL; static const uint32_t INVALID_HANDLE = 0xffffffffUL;
@ -184,7 +184,6 @@ class DeviceQueueOp : public PipelineOp {
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
Status SendDataToAscend(); Status SendDataToAscend();
bool ascend_keep_waiting_; bool ascend_keep_waiting_;
#endif #endif
#ifdef ENABLE_GPUQUE #ifdef ENABLE_GPUQUE


@ -169,7 +169,7 @@ Status MapOp::operator()() {
} }
// The operator class just starts off threads by calling the tree_ function // The operator class just starts off threads by calling the tree_ function
rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1), NameWithID());
// Synchronize with TaskManager // Synchronize with TaskManager
TaskManager::FindMe()->Post(); TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(rc); RETURN_IF_NOT_OK(rc);


@ -704,6 +704,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
} }
if (image_ids_.size() == 0) { if (image_ids_.size() == 0) {
RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
} else {
num_rows = image_ids_.size();
} }
sample_size = sampler_->CalculateNumSamples(num_rows); sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;

View File

@ -480,13 +480,13 @@ Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
*dataset_size = dataset_size_; *dataset_size = dataset_size_;
return Status::OK(); return Status::OK();
} }
int64_t num_rows = num_rows_, sample_size; int64_t num_rows = num_rows_;
if (num_rows_ <= 0) { if (num_rows_ <= 0) {
std::shared_ptr<ShardOperator> op; // The last operator is parent sampler
std::shared_ptr<ShardOperator> op = operators_.back();
RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_)); RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
} }
sample_size = operators_[0]->GetNumSamples(num_rows, 0); *dataset_size = num_rows;
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size; dataset_size_ = *dataset_size;
return Status::OK(); return Status::OK();
} }


@ -1067,6 +1067,19 @@ Status TFReaderOp::PrepareNodePostAction() {
return Status::OK(); return Status::OK();
} }
// Get the file list of the specific shard ID
Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
if (!shard_filenames->empty()) {
RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
}
for (int index = 0; index < dataset_files_list_.size(); index++) {
if (index % num_devices_ == device_id_) {
shard_filenames->push_back(dataset_files_list_.at(index));
}
}
return Status::OK();
}
// Get Dataset size // Get Dataset size
Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) { if (dataset_size_ > 0) {
@ -1080,7 +1093,9 @@ Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_; num_rows = num_rows_per_shard_;
} else { } else {
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_)); std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
} }
} }
sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();


@ -400,6 +400,11 @@ class TFReaderOp : public ParallelOp {
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;
// Private function for computing the file list of the specific shard ID. This is needed because, in a distributed
// scenario, data is divided into shards by row when equal_rows_per_shard is true, but by file otherwise.
// @return - Status - the status code returned.
Status GetShardFileList(std::vector<std::string> *shard_filenames);
int32_t device_id_; int32_t device_id_;
int32_t num_devices_; int32_t num_devices_;
int64_t rows_per_buffer_; int64_t rows_per_buffer_;


@ -536,6 +536,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
RETURN_IF_NOT_OK(op->ParseImageIds()); RETURN_IF_NOT_OK(op->ParseImageIds());
num_rows = static_cast<int64_t>(op->image_ids_.size()); num_rows = static_cast<int64_t>(op->image_ids_.size());
} }
} else {
num_rows = image_ids_.size();
} }
sample_size = sampler_->CalculateNumSamples(num_rows); sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size; *dataset_size = sample_size;


@@ -141,8 +141,6 @@ Status ExecutionTree::Launch() {
 " Expected state: " + std::to_string(static_cast<int>(kDeTStateReady));
 RETURN_STATUS_UNEXPECTED(err_msg);
 }
-std::ostringstream ss;
-ss << *this;
 // Profiling infrastructures need to be initialized before Op launching
 if (profiling_manager_->IsProfilingEnable()) {
@@ -152,6 +150,8 @@ Status ExecutionTree::Launch() {
 RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
 }
+std::ostringstream ss;
+ss << *this;
 MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str();
 for (auto itr = this->begin(); itr != this->end(); ++itr) {
 // An inlined operator is one that has an output connector size of 0, and it does not
@@ -160,7 +160,7 @@ Status ExecutionTree::Launch() {
 // the launching tree/user thread. Do not exec any thread for an inlined op.
 itr->state_ = DatasetOp::OpState::kDeOpRunning;
 if (!itr->inlined()) {
-RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr)));
+RETURN_IF_NOT_OK(tg_->CreateAsyncTask(itr->NameWithID(), std::ref(*itr)));
 // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp
 }
 }
@@ -189,10 +189,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_
 // Given the number of workers, launches the worker entry function for each. Essentially a
 // wrapper for the TaskGroup handling that is stored inside the execution tree.
-Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func) {
+Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) {
 // Launch the workers
 for (int32_t i = 0; i < num_workers; ++i) {
-RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i)));
+RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i)));
 }
 return Status::OK();
 }


@@ -150,7 +150,7 @@ class ExecutionTree {
 // @param num_workers - The number of workers to launch
 // @param func - The function entry point that workers will execute
 // @return Status - The error code return
-Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func);
+Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "");
 // Getter method
 // @return shared_ptr to the root operator


@@ -1,4 +1,5 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-ir-cache OBJECT
-dataset_cache_impl.cc)
+pre_built_dataset_cache.cc
+dataset_cache_impl.cc)


@@ -18,8 +18,8 @@
 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
 #include "minddata/dataset/engine/datasetops/cache_op.h"
-namespace mindspore::dataset {
+namespace mindspore {
+namespace dataset {
 /// Method to initialize the DatasetCache by creating an instance of a CacheClient
 /// \return Status Error code
 Status DatasetCacheImpl::Build() {
@@ -40,5 +40,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
 return Status::OK();
 }
+} // namespace dataset
-} // namespace mindspore::dataset
+} // namespace mindspore


@@ -24,8 +24,8 @@
 #include "minddata/dataset/engine/datasetops/cache_op.h"
 #include "minddata/dataset/engine/ir/cache/dataset_cache.h"
-namespace mindspore::dataset {
+namespace mindspore {
+namespace dataset {
 /// DatasetCache is the IR of CacheClient
 class DatasetCacheImpl : public DatasetCache {
 public:
@@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache {
 std::optional<int32_t> num_connections_;
 std::optional<int32_t> prefetch_sz_;
 };
-} // namespace mindspore::dataset
+} // namespace dataset
+} // namespace mindspore
 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_


@@ -0,0 +1,40 @@
+/**
+* Copyright 2020 Huawei Technologies Co., Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+#include <memory>
+#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+namespace mindspore {
+namespace dataset {
+/// Method to initialize the DatasetCache by creating an instance of a CacheClient
+/// \return Status Error code
+Status PreBuiltDatasetCache::Build() {
+// we actually want to keep a reference of the runtime object so it can be shared by different pipelines
+return Status::OK();
+}
+Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
+CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
+std::shared_ptr<CacheOp> cache_op = nullptr;
+RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
+*ds = cache_op;
+return Status::OK();
+}
+} // namespace dataset
+} // namespace mindspore


@@ -0,0 +1,49 @@
+/**
+* Copyright 2020 Huawei Technologies Co., Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
+#include <memory>
+#include <string>
+#include <utility>
+#include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
+namespace mindspore {
+namespace dataset {
+/// DatasetCache is the IR of CacheClient
+class PreBuiltDatasetCache : public DatasetCache {
+public:
+/// \brief Constructor
+/// \param cc a pre-built cache client
+explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc) : cache_client_(std::move(cc)) {}
+/// Method to initialize the DatasetCache by creating an instance of a CacheClient
+/// \return Status Error code
+Status Build() override;
+Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
+Status ValidateParams() override { return Status::OK(); }
+private:
+std::shared_ptr<CacheClient> cache_client_;
+};
+} // namespace dataset
+} // namespace mindspore
+#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_


@@ -31,7 +31,7 @@ namespace dataset {
 BucketBatchByLengthNode::BucketBatchByLengthNode(
 std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
 const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
-std::function<TensorRow(TensorRow)> element_length_function,
+std::shared_ptr<TensorOp> element_length_function,
 const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
 bool drop_remainder)
 : column_names_(column_names),
@@ -47,16 +47,13 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
 std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
 // A vector containing shared pointer to the Dataset Ops that this object will create
 std::vector<std::shared_ptr<DatasetOp>> node_ops;
+bucket_boundaries_.insert(bucket_boundaries_.begin(), 0);
-std::shared_ptr<TensorOp> c_func;
-if (element_length_function_ != nullptr) {
-c_func = std::make_shared<CFuncOp>(element_length_function_);
-} else {
-c_func = nullptr;
+node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(
+column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_,
+pad_to_bucket_boundary_, drop_remainder_, connector_que_size_));
+if (bucket_boundaries_[0] == 0) {
+bucket_boundaries_.erase(bucket_boundaries_.begin());
 }
-node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
-c_func, pad_info_, pad_to_bucket_boundary_,
-drop_remainder_, connector_que_size_));
 return node_ops;
 }


@@ -33,7 +33,7 @@ class BucketBatchByLengthNode : public DatasetNode {
 /// \brief Constructor
 BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
 const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
-std::function<TensorRow(TensorRow)> element_length_function = nullptr,
+std::shared_ptr<TensorOp> element_length_function = nullptr,
 const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
 bool pad_to_bucket_boundary = false, bool drop_remainder = false);
@@ -52,7 +52,7 @@ class BucketBatchByLengthNode : public DatasetNode {
 std::vector<std::string> column_names_;
 std::vector<int32_t> bucket_boundaries_;
 std::vector<int32_t> bucket_batch_sizes_;
-std::function<TensorRow(TensorRow)> element_length_function_;
+std::shared_ptr<TensorOp> element_length_function_;
 std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
 bool pad_to_bucket_boundary_;
 bool drop_remainder_;


@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 #include "minddata/dataset/engine/datasetops/concat_op.h"
@@ -27,7 +28,15 @@ namespace mindspore {
 namespace dataset {
 // Function to build ConcatOp
-ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }
+ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
+const std::shared_ptr<SamplerObj> &sampler,
+const std::vector<std::pair<int, int>> &children_flag_and_nums,
+const std::vector<std::pair<int, int>> &children_start_end_index)
+: sampler_(sampler),
+children_flag_and_nums_(children_flag_and_nums),
+children_start_end_index_(children_start_end_index) {
+this->children = datasets;
+}
 Status ConcatNode::ValidateParams() {
 if (children.size() < 2) {
@@ -42,14 +51,25 @@ Status ConcatNode::ValidateParams() {
 RETURN_STATUS_SYNTAX_ERROR(err_msg);
 }
+if ((children_flag_and_nums_.empty() && !children_start_end_index_.empty()) ||
+(!children_flag_and_nums_.empty() && children_start_end_index_.empty())) {
+std::string err_msg = "ConcatNode: children_flag_and_nums and children_start_end_index should be used together";
+MS_LOG(ERROR) << err_msg;
+RETURN_STATUS_SYNTAX_ERROR(err_msg);
+}
 return Status::OK();
 }
 std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
 // A vector containing shared pointer to the Dataset Ops that this object will create
 std::vector<std::shared_ptr<DatasetOp>> node_ops;
+if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
+node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
+} else {
+node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_,
+children_start_end_index_));
+}
-node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
 return node_ops;
 }
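
The new check in ConcatNode::ValidateParams rejects the case where exactly one of children_flag_and_nums and children_start_end_index is supplied. As a small illustration only, the two-branch condition is equivalent to comparing the two `empty()` results; `UsedTogether` and `PairList` are hypothetical names introduced for this sketch.

```cpp
#include <cassert>
#include <utility>
#include <vector>

using PairList = std::vector<std::pair<int, int>>;

// Hypothetical helper: true when the two lists are either both empty or both
// non-empty, i.e. the configuration the validation in the hunk accepts.
bool UsedTogether(const PairList &flag_and_nums, const PairList &start_end_index) {
  return flag_and_nums.empty() == start_end_index.empty();
}
```

Both-empty selects the plain ConcatOp in Build(), and both-non-empty selects the sampler-aware one, so the validation keeps Build() from silently falling back to the plain form when only one list was given.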


@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
@@ -29,7 +30,10 @@ namespace dataset {
 class ConcatNode : public DatasetNode {
 public:
 /// \brief Constructor
-explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets);
+explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
+const std::shared_ptr<SamplerObj> &sampler = nullptr,
+const std::vector<std::pair<int, int>> &children_flag_and_nums = {},
+const std::vector<std::pair<int, int>> &children_start_end_index = {});
 /// \brief Destructor
 ~ConcatNode() = default;
@@ -41,6 +45,11 @@ class ConcatNode : public DatasetNode {
 /// \brief Parameters validation
 /// \return Status Status::OK() if all the parameters are valid
 Status ValidateParams() override;
+private:
+std::shared_ptr<SamplerObj> sampler_;
+std::vector<std::pair<int, int>> children_flag_and_nums_;
+std::vector<std::pair<int, int>> children_start_end_index_;
 };
 } // namespace dataset


@@ -240,6 +240,7 @@ DatasetNode::DatasetNode() {
 rows_per_buffer_ = cfg->rows_per_buffer();
 connector_que_size_ = cfg->op_connector_size();
 worker_connector_size_ = cfg->worker_connector_size();
+build_status = Status::OK(); // remove me after changing return val of Build()
 }
 // In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
@@ -254,5 +255,13 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
 // This method will only be called if its derived class does not implement one.
 return p->VisitAfter(shared_from_this(), modified);
 }
+Status DatasetNode::GetShardId(int32_t *shard_id) {
+if (!Children().empty()) {
+// Get shard id from the child node
+return Children()[0]->GetShardId(shard_id);
+} else {
+RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
+}
+}
 } // namespace dataset
 } // namespace mindspore
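
The new default DatasetNode::GetShardId forwards the query to the first child and fails only when it reaches a childless node that has not overridden it. A stripped-down sketch of that delegation follows, using `bool` in place of Status; `Node` and `SourceNode` are stand-in types written for this example, not the MindSpore classes.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Simplified stand-in for an IR node; only the shard-id lookup is modelled.
struct Node {
  std::vector<std::shared_ptr<Node>> children;
  virtual ~Node() = default;
  // Default behaviour: ask the first child. A node with no children that does
  // not override this method reports failure.
  virtual bool GetShardId(int *shard_id) const {
    if (!children.empty()) {
      return children[0]->GetShardId(shard_id);
    }
    return false;
  }
};

// A leaf that knows its own shard id, as a source dataset node would.
struct SourceNode : Node {
  int shard_id_;
  explicit SourceNode(int id) : shard_id_(id) {}
  bool GetShardId(int *shard_id) const override {
    if (shard_id == nullptr) {
      return false;
    }
    *shard_id = shard_id_;
    return true;
  }
};
```

With this shape, intermediate nodes need no shard logic of their own; only leaves override the method, as GeneratorNode does in this change by always reporting shard 0.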


@@ -99,9 +99,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
 /// \brief Pure virtual function for derived class to get the shard id of specific node
 /// \return Status Status::OK() if get shard id successfully
-virtual Status GetShardId(int32_t *shard_id) {
+virtual Status GetShardId(int32_t *shard_id);
-return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-}
 /// \brief Setter function for runtime number of workers
 /// \param[in] num_workers The number of threads in this operator
@@ -126,6 +124,10 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
 /// \return Status of the node visit
 virtual Status AcceptAfter(NodePass *p, bool *modified);
+/// \brief Method to get status from Node.Build()
+/// \notes Remove me after changing return val of Build()
+Status BuildStatus() { return build_status; }
 protected:
 std::vector<std::shared_ptr<DatasetNode>> children;
 std::shared_ptr<DatasetCache> cache_;
@@ -135,6 +137,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
 int32_t rows_per_buffer_;
 int32_t connector_que_size_;
 int32_t worker_connector_size_;
+Status build_status; // remove me after changing return val of Build()
 };
 } // namespace dataset
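
Build() returns a vector of ops, so it has no channel for an error; the build_status member added here records the failing Status so it can be read back through BuildStatus(). The calling code is not shown in this diff, so the sketch below only illustrates the pattern under that assumption. `SketchStatus` and `SketchNode` are hypothetical stand-ins, and `int` replaces the op pointers.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Minimal stand-in for Status: a success flag plus a message.
struct SketchStatus {
  bool ok;
  std::string msg;
};

// Hypothetical node whose Build() can only signal failure by returning an empty
// vector, and which therefore stores the real error in a member for later lookup.
class SketchNode {
 public:
  explicit SketchNode(bool fail) : fail_(fail), build_status_{true, ""} {}

  std::vector<int> Build() {
    build_status_ = fail_ ? SketchStatus{false, "schema load failed"} : SketchStatus{true, ""};
    if (!build_status_.ok) {
      return {};  // same effect as RETURN_EMPTY_IF_ERROR: empty result on error
    }
    return {1};
  }

  const SketchStatus &BuildStatus() const { return build_status_; }

 private:
  bool fail_;
  SketchStatus build_status_;
};
```

A caller would check for an empty result and then consult BuildStatus() for the reason. The "remove me" comments in the diff indicate this is a stopgap until Build() returns a Status itself.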


@@ -28,7 +28,7 @@ namespace mindspore {
 namespace dataset {
 // Constructor for FilterNode
-FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
+FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
 std::vector<std::string> input_columns)
 : predicate_(predicate), input_columns_(input_columns) {
 this->children.push_back(child);
@@ -38,10 +38,7 @@ std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
 // A vector containing shared pointer to the Dataset Ops that this object will create
 std::vector<std::shared_ptr<DatasetOp>> node_ops;
-std::shared_ptr<TensorOp> c_func;
-c_func = std::make_shared<CFuncOp>(predicate_);
-node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, c_func));
+node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_));
 return node_ops;
 }


@@ -29,7 +29,7 @@ namespace dataset {
 class FilterNode : public DatasetNode {
 public:
 /// \brief Constructor
-FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
+FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
 std::vector<std::string> input_columns = {});
 /// \brief Destructor
@@ -44,7 +44,7 @@ class FilterNode : public DatasetNode {
 Status ValidateParams() override;
 private:
-std::function<TensorRow(TensorRow)> predicate_;
+std::shared_ptr<TensorOp> predicate_;
 std::vector<std::string> input_columns_;
 };


@@ -64,7 +64,8 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
 auto project_op = std::make_shared<ProjectOp>(project_columns_);
 node_ops.push_back(project_op);
 }
-RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
+RETURN_EMPTY_IF_ERROR(build_status);
 node_ops.push_back(map_op);
 return node_ops;


@@ -59,7 +59,8 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
 std::vector<std::shared_ptr<DatasetOp>> node_ops;
 auto schema = std::make_unique<DataSchema>();
-RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_));
+build_status = schema->LoadSchemaFile(schema_path_, column_names_);
+RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
 // Argument that is not exposed to user in the API.
 std::set<std::string> extensions = {};


@@ -60,7 +60,8 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
 RETURN_EMPTY_IF_ERROR(
 schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
-RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
+RETURN_EMPTY_IF_ERROR(build_status);
 node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
 decode_, usage_, extensions_, std::move(schema),


@@ -56,7 +56,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
 RETURN_EMPTY_IF_ERROR(
 schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
+RETURN_EMPTY_IF_ERROR(build_status);
 node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
 dataset_dir_, connector_que_size_, std::move(schema),


@@ -54,7 +54,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
 RETURN_EMPTY_IF_ERROR(
 schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
+RETURN_EMPTY_IF_ERROR(build_status);
 node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
 dataset_dir_, connector_que_size_, std::move(schema),


@@ -197,18 +197,23 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
 std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
 num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
 connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
-RETURN_EMPTY_IF_ERROR(clue_op->Init());
+build_status = clue_op->Init(); // remove me after changing return val of Build()
+RETURN_EMPTY_IF_ERROR(build_status);
 if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
 // Inject ShuffleOp
 std::shared_ptr<DatasetOp> shuffle_op = nullptr;
 int64_t num_rows = 0;
 // First, get the number of rows in the dataset
-RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));
+build_status = ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows);
+RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
 // Add the shuffle op after this op
-RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
-rows_per_buffer_, &shuffle_op));
+build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
+rows_per_buffer_, &shuffle_op);
+RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
 node_ops.push_back(shuffle_op);
 }
 RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));


@@ -111,7 +111,8 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
 std::shared_ptr<CocoOp> op =
 std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
 connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
-RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
+build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
+RETURN_EMPTY_IF_ERROR(build_status);
 node_ops.push_back(op);


@@ -108,18 +108,23 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
 std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
 rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
 num_shards_, shard_id_, std::move(sampler_->Build()));
-RETURN_EMPTY_IF_ERROR(csv_op->Init());
+build_status = csv_op->Init(); // remove me after changing return val of Build()
+RETURN_EMPTY_IF_ERROR(build_status);
 if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
 // Inject ShuffleOp
 std::shared_ptr<DatasetOp> shuffle_op = nullptr;
 int64_t num_rows = 0;
 // First, get the number of rows in the dataset
-RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));
+build_status = CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows);
+RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
 // Add the shuffle op after this op
-RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
-rows_per_buffer_, &shuffle_op));
+build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
+rows_per_buffer_, &shuffle_op);
+RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
 node_ops.push_back(shuffle_op);
 }


@ -30,7 +30,25 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
const std::vector<DataType> &column_types) const std::vector<DataType> &column_types)
: generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
: generator_function_(generator_function), schema_(schema) {}
std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
if (schema_ != nullptr) {
column_names_.clear();
column_types_.clear();
std::string schema_json_string = schema_->to_json();
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, {}));
for (int32_t i = 0; i < data_schema->NumColumns(); i++) {
ColDescriptor col = data_schema->column(i);
column_names_.push_back(col.name());
column_types_.push_back((col.type()));
}
}
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
@ -43,6 +61,8 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
// This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which // This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which
// is best delivered when the test cases for this API are ready. // is best delivered when the test cases for this API are ready.
Status rc = op->Init(); Status rc = op->Init();
build_status = rc; // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (rc.IsOk()) { if (rc.IsOk()) {
node_ops.push_back(op); node_ops.push_back(op);
@ -56,5 +76,11 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
// no validation is needed for generator op. // no validation is needed for generator op.
Status GeneratorNode::ValidateParams() { return Status::OK(); } Status GeneratorNode::ValidateParams() { return Status::OK(); }
Status GeneratorNode::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);
*shard_id = 0;
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
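
Note on the pattern in the hunks above: each RETURN_EMPTY_IF_ERROR(expr) call is rewritten to first store the result in build_status and then test it, so that a caller can still read the error after Build() has returned an empty vector. The following is a minimal standalone sketch of that idea, not the MindSpore implementation. SketchStatus, SketchNode, SKETCH_RETURN_EMPTY_IF_ERROR and Step() are hypothetical stand-ins introduced only for illustration.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the real Status class (assumption, not the MindSpore type).
struct SketchStatus {
  SketchStatus() = default;
  explicit SketchStatus(std::string m) : ok(false), msg(std::move(m)) {}
  static SketchStatus OK() { return SketchStatus(); }
  static SketchStatus Error(std::string m) { return SketchStatus(std::move(m)); }
  bool IsOk() const { return ok; }
  bool ok = true;
  std::string msg;
};

// Hypothetical stand-in for RETURN_EMPTY_IF_ERROR: leave Build() with an empty result on failure.
#define SKETCH_RETURN_EMPTY_IF_ERROR(s) \
  do {                                  \
    if (!(s).IsOk()) {                  \
      return {};                        \
    }                                   \
  } while (false)

class SketchNode {
 public:
  explicit SketchNode(bool fail) : fail_(fail) {}

  // Build() can only return a vector, so the failure reason is kept in a member.
  std::vector<int> Build() {
    std::vector<int> ops;
    build_status_ = Step();
    SKETCH_RETURN_EMPTY_IF_ERROR(build_status_);
    ops.push_back(1);  // stands in for pushing a DatasetOp
    return ops;
  }

  // The caller checks this after Build(), as TreeAdapter does with BuildStatus().
  const SketchStatus &BuildStatus() const { return build_status_; }

 private:
  SketchStatus Step() { return fail_ ? SketchStatus::Error("step failed") : SketchStatus::OK(); }

  bool fail_;
  SketchStatus build_status_;
};
```

The member is a temporary workaround in the same spirit as the "remove me" comments in the diff: once Build() returns a status directly, the stored copy is no longer needed.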

@ -35,6 +35,9 @@ class GeneratorNode : public DatasetNode {
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
const std::vector<DataType> &column_types); const std::vector<DataType> &column_types);
/// \brief Constructor
GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema);
/// \brief Destructor /// \brief Destructor
~GeneratorNode() = default; ~GeneratorNode() = default;
@ -46,10 +49,15 @@ class GeneratorNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;
/// \brief Get the shard id of the node; always 0 because GeneratorNode doesn't support sharding
/// \param[out] shard_id Pointer that receives the shard id
/// \return Status Status::OK() if the shard id was retrieved successfully
Status GetShardId(int32_t *shard_id) override;
private: private:
py::function generator_function_; py::function generator_function_;
std::vector<std::string> column_names_; std::vector<std::string> column_names_;
std::vector<DataType> column_types_; std::vector<DataType> column_types_;
std::shared_ptr<SchemaObj> schema_;
}; };
} // namespace dataset } // namespace dataset

@ -62,7 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
RETURN_EMPTY_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema), recursive_, decode_, exts_, class_indexing_, std::move(schema),

@ -79,7 +79,8 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
manifest_op = manifest_op =
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
class_index_, std::move(schema), std::move(sampler_->Build()), usage_); class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(manifest_op); node_ops.push_back(manifest_op);

@ -138,7 +138,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::vector<std::shared_ptr<ShardOperator>> operators_; std::vector<std::shared_ptr<ShardOperator>> operators_;
RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_)); build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
std::shared_ptr<MindRecordOp> mindrecord_op; std::shared_ptr<MindRecordOp> mindrecord_op;
// If a string is passed to MindData(), it is treated as a pattern to search for matching files, // If a string is passed to MindData(), it is treated as a pattern to search for matching files,
@ -154,7 +155,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
padded_sample_, sample_bytes_); padded_sample_, sample_bytes_);
} }
RETURN_EMPTY_IF_ERROR(mindrecord_op->Init()); build_status = mindrecord_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(mindrecord_op); node_ops.push_back(mindrecord_op);
return node_ops; return node_ops;

@ -51,7 +51,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
RETURN_EMPTY_IF_ERROR( RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->Build()))); connector_que_size_, std::move(schema), std::move(sampler_->Build())));

@ -98,7 +98,8 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
std::shared_ptr<RandomDataOp> op; std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema), std::move(sampler_->Build())); std::move(data_schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(op); node_ops.push_back(op);

@ -78,7 +78,8 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(text_file_op->Init()); build_status = text_file_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp // Inject ShuffleOp
@ -86,14 +87,17 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
int64_t num_rows = 0; int64_t num_rows = 0;
// First, get the number of rows in the dataset // First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows)); build_status = TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
// Add the shuffle op after this op // Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op)); rows_per_buffer_, &shuffle_op);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
node_ops.push_back(shuffle_op); node_ops.push_back(shuffle_op);
} }
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
// Add TextFileOp // Add TextFileOp
node_ops.push_back(text_file_op); node_ops.push_back(text_file_op);

@ -118,7 +118,8 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
shard_id_, shard_equal_rows_, std::move(sampler_->Build())); shard_id_, shard_equal_rows_, std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init()); build_status = tf_reader_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp // Inject ShuffleOp
@ -127,14 +128,17 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
int64_t num_rows = 0; int64_t num_rows = 0;
// First, get the number of rows in the dataset // First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files)); build_status = TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
// Add the shuffle op after this op // Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, build_status = AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op)); rows_per_buffer_, &shuffle_op);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
node_ops.push_back(shuffle_op); node_ops.push_back(shuffle_op);
} }
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
// Add TFReaderOp // Add TFReaderOp
node_ops.push_back(tf_reader_op); node_ops.push_back(tf_reader_op);

@ -106,7 +106,8 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
std::shared_ptr<VOCOp> voc_op; std::shared_ptr<VOCOp> voc_op;
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(voc_op); node_ops.push_back(voc_op);
return node_ops; return node_ops;

@ -27,9 +27,8 @@ namespace mindspore {
namespace dataset { namespace dataset {
// Constructor for SyncWaitNode // Constructor for SyncWaitNode
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback)
py::function callback) : condition_name_(condition_name), callback_(callback) {
: condition_name_(condition_name), num_batch_(num_batch), callback_(callback) {
this->children.push_back(child); this->children.push_back(child);
} }
@ -38,20 +37,16 @@ std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_)); // Right now barrier should only take num_rows_per_buffer = 1
// The reason for this is because having it otherwise can lead to blocking issues
// See barrier_op.h for more details
int32_t rows_per_buffer = 1;
node_ops.push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_));
return node_ops; return node_ops;
} }
// Function to validate the parameters for SyncWaitNode // Function to validate the parameters for SyncWaitNode
Status SyncWaitNode::ValidateParams() { Status SyncWaitNode::ValidateParams() { return Status::OK(); }
if (num_batch_ <= 0) {
std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -31,8 +31,7 @@ namespace dataset {
class SyncWaitNode : public DatasetNode { class SyncWaitNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback);
py::function callback);
/// \brief Destructor /// \brief Destructor
~SyncWaitNode() = default; ~SyncWaitNode() = default;
@ -47,7 +46,6 @@ class SyncWaitNode : public DatasetNode {
private: private:
std::string condition_name_; std::string condition_name_;
int32_t num_batch_;
py::function callback_; py::function callback_;
}; };

@ -18,73 +18,81 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "minddata/dataset/engine/datasetops/device_queue_op.h" #include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#include "utils/ms_context.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Constructor for TransferNode // Constructor for TransferNode
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end) TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type,
: prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) { bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
: prefetch_size_(16),
queue_name_(std::move(queue_name)),
device_type_(std::move(device_type)),
send_epoch_end_(send_epoch_end),
total_batch_(total_batch),
create_data_info_queue_(create_data_info_queue),
device_id_(0) {
this->children.push_back(child); this->children.push_back(child);
} }
// Validator for TransferNode // Validator for TransferNode
Status TransferNode::ValidateParams() { Status TransferNode::ValidateParams() {
// Check if device_type_ is in {"CPU", "GPU", "Ascend"} if (total_batch_ < 0) {
RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"})); std::string err_msg = "TransferNode: Total batches should be >= 0, value given: " + std::to_string(total_batch_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK(); return Status::OK();
} }
// Function to build TransferNode // Function to build TransferNode
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
// Get a uuid for queue name if (queue_name_.empty()) {
queue_name_ = Services::GetUniqueID(); // Get a uuid for queue name
// TODO(CRC): queue_name_ = Services::GetUniqueID();
}
if (device_type_.empty()) {
auto context = MsContext::GetInstance();
if (context == nullptr) {
device_type_ = kCPUDevice;
} else {
device_type_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
}
}
// Get device type from ms context // Get device type from ms context
device_type_ = "CPU"; // Convert device_type_ from string to DeviceType
// Get device ID from children DeviceQueueOp::DeviceType type;
if (device_type_ == kCPUDevice) {
type = DeviceQueueOp::DeviceType::CPU;
} else if (device_type_ == kGPUDevice) {
type = DeviceQueueOp::DeviceType::GPU;
} else if (device_type_ == kAscendDevice) {
type = DeviceQueueOp::DeviceType::Ascend;
} else {
MS_LOG(ERROR) << "Unknown device target.";
return {};
}
// Get device ID (shard ID) from children
device_id_ = 0; device_id_ = 0;
RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_)); build_status = this->GetShardId(&device_id_); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
// A vector containing shared pointer to the Dataset Ops that this object will create // A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops; std::vector<std::shared_ptr<DatasetOp>> node_ops;
// Convert device_type_ from string to DeviceType
DeviceQueueOp::DeviceType type;
if (device_type_ == "CPU") {
type = DeviceQueueOp::DeviceType::CPU;
} else if (device_type_ == "GPU") {
type = DeviceQueueOp::DeviceType::GPU;
} else if (device_type_ == "Ascend") {
type = DeviceQueueOp::DeviceType::Ascend;
}
node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
total_batch_, false)); total_batch_, create_data_info_queue_));
return node_ops; return node_ops;
} }
// Function to get the device_id
Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) {
// Get device id according to the type of dataset
Status rc = ds->GetShardId(device_id);
if (rc != Status::OK()) {
// Get device id from the child node
if (ds->Children().size()) {
ds = ds->Children()[0];
return TransferNode::get_distribution(ds, device_id);
} else {
std::string err_msg = "Unknown dataset type.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
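
Note on the device-type handling above: the new Build() converts the device string to an enum before creating the op and returns an empty vector for an unknown target, where the old code left the enum uninitialized in that case. A minimal standalone sketch of that conversion follows. SketchDeviceType and ParseDeviceType are hypothetical names, and the literal strings "CPU", "GPU" and "Ascend" are assumed to match kCPUDevice, kGPUDevice and kAscendDevice, as the removed lines suggest.

```cpp
#include <string>

// Hypothetical stand-in for DeviceQueueOp::DeviceType.
enum class SketchDeviceType { CPU, GPU, Ascend };

// Convert a device name to the enum.
// Returns false and leaves *out untouched when the name is unknown or out is null,
// so the caller never reads an uninitialized enum value.
bool ParseDeviceType(const std::string &name, SketchDeviceType *out) {
  if (out == nullptr) {
    return false;
  }
  if (name == "CPU") {  // assumed value of kCPUDevice
    *out = SketchDeviceType::CPU;
  } else if (name == "GPU") {  // assumed value of kGPUDevice
    *out = SketchDeviceType::GPU;
  } else if (name == "Ascend") {  // assumed value of kAscendDevice
    *out = SketchDeviceType::Ascend;
  } else {
    return false;  // unknown device target
  }
  return true;
}
```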

@ -29,7 +29,8 @@ namespace dataset {
class TransferNode : public DatasetNode { class TransferNode : public DatasetNode {
public: public:
/// \brief Constructor /// \brief Constructor
TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end); TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end,
int32_t total_batch, bool create_data_info_queue);
/// \brief Destructor /// \brief Destructor
~TransferNode() = default; ~TransferNode() = default;
@ -42,8 +43,6 @@ class TransferNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
private: private:
std::string queue_name_; std::string queue_name_;
int32_t device_id_; int32_t device_id_;
@ -51,6 +50,7 @@ class TransferNode : public DatasetNode {
int32_t prefetch_size_; int32_t prefetch_size_;
bool send_epoch_end_; bool send_epoch_end_;
int32_t total_batch_; int32_t total_batch_;
bool create_data_info_queue_;
}; };
} // namespace dataset } // namespace dataset

@ -40,21 +40,7 @@ Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *mo
} }
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) { nodes_to_clear_callback_.push_back(node);
nodes_to_clear_callback_.push_back(node);
} else if (type_ == kDatasetSize) {
nodes_to_remove_.push_back(node);
}
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
return Status::OK(); return Status::OK();
} }
@ -83,5 +69,6 @@ Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
return Status::OK(); return Status::OK();
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -34,6 +34,10 @@ class GetterPass : public TreePass {
enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 }; enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
/// \brief Constructor /// \brief Constructor
explicit GetterPass(GetterType tp) : pass_(tp) {} explicit GetterPass(GetterType tp) : pass_(tp) {}
/// \brief default copy Constructor
explicit GetterPass(const GetterPass &) = default;
/// \brief Destructor /// \brief Destructor
~GetterPass() = default; ~GetterPass() = default;
@ -51,11 +55,10 @@ class GetterPass : public TreePass {
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); }
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
// whether this is Run or PreRun does not matter here, however, Only Accept() is defined in ConcatOp // whether this is Run or PreRun does not matter here, however, Only Accept() is defined in ConcatOp
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
@ -67,7 +70,7 @@ class GetterPass : public TreePass {
std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_; std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_; std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
}; };
// outter class needs only to own the inner class object since it automatically has access to its private variables // outer class needs only to own the inner class object since it automatically has access to its private variables
GetterNodes pass_; GetterNodes pass_;
}; };
} // namespace dataset } // namespace dataset

@ -19,7 +19,14 @@
namespace mindspore::dataset { namespace mindspore::dataset {
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); } Status PythonRuntimeContext::Terminate() {
MS_LOG(INFO) << "Terminating a PythonRuntime";
if (tree_consumer_ != nullptr) {
return TerminateImpl();
}
MS_LOG(WARNING) << "TreeConsumer was not initialized";
return Status::OK();
}
Status PythonRuntimeContext::TerminateImpl() { Status PythonRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");

@ -22,7 +22,14 @@ namespace mindspore::dataset {
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer); tree_consumer_ = std::move(tree_consumer);
} }
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); } Status NativeRuntimeContext::Terminate() {
MS_LOG(INFO) << "Terminating a NativeRuntime";
if (tree_consumer_ != nullptr) {
return TerminateImpl();
}
MS_LOG(WARNING) << "TreeConsumer was not initialized";
return Status::OK();
}
Status NativeRuntimeContext::TerminateImpl() { Status NativeRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");

@ -97,6 +97,8 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) { Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
// Build the DatasetOp ExecutionTree from the optimized IR tree // Build the DatasetOp ExecutionTree from the optimized IR tree
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build(); std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
RETURN_IF_NOT_OK(ir->BuildStatus()); // remove me after changing return val of Build()
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node."); CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node.");
(*op) = ops.front(); // return the first op to be added as child by the caller of this function (*op) = ops.front(); // return the first op to be added as child by the caller of this function
@ -141,6 +143,8 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op)); RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op));
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
// Note: We will gradually move the pre pass, optimizer pass, and post pass // Note: We will gradually move the pre pass, optimizer pass, and post pass
// on ExecutionTree to perform on IR tree. // on ExecutionTree to perform on IR tree.
// Prepare the tree // Prepare the tree
@ -149,6 +153,11 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
// After the tree is prepared, the col_name_id_map can safely be obtained // After the tree is prepared, the col_name_id_map can safely be obtained
column_name_map_ = tree_->root()->column_name_id_map(); column_name_map_ = tree_->root()->column_name_id_map();
// Profiling parameters init
cur_batch_num_ = 0;
cur_connector_size_ = 0;
cur_connector_capacity_ = 0;
return Status::OK(); return Status::OK();
} }
@ -156,21 +165,55 @@ Status TreeAdapter::GetNext(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(tree_); RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_UNEXPECTED_IF_NULL(row); RETURN_UNEXPECTED_IF_NULL(row);
row->clear(); // make sure row is empty row->clear(); // make sure row is empty
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
// When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree // When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
if (cur_db_ == nullptr) { if (cur_db_ == nullptr) {
RETURN_IF_NOT_OK(tree_->Launch()); RETURN_IF_NOT_OK(tree_->Launch());
// Profiling
std::shared_ptr<Tracing> node;
Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
if (s.IsOk()) {
tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
}
if (tracing_ != nullptr) {
cur_connector_size_ = tree_->root()->ConnectorSize();
cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
}
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag
RETURN_OK_IF_TRUE(cur_db_->eoe()); // return empty tensor if 1st buf is a ctrl buf (no rows) if (cur_db_->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows)
MS_LOG(INFO) << "End of data iteration.";
if (isProfilingEnable) {
tree_->SetEpochEnd();
}
return Status::OK();
}
} }
CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached."); CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached.");
if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_));
RETURN_OK_IF_TRUE(cur_db_->eoe() || cur_db_->eof()); // return empty if this new buffer is a ctrl flag if (cur_db_->eoe()) { // return empty if this new buffer is a ctrl flag
MS_LOG(INFO) << "End of data iteration.";
if (isProfilingEnable) {
tree_->SetEpochEnd();
}
return Status::OK();
}
if (cur_db_->eof()) {
tree_->SetFinished();
std::string err = "EOF buffer encountered. The user tried to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
} }
RETURN_IF_NOT_OK(cur_db_->PopRow(row)); RETURN_IF_NOT_OK(cur_db_->PopRow(row));
// Record profiling info
if (tracing_ != nullptr) {
cur_batch_num_++;
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
}
return Status::OK(); return Status::OK();
} }

@ -25,6 +25,7 @@
#include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -60,6 +61,9 @@ class TreeAdapter {
// Set optional optimization pass // Set optional optimization pass
void SetOptimize(bool value) { optimize_ = value; } void SetOptimize(bool value) { optimize_ = value; }
// function to override the pre-pass
void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
// Optional optimizations status // Optional optimizations status
bool OptimizationEnabled() const { return optimize_; } bool OptimizationEnabled() const { return optimize_; }
@ -82,9 +86,14 @@ class TreeAdapter {
std::unique_ptr<DataBuffer> cur_db_; std::unique_ptr<DataBuffer> cur_db_;
std::unordered_map<std::string, int32_t> column_name_map_; std::unordered_map<std::string, int32_t> column_name_map_;
std::unique_ptr<ExecutionTree> tree_; std::unique_ptr<ExecutionTree> tree_; // current connector capacity of root op, used for profiling
int32_t num_epochs_; int32_t num_epochs_;
bool optimize_; // Flag to enable optional optimization pass bool optimize_; // Flag to enable optional optimization pass
std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data
int32_t cur_batch_num_; // current batch number, used for profiling
int32_t cur_connector_size_; // current connector size of root op, used for profiling
int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -145,9 +145,16 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Function to transfer data through a device. /// \brief Function to transfer data through a device.
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation /// \notes If device is Ascend, features of data will be transferred one by one. The limitation
/// of data transmission per time is 256M. /// of data transmission per time is 256M.
/// \param[in] queue_name Channel name (default="", create new unique name).
/// \param[in] device_type Type of device (default="", get from MSContext).
/// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true). /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
/// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
/// of data or not(default=false).
/// \return Returns true if no error encountered else false. /// \return Returns true if no error encountered else false.
bool DeviceQueue(bool send_epoch_end = true); bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1,
bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false);
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
/// \note Usage restrictions: /// \note Usage restrictions:
@ -371,21 +378,34 @@ class SchemaObj {
/// \brief SchemaObj init function /// \brief SchemaObj init function
/// \return bool true if schema init success /// \return bool true if schema init success
bool init(); Status init();
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(TypeId).
/// \return Status code
Status add_column(std::string name, TypeId de_type);
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(std::string).
/// \return Status code
Status add_column(std::string name, std::string de_type);
/// \brief Add new column to the schema /// \brief Add new column to the schema
/// \param[in] name name of the column. /// \param[in] name name of the column.
/// \param[in] de_type data type of the column(TypeId). /// \param[in] de_type data type of the column(TypeId).
/// \param[in] shape shape of the column. /// \param[in] shape shape of the column.
/// \return bool true if schema init success /// \return bool true if schema init success
bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape); Status add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
/// \brief Add new column to the schema /// \brief Add new column to the schema
/// \param[in] name name of the column. /// \param[in] name name of the column.
/// \param[in] de_type data type of the column(std::string). /// \param[in] de_type data type of the column(std::string).
/// \param[in] shape shape of the column. /// \param[in] shape shape of the column.
/// \return bool true if schema init success /// \return bool true if schema init success
bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape); Status add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
/// \brief Get a JSON string of the schema /// \brief Get a JSON string of the schema
/// \return JSON string of the schema /// \return JSON string of the schema
@ -395,25 +415,27 @@ class SchemaObj {
std::string to_string() { return to_json(); } std::string to_string() { return to_json(); }
/// \brief set a new value to dataset_type /// \brief set a new value to dataset_type
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; } inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); }
/// \brief set a new value to num_rows /// \brief set a new value to num_rows
inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; } inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
/// \brief get the current num_rows /// \brief get the current num_rows
inline int32_t get_num_rows() { return num_rows_; } inline int32_t get_num_rows() const { return num_rows_; }
Status FromJSONString(const std::string &json_string);
private: private:
/// \brief Parse the columns and add it to columns /// \brief Parse the columns and add it to columns
/// \param[in] columns dataset attribution information, decoded from schema file. /// \param[in] columns dataset attribution information, decoded from schema file.
/// support both nlohmann::json::value_t::array and nlohmann::json::value_t::object. /// support both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
/// \return JSON string of the schema /// \return JSON string of the schema
bool parse_column(nlohmann::json columns); Status parse_column(nlohmann::json columns);
/// \brief Get schema file from json file /// \brief Get schema file from json file
/// \param[in] json_obj object of json parsed. /// \param[in] json_obj object of json parsed.
/// \return bool true if json dump success /// \return bool true if json dump success
bool from_json(nlohmann::json json_obj); Status from_json(nlohmann::json json_obj);
int32_t num_rows_; int32_t num_rows_;
std::string dataset_type_; std::string dataset_type_;

View File

@ -61,6 +61,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
class DistributedSamplerObj; class DistributedSamplerObj;
class PKSamplerObj; class PKSamplerObj;
class PreBuiltSamplerObj;
class RandomSamplerObj; class RandomSamplerObj;
class SequentialSamplerObj; class SequentialSamplerObj;
class SubsetRandomSamplerObj; class SubsetRandomSamplerObj;
@ -171,6 +172,31 @@ class PKSamplerObj : public SamplerObj {
int64_t num_samples_; int64_t num_samples_;
}; };
class PreBuiltSamplerObj : public SamplerObj {
public:
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif
~PreBuiltSamplerObj() = default;
std::shared_ptr<SamplerRT> Build() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
private:
std::shared_ptr<SamplerRT> sp_;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
#endif
};
class RandomSamplerObj : public SamplerObj { class RandomSamplerObj : public SamplerObj {
public: public:
RandomSamplerObj(bool replacement, int64_t num_samples); RandomSamplerObj(bool replacement, int64_t num_samples);

View File

@ -70,6 +70,7 @@ namespace transforms {
class ComposeOperation; class ComposeOperation;
class DuplicateOperation; class DuplicateOperation;
class OneHotOperation; class OneHotOperation;
class PreBuiltOperation;
class RandomApplyOperation; class RandomApplyOperation;
class RandomChoiceOperation; class RandomChoiceOperation;
class TypeCastOperation; class TypeCastOperation;
@ -164,6 +165,20 @@ class OneHotOperation : public TensorOperation {
float num_classes_; float num_classes_;
}; };
class PreBuiltOperation : public TensorOperation {
public:
explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op);
~PreBuiltOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
private:
std::shared_ptr<TensorOp> op_;
};
class RandomApplyOperation : public TensorOperation { class RandomApplyOperation : public TensorOperation {
public: public:
explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob); explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);
@ -192,7 +207,6 @@ class RandomChoiceOperation : public TensorOperation {
private: private:
std::vector<std::shared_ptr<TensorOperation>> transforms_; std::vector<std::shared_ptr<TensorOperation>> transforms_;
}; };
class TypeCastOperation : public TensorOperation { class TypeCastOperation : public TensorOperation {
public: public:
explicit TypeCastOperation(std::string data_type); explicit TypeCastOperation(std::string data_type);

View File

@ -71,6 +71,15 @@ namespace dataset {
return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \
} while (false) } while (false)
#define RETURN_SECOND_IF_ERROR(_s, _r) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return _r; \
} \
} while (false)
enum class StatusCode : char { enum class StatusCode : char {
kOK = 0, kOK = 0,
kOutOfMemory = 1, kOutOfMemory = 1,

View File

@ -138,7 +138,9 @@ Status Task::Join(WaitFlag blocking) {
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
// We can't tell which conditional_variable this thread is waiting on. So we may need // We can't tell which conditional_variable this thread is waiting on. So we may need
// to interrupt everything one more time. // to interrupt everything one more time.
MS_LOG(INFO) << "Some threads not responding. Interrupt again"; std::stringstream ss;
ss << get_id();
MS_LOG(ERROR) << MyName() << " Thread ID " << ss.str() << " is not responding. Interrupt again";
interrupt_svc->InterruptAll(); interrupt_svc->InterruptAll();
} }
} else { } else {
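
Reviewer sketch (not part of the patch): the Join loop above waits in one-second slices and re-sends the interrupt each time the thread is still running, now logging which thread is stuck. A rough Python analogue using `threading` follows. `join_with_reinterrupt` and `interrupt_all` are illustrative names, and `print` stands in for the error log.

```python
import threading


def join_with_reinterrupt(thread, interrupt_all, poll_seconds=1.0):
    """Join `thread`, re-sending the interrupt whenever a wait slice times out.

    Returns the number of extra interrupts that were needed.
    """
    retries = 0
    while True:
        thread.join(timeout=poll_seconds)
        if not thread.is_alive():
            return retries
        retries += 1
        # Name the unresponsive thread so the log entry is actionable.
        print(f"{thread.name} Thread ID {thread.ident} is not responding. Interrupt again")
        interrupt_all()
```

Including the thread name and id in the message is what makes the log useful when several workers are shutting down at once.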

View File

@ -21,7 +21,8 @@ import numpy
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers', __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load'] 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
'get_callback_timeout']
INT32_MAX = 2147483647 INT32_MAX = 2147483647
UINT32_MAX = 4294967295 UINT32_MAX = 4294967295

View File

@ -65,5 +65,7 @@ def mstypelist_to_detypelist(type_list):
for index, _ in enumerate(type_list): for index, _ in enumerate(type_list):
if type_list[index] is not None: if type_list[index] is not None:
type_list[index] = mstype_to_detype(type_list[index]) type_list[index] = mstype_to_detype(type_list[index])
else:
type_list[index] = cde.DataType("")
return type_list return type_list
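
Reviewer sketch (not part of the patch): the added `else` branch replaces a `None` entry with an explicit "unknown" type rather than leaving `None` in the list. The stand-alone version below shows the same rule with hypothetical names (`convert_type_list`, `convert`, `unknown`). Unlike the patched function, it returns a new list instead of mutating its argument.

```python
def convert_type_list(type_list, convert, unknown):
    """Convert every entry with `convert`; map None to the `unknown` placeholder."""
    return [convert(t) if t is not None else unknown for t in type_list]
```

Returning a fresh list keeps the caller's list untouched; the in-place version in the patch avoids the copy.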

File diff suppressed because it is too large

View File

@ -15,17 +15,13 @@
"""Built-in iterators. """Built-in iterators.
""" """
from abc import abstractmethod from abc import abstractmethod
import copy
import weakref import weakref
import numpy as np import numpy as np
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._c_dataengine import DEPipeline import mindspore._c_dataengine as cde
from mindspore._c_dataengine import OpName
from mindspore import log as logger from mindspore import log as logger
from . import datasets as de
_ITERATOR_CLEANUP = False _ITERATOR_CLEANUP = False
@ -57,29 +53,6 @@ def _cleanup():
itr.release() itr.release()
def alter_tree(node):
"""Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
if not node.children:
return _alter_node(node)
converted_children = []
for input_op in node.children:
converted_children.append(alter_tree(input_op))
node.children = converted_children
return _alter_node(node)
def _alter_node(node):
"""DEPRECATED"""
# Please check ccsrc/dataset/engine/opt for tree transformation.
if isinstance(node, de.MapDataset):
if node.python_multiprocessing:
# Bootstrap can only be performed on a copy of the original dataset node.
# Bootstrap on original dataset node will make all iterators share the same process pool
node.iterator_bootstrap()
return node
class Iterator: class Iterator:
""" """
General Iterator over a dataset. General Iterator over a dataset.
@ -89,185 +62,62 @@ class Iterator:
""" """
def __init__(self, dataset, num_epochs=-1, output_numpy=False): def __init__(self, dataset, num_epochs=-1, output_numpy=False):
self.num_epochs = num_epochs self._col_names = None
self.output_numpy = output_numpy
# create a copy of tree and work on it.
self.ori_dataset = dataset
self.ir_tree, self.dataset = dataset.create_ir_tree()
self._runtime_context = cde.PythonRuntimeContext()
self._runtime_context.Init()
consumer = cde.PythonIteratorConsumer(num_epochs)
consumer.Init(self.ir_tree)
self._runtime_context.AssignConsumer(consumer)
self._iterator = self._runtime_context.GetConsumer()
self._transform_tensor = lambda t: t.as_array()
if not output_numpy:
self._transform_tensor = lambda t: Tensor(t.as_array())
self._index = 0
# todo remove next when ContextManager is done
ITERATORS_LIST.append(weakref.ref(self)) ITERATORS_LIST.append(weakref.ref(self))
_unset_iterator_cleanup() _unset_iterator_cleanup()
# create a copy of tree and work on it. #######
self.dataset = copy.deepcopy(dataset)
self.ori_dataset = dataset
self.parent_subtree = []
# The dataset passed into the iterator is not the root of the tree. def __iter__(self):
# Trim the tree by saving the parent subtree into self.parent_subtree and return self
# restore it after launching our C++ pipeline.
if self.dataset.parent:
logger.info("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
self.parent_subtree = self.dataset.parent
self.dataset.parent = []
self.dataset = alter_tree(self.dataset)
if not self.__is_tree():
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers).")
self.depipeline = DEPipeline()
# for manifest temporary use
self.__batch_node(self.dataset, 0)
root = self.__convert_node_postorder(self.dataset)
self.depipeline.AssignRootNode(root)
self.depipeline.PrepareTree(self.num_epochs)
self._index = 0
def stop(self): def stop(self):
""" """
Manually terminate Python iterator instead of relying on out of scope destruction. Manually terminate Python iterator instead of relying on out of scope destruction.
""" """
logger.info("Terminating Python iterator. This will also terminate C++ pipeline.") logger.info("Terminating Python iterator. This will also terminate C++ pipeline.")
if hasattr(self, 'depipeline') and self.depipeline: if hasattr(self, '_runtime_context') and self._runtime_context:
del self.depipeline if hasattr(self, '_iterator') and self._iterator:
self._runtime_context.Terminate()
def __is_tree_node(self, node): del self._iterator
"""Check if a node is tree node.""" del self._runtime_context
if not node.children: del self.dataset
if len(node.parent) > 1:
return False
if len(node.parent) > 1:
return False
for input_node in node.children:
cls = self.__is_tree_node(input_node)
if not cls:
return False
return True
def __is_tree(self):
return self.__is_tree_node(self.dataset)
@staticmethod
def __get_dataset_type(dataset):
"""Get the dataset type."""
op_type = None
if isinstance(dataset, de.ShuffleDataset):
op_type = OpName.SHUFFLE
elif isinstance(dataset, de.MindDataset):
op_type = OpName.MINDRECORD
elif isinstance(dataset, de.BatchDataset):
op_type = OpName.BATCH
elif isinstance(dataset, de.BucketBatchByLengthDataset):
op_type = OpName.BUCKETBATCH
elif isinstance(dataset, de.SyncWaitDataset):
op_type = OpName.BARRIER
elif isinstance(dataset, de.ZipDataset):
op_type = OpName.ZIP
elif isinstance(dataset, de.ConcatDataset):
op_type = OpName.CONCAT
elif isinstance(dataset, de.MapDataset):
op_type = OpName.MAP
elif isinstance(dataset, de.FilterDataset):
op_type = OpName.FILTER
elif isinstance(dataset, de.RepeatDataset):
op_type = OpName.REPEAT
elif isinstance(dataset, de.SkipDataset):
op_type = OpName.SKIP
elif isinstance(dataset, de.TakeDataset):
op_type = OpName.TAKE
elif isinstance(dataset, de.ImageFolderDataset):
op_type = OpName.IMAGEFOLDER
elif isinstance(dataset, de.GeneratorDataset):
op_type = OpName.GENERATOR
elif isinstance(dataset, de.TransferDataset):
op_type = OpName.DEVICEQUEUE
elif isinstance(dataset, de.RenameDataset):
op_type = OpName.RENAME
elif isinstance(dataset, de.TFRecordDataset):
op_type = OpName.TFREADER
elif isinstance(dataset, de.ProjectDataset):
op_type = OpName.PROJECT
elif isinstance(dataset, de.MnistDataset):
op_type = OpName.MNIST
elif isinstance(dataset, de.ManifestDataset):
op_type = OpName.MANIFEST
elif isinstance(dataset, de.VOCDataset):
op_type = OpName.VOC
elif isinstance(dataset, de.CocoDataset):
op_type = OpName.COCO
elif isinstance(dataset, de.Cifar10Dataset):
op_type = OpName.CIFAR10
elif isinstance(dataset, de.Cifar100Dataset):
op_type = OpName.CIFAR100
elif isinstance(dataset, de.CelebADataset):
op_type = OpName.CELEBA
elif isinstance(dataset, de.RandomDataset):
op_type = OpName.RANDOMDATA
elif isinstance(dataset, de.TextFileDataset):
op_type = OpName.TEXTFILE
elif isinstance(dataset, de.BuildVocabDataset):
op_type = OpName.BUILDVOCAB
elif isinstance(dataset, de.BuildSentencePieceVocabDataset):
op_type = OpName.SENTENCEPIECEVOCAB
elif isinstance(dataset, de.CLUEDataset):
op_type = OpName.CLUE
elif isinstance(dataset, de.CSVDataset):
op_type = OpName.CSV
else:
raise ValueError("Unsupported DatasetOp.")
return op_type
# Convert Python node into C node and add to C layer execution tree in postorder traversal.
def __convert_node_postorder(self, node):
self.check_node_type(node)
op_type = self.__get_dataset_type(node)
c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
for py_child in node.children:
c_child = self.__convert_node_postorder(py_child)
self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
return c_nodes["top"]
def __batch_node(self, dataset, level):
"""Recursively get batch node in the dataset tree."""
if isinstance(dataset, de.BatchDataset):
return
for input_op in dataset.children:
self.__batch_node(input_op, level + 1)
@staticmethod
def __print_local(dataset, level):
"""Recursively print the name and address of nodes in the dataset tree."""
name = dataset.__class__.__name__
ptr = hex(id(dataset))
for _ in range(level):
logger.info("\t", end='')
if not dataset.children:
logger.info("-%s (%s)", name, ptr)
else:
logger.info("+%s (%s)", name, ptr)
for input_op in dataset.children:
Iterator.__print_local(input_op, level + 1)
def print(self):
"""Print the dataset tree"""
self.__print_local(self.dataset, 0)
def release(self): def release(self):
if hasattr(self, 'depipeline') and self.depipeline: self.stop()
del self.depipeline
def __del__(self):
self.release()
@abstractmethod @abstractmethod
def get_next(self): def _get_next(self):
raise RuntimeError("Calling base class Iterator's get_next is invalid.") raise RuntimeError("Calling base class Iterator's get_next is invalid.")
def __next__(self): def __next__(self):
if not self.depipeline: if not self._runtime_context:
logger.warning("Iterator does not have a running C++ pipeline." + logger.warning("Iterator does not have a running C++ pipeline." +
" It might be because Iterator stop() had been called, or the C++ pipeline crashed silently.") " It might be because Iterator stop() had been called, or the C++ pipeline crashed silently.")
raise RuntimeError("Iterator does not have a running C++ pipeline.") raise RuntimeError("Iterator does not have a running C++ pipeline.")
data = self.get_next() data = self._get_next()
if not data: if not data:
if self._index == 0: if self._index == 0:
logger.warning("No records available.") logger.warning("No records available.")
@ -277,100 +127,56 @@ class Iterator:
self._index += 1 self._index += 1
return data return data
@abstractmethod
def check_node_type(self, node):
pass
def get_output_shapes(self):
return [t for t in self.depipeline.GetOutputShapes()]
def get_output_types(self):
return [t for t in self.depipeline.GetOutputTypes()]
def get_dataset_size(self):
return self.depipeline.GetDatasetSize()
def get_batch_size(self):
return self.depipeline.GetBatchSize()
def get_repeat_count(self):
return self.depipeline.GetRepeatCount()
def num_classes(self):
return self.depipeline.GetNumClasses()
def get_col_names(self):
return self.depipeline.GetColumnNames()
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return self return self
def _getters(self):
"""
Get pipeline information.
"""
getter = cde.TreeGetters()
getter.Init(self.ir_tree)
self._runtime_context.AssignConsumer(getter)
self._col_names = getter.GetColumnNames()
class SaveOp(Iterator): def get_col_names(self):
""" """
The derived class of Iterator with dict type. Get names of the columns in the dataset
""" """
def __init__(self, dataset, num_epochs=-1): if self._col_names is None:
super().__init__(dataset, num_epochs) self._getters()
self.depipeline.LaunchTreeExec() return self._col_names
def get_next(self):
pass
def check_node_type(self, node):
if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
logger.warning("Used shuffle, repeat, batch before save operator.")
def save(self, file_names, file_type):
return self.depipeline.SaveDataset(file_names, file_type)
class DictIterator(Iterator): class DictIterator(Iterator):
""" """
The derived class of Iterator with dict type. The derived class of Iterator with dict type.
""" """
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
super().__init__(dataset, num_epochs, output_numpy)
self.depipeline.LaunchTreeExec()
def check_node_type(self, node): def _get_next(self):
pass
def __iter__(self):
return self
def get_next(self):
""" """
Returns the next record in the dataset as dictionary Returns the next record in the dataset as dictionary
Returns: Returns:
Dict, the next record in the dataset. Dict, the next record in the dataset.
""" """
return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
if self.output_numpy:
return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
return {k: Tensor(v.as_array()) for k, v in self.depipeline.GetNextAsMap().items()}
class TupleIterator(Iterator): class TupleIterator(Iterator):
""" """
The derived class of Iterator with list type. The derived class of Iterator with list type.
""" """
def check_node_type(self, node):
pass
def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False): def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False):
if columns is not None: if columns is not None:
if not isinstance(columns, list): if not isinstance(columns, list):
columns = [columns] columns = [columns]
# todo: move next to IR
dataset = dataset.project(columns) dataset = dataset.project(columns)
super().__init__(dataset, num_epochs, output_numpy) super().__init__(dataset, num_epochs, output_numpy)
self.depipeline.LaunchTreeExec()
def __iter__(self): def _get_next(self):
return self
def get_next(self):
""" """
Returns the next record in the dataset as a list Returns the next record in the dataset as a list
@ -378,15 +184,14 @@ class TupleIterator(Iterator):
List, the next record in the dataset. List, the next record in the dataset.
""" """
if self.output_numpy: return [self._transform_tensor(t) for t in self._iterator.GetNextAsList()]
return [t.as_array() for t in self.depipeline.GetNextAsList()]
return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()]
class DummyIterator: class DummyIterator:
""" """
A DummyIterator only works when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED" A DummyIterator only works when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
""" """
def __init__(self, dataset, mode): def __init__(self, dataset, mode):
self.mode = mode self.mode = mode
self.shapes = dataset.output_shapes() self.shapes = dataset.output_shapes()
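
Reviewer sketch (not part of the patch): the rewritten Iterator picks its output conversion once in `__init__` (`_transform_tensor`) instead of branching on `output_numpy` in every `get_next` call. The snippet below illustrates that pattern with stand-in classes. `RawTensor`, `Wrapped` and `DictIter` are hypothetical and use plain lists, not MindSpore or NumPy types.

```python
class RawTensor:
    """Hypothetical stand-in for the C++ tensor handle."""

    def __init__(self, data):
        self._data = list(data)

    def as_array(self):
        return list(self._data)


class Wrapped:
    """Hypothetical stand-in for the framework Tensor wrapper."""

    def __init__(self, array):
        self.array = array


def make_transform(output_numpy):
    """Choose the per-tensor conversion once, up front."""
    if output_numpy:
        return lambda t: t.as_array()
    return lambda t: Wrapped(t.as_array())


class DictIter:
    def __init__(self, rows, output_numpy=False):
        self._rows = iter(rows)
        self._transform = make_transform(output_numpy)

    def get_next(self):
        # An exhausted source yields an empty dict, mirroring an empty row.
        row = next(self._rows, {})
        return {k: self._transform(t) for k, t in row.items()}
```

Because the choice is made once, the dict and tuple subclasses can share the same conversion without repeating the `output_numpy` check.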

View File

@ -283,9 +283,12 @@ def create_node(node):
node.get('shard_id'), sampler) node.get('shard_id'), sampler)
elif dataset_op == 'TFRecordDataset': elif dataset_op == 'TFRecordDataset':
shuffle = node.get('shuffle')
if shuffle is not None and isinstance(shuffle, str):
shuffle = de.Shuffle(shuffle)
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'), pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
node.get('num_samples'), node.get('num_parallel_workers'), node.get('num_samples'), node.get('num_parallel_workers'),
de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id')) shuffle, node.get('num_shards'), node.get('shard_id'))
elif dataset_op == 'ManifestDataset': elif dataset_op == 'ManifestDataset':
sampler = construct_sampler(node.get('sampler')) sampler = construct_sampler(node.get('sampler'))
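
Reviewer sketch (not part of the patch): the TFRecordDataset fix above only wraps `shuffle` in the enum when the serialized value is a string, so a boolean or a missing value passes through unchanged. A minimal stand-alone version follows. The `Shuffle` enum and its values here are hypothetical placeholders, not the real `de.Shuffle` definition.

```python
from enum import Enum


class Shuffle(Enum):
    """Hypothetical stand-in for de.Shuffle; the values are illustrative."""
    GLOBAL = "global"
    FILES = "files"


def restore_shuffle(node):
    """Rebuild the shuffle argument from a deserialized node dict."""
    shuffle = node.get("shuffle")
    if isinstance(shuffle, str):
        shuffle = Shuffle(shuffle)  # only strings name an enum member
    return shuffle
```

Calling the enum constructor unconditionally, as the old line did, would fail for `None` or a boolean.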

View File

@ -293,14 +293,38 @@ def check_save(method):
return new_method return new_method
def check_iterator(method): def check_tuple_iterator(method):
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs) [columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_bool = ['output_numpy'] nreq_param_bool = ['output_numpy']
validate_dataset_param_value(nreq_param_bool, param_dict, bool) validate_dataset_param_value(nreq_param_bool, param_dict, bool)
if num_epochs is not None:
type_check(num_epochs, (int,), "num_epochs")
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
if columns is not None:
check_columns(columns, "column_names")
return method(self, *args, **kwargs)
return new_method
def check_dict_iterator(method):
"""A wrapper that wraps a parameter checker around the original create_dict_iterator."""
@wraps(method)
def new_method(self, *args, **kwargs):
[num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_bool = ['output_numpy']
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
if num_epochs is not None:
type_check(num_epochs, (int,), "num_epochs")
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method
@ -523,6 +547,8 @@ def check_batch(method):
sig = ins.signature(batch_size) sig = ins.signature(batch_size)
if len(sig.parameters) != 1: if len(sig.parameters) != 1:
raise ValueError("callable batch_size should take one parameter (BatchInfo).") raise ValueError("callable batch_size should take one parameter (BatchInfo).")
else:
check_pos_int32(int(batch_size), "batch_size")
if num_parallel_workers is not None: if num_parallel_workers is not None:
check_num_parallel_workers(num_parallel_workers) check_num_parallel_workers(num_parallel_workers)
@ -807,6 +833,21 @@ def check_project(method):
return new_method return new_method
def check_schema(method):
"""check the input arguments of Schema.__init__."""
@wraps(method)
def new_method(self, *args, **kwargs):
[schema_file], _ = parse_user_args(method, *args, **kwargs)
if schema_file is not None:
type_check(schema_file, (str,), "schema_file")
return method(self, *args, **kwargs)
return new_method
def check_add_column(method): def check_add_column(method):
"""check the input arguments of add_column.""" """check the input arguments of add_column."""
@ -1261,3 +1302,23 @@ def check_cache_option(cache):
"""Sanity check for cache parameter""" """Sanity check for cache parameter"""
if cache is not None: if cache is not None:
type_check(cache, (cache_client.DatasetCache,), "cache") type_check(cache, (cache_client.DatasetCache,), "cache")
def check_to_device_send(method):
"""A wrapper that wraps a parameter checker around the check_to_device_send."""
@wraps(method)
def new_method(self, *args, **kwargs):
[num_epochs], _ = parse_user_args(method, *args, **kwargs)
if num_epochs is not None:
type_check(num_epochs, (int,), "num_epochs")
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
return method(self, *args, **kwargs)
return new_method
def replace_none(value, default):
return value if value is not None else default
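
Reviewer sketch (not part of the patch): the new iterator and to-device validators all apply the same `num_epochs` rule, an int in [-1, INT32_MAX]. The decorator below reproduces that rule with only the standard library. It uses `inspect.signature` in place of the project's `parse_user_args`, `type_check` and `check_value` helpers, so the names and error messages are assumptions. Rejecting `bool` is also a choice made here, not something the patch states.

```python
import inspect
from functools import wraps

INT32_MAX = 2147483647


def check_num_epochs(method):
    """Validate the `num_epochs` argument before calling `method`."""
    sig = inspect.signature(method)

    @wraps(method)
    def new_method(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        num_epochs = bound.arguments.get("num_epochs")
        if num_epochs is not None:
            # bool is a subclass of int, so it is rejected explicitly here.
            if isinstance(num_epochs, bool) or not isinstance(num_epochs, int):
                raise TypeError("num_epochs must be an int.")
            if not -1 <= num_epochs <= INT32_MAX:
                raise ValueError("num_epochs must be in [-1, INT32_MAX].")
        return method(self, *args, **kwargs)

    return new_method


class Pipeline:
    """Hypothetical class used only to exercise the decorator."""

    @check_num_epochs
    def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
        return num_epochs
```

Binding against the signature means the check works the same whether `num_epochs` is passed positionally or by keyword.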

View File

@ -18,13 +18,13 @@ use to_bytes and to_str to encode and decode strings into a specified format.
""" """
from enum import IntEnum from enum import IntEnum
import copy
import numpy as np import numpy as np
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \ from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model
__all__ = [ __all__ = [
"Vocab", "SentencePieceVocab", "to_str", "to_bytes" "Vocab", "SentencePieceVocab", "to_str", "to_bytes"
] ]
@ -39,8 +39,7 @@ class Vocab(cde.Vocab):
@classmethod @classmethod
@check_from_dataset @check_from_dataset
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, special_first=True):
special_first=True):
""" """
Build a vocab from a dataset. Build a vocab from a dataset.
@ -69,21 +68,7 @@ class Vocab(cde.Vocab):
Returns: Returns:
Vocab, Vocab object built from dataset. Vocab, Vocab object built from dataset.
""" """
return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
vocab = Vocab()
if columns is None:
columns = []
if not isinstance(columns, list):
columns = [columns]
if freq_range is None:
freq_range = (None, None)
if special_tokens is None:
special_tokens = []
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
for d in root.create_dict_iterator(num_epochs=1):
if d is not None:
raise ValueError("from_dataset should receive data other than None.")
return vocab
@classmethod @classmethod
@check_from_list @check_from_list
@ -143,6 +128,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
""" """
SentencePiece object that is used to segment words SentencePiece object that is used to segment words
""" """
@classmethod @classmethod
@check_from_dataset_sentencepiece @check_from_dataset_sentencepiece
def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params): def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params):
@ -164,13 +150,8 @@ class SentencePieceVocab(cde.SentencePieceVocab):
SentencePiece, SentencePiece object from dataset. SentencePiece, SentencePiece object from dataset.
""" """
vocab = SentencePieceVocab() return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage,
root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage, DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
model_type, params)
for d in root.create_dict_iterator(num_epochs=1):
if d is None:
raise ValueError("from_dataset should receive data other than None.")
return vocab
@classmethod @classmethod
@check_from_file_sentencepiece @check_from_file_sentencepiece
@ -270,6 +251,7 @@ class SentencePieceModel(IntEnum):
CHAR = 2 CHAR = 2
WORD = 3 WORD = 3
DE_C_INTER_SENTENCEPIECE_MODE = { DE_C_INTER_SENTENCEPIECE_MODE = {
SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM, SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM,
SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE, SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE,


@ -432,7 +432,7 @@ def check_from_dataset_sentencepiece(method):
[_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs) [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
if col_names is not None: if col_names is not None:
type_check(col_names, (list,), "col_names") type_check_list(col_names, (str,), "col_names")
if vocab_size is not None: if vocab_size is not None:
check_uint32(vocab_size, "vocab_size") check_uint32(vocab_size, "vocab_size")


@ -146,6 +146,7 @@ if (BUILD_MINDDATA STREQUAL "full")
list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES
"${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc" "${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc"
"${MINDDATA_DIR}/engine/ir/cache/pre_built_dataset_cache.cc"
) )
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES


@ -123,6 +123,7 @@ def connect_network_with_dataset(network, dataset_helper):
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
return network return network
class DatasetHelper: class DatasetHelper:
""" """
DatasetHelper is a class to process the MindData dataset and it provides the information of dataset. DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.
@ -197,7 +198,6 @@ class DatasetHelper:
def get_data_info(self): def get_data_info(self):
return self.iter.get_data_info() return self.iter.get_data_info()
class _DatasetIter: class _DatasetIter:
"""Base iter for dataset helper""" """Base iter for dataset helper"""
@ -331,7 +331,6 @@ class _DatasetIterPSLite(_DatasetIter):
class _DatasetIterNormal: class _DatasetIterNormal:
"""Iter for normal(non sink) mode, feed the data from host.""" """Iter for normal(non sink) mode, feed the data from host."""
def __init__(self, dataset, epoch_num=-1): def __init__(self, dataset, epoch_num=-1):
self.dataset = dataset self.dataset = dataset
self.device_num = _get_device_num() self.device_num = _get_device_num()


@ -61,15 +61,15 @@ class MindData:
def send(self, num_epochs=-1): def send(self, num_epochs=-1):
pass pass
def get_data_info(self):
pass
def stop_send(self): def stop_send(self):
pass pass
def continue_send(self): def continue_send(self):
pass pass
def get_data_info(self):
pass
def __len__(self): def __len__(self):
return self._size return self._size


@ -177,8 +177,8 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) {
MS_LOG(INFO) << "Tensor image shape: " << image->shape(); MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row); iter->GetNextRow(&row);
} }
// 5 batches of size 2 // With 2 boundaries, 3 buckets are created
EXPECT_EQ(i, 5); EXPECT_EQ(i, 3);
// Manually terminate the pipeline // Manually terminate the pipeline
iter->Stop(); iter->Stop();


@ -132,6 +132,6 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
// verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not
EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos);
EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos);
EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos); EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos);
EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
} }


@ -63,7 +63,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}}; const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}};
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
std::vector<size_t> row_sizes = {2, 2, 0, 0}; std::vector<size_t> row_sizes = {2, 2, 0};
TensorRow row; TensorRow row;
for (size_t sz : row_sizes) { for (size_t sz : row_sizes) {
@ -75,7 +75,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
rc = tree_adapter.GetNext(&row); rc = tree_adapter.GetNext(&row);
EXPECT_TRUE(rc.IsError()); EXPECT_TRUE(rc.IsError());
const std::string err_msg = rc.ToString(); const std::string err_msg = rc.ToString();
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
} }
TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
@ -97,7 +97,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap(); const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap();
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0}; std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0};
TensorRow row; TensorRow row;
for (size_t sz : row_sizes) { for (size_t sz : row_sizes) {
@ -107,7 +107,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
} }
rc = tree_adapter.GetNext(&row); rc = tree_adapter.GetNext(&row);
const std::string err_msg = rc.ToString(); const std::string err_msg = rc.ToString();
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
} }
TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
@ -135,7 +135,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
const std::unordered_map<std::string, int32_t> map = {{"label", 0}}; const std::unordered_map<std::string, int32_t> map = {{"label", 0}};
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0}; std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0};
TensorRow row; TensorRow row;
for (size_t sz : row_sizes) { for (size_t sz : row_sizes) {
@ -145,5 +145,5 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
} }
rc = tree_adapter.GetNext(&row); rc = tree_adapter.GetNext(&row);
const std::string err_msg = rc.ToString(); const std::string err_msg = rc.ToString();
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
} }


@ -451,6 +451,10 @@ def test_batch_exception_13():
def test_batch_exception_14(): def test_batch_exception_14():
"""
Test per_batch_map and input column name
"""
logger.info("test_batch_exception_14")
batch_size = 2 batch_size = 2
input_columns = ["num"] input_columns = ["num"]
data1 = ds.TFRecordDataset(DATA_DIR) data1 = ds.TFRecordDataset(DATA_DIR)
@ -460,6 +464,22 @@ def test_batch_exception_14():
assert "per_batch_map and input_columns need to be passed in together." in str(e) assert "per_batch_map and input_columns need to be passed in together." in str(e)
def test_batch_exception_15():
"""
Test batch_size = int32 max value + 1
"""
logger.info("test_batch_exception_15")
batch_size = 2147483647 + 1
input_columns = ["num"]
data1 = ds.TFRecordDataset(DATA_DIR)
err_msg = ""
try:
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
except ValueError as e:
err_msg = str(e)
assert "batch_size is not within the required interval of (1 to 2147483647)" in err_msg
if __name__ == '__main__': if __name__ == '__main__':
test_batch_01() test_batch_01()
test_batch_02() test_batch_02()
@ -486,4 +506,5 @@ if __name__ == '__main__':
test_batch_exception_12() test_batch_exception_12()
test_batch_exception_13() test_batch_exception_13()
test_batch_exception_14() test_batch_exception_14()
test_batch_exception_15()
logger.info('\n') logger.info('\n')
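
Reviewer note on `test_batch_exception_15`: the new test expects a `ValueError` when `batch_size` exceeds the int32 maximum. The range check it exercises can be sketched roughly as below. This is a minimal, standalone illustration; `check_batch_size` and `INT32_MAX` are hypothetical names chosen for this note, not the validator MindSpore actually uses, and the real API may accept other kinds of `batch_size` values that this sketch does not model.

```python
# Hypothetical sketch of an int32 range check for batch_size.
# Names here are illustrative assumptions, not MindSpore internals.
INT32_MAX = 2147483647


def check_batch_size(batch_size):
    """Return batch_size if it is an int in [1, INT32_MAX], else raise."""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(batch_size, bool) or not isinstance(batch_size, int):
        raise TypeError(
            "batch_size must be an int, got {}.".format(type(batch_size).__name__))
    if not 1 <= batch_size <= INT32_MAX:
        raise ValueError(
            "batch_size is not within the required interval of (1 to {}).".format(INT32_MAX))
    return batch_size
```

With this sketch, `check_batch_size(2147483647 + 1)` raises a `ValueError` whose message contains the substring asserted in the test, while `check_batch_size(2)` returns `2`.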


@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import os
import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
@ -354,6 +355,18 @@ def test_clue_to_device():
data.send() data.send()
def test_clue_invalid_files():
"""
Test CLUE with invalid files
"""
AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
afqmc_train_json = os.path.join(AFQMC_DIR)
with pytest.raises(ValueError) as info:
_ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
assert "The following patterns did not match any files" in str(info.value)
assert AFQMC_DIR in str(info.value)
if __name__ == "__main__": if __name__ == "__main__":
test_clue() test_clue()
test_clue_num_shards() test_clue_num_shards()
@ -366,3 +379,4 @@ if __name__ == "__main__":
test_clue_tnews() test_clue_tnews()
test_clue_wsc() test_clue_wsc()
test_clue_to_device() test_clue_to_device()
test_clue_invalid_files()


@ -195,6 +195,19 @@ def test_csv_dataset_size():
assert data.get_dataset_size() == 5 assert data.get_dataset_size() == 5
def test_csv_dataset_type_error():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", 0, "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "type does not match" in str(err.value)
def test_csv_dataset_exception(): def test_csv_dataset_exception():
TEST_FILE = '../data/dataset/testCSV/exception.csv' TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset( data = ds.CSVDataset(
@ -208,17 +221,16 @@ def test_csv_dataset_exception():
assert "failed to parse file" in str(err.value) assert "failed to parse file" in str(err.value)
def test_csv_dataset_type_error(): def test_csv_dataset_duplicate_columns():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset( data = ds.CSVDataset(
TEST_FILE, DATA_FILE,
column_defaults=["", 0, "", ""], column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'], column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'],
shuffle=False) shuffle=False)
with pytest.raises(Exception) as err: with pytest.raises(RuntimeError) as info:
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): _ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
pass assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value)
assert "type does not match" in str(err.value) assert "column_names" in str(info.value)
if __name__ == "__main__": if __name__ == "__main__":
@ -234,5 +246,6 @@ if __name__ == "__main__":
test_csv_dataset_header() test_csv_dataset_header()
test_csv_dataset_number() test_csv_dataset_number()
test_csv_dataset_size() test_csv_dataset_size()
test_csv_dataset_exception()
test_csv_dataset_type_error() test_csv_dataset_type_error()
test_csv_dataset_exception()
test_csv_dataset_duplicate_columns()
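
Reviewer note on `test_csv_dataset_duplicate_columns`: the test passes `column_names` with each name repeated and expects a `RuntimeError` that names the first duplicate (`col1`). A duplicate check of that shape might look like the sketch below. The function names are hypothetical and the error text only mirrors the substring asserted in the test; the real check lives in the C++ IR layer and may be worded or structured differently.

```python
# Hypothetical sketch of a duplicate-column-name check.
# Function names are illustrative assumptions, not MindSpore internals.
def find_duplicate_columns(column_names):
    """Return duplicated names in the order they are first repeated."""
    seen = set()
    duplicates = []
    for name in column_names:
        if name in seen and name not in duplicates:
            duplicates.append(name)
        seen.add(name)
    return duplicates


def check_unique_columns(column_names):
    """Raise RuntimeError naming the first duplicated column, if any."""
    duplicates = find_duplicate_columns(column_names)
    if duplicates:
        raise RuntimeError(
            "Invalid parameter, duplicate column names are not allowed: " + duplicates[0])
```

Reporting only the first duplicate matches what the test asserts; returning the full list from `find_duplicate_columns` leaves room for a more detailed message.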


@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train" IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
@ -21,9 +22,18 @@ IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-000
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
MNIST_DATA_DIR = "../data/dataset/testMnistData" MNIST_DATA_DIR = "../data/dataset/testMnistData"
MIND_CV_FILE_NAME = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord"
SCHEMA_FILE = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data" CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data" CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data"
VOC_DATA_DIR = "../data/dataset/testVOC2012"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
CLUE_FILE = '../data/dataset/testCLUE/afqmc/train.json'
CSV_FILE = '../data/dataset/testCSV/1.csv'
TEXT_DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
def test_imagenet_rawdata_dataset_size(): def test_imagenet_rawdata_dataset_size():
@ -50,8 +60,15 @@ def test_imagenet_tf_file_dataset_size():
ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0) ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0)
assert ds_shard_2_0.get_dataset_size() == 6 assert ds_shard_2_0.get_dataset_size() == 6
# FIXME: dataset_size == 6 looks wrong, but it seems intended to match the current code.
# The correct answer should be 12/3=4; the code issue should be addressed.
ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0) ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 4 assert ds_shard_3_0.get_dataset_size() == 6
count = 0
for _ in ds_shard_3_0.create_dict_iterator():
count += 1
assert ds_shard_3_0.get_dataset_size() == count
def test_mnist_dataset_size(): def test_mnist_dataset_size():
@ -76,6 +93,14 @@ def test_mnist_dataset_size():
assert ds_shard_3_0.get_dataset_size() == 3334 assert ds_shard_3_0.get_dataset_size() == 3334
def test_mind_dataset_size():
dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0")
assert dataset.get_dataset_size() == 20
dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 10
def test_manifest_dataset_size(): def test_manifest_dataset_size():
ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE) ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE)
assert ds_total.get_dataset_size() == 4 assert ds_total.get_dataset_size() == 4
@ -95,10 +120,11 @@ def test_cifar10_dataset_size():
assert ds_total.get_dataset_size() == 10000 assert ds_total.get_dataset_size() == 10000
# test get_dataset_size with usage flag # test get_dataset_size with usage flag
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 0
train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size() train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 10000 assert train_size == 10000
test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
assert test_size == 0
all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size() all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
assert all_size == 10000 assert all_size == 10000
@ -120,8 +146,6 @@ def test_cifar100_dataset_size():
assert ds_total.get_dataset_size() == 10000 assert ds_total.get_dataset_size() == 10000
# test get_dataset_size with usage flag # test get_dataset_size with usage flag
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 0
test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size() test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
assert test_size == 10000 assert test_size == 10000
all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size() all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
@ -137,10 +161,97 @@ def test_cifar100_dataset_size():
assert ds_shard_3_0.get_dataset_size() == 3334 assert ds_shard_3_0.get_dataset_size() == 3334
def test_voc_dataset_size():
dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
assert dataset.get_dataset_size() == 10
dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True,
num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 5
def test_coco_dataset_size():
dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
decode=True, shuffle=False)
assert dataset.get_dataset_size() == 6
dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 3
def test_celeba_dataset_size():
dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
assert dataset.get_dataset_size() == 4
dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_clue_dataset_size():
dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
assert dataset.get_dataset_size() == 3
dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_csv_dataset_size():
dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
shuffle=False)
assert dataset.get_dataset_size() == 3
dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_text_file_dataset_size():
dataset = ds.TextFileDataset(TEXT_DATA_FILE)
assert dataset.get_dataset_size() == 3
dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_padded_dataset_size():
dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}])
assert dataset.get_dataset_size() == 2
def test_pipeline_get_dataset_size():
dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False)
assert dataset.get_dataset_size() == 12
dataset = dataset.shuffle(buffer_size=3)
assert dataset.get_dataset_size() == 12
decode_op = vision.Decode()
resize_op = vision.RandomResize(10)
dataset = dataset.map([decode_op, resize_op], input_columns=["image"])
assert dataset.get_dataset_size() == 12
dataset = dataset.batch(batch_size=3)
assert dataset.get_dataset_size() == 4
dataset = dataset.repeat(count=2)
assert dataset.get_dataset_size() == 8
if __name__ == '__main__': if __name__ == '__main__':
test_imagenet_rawdata_dataset_size() test_imagenet_rawdata_dataset_size()
test_imagenet_tf_file_dataset_size() test_imagenet_tf_file_dataset_size()
test_mnist_dataset_size() test_mnist_dataset_size()
test_mind_dataset_size()
test_manifest_dataset_size() test_manifest_dataset_size()
test_cifar10_dataset_size() test_cifar10_dataset_size()
test_cifar100_dataset_size() test_cifar100_dataset_size()
test_voc_dataset_size()
test_coco_dataset_size()
test_celeba_dataset_size()
test_clue_dataset_size()
test_csv_dataset_size()
test_text_file_dataset_size()
test_padded_dataset_size()
test_pipeline_get_dataset_size()
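
Reviewer note on the expected sizes in these tests: most of the asserted values follow from simple arithmetic (for example 10000 rows over 3 shards gives 3334, and 12 rows batched by 3 then repeated twice gives 8). The sketch below reproduces that arithmetic. The helper names are hypothetical and this is not how `get_dataset_size` is implemented. It also assumes rows are split evenly with rounding up, which, per the FIXME in `test_imagenet_tf_file_dataset_size`, does not hold for the 3-shard TFRecord case.

```python
import math

# Hypothetical helpers that mirror the size arithmetic in the tests above.
# They are illustrative assumptions, not MindSpore functions.


def sharded_size(num_rows, num_shards=1):
    """Rows seen by one shard when rows are split evenly, rounded up."""
    return math.ceil(num_rows / num_shards)


def batched_size(num_rows, batch_size, drop_remainder=False):
    """Number of batches produced from num_rows rows."""
    if drop_remainder:
        return num_rows // batch_size
    return math.ceil(num_rows / batch_size)


def repeated_size(num_rows, count):
    """Rows produced after repeating the dataset count times."""
    return num_rows * count


# Mirrors test_pipeline_get_dataset_size: 12 rows, batch(3), repeat(2).
size = 12
size = batched_size(size, 3)
size = repeated_size(size, 2)
```

Shuffle and map do not change the row count, which is why the pipeline test asserts 12 both before and after those steps.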


@ -521,7 +521,7 @@ def test_chained_sampler_04():
# Verify dataset size # Verify dataset size
data1_size = data1.get_dataset_size() data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size)) logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 24 assert data1_size == 6
# Verify number of iterations # Verify number of iterations
num_iter = 0 num_iter = 0


@ -182,6 +182,15 @@ def test_voc_exception():
pass pass
def test_voc_num_classes():
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
assert data1.num_classes() is None
class_index = {'car': 0, 'cat': 1, 'train': 5}
data2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
assert data2.num_classes() is None
if __name__ == '__main__': if __name__ == '__main__':
test_voc_segmentation() test_voc_segmentation()
test_voc_detection() test_voc_detection()
@ -191,3 +200,4 @@ if __name__ == '__main__':
test_case_1() test_case_1()
test_case_2() test_case_2()
test_voc_exception() test_voc_exception()
test_voc_num_classes()


@ -107,7 +107,7 @@ def test_decode_op():
# Expect a AttributeError since iter1 has been stopped. # Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info: with pytest.raises(AttributeError) as info:
iter1.__next__() iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value) assert "object has no attribute '_runtime_context'" in str(info.value)
with pytest.raises(RuntimeError) as info: with pytest.raises(RuntimeError) as info:
iter2.__next__() iter2.__next__()
@ -205,7 +205,7 @@ def test_generator_dict_3():
# Expect a AttributeError since iter1 has been stopped. # Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info: with pytest.raises(AttributeError) as info:
iter1.__next__() iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value) assert "object has no attribute '_runtime_context'" in str(info.value)
def test_generator_dict_4(): def test_generator_dict_4():
@ -396,7 +396,7 @@ def test_generator_tuple_3():
# Expect a AttributeError since iter1 has been stopped. # Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info: with pytest.raises(AttributeError) as info:
iter1.__next__() iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value) assert "object has no attribute '_runtime_context'" in str(info.value)
def test_generator_tuple_4(): def test_generator_tuple_4():
@ -546,7 +546,7 @@ def test_generator_tuple_repeat_repeat_2():
# Expect a AttributeError since iter1 has been stopped. # Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info: with pytest.raises(AttributeError) as info:
iter1.__next__() iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value) assert "object has no attribute '_runtime_context'" in str(info.value)
def test_generator_tuple_repeat_repeat_3(): def test_generator_tuple_repeat_repeat_3():


@ -74,9 +74,11 @@ def test_case2():
def test_case3(): def test_case3():
data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10) data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(5) ["col_sint64"], ["a1"])
data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2) data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(5).rename(
["col_sint64"], ["a2"])
data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).rename(["col_sint64"], ["a3"])
data4 = ds.zip((data1, data2, data3)) data4 = ds.zip((data1, data2, data3))
@ -84,8 +86,9 @@ def test_case3():
def test_case4(): def test_case4():
data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10) data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
data2 = ds.TFRecordDataset(FILES) ["col_sint64"], ["a1"])
data2 = ds.TFRecordDataset(FILES, columns_list=["col_sint64"]).rename(["col_sint64"], ["a2"])
assert data2.get_dataset_size() == 12 assert data2.get_dataset_size() == 12
data2 = data2.batch(2) data2 = data2.batch(2)
assert data2.get_dataset_size() == 6 assert data2.get_dataset_size() == 6

Some files were not shown because too many files have changed in this diff.