forked from mindspore-Ecosystem/mindspore
Added Python API based on the C++ API
1st draft of python iterator
Added Cifar10 and Cifar100 pybind port
Change pybind to use IR for Skip and Manifest
Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
DatasetNode as a base for all IR nodes; namespace change
Fix the namespace issue and make ut tests work
Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
Add VOCDataset
!63 Added RandomDataset
Add ImageFolder IR
Pybind switch: CelebA and UT
!61 CLUE example with class definition
* Merge branch 'python-api' of gitee.com:ezphlow/mindspore into clue_class_pybind
* Passing testcases
* Added CLUE, not working
Add ManifestDataset IR
Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
Update Coco & VOC & TFReader, update clang-format, reorder datasets_binding
!69 Add Generator and move c_dataset.Iterator to dataset.Iterator
* Add GeneratorDataset to c_dataset
!67 Moving c_datasets and adding sampler wrapper
* Need to add create() method in datasets.py
* Migration from c_dataset to dataset, part 1
!71 Fix indentation error
!72 Fix c_api test cases
!73 Added CSV Dataset
* Added CSVDataset
Pybind switch: Take, and CelebA fixes
!75 Move c_dataset functionality to datasets
* Fixed existing testcases
* Added working CLUE and ImageFolder
* Added sampler conversion from pybind
* Added sampler creation
!77 Add Python API tree
Add minddataset and TextFileDataset pybind
Rename to skip test_concat.py and test_minddataset_exception.py
!80 Add batch IR to python-api branch, most test cases work
* Staging III
* Staging, add pybind
Enable more c_api Take and CelebA tests; delete util_c_api
!84 Schema changes in datasets.py
!85 Remove input_indexes from sub-classes
* Remove input_index from each subclass
!83 Remove C datasets
* Removed c_dataset package
!82 Pybind switch: shuffle
!86 Add build_vocab
Rebase with upstream/master; _shuffle conflict; BatchNode error
!88 Fix rebase problem
Enable more unit tests; code typo/nit fixes
!91 Fix python vocab hang
!89 Added BucketBatchByLength pybind switch
Update and enable more test_c_api_*.py tests
!95 Add BuildSentencePieceVocab
!96 Fix more tests
* Fix some tests, enable more test_c_api_*, add syncwait
!99 Pybind switch for device op
!93 Add getters to python API
!101 Validate tree, error if graph
* Add sync wait
!103 TFRecord/Random Datasets schema problem
!102 Added Filter pybind switch
!104 Fix num_samples
* TFRecord/Random schema problem
!105 Fix to_device hang
!94 Add Cache support for CLUE dataset
* Added cache for all dataset ops
* Format change
* Added CLUE cache support
* Added Cache conversion
Add save pybind
Fix compile err; init; modify concat_node
!107 Fix some test cases
Enable and fix more tests
!109 Pybind switch for get dataset size
* pybind_get_dataset_size
Some check-code fixes for pylint, cpplint and clang-format
!113 Add callback
* Revert
* dataset_sz 1 line
* Fix typo
* Get callback to work
!114 Make Android compile clean
Fix build issues due to rebase
!115 Fix more tests
* Fix test cases
* !93 Add getters to python API
Fix test_profiling.py
!116 Fix get dataset size
!117 GetColumnNames pybind switch
* Added GetColumnNames pybind switch
Code-check fixes: clang-format, cppcheck, cpplint, pylint
Delete duplicate test_c_api_*.py files; more lint fixes
!121 Fix cpp tests
* Remove extra call to GetNext in cpp tests
!122 Fix Schema with Generator
Fix some cases of csv & mindrecord
!124 Fix tfrecord get_dataset_size and add some UTs
!125 Getter separation
!126 Fix sampler.GetNumSamples
!127 Assign runtime getter to each get function
Fix compile issues
!128 Match master code
!129 Cleanup DeviceOp/save code
* Cleanup ToDevice/Save code
!130 Add cache fix
* Added cache fix for map and image folder
!132 Fix testing team issues
* Pass queue_name from python to C++
* Add Schema.from_json
!131 Fix Cache op issues and delete de_pipeline
* Roll back C++ change
* Removed de_pipeline and passing all cache tests
* Fixed cache tests
!134 Cleanup datasets.py, part 1
!133 Updated validation for SentencePieceVocab.from_dataset
* Added type_check for column names in SentencePieceVocab.from_dataset
Rebase on master 181120 10:20
Fix profiling
Temporary solution for catching status from Node.Build()
!141 ToDevice termination
Pylint fixes
!137 Fix test team issues and add some corresponding tests
!138 TreeGetter changes to use OptPass
* Getter changes to use OptPass (Zirui)
Rebase fix
!143 Fix cpplint issue
Pylint fixes in updated testcases
!145 Reset exceptions testcase
* Reset exception test to master
!146 Fix Check_Pylint error
!147 Fix android
!148 ToDevice changes
* Add ToDevice to the iterator list for cleanup at exit
!149 Pylint issue
!150 Pylint 2
!152 ExecutionTree error
* ET destructor error
!153 In getter_pass, only remove callback, without deleting map op
* Getter pass no longer removes map
!156 Early __del__ of iterator/to_device
!155 Address review comments (Eric, 1)
* Added one-liner fix to validators.py
* Roll back signature fix
* Lint fix
* Address comments (Eric, 2); C++ lint fix
!158 Review rework for dataset bindings - part 1
* Reorder nodes repeat and rename
!154 Fix minor problems in the comments (datasets.py, python_tree_consumer.cc, iterators_bindings.cc, and iterators.py)
!157 Add replace_none
* Add replace_none to datasets.py, address comments in tests
Trying to resolve copy: override the deepcopy method of deviceop
Create_ir_tree method; Create_ir_tree method 2
Del to_device if already exists
Cache getters shapes and types
Added yolov3 relaxation, to be rolled back
Get shapes and types together; bypass yolo
NumWorkers for MapOp
Revert Yolo; revert Thor
Print more info
Debug code: update LOG INFO to LOG ERROR
Do not remove epochctrl for getter pass
Remove repeat(1)
Print batch size
Add log to tree_consumer and device_queue op
Revert PR 8744
Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
__del__ toDevice; __del__ toDevice 2
!165 Add ifndef ENABLE_ANDROID to device queue print
Revert some changes
!166 Getter: get_data_info
!168 Add back tree print
* Revert info to warning in one log
* Add back the missed print tree log
Release GIL in GetDataInfo
parent d9b4b5c750
commit 809e1d5086
@@ -2,9 +2,11 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
    add_library(APItoPython OBJECT
            python/de_pipeline.cc
            python/pybind_register.cc
            python/bindings.cc
            python/pybind_conversion.cc
            python/bindings/dataset/include/datasets_bindings.cc
            python/bindings/dataset/include/iterator_bindings.cc
            python/bindings/dataset/include/schema_bindings.cc
            python/bindings/dataset/engine/cache/bindings.cc
            python/bindings/dataset/core/bindings.cc
            python/bindings/dataset/callback/bindings.cc
@@ -115,7 +115,8 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
 #ifndef ENABLE_ANDROID
 // Function to return a transferred Node that transfers data through a device.
-bool Dataset::DeviceQueue(bool send_epoch_end) {
+bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t num_epochs, bool send_epoch_end,
+                          int32_t total_batches, bool create_data_info_queue) {
   Status rc;
 
   // Build and launch tree
@@ -126,11 +127,12 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
     return false;
   }
 
-  // Add TransferNode IR on top of dataset d
-  auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end);
+  // Add TransferNode IR on top of dataset
+  auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, send_epoch_end,
+                                           total_batches, create_data_info_queue);
 
   // Get ToDevice consumer
-  auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1);
+  auto consumer = std::make_unique<ToDevice>(num_epochs);
   ToDevice *consumer_ = consumer.get();
   rc = consumer->Init(ds);
   if (rc.IsError()) {
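The reworked DeviceQueue() above is now fully IR-driven: the caller passes the queue name, device type, and epoch count, a TransferNode is layered on top of the dataset's IR tree, and a ToDevice consumer initializes and launches it. The following is a minimal caller-side sketch, assuming an ImageFolder source; the path, queue name, device type, and batch size are illustrative placeholders, not part of this commit.

// Hedged usage sketch: calling the new DeviceQueue overload.
#include <memory>
#include "minddata/dataset/include/datasets.h"

bool SendToDevice() {
  using namespace mindspore::dataset;
  std::shared_ptr<Dataset> ds = ImageFolder("/path/to/images", /*decode=*/true);
  ds = ds->Batch(32);
  // Arguments: queue_name, device_type, num_epochs, send_epoch_end,
  // total_batches (0 = send everything), create_data_info_queue.
  return ds->DeviceQueue("queue_0", "Ascend", 1, true, 0, false);
}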
@@ -199,127 +201,55 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
|
|||
|
||||
int64_t Dataset::GetDatasetSize() {
|
||||
int64_t dataset_size;
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetDatasetSize(&dataset_size);
|
||||
return rc.IsError() ? -1 : dataset_size;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
|
||||
return dataset_size;
|
||||
}
|
||||
|
||||
std::vector<DataType> Dataset::GetOutputTypes() {
|
||||
std::vector<DataType> types;
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
|
||||
return types;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
|
||||
return types;
|
||||
}
|
||||
rc = tree_getters_->GetOutputTypes(&types);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
|
||||
types.clear();
|
||||
return types;
|
||||
}
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputTypes(&types), {});
|
||||
return types;
|
||||
}
|
||||
|
||||
std::vector<TensorShape> Dataset::GetOutputShapes() {
|
||||
std::vector<TensorShape> shapes;
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
|
||||
return shapes;
|
||||
}
|
||||
rc = tree_getters_->Init(this->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
|
||||
return shapes;
|
||||
}
|
||||
rc = tree_getters_->GetOutputShapes(&shapes);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
|
||||
shapes.clear();
|
||||
return shapes;
|
||||
}
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputShapes(&shapes), {});
|
||||
return shapes;
|
||||
}
|
||||
|
||||
int64_t Dataset::GetNumClasses() {
|
||||
int64_t num_classes;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetNumClasses(&num_classes);
|
||||
return rc.IsError() ? -1 : num_classes;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetNumClasses(&num_classes), -1);
|
||||
return num_classes;
|
||||
}
|
||||
|
||||
std::vector<std::string> Dataset::GetColumnNames() {
|
||||
std::vector<std::string> col_names;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed.";
|
||||
return std::vector<std::string>();
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed.";
|
||||
return std::vector<std::string>();
|
||||
}
|
||||
rc = tree_getters_->GetColumnNames(&col_names);
|
||||
return rc.IsError() ? std::vector<std::string>() : col_names;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetColumnNames(&col_names), {});
|
||||
return col_names;
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
|
||||
return output_class_indexing;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
|
||||
return output_class_indexing;
|
||||
}
|
||||
rc = tree_getters_->GetClassIndexing(&output_class_indexing);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
|
||||
output_class_indexing.clear();
|
||||
return output_class_indexing;
|
||||
}
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetClassIndexing(&output_class_indexing), {});
|
||||
return output_class_indexing;
|
||||
}
|
||||
|
||||
|
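The getters above replace the repeated "rc = ...; if (rc.IsError()) { MS_LOG(ERROR) ...; return <fallback>; }" blocks with RETURN_SECOND_IF_ERROR. Its definition is not part of this hunk; the sketch below is only an assumption about its shape, consistent with how it is used here (evaluate a Status and, on error, log it and return the second argument).

// Assumed shape of RETURN_SECOND_IF_ERROR; the real definition lives outside
// this diff and may differ. Shown only to make the refactor easier to read.
#define RETURN_SECOND_IF_ERROR(_s, _r) \
  do {                                 \
    Status _rc = (_s);                 \
    if (_rc.IsError()) {               \
      MS_LOG(ERROR) << _rc.ToString(); \
      return _r;                       \
    }                                  \
  } while (false)

Read this way, RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1) keeps the old "return -1 on failure" behavior, and the {} fallback used by the shape, type, and name getters returns an empty container.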
@@ -501,9 +431,13 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(
|
|||
std::function<TensorRow(TensorRow)> element_length_function,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
|
||||
bool drop_remainder) {
|
||||
auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries,
|
||||
bucket_batch_sizes, element_length_function, pad_info,
|
||||
pad_to_bucket_boundary, drop_remainder);
|
||||
std::shared_ptr<TensorOp> c_func = nullptr;
|
||||
if (element_length_function != nullptr) {
|
||||
c_func = std::make_shared<CFuncOp>(element_length_function);
|
||||
}
|
||||
auto ds =
|
||||
std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, bucket_batch_sizes,
|
||||
c_func, pad_info, pad_to_bucket_boundary, drop_remainder);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
@@ -522,7 +456,9 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase
|
|||
|
||||
FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate,
|
||||
std::vector<std::string> input_columns) {
|
||||
auto ds = std::make_shared<FilterNode>(input->IRNode(), predicate, input_columns);
|
||||
std::shared_ptr<TensorOp> c_func = nullptr;
|
||||
if (predicate) c_func = std::make_shared<CFuncOp>(predicate);
|
||||
auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, input_columns);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
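FilterDataset and BucketBatchByLengthDataset now share the same pattern: the raw std::function<TensorRow(TensorRow)> is wrapped in a CFuncOp only when it is non-null, so the IR node stores a TensorOp (or nullptr) instead of a bare callable. A condensed sketch of that pattern follows; the callable body is a placeholder, and the keep/drop contract of a real Filter predicate is not shown in this diff.

// Placeholder callable standing in for a user-supplied predicate or
// element-length function.
std::function<TensorRow(TensorRow)> fn = [](TensorRow row) { return row; };
std::shared_ptr<TensorOp> c_func = nullptr;
if (fn) {
  c_func = std::make_shared<CFuncOp>(fn);  // wrap the callable as a TensorOp
}
// The IR node (FilterNode / BucketBatchByLengthNode) is then built with c_func.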
@@ -604,40 +540,20 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
|||
#endif
|
||||
int64_t Dataset::GetBatchSize() {
|
||||
int64_t batch_size;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetBatchSize(&batch_size);
|
||||
return rc.IsError() ? -1 : batch_size;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
|
||||
return batch_size;
|
||||
}
|
||||
|
||||
int64_t Dataset::GetRepeatCount() {
|
||||
int64_t repeat_count;
|
||||
auto ds = shared_from_this();
|
||||
Status rc;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->Init(ds->IRNode());
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
|
||||
return -1;
|
||||
}
|
||||
rc = tree_getters_->GetRepeatCount(&repeat_count);
|
||||
return rc.IsError() ? 0 : repeat_count;
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
|
||||
return repeat_count;
|
||||
}
|
||||
|
||||
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
|
||||
|
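All of the getters above now follow one flow: create and initialize a NativeRuntimeContext, initialize tree_getters_ against the IR tree, then delegate, returning a sentinel (-1, 0, or an empty container) on failure instead of surfacing a Status. A small caller-side sketch; how the pipeline handle ds was built is assumed and not part of this commit.

// Hedged sketch: reading pipeline metadata through the getter API above.
#include <memory>
#include "minddata/dataset/include/datasets.h"

void DumpPipelineInfo(const std::shared_ptr<mindspore::dataset::Dataset> &ds) {
  MS_LOG(INFO) << "rows: " << ds->GetDatasetSize()        // -1 on error
               << ", batch size: " << ds->GetBatchSize()  // -1 on error
               << ", repeats: " << ds->GetRepeatCount();  // 0 (or -1) on error
  for (const auto &col : ds->GetColumnNames()) {          // empty on error
    MS_LOG(INFO) << "column: " << col;
  }
}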
@@ -720,62 +636,65 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
|
|||
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
|
||||
|
||||
// SchemaObj init function
|
||||
bool SchemaObj::init() {
|
||||
if (schema_file_ != "") {
|
||||
Status SchemaObj::init() {
|
||||
if (!schema_file_.empty()) {
|
||||
Path schema_file(schema_file_);
|
||||
if (!schema_file.Exists()) {
|
||||
MS_LOG(ERROR) << "The file " << schema_file << " does not exist or permission denied!";
|
||||
return false;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
|
||||
"The file " + schema_file_ + " does not exist or permission denied!");
|
||||
|
||||
nlohmann::json js;
|
||||
try {
|
||||
std::ifstream in(schema_file_);
|
||||
in >> js;
|
||||
if (js.find("columns") == js.end()) {
|
||||
MS_LOG(ERROR) << "\"columns\" node is required in the schema json file.";
|
||||
return false;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
|
||||
"\"columns\" node is required in the schema json file.");
|
||||
} catch (const std::exception &err) {
|
||||
MS_LOG(ERROR) << "Schema file failed to load";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Schema file failed to load");
|
||||
}
|
||||
return from_json(js);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a mstype de_type
|
||||
bool SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
|
||||
nlohmann::json new_column;
|
||||
new_column["name"] = name;
|
||||
// if de_type is mstype
|
||||
// Function to add a column to schema with a mstype de_type and known shape
|
||||
Status SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
|
||||
DataType data_type = dataset::MSTypeToDEType(de_type);
|
||||
new_column["type"] = data_type.ToString();
|
||||
if (shape.size() > 0) {
|
||||
new_column["shape"] = shape;
|
||||
new_column["rank"] = shape.size();
|
||||
} else {
|
||||
new_column["rank"] = 1;
|
||||
}
|
||||
columns_.push_back(new_column);
|
||||
return true;
|
||||
return add_column(name, data_type.ToString(), shape);
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a string de_type
|
||||
bool SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
|
||||
// Function to add a column to schema with a string de_type and known shape
|
||||
Status SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
|
||||
DataType data_type(de_type);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
|
||||
|
||||
nlohmann::json new_column;
|
||||
new_column["name"] = name;
|
||||
DataType data_type(de_type);
|
||||
new_column["type"] = data_type.ToString();
|
||||
if (shape.size() > 0) {
|
||||
new_column["shape"] = shape;
|
||||
new_column["rank"] = shape.size();
|
||||
} else {
|
||||
new_column["rank"] = 1;
|
||||
}
|
||||
new_column["shape"] = shape;
|
||||
new_column["rank"] = shape.size();
|
||||
|
||||
columns_.push_back(new_column);
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a mstype de_type and without shape
|
||||
Status SchemaObj::add_column(std::string name, TypeId de_type) {
|
||||
DataType data_type = dataset::MSTypeToDEType(de_type);
|
||||
return add_column(name, data_type.ToString());
|
||||
}
|
||||
|
||||
// Function to add a column to schema with a string de_type and without shape
|
||||
Status SchemaObj::add_column(std::string name, std::string de_type) {
|
||||
DataType data_type(de_type);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
|
||||
|
||||
nlohmann::json new_column;
|
||||
new_column["name"] = name;
|
||||
new_column["type"] = data_type.ToString();
|
||||
new_column["rank"] = 1;
|
||||
|
||||
columns_.push_back(new_column);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string SchemaObj::to_json() {
|
||||
|
@@ -792,7 +711,7 @@ std::string SchemaObj::to_json() {
|
|||
return json_file.dump(2);
|
||||
}
|
||||
|
||||
bool SchemaObj::parse_column(nlohmann::json columns) {
|
||||
Status SchemaObj::parse_column(nlohmann::json columns) {
|
||||
std::string name, de_type;
|
||||
std::vector<int32_t> shape;
|
||||
|
||||
|
@@ -802,15 +721,13 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
|
|||
for (auto column : columns) {
|
||||
auto key_name = column.find("name");
|
||||
if (key_name == column.end()) {
|
||||
MS_LOG(ERROR) << "Column's name is missing";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
|
||||
}
|
||||
name = *key_name;
|
||||
|
||||
auto key_type = column.find("type");
|
||||
if (key_type == column.end()) {
|
||||
MS_LOG(ERROR) << "Column's type is missing";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
|
||||
}
|
||||
de_type = *key_type;
|
||||
|
||||
|
@@ -819,17 +736,14 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
|
|||
if (key_shape != column.end()) {
|
||||
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
|
||||
}
|
||||
if (!add_column(name, de_type, shape)) {
|
||||
return false;
|
||||
}
|
||||
RETURN_IF_NOT_OK(add_column(name, de_type, shape));
|
||||
}
|
||||
} else if (columns.type() == nlohmann::json::value_t::object) {
|
||||
for (const auto &it_child : columns.items()) {
|
||||
name = it_child.key();
|
||||
auto key_type = it_child.value().find("type");
|
||||
if (key_type == it_child.value().end()) {
|
||||
MS_LOG(ERROR) << "Column's type is missing";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
|
||||
}
|
||||
de_type = *key_type;
|
||||
|
||||
|
@@ -839,43 +753,45 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
|
|||
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
|
||||
}
|
||||
|
||||
if (!add_column(name, de_type, shape)) {
|
||||
return false;
|
||||
}
|
||||
RETURN_IF_NOT_OK(add_column(name, de_type, shape));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "columns must be dict or list, columns contain name, type, shape(optional).";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool SchemaObj::from_json(nlohmann::json json_obj) {
|
||||
Status SchemaObj::from_json(nlohmann::json json_obj) {
|
||||
for (const auto &it_child : json_obj.items()) {
|
||||
if (it_child.key() == "datasetType") {
|
||||
dataset_type_ = it_child.value();
|
||||
} else if (it_child.key() == "numRows") {
|
||||
num_rows_ = it_child.value();
|
||||
} else if (it_child.key() == "columns") {
|
||||
if (!parse_column(it_child.value())) {
|
||||
MS_LOG(ERROR) << "parse columns failed";
|
||||
return false;
|
||||
}
|
||||
RETURN_IF_NOT_OK(parse_column(it_child.value()));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknown field " << it_child.key();
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
|
||||
}
|
||||
}
|
||||
if (columns_.empty()) {
|
||||
MS_LOG(ERROR) << "Columns are missing.";
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
|
||||
}
|
||||
if (num_rows_ <= 0) {
|
||||
MS_LOG(ERROR) << "numRows must be greater than 0";
|
||||
return false;
|
||||
if (num_rows_ < 0) {
|
||||
RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
|
||||
}
|
||||
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
Status SchemaObj::FromJSONString(const std::string &json_string) {
|
||||
try {
|
||||
nlohmann::json js = nlohmann::json::parse(json_string);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
|
||||
"\"columns\" node is required in the schema json JSON.");
|
||||
RETURN_IF_NOT_OK(from_json(js));
|
||||
} catch (const std::exception &err) {
|
||||
RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// OTHER FUNCTIONS
|
||||
|
|
|
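With init, add_column, parse_column, from_json, and the new FromJSONString all returning Status instead of bool, schema construction composes with the usual status macros. A minimal sketch using only the overloads shown above; the column names, types, and shapes are illustrative assumptions.

// Hedged sketch of the Status-returning SchemaObj API.
Status BuildExampleSchema(const std::shared_ptr<SchemaObj> &schema) {
  RETURN_IF_NOT_OK(schema->add_column("image", "uint8", {28, 28, 1}));
  RETURN_IF_NOT_OK(schema->add_column("label", "int32"));  // no shape: rank defaults to 1
  MS_LOG(INFO) << schema->to_json();
  return Status::OK();
}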
@@ -1,136 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(
|
||||
DEPipeline, 0, ([](const py::module *m) {
|
||||
(void)py::class_<DEPipeline>(*m, "DEPipeline")
|
||||
.def(py::init<>())
|
||||
.def(
|
||||
"AddNodeToTree",
|
||||
[](DEPipeline &de, const OpName &op_name, const py::dict &args) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
|
||||
return out;
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def_static("AddChildToParentNode",
|
||||
[](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
|
||||
THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
|
||||
})
|
||||
.def("AssignRootNode",
|
||||
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
|
||||
.def("SetBatchParameters",
|
||||
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
|
||||
.def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
|
||||
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
|
||||
.def("GetColumnNames",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetColumnNames(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetNextAsMap",
|
||||
[](DEPipeline &de) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(de.GetNextAsMap(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetNextAsList",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetNextAsList(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetOutputShapes",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetOutputShapes(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetOutputTypes",
|
||||
[](DEPipeline &de) {
|
||||
py::list out;
|
||||
THROW_IF_ERROR(de.GetOutputTypes(&out));
|
||||
return out;
|
||||
})
|
||||
.def("GetDataInfo",
|
||||
[](DEPipeline &de) {
|
||||
py::list types, shapes;
|
||||
THROW_IF_ERROR(de.GetDataInfo(&types, &shapes));
|
||||
return py::make_tuple(types, shapes);
|
||||
})
|
||||
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
|
||||
.def("GetBatchSize", &DEPipeline::GetBatchSize)
|
||||
.def("GetNumClasses", &DEPipeline::GetNumClasses)
|
||||
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
|
||||
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
|
||||
.def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
|
||||
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
|
||||
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
|
||||
return true;
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
|
||||
.value("SHUFFLE", OpName::kShuffle)
|
||||
.value("BATCH", OpName::kBatch)
|
||||
.value("BUCKETBATCH", OpName::kBucketBatch)
|
||||
.value("BARRIER", OpName::kBarrier)
|
||||
.value("MINDRECORD", OpName::kMindrecord)
|
||||
.value("CACHE", OpName::kCache)
|
||||
.value("REPEAT", OpName::kRepeat)
|
||||
.value("SKIP", OpName::kSkip)
|
||||
.value("TAKE", OpName::kTake)
|
||||
.value("ZIP", OpName::kZip)
|
||||
.value("CONCAT", OpName::kConcat)
|
||||
.value("MAP", OpName::kMap)
|
||||
.value("FILTER", OpName::kFilter)
|
||||
.value("DEVICEQUEUE", OpName::kDeviceQueue)
|
||||
.value("GENERATOR", OpName::kGenerator)
|
||||
.export_values()
|
||||
.value("RENAME", OpName::kRename)
|
||||
.value("TFREADER", OpName::kTfReader)
|
||||
.value("PROJECT", OpName::kProject)
|
||||
.value("IMAGEFOLDER", OpName::kImageFolder)
|
||||
.value("MNIST", OpName::kMnist)
|
||||
.value("MANIFEST", OpName::kManifest)
|
||||
.value("VOC", OpName::kVoc)
|
||||
.value("COCO", OpName::kCoco)
|
||||
.value("CIFAR10", OpName::kCifar10)
|
||||
.value("CIFAR100", OpName::kCifar100)
|
||||
.value("RANDOMDATA", OpName::kRandomData)
|
||||
.value("BUILDVOCAB", OpName::kBuildVocab)
|
||||
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
|
||||
.value("CELEBA", OpName::kCelebA)
|
||||
.value("TEXTFILE", OpName::kTextFile)
|
||||
.value("EPOCHCTRL", OpName::kEpochCtrl)
|
||||
.value("CSV", OpName::kCsv)
|
||||
.value("CLUE", OpName::kClue);
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@@ -19,8 +19,10 @@
|
|||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/core/client.h" // DE client
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/api/python/de_pipeline.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
|
|
@@ -0,0 +1,551 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
// IR non-leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
// IR non-leaf nodes disabled for android
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
|
||||
#endif
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
|
||||
// IR leaf nodes disabled for android
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
|
||||
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
|
||||
.def("SetNumWorkers",
|
||||
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
|
||||
return num_workers ? self->SetNumWorkers(*num_workers) : self;
|
||||
})
|
||||
.def(
|
||||
"Zip",
|
||||
[](std::shared_ptr<DatasetNode> self, py::list datasets) {
|
||||
auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets)));
|
||||
THROW_IF_ERROR(zip->ValidateParams());
|
||||
return zip;
|
||||
},
|
||||
py::arg("datasets"));
|
||||
}));
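Every binding in this file is wrapped in PYBIND_REGISTER(name, priority, lambda). The macro itself is defined in pybind_register.h and is not part of this diff; the sketch below is only an assumption about the idea behind it: each lambda is stored in a priority-ordered registry and replayed against the py::module when the extension initializes, which is why the DatasetNode base class registers with priority 1 and every derived node with priority 2 (base classes must be bound before subclasses).

// Assumption-only sketch of a priority-ordered binding registry; the real
// PYBIND_REGISTER macro and its registry live in pybind_register.h, not here.
#include <functional>
#include <map>
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"

namespace py = pybind11;
using BindFn = std::function<void(const py::module *)>;

class BindingRegistry {
 public:
  static BindingRegistry &Instance() {
    static BindingRegistry registry;
    return registry;
  }
  void Add(int priority, BindFn fn) { fns_[priority].push_back(std::move(fn)); }
  void RunAll(const py::module *m) {
    for (auto &entry : fns_) {          // lower priorities bind first
      for (auto &fn : entry.second) fn(m);
    }
  }

 private:
  std::map<int, std::vector<BindFn>> fns_;  // priority -> registered binders
};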
|
||||
|
||||
// PYBIND FOR LEAF NODES
|
||||
// (In alphabetical order)
|
||||
|
||||
PYBIND_REGISTER(
|
||||
CelebANode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", "to create a CelebANode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, bool decode,
|
||||
std::optional<py::list> extensions, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
|
||||
toStringSet(extensions), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(celebA->ValidateParams());
|
||||
return celebA;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
|
||||
"to create a Cifar10Node")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler),
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(cifar10->ValidateParams());
|
||||
return cifar10;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
|
||||
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
|
||||
"to create a Cifar100Node")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler),
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(cifar100->ValidateParams());
|
||||
return cifar100;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
CLUENode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode")
|
||||
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle,
|
||||
int32_t num_shards, int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<CLUENode> clue_node =
|
||||
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle),
|
||||
num_shards, shard_id, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(clue_node->ValidateParams());
|
||||
return clue_node;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
CocoNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", "to create a CocoNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, bool decode,
|
||||
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
|
||||
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(coco->ValidateParams());
|
||||
return coco;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
|
||||
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
|
||||
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
|
||||
int32_t num_shards, int32_t shard_id,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults),
|
||||
column_names, num_samples, toShuffleMode(shuffle),
|
||||
num_shards, shard_id, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(csv->ValidateParams());
|
||||
return csv;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
||||
*m, "GeneratorNode", "to create a GeneratorNode")
|
||||
.def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
|
||||
const std::vector<DataType> &column_types) {
|
||||
auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types);
|
||||
THROW_IF_ERROR(gen->ValidateParams());
|
||||
return gen;
|
||||
}))
|
||||
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) {
|
||||
auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
|
||||
THROW_IF_ERROR(gen->ValidateParams());
|
||||
return gen;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
|
||||
*m, "ImageFolderNode", "to create an ImageFolderNode")
|
||||
.def(py::init([](std::string dataset_dir, bool decode, std::optional<py::handle> sampler,
|
||||
std::optional<py::list> extensions, std::optional<py::dict> class_indexing,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
bool recursive = true;
|
||||
auto imagefolder = std::make_shared<ImageFolderNode>(
|
||||
dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions),
|
||||
toStringMap(class_indexing), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(imagefolder->ValidateParams());
|
||||
return imagefolder;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
|
||||
"to create a ManifestNode")
|
||||
.def(py::init([](std::string dataset_file, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<py::dict> class_indexing, bool decode,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
|
||||
toStringMap(class_indexing), decode,
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(manifest->ValidateParams());
|
||||
return manifest;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
|
||||
"to create a MindDataNode")
|
||||
.def(py::init([](std::string dataset_file, std::optional<py::list> columns_list,
|
||||
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
|
||||
nlohmann::json padded_sample_json;
|
||||
std::map<std::string, std::string> sample_bytes;
|
||||
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
||||
auto minddata =
|
||||
std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list),
|
||||
toSamplerObj(sampler, true), padded_sample_json, num_padded);
|
||||
minddata->SetSampleBytes(&sample_bytes);
|
||||
THROW_IF_ERROR(minddata->ValidateParams());
|
||||
return minddata;
|
||||
}))
|
||||
.def(py::init([](py::list dataset_file, std::optional<py::list> columns_list,
|
||||
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
|
||||
nlohmann::json padded_sample_json;
|
||||
std::map<std::string, std::string> sample_bytes;
|
||||
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
|
||||
auto minddata =
|
||||
std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list),
|
||||
toSamplerObj(sampler, true), padded_sample_json, num_padded);
|
||||
minddata->SetSampleBytes(&sample_bytes);
|
||||
THROW_IF_ERROR(minddata->ValidateParams());
|
||||
return minddata;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
|
||||
"to create an MnistNode")
|
||||
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler),
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(mnist->ValidateParams());
|
||||
return mnist;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RandomNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
|
||||
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto random_node =
|
||||
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(random_node->ValidateParams());
|
||||
return random_node;
|
||||
}))
|
||||
.def(py::init([](int32_t total_rows, std::string schema, std::optional<py::list> columns_list,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
auto random_node =
|
||||
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(random_node->ValidateParams());
|
||||
return random_node;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
|
||||
"to create a TextFileNode")
|
||||
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
|
||||
int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>(
|
||||
toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id,
|
||||
toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(textfile_node->ValidateParams());
|
||||
return textfile_node;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
TFRecordNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
|
||||
"to create a TFRecordNode")
|
||||
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
|
||||
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
|
||||
std::optional<int32_t> shard_id, bool shard_equal_rows,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
if (!num_samples) {
|
||||
num_samples = 0;  // assign a default; dereferencing an empty optional is undefined behavior
|
||||
}
|
||||
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
||||
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
|
||||
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(tfrecord->ValidateParams());
|
||||
return tfrecord;
|
||||
}))
|
||||
.def(py::init([](py::list dataset_files, std::string schema, std::optional<py::list> columns_list,
|
||||
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
|
||||
std::optional<int32_t> shard_id, bool shard_equal_rows,
|
||||
std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
if (!num_samples) {
|
||||
num_samples = 0;  // assign a default; dereferencing an empty optional is undefined behavior
|
||||
}
|
||||
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
|
||||
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
|
||||
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(tfrecord->ValidateParams());
|
||||
return tfrecord;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
|
||||
.def(
|
||||
py::init([](std::string dataset_dir, std::string task, std::string usage,
|
||||
std::optional<py::dict> class_indexing, bool decode,
|
||||
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
|
||||
std::shared_ptr<VOCNode> voc =
|
||||
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
|
||||
toSamplerObj(sampler), toDatasetCache(std::move(cc)));
|
||||
THROW_IF_ERROR(voc->ValidateParams());
|
||||
return voc;
|
||||
}));
|
||||
}));
|
||||
|
||||
// PYBIND FOR NON-LEAF NODES
|
||||
// (In alphabetical order)
|
||||
|
||||
PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode",
|
||||
"to create a BatchNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder,
|
||||
bool pad, py::list in_col_names, py::list out_col_names, py::list col_order,
|
||||
py::object size_obj, py::object map_obj, py::dict pad_info) {
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
||||
if (pad) {
|
||||
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
||||
}
|
||||
py::function size_func =
|
||||
py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function();
|
||||
py::function map_func =
|
||||
py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function();
|
||||
auto batch = std::make_shared<BatchNode>(
|
||||
self, batch_size, drop_remainder, pad, toStringVector(in_col_names),
|
||||
toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info);
|
||||
THROW_IF_ERROR(batch->ValidateParams());
|
||||
return batch;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>(
|
||||
*m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode")
|
||||
.def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names,
|
||||
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes,
|
||||
py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary,
|
||||
bool drop_remainder) {
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
|
||||
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
|
||||
|
||||
auto bucket_batch = std::make_shared<BucketBatchByLengthNode>(
dataset, toStringVector(column_names), bucket_boundaries, bucket_batch_sizes,
toPyFuncOp(std::move(element_length_function), DataType::DE_INT32), c_pad_info,
pad_to_bucket_boundary, drop_remainder);
THROW_IF_ERROR(bucket_batch->ValidateParams());
return bucket_batch;
}),
py::arg("dataset"), py::arg("column_names"), py::arg("bucket_boundaries"),
py::arg("bucket_batch_sizes"), py::arg("element_length_function") = py::none(),
py::arg("pad_info"), py::arg("pad_to_bucket_boundary"), py::arg("drop_remainder"));
}));

PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>(
*m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size,
float character_coverage, SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params) {
auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>(
self, vocab, col_names, vocab_size, character_coverage, model_type, params);
THROW_IF_ERROR(build_sentence_vocab->ValidateParams());
return build_sentence_vocab;
}));
}));

PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>(
*m, "BuildVocabNode", "to create a BuildVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns,
py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) {
auto build_vocab =
std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range),
top_k, toStringVector(special_tokens), special_first);
THROW_IF_ERROR(build_vocab->ValidateParams());
return build_vocab;
}));
}));

PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
"to create a ConcatNode")
.def(
py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, std::optional<py::handle> sampler,
py::list children_flag_and_nums, py::list children_start_end_index) {
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
toPairVector(children_flag_and_nums),
toPairVector(children_start_end_index));
THROW_IF_ERROR(concat->ValidateParams());
return concat;
}));
}));

PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
(void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode",
"to create a FilterNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate,
std::vector<std::string> input_columns) {
auto filter =
std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns);
THROW_IF_ERROR(filter->ValidateParams());
return filter;
}));
}));

PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> operations,
std::optional<py::list> input_columns, std::optional<py::list> output_columns,
std::optional<py::list> project_columns,
std::optional<std::shared_ptr<CacheClient>> cc,
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
auto map = std::make_shared<MapNode>(
self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)),
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()));
THROW_IF_ERROR(map->ValidateParams());
return map;
}));
}));

PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
(void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode",
"to create a ProjectNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) {
auto project = std::make_shared<ProjectNode>(self, toStringVector(columns));
THROW_IF_ERROR(project->ValidateParams());
return project;
}));
}));

PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) {
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode",
"to create a RenameNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> input_columns,
std::optional<py::list> output_columns) {
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns),
toStringVector(output_columns));
THROW_IF_ERROR(rename->ValidateParams());
return rename;
}));
}));

PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",
"to create a RepeatNode")
.def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) {
auto repeat = std::make_shared<RepeatNode>(input, count);
THROW_IF_ERROR(repeat->ValidateParams());
return repeat;
}));
}));

PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
(void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode",
"to create a ShuffleNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) {
auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch);
THROW_IF_ERROR(shuffle->ValidateParams());
return shuffle;
}));
}));

PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
(void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode",
"to create a SkipNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
auto skip = std::make_shared<SkipNode>(self, count);
THROW_IF_ERROR(skip->ValidateParams());
return skip;
}));
}));

PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
(void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode",
"to create a SyncWaitNode")
.def(
py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) {
py::function callback_func =
py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function();
auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback_func);
THROW_IF_ERROR(sync_wait->ValidateParams());
return sync_wait;
}));
}));

PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
(void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode",
"to create a TakeNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
auto take = std::make_shared<TakeNode>(self, count);
THROW_IF_ERROR(take->ValidateParams());
return take;
}));
}));

PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
"to create a TransferNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type,
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) {
auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end,
total_batch, create_data_info_queue);
THROW_IF_ERROR(transfer->ValidateParams());
return transfer;
}));
}));

PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
(void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode")
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) {
auto zip = std::make_shared<ZipNode>(datasets);
THROW_IF_ERROR(zip->ValidateParams());
return zip;
}));
}));

} // namespace dataset
} // namespace mindspore
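Note: every binding above funnels C++ Status values through THROW_IF_ERROR so that failures surface as Python exceptions instead of silently returning. The real macro and Status class live in minddata and differ in detail; the following stand-alone sketch only illustrates the idea, and the names MiniStatus and THROW_IF_ERROR_SKETCH are hypothetical stand-ins:

#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the dataset Status type (illustration only).
struct MiniStatus {
  bool ok;
  std::string msg;
  static MiniStatus OK() { return {true, ""}; }
  static MiniStatus Error(std::string m) { return {false, std::move(m)}; }
};

// Sketch of the pattern: turn an error Status into a C++ exception, which
// pybind11 translates into a Python RuntimeError at the binding boundary.
#define THROW_IF_ERROR_SKETCH(expr)               \
  do {                                            \
    MiniStatus _s = (expr);                       \
    if (!_s.ok) throw std::runtime_error(_s.msg); \
  } while (false)

MiniStatus ValidateCount(int count) {
  if (count < -1 || count == 0) return MiniStatus::Error("count must be -1 or a positive integer");
  return MiniStatus::OK();
}

int main() {
  THROW_IF_ERROR_SKETCH(ValidateCount(3));  // passes
  try {
    THROW_IF_ERROR_SKETCH(ValidateCount(0));  // throws
  } catch (const std::runtime_error &e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}
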

@@ -0,0 +1,168 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"

#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/api/python/pybind_conversion.h"

#include "minddata/dataset/engine/python_runtime_context.h"
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"

namespace mindspore {
namespace dataset {
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer");
}));
PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>(
*m, "PythonIteratorConsumer")
.def(py::init<int32_t>())
.def("Init", [](PythonIteratorConsumer &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetNextAsMap",
[](PythonIteratorConsumer &self) {
py::dict output;
THROW_IF_ERROR(self.GetNextAsDict(&output));
return output;
})
.def("GetNextAsList", [](PythonIteratorConsumer &self) {
py::list output;
THROW_IF_ERROR(self.GetNextAsList(&output));
return output;
});
}));

PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
(void)py::class_<PythonTreeGetters, TreeConsumer, std::shared_ptr<PythonTreeGetters>>(*m,
"TreeGetters")
.def(py::init<>())
.def("Init",
[](PythonTreeGetters &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetOutputShapes",
[](PythonTreeGetters &self) {
std::vector<TensorShape> shapes;
THROW_IF_ERROR(self.GetOutputShapes(&shapes));
return shapesToListOfShape(shapes);
})
.def("GetOutputTypes",
[](PythonTreeGetters &self) {
std::vector<DataType> types;
THROW_IF_ERROR(self.GetOutputTypes(&types));
return typesToListOfType(types);
})
.def("GetNumClasses",
[](PythonTreeGetters &self) {
int64_t num_classes;
THROW_IF_ERROR(self.GetNumClasses(&num_classes));
return num_classes;
})
.def("GetRepeatCount",
[](PythonTreeGetters &self) {
int64_t repeat_count;
THROW_IF_ERROR(self.GetRepeatCount(&repeat_count));
return repeat_count;
})
.def("GetBatchSize",
[](PythonTreeGetters &self) {
int64_t batch_size;
THROW_IF_ERROR(self.GetBatchSize(&batch_size));
return batch_size;
})
.def("GetColumnNames",
[](PythonTreeGetters &self) {
std::vector<std::string> col_names;
THROW_IF_ERROR(self.GetColumnNames(&col_names));
return col_names;
})
.def("GetClassIndexing",
[](PythonTreeGetters &self) {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
return output_class_indexing;
})
.def("GetDatasetSize",
[](PythonTreeGetters &self) {
int64_t dataset_size;
THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
return dataset_size;
})
.def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
}));

PYBIND_REGISTER(PythonRuntimeContext, 2, ([](const py::module *m) {
(void)py::class_<PythonRuntimeContext, std::shared_ptr<PythonRuntimeContext>>(*m,
"PythonRuntimeContext")
.def(py::init<>())
.def("Init", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Init()); })
.def("AssignConsumer", &PythonRuntimeContext::AssignConsumer)
.def("Terminate", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Terminate()); })
.def("GetConsumer", &PythonRuntimeContext::GetPythonConsumer, py::return_value_policy::reference)
.def("__deepcopy__", [](py::object &runtime_context, py::dict memo) { return runtime_context; });
}));

PYBIND_REGISTER(PythonBuildVocabConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonBuildVocabConsumer, TreeConsumer, std::shared_ptr<PythonBuildVocabConsumer>>(
*m, "PythonBuildVocabConsumer")
.def(py::init<>())
.def("Init", [](PythonBuildVocabConsumer &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Start", [](PythonBuildVocabConsumer &self) { THROW_IF_ERROR(self.Start()); });
}));

PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) {
(void)py::class_<ToDevice, TreeConsumer, std::shared_ptr<ToDevice>>(*m, "ToDevice")
.def(py::init<int32_t>())
.def("Init", [](ToDevice &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); })
.def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); })
.def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); })
.def("GetDataInfo",
[](ToDevice &self) {
std::vector<DataType> types_c;
std::vector<TensorShape> shapes_c;
{
py::gil_scoped_release rel;
THROW_IF_ERROR(self.GetDataInfo(&types_c, &shapes_c));
}
py::list types, shapes;
for (auto el : types_c) {
types.append(el.AsNumpyType());
}
for (auto el : shapes_c) {
py::list shape = el.AsPyList();
shapes.append(shape);
}
return py::make_tuple(types, shapes);
})
.def("__deepcopy__", [](py::object &to_device, py::dict memo) { return to_device; });
}));

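The consumer bindings above (GetDataInfo, and the Python consumers further down) drop the Python GIL around blocking pipeline calls so other Python threads are not stalled while C++ waits. A minimal, self-contained pybind11 sketch of that pattern, assuming pybind11 with embedded-interpreter support is available; the blocking pipeline call is simulated with a sleep:

#include <chrono>
#include <thread>
#include <pybind11/embed.h>

namespace py = pybind11;

// Stand-in for a blocking C++ call such as ToDevice::GetDataInfo (hypothetical).
void BlockingPipelineCall() { std::this_thread::sleep_for(std::chrono::milliseconds(100)); }

int main() {
  py::scoped_interpreter guard{};    // start Python; this thread now holds the GIL
  {
    py::gil_scoped_release release;  // release the GIL while C++ blocks
    BlockingPipelineCall();
  }                                  // GIL re-acquired before building py::list / py::dict results
  return 0;
}
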
PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
(void)py::class_<PythonSaveToDisk, TreeConsumer, std::shared_ptr<PythonSaveToDisk>>(
*m, "PythonSaveToDisk")
.def(py::init([](std::string &dataset_path, int32_t numFiles, std::string &datasetType) {
auto save = std::make_shared<PythonSaveToDisk>(dataset_path, numFiles, datasetType);
THROW_IF_ERROR(save->ValidateParams());
return save;
}))
.def("Init",
[](PythonSaveToDisk &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
}));

} // namespace dataset
} // namespace mindspore

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"

#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

PYBIND_REGISTER(
SchemaObj, 0, ([](const py::module *m) {
(void)py::class_<SchemaObj, std::shared_ptr<SchemaObj>>(*m, "SchemaObj", "to create a SchemaObj")
.def(py::init([](std::string schema_file) {
auto schema = std::make_shared<SchemaObj>(schema_file);
THROW_IF_ERROR(schema->init());
return schema;
}))
.def("add_column", [](SchemaObj &self, std::string name, TypeId de_type,
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
.def("add_column", [](SchemaObj &self, std::string name, std::string de_type,
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
.def("add_column",
[](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("add_column", [](SchemaObj &self, std::string name,
std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("to_json", &SchemaObj::to_json)
.def("to_string", &SchemaObj::to_string)
.def("from_string",
[](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.FromJSONString(json_string)); })
.def("set_dataset_type", [](SchemaObj &self, std::string dataset_type) { self.set_dataset_type(dataset_type); })
.def("set_num_rows", [](SchemaObj &self, int32_t num_rows) { self.set_num_rows(num_rows); })
.def("get_num_rows", &SchemaObj::get_num_rows)
.def("__deepcopy__", [](py::object &schema, py::dict memo) { return schema; });
}));

} // namespace dataset
} // namespace mindspore

@@ -17,7 +17,6 @@

#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/api/python/de_pipeline.h"

#include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h"
#include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h"

File diff suppressed because it is too large

@@ -1,265 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_

#include <iostream>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/core/client.h"  // DE client
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/util/status.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace py = pybind11;
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using DsOpPtr = std::shared_ptr<DatasetOp>;

class CacheClient;

// enum for the dataset operator names
enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBucketBatch,
kBarrier,
kCache,
kRepeat,
kSkip,
kTake,
kZip,
kConcat,
kMap,
kFilter,
kDeviceQueue,
kGenerator,
kRename,
kTfReader,
kProject,
kImageFolder,
kMnist,
kManifest,
kVoc,
kCoco,
kCifar10,
kCifar100,
kCelebA,
kRandomData,
kTextFile,
kBuildVocab,
kClue,
kEpochCtrl,
kSentencePieceVocab,
kCsv
};

// The C++ binder class that we expose to the python script.
class DEPipeline {
public:
DEPipeline();

~DEPipeline();

// Function to add a Node to the Execution Tree.
Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output);

// Function to add a child and parent relationship.
static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op);

// Function to assign the node as root.
Status AssignRootNode(const DsOpPtr &dataset_op);

// Function to get the column names in the last node in the tree in order
Status GetColumnNames(py::list *output);

// Function to prepare the tree for execution
Status PrepareTree(const int32_t num_epochs);

// Function to launch the tree execution.
Status LaunchTreeExec();

// Get a row of data as dictionary of column name to the value.
Status GetNextAsMap(py::dict *output);

// Get a row of data as list.
Status GetNextAsList(py::list *output);

Status GetOutputShapes(py::list *output);

Status GetOutputTypes(py::list *output);

Status GetDataInfo(py::list *types, py::list *shapes);

Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);

int GetDatasetSize() const;

int GetBatchSize() const;

int GetRepeatCount() const;

Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

template <typename T, typename S>
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert = false);

Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
const TensorRow &row, json *schema, std::vector<std::string> *index_fields);

Status FetchDataFromTensorRow(const TensorRow &row,
const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data,
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);

Status BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded);

Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);

Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

void PrintTree();

int32_t GetNumClasses() const;

Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status SetBatchParameters(const py::dict &args);

Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status StopSend();

Status ContinueSend();

Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);

Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;

std::unique_ptr<DatasetIterator> iterator_;

static Status ParsePadInfo(py::handle value, PadInfo *pad_info);

/// \brief Helper function to inject a cache operator over top of the current operation being built.
/// \param[in] cache_client The client to use for caching
/// \param[in] num_workers The number of workers to use in the cache op
/// \param[in] input_op The operator to build the cache on top of
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
///     the cache operator
/// \return Status return code
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *cache_op);

/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
/// \param[in] shuffle_size The size to use in the shuffle buffer
/// \param[in] input_op The operator to build shuffle on top of
/// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be
///     the shuffle operator
/// \return Status return code
Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op);

/// \brief Helper function to compute the shuffle size
/// \param[in] num_files The number of files in the dataset
/// \param[in] num_devices The number of devices in the dataset
/// \param[in] num_rows The number of rows in the dataset
/// \param[in] total_rows An upper bound on the total rows in the dataset
/// \param[out] shuffle_size The resultant computed shuffle size
/// \return Status return code
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size);

int batch_size_;
int repeat_num_;
int num_rows_;
int num_classes_;

int temp_batch_size_;
bool temp_drop_remainder_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_

@@ -0,0 +1,265 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/api/python/pybind_conversion.h"

namespace mindspore {
namespace dataset {
float toFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }

int toInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }

int64_t toInt64(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }

bool toBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }

std::string toString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); }

std::set<std::string> toStringSet(const std::optional<py::list> list) {
std::set<std::string> set;
if (list) {
for (auto l : *list) {
if (!l.is_none()) {
(void)set.insert(py::str(l));
}
}
}
return set;
}

std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict) {
std::map<std::string, int32_t> map;
if (dict) {
for (auto p : *dict) {
(void)map.emplace(toString(p.first), toInt(p.second));
}
}
return map;
}

std::vector<std::string> toStringVector(const std::optional<py::list> list) {
std::vector<std::string> vector;
if (list) {
for (auto l : *list) {
if (l.is_none())
vector.emplace_back("");
else
vector.push_back(py::str(l));
}
}
return vector;
}

std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple) {
std::pair<int64_t, int64_t> pair;
if (tuple) {
pair = std::make_pair(toInt64((*tuple)[0]), toInt64((*tuple)[1]));
}
return pair;
}

std::vector<std::pair<int, int>> toPairVector(const py::list list) {
std::vector<std::pair<int, int>> vector;
if (list) {
for (auto data : list) {
auto l = data.cast<py::tuple>();
if (l[1].is_none())
vector.emplace_back(toInt64(l[0]), 0);
else
vector.emplace_back(toInt64(l[0]), toInt64(l[1]));
}
}
return vector;
}

std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations) {
std::vector<std::shared_ptr<TensorOperation>> vector;
if (operations) {
for (auto op : *operations) {
std::shared_ptr<TensorOp> tensor_op;
if (py::isinstance<TensorOp>(op)) {
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
} else if (py::isinstance<py::function>(op)) {
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
} else {
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }());
}
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
}
}
return vector;
}

std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) {
std::vector<std::shared_ptr<DatasetNode>> vector;
vector.push_back(self);
if (datasets) {
for (auto ds : *datasets) {
if (py::isinstance<DatasetNode>(ds)) {
vector.push_back(ds.cast<std::shared_ptr<DatasetNode>>());
} else {
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: datasets is not recognised (not a DatasetNode instance)."); }());
}
}
}
return vector;
}

std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset) {
if (py_sampler) {
std::shared_ptr<SamplerObj> sampler_obj;
if (!isMindDataset) {
// Common Sampler
std::shared_ptr<SamplerRT> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create");
sampler = create().cast<std::shared_ptr<SamplerRT>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
} else {
// Mindrecord Sampler
std::shared_ptr<mindrecord::ShardOperator> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create_for_minddataset");
sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
}
return sampler_obj;
} else {
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: sampler input is not SamplerRT."); }());
}
return nullptr;
}

// Here we take in a python object, that holds a reference to a C++ object
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc) {
if (cc) {
std::shared_ptr<DatasetCache> built_cache;
// Wrap the existing cache client in a pre-built DatasetCache
built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc.value()));
return built_cache;
} else {
// don't need to check here as cache is not enabled.
return nullptr;
}
}

ShuffleMode toShuffleMode(const int32_t shuffle) {
if (shuffle == 0) return ShuffleMode::kFalse;
if (shuffle == 1) return ShuffleMode::kFiles;
if (shuffle == 2) return ShuffleMode::kGlobal;
return ShuffleMode();
}

std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases) {
std::vector<std::shared_ptr<CsvBase>> vector;
if (csv_bases) {
for (auto base : *csv_bases) {
if (py::isinstance<py::int_>(base)) {
vector.push_back(std::make_shared<CsvRecord<int>>(CsvType::INT, toInt(base)));
} else if (py::isinstance<py::float_>(base)) {
vector.push_back(std::make_shared<CsvRecord<float>>(CsvType::FLOAT, toFloat(base)));
} else if (py::isinstance<py::str>(base)) {
vector.push_back(std::make_shared<CsvRecord<std::string>>(CsvType::STRING, toString(base)));
} else {
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: each default value must be int, float, or string"); }());
}
}
}
return vector;
}

Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
std::map<std::string, std::string> *sample_bytes) {
for (const py::handle &key : padded_sample) {
if (py::isinstance<py::bytes>(padded_sample[key])) {
(*sample_bytes)[py::str(key).cast<std::string>()] = padded_sample[key].cast<std::string>();
// a py::bytes value would lose its key name here, so create an unused placeholder key in json to pass ValidateParams
(*padded_sample_json)[py::str(key).cast<std::string>()] = nlohmann::json::object();
} else {
nlohmann::json obj_json;
if (padded_sample[key].is_none()) {
obj_json = nullptr;
} else if (py::isinstance<py::int_>(padded_sample[key])) {
obj_json = padded_sample[key].cast<int64_t>();
} else if (py::isinstance<py::float_>(padded_sample[key])) {
obj_json = padded_sample[key].cast<double>();
} else if (py::isinstance<py::str>(padded_sample[key])) {
obj_json = padded_sample[key].cast<std::string>();  // also catch py::bytes
} else {
MS_LOG(ERROR) << "Python object convert to json failed: " << py::cast<std::string>(padded_sample[key]);
RETURN_STATUS_SYNTAX_ERROR("Python object convert to json failed");
}
(*padded_sample_json)[py::str(key).cast<std::string>()] = obj_json;
}
}
return Status::OK();
}

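ToJson above maps a Python padded_sample into nlohmann::json, branching on None/int/float/str and routing raw bytes into a side table. A stand-alone sketch of the same branching, assuming nlohmann/json.hpp is available and substituting a plain C++ variant for py::handle (illustration only, not the minddata code path):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <variant>
#include <nlohmann/json.hpp>

using Value = std::variant<std::monostate, int64_t, double, std::string>;

// Mirrors the None/int/float/str branches of ToJson.
nlohmann::json ToJsonSketch(const std::map<std::string, Value> &sample) {
  nlohmann::json out = nlohmann::json::object();
  for (const auto &kv : sample) {
    if (std::holds_alternative<std::monostate>(kv.second)) {
      out[kv.first] = nullptr;                           // None -> null
    } else if (std::holds_alternative<int64_t>(kv.second)) {
      out[kv.first] = std::get<int64_t>(kv.second);      // int
    } else if (std::holds_alternative<double>(kv.second)) {
      out[kv.first] = std::get<double>(kv.second);       // float
    } else {
      out[kv.first] = std::get<std::string>(kv.second);  // str
    }
  }
  return out;
}

int main() {
  std::map<std::string, Value> sample{
      {"id", int64_t{7}}, {"score", 0.5}, {"label", std::string("cat")}, {"mask", std::monostate{}}};
  std::cout << ToJsonSketch(sample).dump() << "\n";
  return 0;
}
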
Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info) {
for (auto p : value) {
if (!p.second.is_none()) {
auto tp = py::reinterpret_borrow<py::tuple>(p.second);
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
std::shared_ptr<Tensor> pad_val = nullptr;
if (py::isinstance<py::str>(tp[1])) {
std::string pad_val_string = tp[1].is_none() ? "" : toString(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(
Tensor::CreateFromVector(std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar(), &pad_val),
"Cannot create pad_value Tensor");
} else {
float pad_val_float = tp[1].is_none() ? 0 : toFloat(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(
Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val),
"Cannot create pad_value Tensor");
pad_val->SetItemAt<float>({}, pad_val_float);
}
(void)pad_info->insert({toString(p.first), {shape, pad_val}});
} else {  // tuple is None
(void)pad_info->insert({toString(p.first), {TensorShape({}), nullptr}});
}
}
return Status::OK();
}

std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type) {
std::shared_ptr<TensorOp> py_func;
if (!func.is_none()) {
py::function py_function = func.cast<py::function>();
py_func = std::make_shared<PyFuncOp>(py_function, data_type);
} else {
py_func = nullptr;
}
return py_func;
}

py::list shapesToListOfShape(std::vector<TensorShape> shapes) {
py::list shape_list;
for (const auto &shape : shapes) {
shape_list.append(shape.AsVector());
}
return shape_list;
}

py::list typesToListOfType(std::vector<DataType> types) {
py::list type_list;
for (const auto &type : types) {
type_list.append(type.AsNumpyType());
}
return type_list;
}
} // namespace dataset
} // namespace mindspore

@@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/kernels/py_func_op.h"
namespace py = pybind11;

namespace mindspore {
namespace dataset {
float toFloat(const py::handle &handle);

int toInt(const py::handle &handle);

int64_t toInt64(const py::handle &handle);

bool toBool(const py::handle &handle);

std::string toString(const py::handle &handle);

std::set<std::string> toStringSet(const std::optional<py::list> list);

std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict);

std::vector<std::string> toStringVector(const std::optional<py::list> list);

std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple);

std::vector<std::pair<int, int>> toPairVector(const py::list list);

std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations);

std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets);

std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false);

std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc);

ShuffleMode toShuffleMode(const int32_t shuffle);

std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases);

std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type);

Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
std::map<std::string, std::string> *sample_bytes);

Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info);

py::list shapesToListOfShape(std::vector<TensorShape> shapes);

py::list typesToListOfType(std::vector<DataType> types);

} // namespace dataset
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_

@@ -190,6 +190,23 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
return sampler;
}

#ifndef ENABLE_ANDROID
// PreBuiltSamplerObj
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}

PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
#endif

bool PreBuiltSamplerObj::ValidateParams() { return true; }

std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif

#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object

@@ -222,6 +222,13 @@ Status OneHotOperation::ValidateParams() {

std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }

// PreBuiltOperation
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {}

Status PreBuiltOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }

// RandomApplyOperation
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
: transforms_(transforms), prob_(prob) {}
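PreBuiltSamplerObj and PreBuiltOperation both apply the same "pre-built" idea: wrap an object that the Python layer has already constructed and simply hand it back from Build(), with ValidateParams() a no-op. A minimal sketch of that wrapper pattern with placeholder types (not the actual minddata classes):

#include <iostream>
#include <memory>
#include <utility>

struct RuntimeOp {                  // placeholder for an already-built runtime object
  void Run() { std::cout << "running prebuilt op\n"; }
};

class Operation {                   // placeholder IR-level interface
 public:
  virtual ~Operation() = default;
  virtual bool ValidateParams() = 0;
  virtual std::shared_ptr<RuntimeOp> Build() = 0;
};

class PreBuiltOperationSketch : public Operation {
 public:
  explicit PreBuiltOperationSketch(std::shared_ptr<RuntimeOp> op) : op_(std::move(op)) {}
  bool ValidateParams() override { return true; }              // nothing to validate: op already exists
  std::shared_ptr<RuntimeOp> Build() override { return op_; }  // just return the wrapped object
 private:
  std::shared_ptr<RuntimeOp> op_;
};

int main() {
  PreBuiltOperationSketch op(std::make_shared<RuntimeOp>());
  if (op.ValidateParams()) op.Build()->Run();
  return 0;
}
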
@@ -18,6 +18,7 @@ set(SRC_FILES_LIST
dataset_iterator.cc
tree_adapter.cc
runtime_context.cc
python_runtime_context.cc
consumers/tree_consumer.cc
)
if (ENABLE_PYTHON)

@@ -32,15 +32,37 @@ Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
}
return Status::OK();
}

Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) {
std::unordered_map<std::string, TensorPtr> row;
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec;
Status s;
{
py::gil_scoped_release gil_release;
RETURN_IF_NOT_OK(GetNextAsMap(&row));
s = GetNextAsOrderedPair(&vec);
}
for (auto el : row) {
(*out)[common::SafeCStr(el.first)] = el.second;
RETURN_IF_NOT_OK(s);
// Generate Python dict, python dict maintains its insertion order
for (const auto &pair : vec) {
(*out)[common::SafeCStr(pair.first)] = pair.second;
}
return Status::OK();
}

Status PythonBuildVocabConsumer::Start() {
py::gil_scoped_release gil_release;
return BuildVocabConsumer::Start();
}

Status PythonSaveToDisk::Save() {
py::gil_scoped_release gil_release;
return SaveToDisk::Save();
}

PythonSaveToDisk::PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType)
: SaveToDisk(datasetPath, numFiles, datasetType) {}

Status PythonTreeGetters::GetRow(TensorRow *r) {
py::gil_scoped_release gil_release;
return TreeGetters::GetRow(r);
}
} // namespace mindspore::dataset

@@ -44,5 +44,21 @@ class PythonIteratorConsumer : public IteratorConsumer {
/// \return Status error code
Status GetNextAsDict(py::dict *out);
};

class PythonBuildVocabConsumer : public BuildVocabConsumer {
public:
Status Start() override;
};

class PythonSaveToDisk : public SaveToDisk {
public:
PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType);
Status Save() override;
};

class PythonTreeGetters : public TreeGetters {
public:
Status GetRow(TensorRow *r) override;
};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_

@@ -23,6 +23,7 @@
#include <vector>
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/tree_adapter.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"

#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_header.h"

@@ -35,7 +36,7 @@ namespace mindspore::dataset {
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }

Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->ServiceStop(); }

// IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {

@@ -73,6 +74,38 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
return Status::OK();
}

Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) {
CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");

TensorRow curr_row;

RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));
RETURN_OK_IF_TRUE(curr_row.empty());

size_t num_cols = curr_row.size();  // num_cols is non-zero here.
// order the column names according to their ids
if (column_order_.empty()) {
const int32_t invalid_col_id = -1;
column_order_.resize(num_cols, {std::string(), invalid_col_id});
for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
int32_t ind = itr.second;
CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
column_order_[ind] = std::make_pair(itr.first, ind);
}
// error check, make sure the ids in col_name_id_map are continuous and start from 0
for (const auto &col : column_order_) {
CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous.");
}
}

vec->reserve(num_cols);

std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
[curr_row](const auto &col) { return std::make_pair(col.first, curr_row[col.second]); });

return Status::OK();
}

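GetNextAsOrderedPair above rebuilds an id-ordered list of (name, value) pairs from the unordered column-name-to-id map and verifies that the ids form a dense 0..n-1 range. A self-contained sketch of that ordering and continuity check using plain STL types (the real code works on TensorRow and the tree's column map):

#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Returns (name, value) pairs ordered by column id; asserts ids are dense and in range.
std::vector<std::pair<std::string, std::string>> OrderRow(
    const std::unordered_map<std::string, int> &col_name_id_map, const std::vector<std::string> &row) {
  std::vector<std::pair<std::string, int>> order(row.size(), {std::string(), -1});
  for (const auto &itr : col_name_id_map) {
    assert(itr.second >= 0 && itr.second < static_cast<int>(row.size()) && "column id out of bounds");
    order[itr.second] = {itr.first, itr.second};
  }
  std::vector<std::pair<std::string, std::string>> out;
  out.reserve(row.size());
  for (const auto &col : order) {
    assert(col.second != -1 && "column ids are not continuous");
    out.emplace_back(col.first, row[col.second]);
  }
  return out;
}

int main() {
  std::unordered_map<std::string, int> cols{{"label", 1}, {"image", 0}};
  for (const auto &p : OrderRow(cols, {"img_bytes", "3"})) std::cout << p.first << "=" << p.second << "\n";
  return 0;
}
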
// ToDevice
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); }

@@ -81,7 +114,6 @@ Status ToDevice::Send() {
RETURN_IF_NOT_OK(tree_adapter_->Launch());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_IF_NOT_OK(root->GetNextBuffer(&db));
return Status::OK();
}

@@ -101,9 +133,36 @@ Status ToDevice::Stop() {
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopSend();

return Status::OK();
}

Status ToDevice::GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes) {
// tree_.root() must be DeviceQueueOp
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp");
DATA_INFO data_info;
RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
for (auto el : data_info) {
types->push_back(el.first);
shapes->push_back(el.second);
}
return Status::OK();
}

Status ToDevice::Terminate() {
#ifdef ENABLE_TDTQUE
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopWaiting();
#endif
return TreeConsumer::Terminate();
}

#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {

@@ -282,50 +341,50 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
if (column_type == DataType::DE_INT8) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int8_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<uint16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT8) {
std::unique_ptr<uint8_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT32) {
std::unique_ptr<int32_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT32) {
std::unique_ptr<int64_t> data;
std::unique_ptr<uint32_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT64) {
std::unique_ptr<int64_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT32) {
std::unique_ptr<float> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT64) {
std::unique_ptr<double> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_STRING) {

@@ -346,7 +405,7 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
}

template <typename T, typename S>
Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert) {
if (nullptr == src) {
@ -379,47 +438,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
|
|||
 }
 #endif

-TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) {
-  tree_adapter_ = std::make_unique<TreeAdapter>();
-}
+TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); }

 Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
-  if (init_flag_) {
-    return Status::OK();
-  }
-  Status s = tree_adapter_->Compile(std::move(d), 1);
-  if (!s.IsError()) {
-    init_flag_ = true;
-  }
-  return s;
+  root_ = std::move(d);
+  return Status::OK();
 }

 bool TreeGetters::isInitialized() { return init_flag_; }

-Status TreeGetters::GetRow(TensorRow *row) {
-  if (row_flag_ == false) {
-    RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
-    row_flag_ = true;
-  }
-  return Status::OK();
-}
+Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }

 Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
   if (dataset_size_ == -1) {
+    RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
     std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
-    CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
+    RETURN_UNEXPECTED_IF_NULL(root);
     RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
-    dataset_size_ = *dataset_size;
-    if (*dataset_size == -1) {
-      RETURN_IF_NOT_OK(GetRow(&row_));
-      int64_t num_rows = 0;
-      TensorRow row = row_;
-      while (row.size() != 0) {
-        num_rows++;
-        RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
-      }
-      dataset_size_ = num_rows;
-    }
+    if (*dataset_size == -1) {  // run through the tree and get everything
+      TensorRow row;
+      RETURN_IF_NOT_OK(GetRow(&row));
+      int64_t row_cnt = 0;
+      while (!row.empty()) {
+        ++row_cnt;
+        RETURN_IF_NOT_OK(GetRow(&row));
+      }
+      *dataset_size = row_cnt;
+    }
+    dataset_size_ = *dataset_size;  // save the previous result
   }
   *dataset_size = dataset_size_;
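A side note on the -1 fallback in the hunk above: when the tree cannot report its size directly, the getter simply drains the pipeline and counts rows until an empty row comes back. A minimal, self-contained sketch of that counting idea in plain C++ (the Row alias and the getter lambda are illustrative stand-ins, not the MindSpore TensorRow/TreeAdapter API):

#include <functional>
#include <iostream>
#include <vector>

// Illustrative stand-in for a TensorRow: an empty vector signals end of data.
using Row = std::vector<int>;

// Count rows by pulling from a getter until it yields an empty row,
// mirroring the while (!row.empty()) loop in GetDatasetSize above.
int64_t CountRows(const std::function<Row()> &get_next) {
  int64_t row_cnt = 0;
  for (Row row = get_next(); !row.empty(); row = get_next()) {
    ++row_cnt;
  }
  return row_cnt;
}

int main() {
  int remaining = 5;  // pretend the pipeline holds five rows
  auto getter = [&remaining]() -> Row { return remaining-- > 0 ? Row{1} : Row{}; };
  std::cout << CountRows(getter) << "\n";  // prints 5
  return 0;
}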
@@ -427,68 +471,88 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
 }

 Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
-  RETURN_IF_NOT_OK(GetRow(&row_));
-  for (auto ts : row_) {
-    DataType dt = ts->type();
-    types->push_back(dt);
-  }
+  RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
+  if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
+
+  std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types),
+                 [](const TensorPtr &t) { return t->type(); });
   return Status::OK();
 }

 Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
-  RETURN_IF_NOT_OK(GetRow(&row_));
-  for (auto ts : row_) {
-    TensorShape t = ts->shape();
-    shapes->push_back(t);
-  }
+  RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
+  if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
+
+  std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes),
+                 [](const TensorPtr &t) { return t->shape(); });
   return Status::OK();
 }

 Status TreeGetters::GetBatchSize(int64_t *batch_size) {
+  RETURN_IF_NOT_OK(InternalInit());
   std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
-  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
+  RETURN_UNEXPECTED_IF_NULL(root);
   *batch_size = root->GetTreeBatchSize();
   CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
   return Status::OK();
 }

 Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
+  RETURN_IF_NOT_OK(InternalInit());
   std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
-  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
+  RETURN_UNEXPECTED_IF_NULL(root);
   *repeat_count = root->GetTreeRepeatCount();
   return Status::OK();
 }

 Status TreeGetters::GetNumClasses(int64_t *num_classes) {
+  RETURN_IF_NOT_OK(InternalInit());
   std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
-  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
+  RETURN_UNEXPECTED_IF_NULL(root);
   RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
   return Status::OK();
 }

 Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
+  RETURN_IF_NOT_OK(InternalInit());
   std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
   RETURN_UNEXPECTED_IF_NULL(root);
   std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
-  if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty.");
-  std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
-                                                                     column_name_id_map.end());
-  std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
+  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map is empty.");
+  std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end());
+  std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
             [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
               return a.second < b.second;
             });
-  for (auto item : column_name_id_vector) {
-    (*output).push_back(item.first);
-  }
+  std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output),
+                 [](const std::pair<std::string, int32_t> &p) { return p.first; });
   return Status::OK();
 }

 Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
+  RETURN_IF_NOT_OK(InternalInit());
   std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
-  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
+  RETURN_UNEXPECTED_IF_NULL(root);
   RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
   return Status::OK();
 }

+Status TreeGetters::InternalInit(int8_t type) {
+  if (init_flag_) return Status::OK();
+  tree_adapter_->SetPrePassOverride([&type](OptPass pre) {
+    pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type)));
+    return pre;
+  });
+  Status s = tree_adapter_->Compile(std::move(root_), 1);
+  if (!s.IsError()) init_flag_ = true;
+  return s;
+}
+
+Status TreeGetters::InternalInit() {
+  if (init_flag_) return Status::OK();
+  Status s = tree_adapter_->Compile(std::move(root_), 1);
+  if (!s.IsError()) init_flag_ = true;
+  return s;
+}
+
 Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }

 Status BuildVocabConsumer::Start() {
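For context on GetColumnNames above: the getter just turns the op's name-to-id map into a list of names ordered by column id. A standalone sketch of that sort-then-project idea in plain C++ (names and sample data here are made up for illustration; this is not the MindSpore API):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Order column names by their column id, the same idea GetColumnNames uses
// with column_name_id_map in the hunk above.
std::vector<std::string> OrderedColumnNames(const std::unordered_map<std::string, int32_t> &name_to_id) {
  std::vector<std::pair<std::string, int32_t>> items(name_to_id.begin(), name_to_id.end());
  std::sort(items.begin(), items.end(),
            [](const auto &a, const auto &b) { return a.second < b.second; });
  std::vector<std::string> names;
  std::transform(items.begin(), items.end(), std::back_inserter(names),
                 [](const auto &p) { return p.first; });
  return names;
}

int main() {
  // Prints "image" then "label", since ids 0 and 1 determine the order.
  for (const auto &n : OrderedColumnNames({{"label", 1}, {"image", 0}})) std::cout << n << "\n";
  return 0;
}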
@@ -41,7 +41,7 @@ class TreeConsumer {
   /// \return Status error code.
   virtual Status Init(std::shared_ptr<DatasetNode> d);

-  Status Terminate();
+  virtual Status Terminate();

  protected:
   /// The class owns the tree_adapter that handles execution tree operations.

@@ -72,6 +72,11 @@ class IteratorConsumer : public TreeConsumer {
   /// \return Status error code
   Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out);

+  /// Returns the next row in as a map
+  /// \param[out] out std::map of string to Tensor
+  /// \return Status error code
+  Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec);
+
  protected:
   /// Method to return the name of the consumer
   /// \return string

@@ -79,6 +84,7 @@ class IteratorConsumer : public TreeConsumer {
  private:
   int32_t num_epochs_;
+  std::vector<std::pair<std::string, int32_t>> column_order_;  // key: column name, val: column id
 };

 #ifndef ENABLE_ANDROID

@@ -101,7 +107,7 @@ class SaveToDisk : public TreeConsumer {
   /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
   /// would be written to disk)
   /// \return Status error code
-  Status Save();
+  virtual Status Save();

  protected:
   /// Method to return the name of the consumer

@@ -110,7 +116,7 @@ class SaveToDisk : public TreeConsumer {
  private:
   template <typename T, typename S>
-  Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
+  Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
                          std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
                          std::unique_ptr<S> *s, bool need_convert = false);

@@ -131,24 +137,29 @@ class SaveToDisk : public TreeConsumer {
 /// Consumer that iterates over the dataset and send it to a device
 class ToDevice : public TreeConsumer {
  public:
-  explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
-      : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
+  explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}

   ~ToDevice() = default;

   Status Init(std::shared_ptr<DatasetNode> d) override;

+  Status Terminate() override;
+
   /// Send the data to device
   /// \return Status error code
-  Status Send();
+  virtual Status Send();

   /// Stop to send data to device
   /// \return Status error code
-  Status Stop();
+  virtual Status Stop();

   /// Continue to send data to device
   /// \return Status error code
-  Status Continue();
+  virtual Status Continue();

+  /// Get data info from TDT
+  /// \return Status error code
+  virtual Status GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes);
+
  protected:
   /// Method to return the name of the consumer

@@ -156,8 +167,6 @@ class ToDevice : public TreeConsumer {
   std::string Name() override { return "ToDevice"; }

  private:
-  std::string device_type_;
-  bool send_epoch_end_;
   int32_t num_epochs_;
 };

@@ -167,6 +176,7 @@ class TreeGetters : public TreeConsumer {
   TreeGetters();
   ~TreeGetters() = default;
   Status Init(std::shared_ptr<DatasetNode> d) override;

   Status GetDatasetSize(int64_t *size);
   Status GetOutputTypes(std::vector<DataType> *types);
   Status GetOutputShapes(std::vector<TensorShape> *shapes);

@@ -175,15 +185,17 @@ class TreeGetters : public TreeConsumer {
   Status GetNumClasses(int64_t *num_classes);
   Status GetColumnNames(std::vector<std::string> *output);
   Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
   bool isInitialized();
   std::string Name() override { return "TreeGetters"; }
-  Status GetRow(TensorRow *r);
+  virtual Status GetRow(TensorRow *r);

  private:
+  std::shared_ptr<DatasetNode> root_;
   int64_t dataset_size_;
-  TensorRow row_;
+  TensorRow first_row_;
   bool init_flag_;  // indicate whether the tree has initialized
-  bool row_flag_;   // indicate whether the first row has been stored in row_

+  Status InternalInit(int8_t type);
+  Status InternalInit();
 };

 class BuildVocabConsumer : public TreeConsumer {

@@ -197,7 +209,7 @@ class BuildVocabConsumer : public TreeConsumer {

   /// Start consuming
   /// \return Status error code
-  Status Start();
+  virtual Status Start();

  protected:
   /// Method to return the name of the consumer
@@ -44,9 +44,9 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
 }

 // Constructor of the ConcatOp.
-ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
-                   std::vector<std::pair<int, int>> children_flag_and_nums,
-                   std::vector<std::pair<int, int>> children_start_end_index)
+ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
+                   const std::vector<std::pair<int, int>> &children_flag_and_nums,
+                   const std::vector<std::pair<int, int>> &children_start_end_index)
     : PipelineOp(op_connector_size),
       children_num_(0),
       sampler_(sampler),

@@ -70,9 +70,9 @@ class ConcatOp : public PipelineOp {
   // @note The builder class should be used to call it
   // @param op_connector_size - connector size
   explicit ConcatOp(int32_t op_connector_size);
-  explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
-                    std::vector<std::pair<int, int>> children_flag_and_nums,
-                    std::vector<std::pair<int, int>> children_start_end_index);
+  ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
+           const std::vector<std::pair<int, int>> &children_flag_and_nums,
+           const std::vector<std::pair<int, int>> &children_start_end_index);

   // Destructor
   ~ConcatOp() = default;
@@ -346,6 +346,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return Name of the current Op
   virtual std::string Name() const = 0;

+  /// Op name and ID getter
+  /// \return Name and ID of the current Op
+  std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; }
+
   /// Execution Tree getter
   /// \return Pointer to the ExecutionTree the current op belongs to, no ownership
   ExecutionTree *Tree() { return tree_; }

@@ -205,7 +205,6 @@ Status DeviceQueueOp::SendDataToAscend() {
   }

   tree_->SetFinished();
   MS_LOG(INFO) << "Device queue total batch is " << send_batch;

   return Status::OK();
 }
@@ -39,10 +39,10 @@ using mindspore::device::GpuBufferMgr;

 namespace mindspore {
 namespace dataset {

+using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>;
+using DATA_INFO_QUEUE = Queue<DATA_INFO>;
+const int kDataInfoQueueCapacity = 128;
+
 class DeviceQueueOp : public PipelineOp {
  public:
   static const uint32_t INVALID_HANDLE = 0xffffffffUL;

@@ -184,7 +184,6 @@ class DeviceQueueOp : public PipelineOp {
 #ifdef ENABLE_TDTQUE
   Status SendDataToAscend();
   bool ascend_keep_waiting_;

 #endif

 #ifdef ENABLE_GPUQUE
@@ -169,7 +169,7 @@ Status MapOp::operator()() {
   }

   // The operator class just starts off threads by calling the tree_ function
-  rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1));
+  rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1), NameWithID());
   // Synchronize with TaskManager
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(rc);
@@ -704,6 +704,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
   }
   if (image_ids_.size() == 0) {
     RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
   } else {
     num_rows = image_ids_.size();
   }
   sample_size = sampler_->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;

@@ -480,13 +480,13 @@ Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
     *dataset_size = dataset_size_;
     return Status::OK();
   }
-  int64_t num_rows = num_rows_, sample_size;
+  int64_t num_rows = num_rows_;
   if (num_rows_ <= 0) {
-    std::shared_ptr<ShardOperator> op;
+    // The last operator is parent sampler
+    std::shared_ptr<ShardOperator> op = operators_.back();
     RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
   }
-  sample_size = operators_[0]->GetNumSamples(num_rows, 0);
-  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  *dataset_size = num_rows;
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
@@ -1067,6 +1067,19 @@ Status TFReaderOp::PrepareNodePostAction() {
   return Status::OK();
 }

+// Get the file list of the specific shard ID
+Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
+  if (!shard_filenames->empty()) {
+    RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
+  }
+  for (int index = 0; index < dataset_files_list_.size(); index++) {
+    if (index % num_devices_ == device_id_) {
+      shard_filenames->push_back(dataset_files_list_.at(index));
+    }
+  }
+  return Status::OK();
+}
+
 // Get Dataset size
 Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
   if (dataset_size_ > 0) {

@@ -1080,7 +1093,9 @@ Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
       RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
       num_rows = num_rows_per_shard_;
     } else {
-      RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_));
+      std::vector<std::string> shard_file_list;
+      RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
+      RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
     }
   }
   sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();

@@ -400,6 +400,11 @@ class TFReaderOp : public ParallelOp {
   // @return - Status
   Status ComputeColMap() override;

+  // Private function for computing the file list of the specific shard ID. This is because in distributed scenario,
+  // data will be divided into shards by row when equal_rows_per_shard is true, but by file in the opposite case.
+  // @return - Status - the status code returned.
+  Status GetShardFileList(std::vector<std::string> *shard_filenames);
+
   int32_t device_id_;
   int32_t num_devices_;
   int64_t rows_per_buffer_;
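The sharding rule in GetShardFileList above is a plain round-robin assignment of files to devices: shard k keeps every file whose index satisfies index % num_devices == device_id. A small standalone illustration of that rule in plain C++ (function and file names are invented for the example, not part of the MindSpore API):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Round-robin file sharding: keep every file whose index modulo the device
// count equals this device's id, matching the loop in GetShardFileList above.
std::vector<std::string> ShardFiles(const std::vector<std::string> &files, int32_t num_devices, int32_t device_id) {
  std::vector<std::string> shard;
  for (size_t i = 0; i < files.size(); ++i) {
    if (static_cast<int32_t>(i % num_devices) == device_id) shard.push_back(files[i]);
  }
  return shard;
}

int main() {
  // Device 1 of 2 gets the files at odd indexes: here, only "b.tfrecord".
  for (const auto &f : ShardFiles({"a.tfrecord", "b.tfrecord", "c.tfrecord"}, 2, 1)) std::cout << f << "\n";
  return 0;
}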
@@ -536,6 +536,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
       RETURN_IF_NOT_OK(op->ParseImageIds());
       num_rows = static_cast<int64_t>(op->image_ids_.size());
     }
+  } else {
+    num_rows = image_ids_.size();
   }
   sample_size = sampler_->CalculateNumSamples(num_rows);
   *dataset_size = sample_size;
@@ -141,8 +141,6 @@ Status ExecutionTree::Launch() {
                           " Expected state: " + std::to_string(static_cast<int>(kDeTStateReady));
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
-  std::ostringstream ss;
-  ss << *this;

   // Profiling infrastructures need to be initialized before Op launching
   if (profiling_manager_->IsProfilingEnable()) {

@@ -152,6 +150,8 @@ Status ExecutionTree::Launch() {
     RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
   }

+  std::ostringstream ss;
+  ss << *this;
   MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str();
   for (auto itr = this->begin(); itr != this->end(); ++itr) {
     // An inlined operator is one that has an output connector size of 0, and it does not

@@ -160,7 +160,7 @@ Status ExecutionTree::Launch() {
     // the launching tree/user thread. Do not exec any thread for an inlined op.
     itr->state_ = DatasetOp::OpState::kDeOpRunning;
     if (!itr->inlined()) {
-      RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr)));
+      RETURN_IF_NOT_OK(tg_->CreateAsyncTask(itr->NameWithID(), std::ref(*itr)));
       // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp
     }
   }

@@ -189,10 +189,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_

 // Given the number of workers, launches the worker entry function for each. Essentially a
 // wrapper for the TaskGroup handling that is stored inside the execution tree.
-Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func) {
+Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) {
   // Launch the workers
   for (int32_t i = 0; i < num_workers; ++i) {
-    RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i)));
+    RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i)));
   }
   return Status::OK();
 }

@@ -150,7 +150,7 @@ class ExecutionTree {
   // @param num_workers - The number of workers to launch
   // @param func - The function entry point that workers will execute
   // @return Status - The error code return
-  Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func);
+  Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "");

   // Getter method
   // @return shared_ptr to the root operator
@@ -1,4 +1,5 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-ir-cache OBJECT
-  dataset_cache_impl.cc)
+  pre_built_dataset_cache.cc
+  dataset_cache_impl.cc)
@@ -18,8 +18,8 @@
 #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
 #include "minddata/dataset/engine/datasetops/cache_op.h"

-namespace mindspore::dataset {
-
+namespace mindspore {
+namespace dataset {
 /// Method to initialize the DatasetCache by creating an instance of a CacheClient
 /// \return Status Error code
 Status DatasetCacheImpl::Build() {

@@ -40,5 +40,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data

   return Status::OK();
 }

-}  // namespace mindspore::dataset
+}  // namespace dataset
+}  // namespace mindspore

@@ -24,8 +24,8 @@
 #include "minddata/dataset/engine/datasetops/cache_op.h"
 #include "minddata/dataset/engine/ir/cache/dataset_cache.h"

-namespace mindspore::dataset {
-
+namespace mindspore {
+namespace dataset {
 /// DatasetCache is the IR of CacheClient
 class DatasetCacheImpl : public DatasetCache {
  public:

@@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache {
   std::optional<int32_t> num_connections_;
   std::optional<int32_t> prefetch_sz_;
 };
-}  // namespace mindspore::dataset
-
+}  // namespace dataset
+}  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
|
||||
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
|
||||
/// \return Status Error code
|
||||
Status PreBuiltDatasetCache::Build() {
|
||||
// we actually want to keep a reference of the runtime object so it can be shared by different pipelines
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
|
||||
*ds = cache_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// DatasetCache is the IR of CacheClient
|
||||
class PreBuiltDatasetCache : public DatasetCache {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param cc a pre-built cache client
|
||||
explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc) : cache_client_(std::move(cc)) {}
|
||||
|
||||
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
|
||||
/// \return Status Error code
|
||||
Status Build() override;
|
||||
|
||||
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status ValidateParams() override { return Status::OK(); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<CacheClient> cache_client_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
|
|
@ -31,7 +31,7 @@ namespace dataset {
|
|||
BucketBatchByLengthNode::BucketBatchByLengthNode(
|
||||
std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
|
||||
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
|
||||
std::function<TensorRow(TensorRow)> element_length_function,
|
||||
std::shared_ptr<TensorOp> element_length_function,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
|
||||
bool drop_remainder)
|
||||
: column_names_(column_names),
|
||||
|
@ -47,16 +47,13 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
|
|||
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::shared_ptr<TensorOp> c_func;
|
||||
if (element_length_function_ != nullptr) {
|
||||
c_func = std::make_shared<CFuncOp>(element_length_function_);
|
||||
} else {
|
||||
c_func = nullptr;
|
||||
bucket_boundaries_.insert(bucket_boundaries_.begin(), 0);
|
||||
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(
|
||||
column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_,
|
||||
pad_to_bucket_boundary_, drop_remainder_, connector_que_size_));
|
||||
if (bucket_boundaries_[0] == 0) {
|
||||
bucket_boundaries_.erase(bucket_boundaries_.begin());
|
||||
}
|
||||
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
||||
c_func, pad_info_, pad_to_bucket_boundary_,
|
||||
drop_remainder_, connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class BucketBatchByLengthNode : public DatasetNode {
|
|||
/// \brief Constructor
|
||||
BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
|
||||
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
|
||||
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
|
||||
std::shared_ptr<TensorOp> element_length_function = nullptr,
|
||||
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
|
||||
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
|
||||
|
||||
|
@ -52,7 +52,7 @@ class BucketBatchByLengthNode : public DatasetNode {
|
|||
std::vector<std::string> column_names_;
|
||||
std::vector<int32_t> bucket_boundaries_;
|
||||
std::vector<int32_t> bucket_batch_sizes_;
|
||||
std::function<TensorRow(TensorRow)> element_length_function_;
|
||||
std::shared_ptr<TensorOp> element_length_function_;
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
|
||||
bool pad_to_bucket_boundary_;
|
||||
bool drop_remainder_;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
||||
|
@ -27,7 +28,15 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Function to build ConcatOp
|
||||
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }
|
||||
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
|
||||
const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::vector<std::pair<int, int>> &children_flag_and_nums,
|
||||
const std::vector<std::pair<int, int>> &children_start_end_index)
|
||||
: sampler_(sampler),
|
||||
children_flag_and_nums_(children_flag_and_nums),
|
||||
children_start_end_index_(children_start_end_index) {
|
||||
this->children = datasets;
|
||||
}
|
||||
|
||||
Status ConcatNode::ValidateParams() {
|
||||
if (children.size() < 2) {
|
||||
|
@ -42,14 +51,25 @@ Status ConcatNode::ValidateParams() {
|
|||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if ((children_flag_and_nums_.empty() && !children_start_end_index_.empty()) ||
|
||||
(!children_flag_and_nums_.empty() && children_start_end_index_.empty())) {
|
||||
std::string err_msg = "ConcatNode: children_flag_and_nums and children_start_end_index should be used together";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
} else {
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_,
|
||||
children_start_end_index_));
|
||||
}
|
||||
|
||||
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
@ -29,7 +30,10 @@ namespace dataset {
|
|||
class ConcatNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets);
|
||||
explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
|
||||
const std::shared_ptr<SamplerObj> &sampler = nullptr,
|
||||
const std::vector<std::pair<int, int>> &children_flag_and_nums = {},
|
||||
const std::vector<std::pair<int, int>> &children_start_end_index = {});
|
||||
|
||||
/// \brief Destructor
|
||||
~ConcatNode() = default;
|
||||
|
@ -41,6 +45,11 @@ class ConcatNode : public DatasetNode {
|
|||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -240,6 +240,7 @@ DatasetNode::DatasetNode() {
|
|||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
connector_que_size_ = cfg->op_connector_size();
|
||||
worker_connector_size_ = cfg->worker_connector_size();
|
||||
build_status = Status::OK(); // remove me after changing return val of Build()
|
||||
}
|
||||
|
||||
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
|
||||
|
@ -254,5 +255,13 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
|
|||
// This method will only be called if its derived class does not implement one.
|
||||
return p->VisitAfter(shared_from_this(), modified);
|
||||
}
|
||||
Status DatasetNode::GetShardId(int32_t *shard_id) {
|
||||
if (!Children().empty()) {
|
||||
// Get shard id from the child node
|
||||
return Children()[0]->GetShardId(shard_id);
|
||||
} else {
|
||||
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -99,9 +99,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
|
||||
/// \brief Pure virtual function for derived class to get the shard id of specific node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
virtual Status GetShardId(int32_t *shard_id) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
virtual Status GetShardId(int32_t *shard_id);
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
|
@ -126,6 +124,10 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
/// \return Status of the node visit
|
||||
virtual Status AcceptAfter(NodePass *p, bool *modified);
|
||||
|
||||
/// \brief Method to get status from Node.Build()
|
||||
/// \notes Remove me after changing return val of Build()
|
||||
Status BuildStatus() { return build_status; }
|
||||
|
||||
protected:
|
||||
std::vector<std::shared_ptr<DatasetNode>> children;
|
||||
std::shared_ptr<DatasetCache> cache_;
|
||||
|
@ -135,6 +137,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
int32_t rows_per_buffer_;
|
||||
int32_t connector_que_size_;
|
||||
int32_t worker_connector_size_;
|
||||
Status build_status; // remove me after changing return val of Build()
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Constructor for FilterNode
|
||||
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
|
||||
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
|
||||
std::vector<std::string> input_columns)
|
||||
: predicate_(predicate), input_columns_(input_columns) {
|
||||
this->children.push_back(child);
|
||||
|
@ -38,10 +38,7 @@ std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
|
|||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::shared_ptr<TensorOp> c_func;
|
||||
c_func = std::make_shared<CFuncOp>(predicate_);
|
||||
|
||||
node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, c_func));
|
||||
node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace dataset {
|
|||
class FilterNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
|
||||
FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
|
||||
std::vector<std::string> input_columns = {});
|
||||
|
||||
/// \brief Destructor
|
||||
|
@ -44,7 +44,7 @@ class FilterNode : public DatasetNode {
|
|||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::function<TensorRow(TensorRow)> predicate_;
|
||||
std::shared_ptr<TensorOp> predicate_;
|
||||
std::vector<std::string> input_columns_;
|
||||
};
|
||||
|
||||
|
|
|
@ -64,7 +64,8 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
|
|||
auto project_op = std::make_shared<ProjectOp>(project_columns_);
|
||||
node_ops.push_back(project_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(map_op);
|
||||
return node_ops;
|
||||
|
|
|
@ -59,7 +59,8 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
|
|||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_));
|
||||
build_status = schema->LoadSchemaFile(schema_path_, column_names_);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Argument that is not exposed to user in the API.
|
||||
std::set<std::string> extensions = {};
|
||||
|
|
|
@ -60,7 +60,8 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, usage_, extensions_, std::move(schema),
|
||||
|
|
|
@ -56,7 +56,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
|
|
|
@ -54,7 +54,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
|
|
|
@ -197,18 +197,23 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
|
|||
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(clue_op->Init());
|
||||
|
||||
build_status = clue_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
build_status = ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
|
|
|
@ -111,7 +111,8 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
|
|||
std::shared_ptr<CocoOp> op =
|
||||
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(op);
|
||||
|
||||
|
|
|
@ -108,18 +108,23 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
|
|||
std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
|
||||
rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
|
||||
num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(csv_op->Init());
|
||||
|
||||
build_status = csv_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));
|
||||
build_status = CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,25 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
|
|||
const std::vector<DataType> &column_types)
|
||||
: generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
|
||||
|
||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
|
||||
: generator_function_(generator_function), schema_(schema) {}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
|
||||
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
|
||||
|
||||
if (schema_ != nullptr) {
|
||||
column_names_.clear();
|
||||
column_types_.clear();
|
||||
std::string schema_json_string = schema_->to_json();
|
||||
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, {}));
|
||||
|
||||
for (int32_t i = 0; i < data_schema->NumColumns(); i++) {
|
||||
ColDescriptor col = data_schema->column(i);
|
||||
column_names_.push_back(col.name());
|
||||
column_types_.push_back((col.type()));
|
||||
}
|
||||
}
|
||||
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
|
||||
|
@ -43,6 +61,8 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
|
|||
// This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which
|
||||
// best be delivered when the test cases for this api is ready.
|
||||
Status rc = op->Init();
|
||||
build_status = rc; // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (rc.IsOk()) {
|
||||
node_ops.push_back(op);
|
||||
|
@ -56,5 +76,11 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
|
|||
// no validation is needed for generator op.
|
||||
Status GeneratorNode::ValidateParams() { return Status::OK(); }
|
||||
|
||||
Status GeneratorNode::GetShardId(int32_t *shard_id) {
|
||||
RETURN_UNEXPECTED_IF_NULL(shard_id);
|
||||
*shard_id = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,6 +35,9 @@ class GeneratorNode : public DatasetNode {
|
|||
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
||||
const std::vector<DataType> &column_types);
|
||||
|
||||
/// \brief Constructor
|
||||
GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema);
|
||||
|
||||
/// \brief Destructor
|
||||
~GeneratorNode() = default;
|
||||
|
||||
|
@ -46,10 +49,15 @@ class GeneratorNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node, is always 0 because generator_node doesn't support sharding
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
py::function generator_function_;
|
||||
std::vector<std::string> column_names_;
|
||||
std::vector<DataType> column_types_;
|
||||
std::shared_ptr<SchemaObj> schema_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -62,7 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
|
|||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
recursive_, decode_, exts_, class_indexing_, std::move(schema),
|
||||
|
|
|
@ -79,7 +79,8 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
|
|||
manifest_op =
|
||||
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
|
||||
class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(manifest_op);
|
||||
|
||||
|
|
|
@ -138,7 +138,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
|
|||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
std::vector<std::shared_ptr<ShardOperator>> operators_;
|
||||
RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
|
||||
build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
std::shared_ptr<MindRecordOp> mindrecord_op;
|
||||
// If pass a string to MindData(), it will be treated as a pattern to search for matched files,
|
||||
|
@ -154,7 +155,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
|
|||
padded_sample_, sample_bytes_);
|
||||
}
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(mindrecord_op->Init());
|
||||
build_status = mindrecord_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
node_ops.push_back(mindrecord_op);
|
||||
|
||||
return node_ops;
|
||||
|
|
|
@ -51,7 +51,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
|
|||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_->Build())));
|
||||
|
|
|
@ -98,7 +98,8 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
|
|||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
|
||||
std::move(data_schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(op);
|
||||
|
||||
|
|
|
@ -78,7 +78,8 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
|
|||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
|
||||
build_status = text_file_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
@ -86,14 +87,17 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
build_status = TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
// Add TextFileOp
|
||||
node_ops.push_back(text_file_op);
|
||||
|
|
|
@ -118,7 +118,8 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
|
|||
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
|
||||
shard_id_, shard_equal_rows_, std::move(sampler_->Build()));
|
||||
|
||||
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
|
||||
build_status = tf_reader_op->Init(); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
@ -127,14 +128,17 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
|
||||
build_status = TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
build_status = AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op);
|
||||
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
// Add TFReaderOp
|
||||
node_ops.push_back(tf_reader_op);
|
||||
|
|
|
@ -106,7 +106,8 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
|
|||
std::shared_ptr<VOCOp> voc_op;
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
|
||||
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
node_ops.push_back(voc_op);
|
||||
return node_ops;
|
||||
|
|
|
@@ -27,9 +27,8 @@ namespace mindspore {
namespace dataset {

// Constructor for SyncWaitNode
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
                           py::function callback)
    : condition_name_(condition_name), num_batch_(num_batch), callback_(callback) {
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback)
    : condition_name_(condition_name), callback_(callback) {
  this->children.push_back(child);
}

@@ -38,20 +37,16 @@ std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_));
  // Right now barrier should only take num_rows_per_buffer = 1
  // The reason for this is because having it otherwise can lead to blocking issues
  // See barrier_op.h for more details
  int32_t rows_per_buffer = 1;
  node_ops.push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_));
  return node_ops;
}

// Function to validate the parameters for SyncWaitNode
Status SyncWaitNode::ValidateParams() {
  if (num_batch_ <= 0) {
    std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
Status SyncWaitNode::ValidateParams() { return Status::OK(); }

}  // namespace dataset
}  // namespace mindspore

@@ -31,8 +31,7 @@ namespace dataset {
class SyncWaitNode : public DatasetNode {
 public:
  /// \brief Constructor
  explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
                        py::function callback);
  SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback);

  /// \brief Destructor
  ~SyncWaitNode() = default;

@@ -47,7 +46,6 @@ class SyncWaitNode : public DatasetNode {

 private:
  std::string condition_name_;
  int32_t num_batch_;
  py::function callback_;
};
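A minimal Python sketch of the feature SyncWaitNode/BarrierOp backs, assuming the usual sync_wait/sync_update pairing; the generator data and condition name are illustrative only.

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(6):
        yield (np.array([i], dtype=np.int32),)

data = ds.GeneratorDataset(gen, column_names=["col1"])
# sync_wait is what creates the SyncWaitNode above; the BarrierOp blocks the
# pipeline until sync_update releases it.
data = data.sync_wait(condition_name="policy", callback=lambda *args: None)
data = data.batch(2)

for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    data.sync_update(condition_name="policy")  # release the barrier for the next batch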
@ -18,73 +18,81 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor for TransferNode
|
||||
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end)
|
||||
: prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) {
|
||||
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type,
|
||||
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
|
||||
: prefetch_size_(16),
|
||||
queue_name_(std::move(queue_name)),
|
||||
device_type_(std::move(device_type)),
|
||||
send_epoch_end_(send_epoch_end),
|
||||
total_batch_(total_batch),
|
||||
create_data_info_queue_(create_data_info_queue),
|
||||
device_id_(0) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Validator for TransferNode
|
||||
Status TransferNode::ValidateParams() {
|
||||
// Check if device_type_ is in {"CPU", "GPU", "Ascend"}
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"}));
|
||||
if (total_batch_ < 0) {
|
||||
std::string err_msg = "TransferNode: Total batches should be >= 0, value given: ";
|
||||
MS_LOG(ERROR) << err_msg << total_batch_;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build TransferNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
|
||||
// Get a uuid for queue name
|
||||
queue_name_ = Services::GetUniqueID();
|
||||
// TODO(CRC):
|
||||
if (queue_name_.empty()) {
|
||||
// Get a uuid for queue name
|
||||
queue_name_ = Services::GetUniqueID();
|
||||
}
|
||||
if (device_type_.empty()) {
|
||||
auto context = MsContext::GetInstance();
|
||||
if (context == nullptr) {
|
||||
device_type_ = kCPUDevice;
|
||||
} else {
|
||||
device_type_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
}
|
||||
}
|
||||
|
||||
// Get device type from ms context
|
||||
device_type_ = "CPU";
|
||||
// Get device ID from children
|
||||
// Convert device_type_ from string to DeviceType
|
||||
DeviceQueueOp::DeviceType type;
|
||||
if (device_type_ == kCPUDevice) {
|
||||
type = DeviceQueueOp::DeviceType::CPU;
|
||||
} else if (device_type_ == kGPUDevice) {
|
||||
type = DeviceQueueOp::DeviceType::GPU;
|
||||
} else if (device_type_ == kAscendDevice) {
|
||||
type = DeviceQueueOp::DeviceType::Ascend;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknown device target.";
|
||||
return {};
|
||||
}
|
||||
|
||||
// Get device ID (shard ID) from children
|
||||
device_id_ = 0;
|
||||
RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_));
|
||||
build_status = this->GetShardId(&device_id_); // remove me after changing return val of Build()
|
||||
RETURN_EMPTY_IF_ERROR(build_status);
|
||||
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Convert device_type_ from string to DeviceType
|
||||
DeviceQueueOp::DeviceType type;
|
||||
if (device_type_ == "CPU") {
|
||||
type = DeviceQueueOp::DeviceType::CPU;
|
||||
} else if (device_type_ == "GPU") {
|
||||
type = DeviceQueueOp::DeviceType::GPU;
|
||||
} else if (device_type_ == "Ascend") {
|
||||
type = DeviceQueueOp::DeviceType::Ascend;
|
||||
}
|
||||
node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
|
||||
total_batch_, false));
|
||||
total_batch_, create_data_info_queue_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to get the device_id
|
||||
Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) {
|
||||
// Get device id according to the type of dataset
|
||||
Status rc = ds->GetShardId(device_id);
|
||||
if (rc != Status::OK()) {
|
||||
// Get device id from the child node
|
||||
if (ds->Children().size()) {
|
||||
ds = ds->Children()[0];
|
||||
return TransferNode::get_distribution(ds, device_id);
|
||||
} else {
|
||||
std::string err_msg = "Unknown dataset type.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,7 +29,8 @@ namespace dataset {
|
|||
class TransferNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end);
|
||||
TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end,
|
||||
int32_t total_batch, bool create_data_info_queue);
|
||||
|
||||
/// \brief Destructor
|
||||
~TransferNode() = default;
|
||||
|
@ -42,8 +43,6 @@ class TransferNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
|
||||
|
||||
private:
|
||||
std::string queue_name_;
|
||||
int32_t device_id_;
|
||||
|
@ -51,6 +50,7 @@ class TransferNode : public DatasetNode {
|
|||
int32_t prefetch_size_;
|
||||
bool send_epoch_end_;
|
||||
int32_t total_batch_;
|
||||
bool create_data_info_queue_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -40,21 +40,7 @@ Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *mo
|
|||
}
|
||||
|
||||
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||
if (type_ == kOutputShapeAndType) {
|
||||
nodes_to_clear_callback_.push_back(node);
|
||||
} else if (type_ == kDatasetSize) {
|
||||
nodes_to_remove_.push_back(node);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
|
||||
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
|
||||
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
|
||||
nodes_to_clear_callback_.push_back(node);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -83,5 +69,6 @@ Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,10 @@ class GetterPass : public TreePass {
|
|||
enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
|
||||
/// \brief Constructor
|
||||
explicit GetterPass(GetterType tp) : pass_(tp) {}
|
||||
|
||||
/// \brief default copy Constructor
|
||||
explicit GetterPass(const GetterPass &) = default;
|
||||
|
||||
/// \brief Destructor
|
||||
~GetterPass() = default;
|
||||
|
||||
|
@ -51,11 +55,10 @@ class GetterPass : public TreePass {
|
|||
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); }
|
||||
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
|
||||
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
|
||||
// Whether this is Run or PreRun does not matter here; however, only Accept() is defined in ConcatOp
|
||||
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
|
||||
|
||||
|
@ -67,7 +70,7 @@ class GetterPass : public TreePass {
|
|||
std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
|
||||
std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
|
||||
};
|
||||
// outter class needs only to own the inner class object since it automatically has access to its private variables
|
||||
// outer class needs only to own the inner class object since it automatically has access to its private variables
|
||||
GetterNodes pass_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
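These getters are the Python entry points that trigger GetterPass; a short sketch (the generator data is made up) showing that metadata queries run a trimmed tree rather than a full epoch:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(8):
        yield (np.array([i], dtype=np.int32),)

data = ds.GeneratorDataset(gen, column_names=["number"]).batch(2)

# kDatasetSize removes size-irrelevant ops; kOutputShapeAndType clears map callbacks.
print(data.get_dataset_size())   # 4
print(data.output_shapes())      # [[2, 1]]
print(data.output_types())       # [int32]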
@ -19,7 +19,14 @@
|
|||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||
Status PythonRuntimeContext::Terminate() {
|
||||
MS_LOG(INFO) << "Terminating a PythonRuntime";
|
||||
if (tree_consumer_ != nullptr) {
|
||||
return TerminateImpl();
|
||||
}
|
||||
MS_LOG(WARNING) << "TreeConsumer was not initialized";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PythonRuntimeContext::TerminateImpl() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||
|
|
|
@ -22,7 +22,14 @@ namespace mindspore::dataset {
|
|||
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
|
||||
tree_consumer_ = std::move(tree_consumer);
|
||||
}
|
||||
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }
|
||||
Status NativeRuntimeContext::Terminate() {
|
||||
MS_LOG(INFO) << "Terminating a NativeRuntime";
|
||||
if (tree_consumer_ != nullptr) {
|
||||
return TerminateImpl();
|
||||
}
|
||||
MS_LOG(WARNING) << "TreeConsumer was not initialized";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NativeRuntimeContext::TerminateImpl() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");
|
||||
|
|
|
@ -97,6 +97,8 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
|
|||
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
|
||||
// Build the DatasetOp ExecutionTree from the optimized IR tree
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
|
||||
RETURN_IF_NOT_OK(ir->BuildStatus()); // remove me after changing return val of Build()
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node.");
|
||||
|
||||
(*op) = ops.front(); // return the first op to be added as child by the caller of this function
|
||||
|
@ -141,6 +143,8 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
|
|||
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
||||
|
||||
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
|
||||
|
||||
// Note: We will gradually move the pre pass, optimizer pass, and post pass
|
||||
// on ExecutionTree to perform on IR tree.
|
||||
// Prepare the tree
|
||||
|
@ -149,6 +153,11 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
|
|||
// After the tree is prepared, the col_name_id_map can safely be obtained
|
||||
column_name_map_ = tree_->root()->column_name_id_map();
|
||||
|
||||
// Profiling parameters init
|
||||
cur_batch_num_ = 0;
|
||||
cur_connector_size_ = 0;
|
||||
cur_connector_capacity_ = 0;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -156,21 +165,55 @@ Status TreeAdapter::GetNext(TensorRow *row) {
|
|||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_UNEXPECTED_IF_NULL(row);
|
||||
row->clear(); // make sure row is empty
|
||||
|
||||
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
|
||||
|
||||
// When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
|
||||
if (cur_db_ == nullptr) {
|
||||
RETURN_IF_NOT_OK(tree_->Launch());
|
||||
// Profiling
|
||||
std::shared_ptr<Tracing> node;
|
||||
Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
|
||||
if (s.IsOk()) {
|
||||
tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
|
||||
}
|
||||
if (tracing_ != nullptr) {
|
||||
cur_connector_size_ = tree_->root()->ConnectorSize();
|
||||
cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
|
||||
}
|
||||
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag
|
||||
RETURN_OK_IF_TRUE(cur_db_->eoe()); // return empty tensor if 1st buf is a ctrl buf (no rows)
|
||||
if (cur_db_->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows)
|
||||
MS_LOG(INFO) << "End of data iteration.";
|
||||
if (isProfilingEnable) {
|
||||
tree_->SetEpochEnd();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached.");
|
||||
|
||||
if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf
|
||||
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_));
|
||||
RETURN_OK_IF_TRUE(cur_db_->eoe() || cur_db_->eof()); // return empty if this new buffer is a ctrl flag
|
||||
if (cur_db_->eoe()) { // return empty if this new buffer is a ctrl flag
|
||||
MS_LOG(INFO) << "End of data iteration.";
|
||||
if (isProfilingEnable) {
|
||||
tree_->SetEpochEnd();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
if (cur_db_->eof()) {
|
||||
tree_->SetFinished();
|
||||
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
|
||||
RETURN_STATUS_UNEXPECTED(err);
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(cur_db_->PopRow(row));
|
||||
// Record profiling info
|
||||
if (tracing_ != nullptr) {
|
||||
cur_batch_num_++;
|
||||
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
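Seen from Python, the EOE/EOF handling above is what bounds an epoch-limited iterator; a sketch (generator data is illustrative):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(2):
        yield (np.array([i], dtype=np.int32),)

data = ds.GeneratorDataset(gen, column_names=["c"])

# With num_epochs=1 the tree emits one EOE and then EOF; fetching past that is
# the "EOF buffer encountered" error handled in TreeAdapter::GetNext.
rows = list(data.create_tuple_iterator(num_epochs=1, output_numpy=True))
assert len(rows) == 2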
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -60,6 +61,9 @@ class TreeAdapter {
|
|||
// Set optional optimization pass
|
||||
void SetOptimize(bool value) { optimize_ = value; }
|
||||
|
||||
// Function to override the pre-pass
|
||||
void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
|
||||
|
||||
// Optional optimizations status
|
||||
bool OptimizationEnabled() const { return optimize_; }
|
||||
|
||||
|
@ -82,9 +86,14 @@ class TreeAdapter {
|
|||
|
||||
std::unique_ptr<DataBuffer> cur_db_;
|
||||
std::unordered_map<std::string, int32_t> column_name_map_;
|
||||
std::unique_ptr<ExecutionTree> tree_;
|
||||
std::unique_ptr<ExecutionTree> tree_; // current connector capacity of root op, used for profiling
|
||||
int32_t num_epochs_;
|
||||
bool optimize_; // Flag to enable optional optimization pass
|
||||
bool optimize_; // Flag to enable optional optimization pass
|
||||
std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data
|
||||
int32_t cur_batch_num_; // current batch number, used for profiling
|
||||
int32_t cur_connector_size_; // current connector size of root op, used for profiling
|
||||
int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling
|
||||
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@@ -145,9 +145,16 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
  /// \brief Function to transfer data through a device.
  /// \notes If the device is Ascend, data features are transferred one by one; each transfer
  ///     is limited to 256 MB.
  /// \param[in] queue_name Channel name (default="", create a new unique name).
  /// \param[in] device_type Type of device (default="", get from MSContext).
  /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
  /// \param[in] send_epoch_end Whether to send an end-of-sequence marker to the device (default=true).
  /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
  /// \param[in] create_data_info_queue Whether to create a queue that stores the types and shapes
  ///     of the data (default=false).
  /// \return Returns true if no error is encountered, else false.
  bool DeviceQueue(bool send_epoch_end = true);
  bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1,
                   bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false);
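From the Python side this maps roughly onto to_device()/send(); a sketch under the assumption that a device queue is available (the dataset path and batch size are placeholders):

import mindspore.dataset as ds

data = ds.Cifar10Dataset("/path/to/cifar10")  # placeholder dataset directory
data = data.batch(32)

# to_device() builds the TransferNode/DeviceQueueOp pipeline; send() starts
# pushing batches into the device channel described by queue_name/device_type.
data = data.to_device()
data.send(num_epochs=1)
data.stop_send()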
||||
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
|
||||
/// \note Usage restrictions:
|
||||
|
@ -371,21 +378,34 @@ class SchemaObj {
|
|||
|
||||
/// \brief SchemaObj init function
|
||||
/// \return bool true if schema init success
|
||||
bool init();
|
||||
Status init();
|
||||
|
||||
/// \brief Add new column to the schema with unknown shape of rank 1
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(TypeId).
|
||||
/// \return bool true if schema init success
|
||||
Status add_column(std::string name, TypeId de_type);
|
||||
|
||||
/// \brief Add new column to the schema with unknown shape of rank 1
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(std::string).
|
||||
/// \param[in] shape shape of the column.
|
||||
/// \return bool true if schema init success
|
||||
Status add_column(std::string name, std::string de_type);
|
||||
|
||||
/// \brief Add new column to the schema
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(TypeId).
|
||||
/// \param[in] shape shape of the column.
|
||||
/// \return bool true if schema init success
|
||||
bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
|
||||
Status add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
|
||||
|
||||
/// \brief Add new column to the schema
|
||||
/// \param[in] name name of the column.
|
||||
/// \param[in] de_type data type of the column(std::string).
|
||||
/// \param[in] shape shape of the column.
|
||||
/// \return bool true if schema init success
|
||||
bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
|
||||
Status add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
|
||||
|
||||
/// \brief Get a JSON string of the schema
|
||||
/// \return JSON string of the schema
|
||||
|
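On the Python side, SchemaObj is surfaced as ds.Schema; a minimal sketch of the add_column variants above (column names and shapes are made up):

import mindspore.dataset as ds
from mindspore import dtype as mstype

schema = ds.Schema()
# Both the TypeId and string flavours of add_column end up in the C++ SchemaObj.
schema.add_column("image", de_type=mstype.uint8, shape=[28, 28, 1])
schema.add_column("label", de_type="int32", shape=[1])

data = ds.RandomDataset(schema=schema, total_rows=4)
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(row["image"].shape, row["label"].dtype)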
@ -395,25 +415,27 @@ class SchemaObj {
|
|||
std::string to_string() { return to_json(); }
|
||||
|
||||
/// \brief set a new value to dataset_type
|
||||
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; }
|
||||
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); }
|
||||
|
||||
/// \brief set a new value to num_rows
|
||||
inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
|
||||
|
||||
/// \brief get the current num_rows
|
||||
inline int32_t get_num_rows() { return num_rows_; }
|
||||
inline int32_t get_num_rows() const { return num_rows_; }
|
||||
|
||||
Status FromJSONString(const std::string &json_string);
|
||||
|
||||
private:
|
||||
  /// \brief Parse the columns and add them to columns
  /// \param[in] columns Dataset attribute information, decoded from the schema file;
  ///     supports both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
  /// \return Status of the parse
  bool parse_column(nlohmann::json columns);
  Status parse_column(nlohmann::json columns);
|
||||
|
||||
/// \brief Get schema file from json file
|
||||
/// \param[in] json_obj object of json parsed.
|
||||
/// \return bool true if json dump success
|
||||
bool from_json(nlohmann::json json_obj);
|
||||
Status from_json(nlohmann::json json_obj);
|
||||
|
||||
int32_t num_rows_;
|
||||
std::string dataset_type_;
|
||||
|
|
|
@ -61,6 +61,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
|
||||
class DistributedSamplerObj;
|
||||
class PKSamplerObj;
|
||||
class PreBuiltSamplerObj;
|
||||
class RandomSamplerObj;
|
||||
class SequentialSamplerObj;
|
||||
class SubsetRandomSamplerObj;
|
||||
|
@ -171,6 +172,31 @@ class PKSamplerObj : public SamplerObj {
|
|||
int64_t num_samples_;
|
||||
};
|
||||
|
||||
class PreBuiltSamplerObj : public SamplerObj {
|
||||
public:
|
||||
#ifndef ENABLE_ANDROID
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
|
||||
#endif
|
||||
|
||||
~PreBuiltSamplerObj() = default;
|
||||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerRT> sp_;
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
|
||||
#endif
|
||||
};
|
||||
|
||||
class RandomSamplerObj : public SamplerObj {
|
||||
public:
|
||||
RandomSamplerObj(bool replacement, int64_t num_samples);
|
||||
|
|
|
@ -70,6 +70,7 @@ namespace transforms {
|
|||
class ComposeOperation;
|
||||
class DuplicateOperation;
|
||||
class OneHotOperation;
|
||||
class PreBuiltOperation;
|
||||
class RandomApplyOperation;
|
||||
class RandomChoiceOperation;
|
||||
class TypeCastOperation;
|
||||
|
@ -164,6 +165,20 @@ class OneHotOperation : public TensorOperation {
|
|||
float num_classes_;
|
||||
};
|
||||
|
||||
class PreBuiltOperation : public TensorOperation {
|
||||
public:
|
||||
explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op);
|
||||
|
||||
~PreBuiltOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<TensorOp> op_;
|
||||
};
|
||||
|
||||
class RandomApplyOperation : public TensorOperation {
|
||||
public:
|
||||
explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);
|
||||
|
@ -192,7 +207,6 @@ class RandomChoiceOperation : public TensorOperation {
|
|||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
||||
};
|
||||
|
||||
class TypeCastOperation : public TensorOperation {
|
||||
public:
|
||||
explicit TypeCastOperation(std::string data_type);
|
||||
|
|
|
@ -71,6 +71,15 @@ namespace dataset {
|
|||
return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \
|
||||
} while (false)
|
||||
|
||||
#define RETURN_SECOND_IF_ERROR(_s, _r) \
|
||||
do { \
|
||||
Status __rc = (_s); \
|
||||
if (__rc.IsError()) { \
|
||||
MS_LOG(ERROR) << __rc; \
|
||||
return _r; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
enum class StatusCode : char {
|
||||
kOK = 0,
|
||||
kOutOfMemory = 1,
|
||||
|
|
|
@ -138,7 +138,9 @@ Status Task::Join(WaitFlag blocking) {
|
|||
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
|
||||
// We can't tell which conditional_variable this thread is waiting on. So we may need
|
||||
// to interrupt everything one more time.
|
||||
MS_LOG(INFO) << "Some threads not responding. Interrupt again";
|
||||
std::stringstream ss;
|
||||
ss << get_id();
|
||||
MS_LOG(ERROR) << MyName() << " Thread ID " << ss.str() << " is not responding. Interrupt again";
|
||||
interrupt_svc->InterruptAll();
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -21,7 +21,8 @@ import numpy
|
|||
import mindspore._c_dataengine as cde
|
||||
|
||||
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
|
||||
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load']
|
||||
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
|
||||
'get_callback_timeout']
|
||||
|
||||
INT32_MAX = 2147483647
|
||||
UINT32_MAX = 4294967295
|
||||
|
|
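A quick usage note for the expanded __all__: get_callback_timeout joins the existing config getters and setters.

import mindspore.dataset as ds

ds.config.set_seed(1234)
ds.config.set_num_parallel_workers(4)
print(ds.config.get_seed())               # 1234
print(ds.config.get_num_parallel_workers())
print(ds.config.get_callback_timeout())   # timeout (seconds) used by DSCallback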
|
@ -65,5 +65,7 @@ def mstypelist_to_detypelist(type_list):
|
|||
for index, _ in enumerate(type_list):
|
||||
if type_list[index] is not None:
|
||||
type_list[index] = mstype_to_detype(type_list[index])
|
||||
else:
|
||||
type_list[index] = cde.DataType("")
|
||||
|
||||
return type_list
|
||||
|
|
File diff suppressed because it is too large
@ -15,17 +15,13 @@
|
|||
"""Built-in iterators.
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
import copy
|
||||
import weakref
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._c_dataengine import DEPipeline
|
||||
from mindspore._c_dataengine import OpName
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from mindspore import log as logger
|
||||
from . import datasets as de
|
||||
|
||||
|
||||
_ITERATOR_CLEANUP = False
|
||||
|
||||
|
@ -57,29 +53,6 @@ def _cleanup():
|
|||
itr.release()
|
||||
|
||||
|
||||
def alter_tree(node):
|
||||
"""Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
|
||||
if not node.children:
|
||||
return _alter_node(node)
|
||||
|
||||
converted_children = []
|
||||
for input_op in node.children:
|
||||
converted_children.append(alter_tree(input_op))
|
||||
node.children = converted_children
|
||||
return _alter_node(node)
|
||||
|
||||
|
||||
def _alter_node(node):
|
||||
"""DEPRECATED"""
|
||||
# Please check ccsrc/dataset/engine/opt for tree transformation.
|
||||
if isinstance(node, de.MapDataset):
|
||||
if node.python_multiprocessing:
|
||||
# Bootstrap can only be performed on a copy of the original dataset node.
|
||||
# Bootstrap on original dataset node will make all iterators share the same process pool
|
||||
node.iterator_bootstrap()
|
||||
return node
|
||||
|
||||
|
||||
class Iterator:
|
||||
"""
|
||||
General Iterator over a dataset.
|
||||
|
@ -89,185 +62,62 @@ class Iterator:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
|
||||
self.num_epochs = num_epochs
|
||||
self.output_numpy = output_numpy
|
||||
self._col_names = None
|
||||
|
||||
# create a copy of tree and work on it.
|
||||
self.ori_dataset = dataset
|
||||
|
||||
self.ir_tree, self.dataset = dataset.create_ir_tree()
|
||||
|
||||
self._runtime_context = cde.PythonRuntimeContext()
|
||||
self._runtime_context.Init()
|
||||
consumer = cde.PythonIteratorConsumer(num_epochs)
|
||||
consumer.Init(self.ir_tree)
|
||||
self._runtime_context.AssignConsumer(consumer)
|
||||
self._iterator = self._runtime_context.GetConsumer()
|
||||
|
||||
self._transform_tensor = lambda t: t.as_array()
|
||||
if not output_numpy:
|
||||
self._transform_tensor = lambda t: Tensor(t.as_array())
|
||||
self._index = 0
|
||||
|
||||
# todo remove next when ContextManager is done
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
_unset_iterator_cleanup()
|
||||
# create a copy of tree and work on it.
|
||||
self.dataset = copy.deepcopy(dataset)
|
||||
self.ori_dataset = dataset
|
||||
self.parent_subtree = []
|
||||
#######
|
||||
|
||||
# The dataset passed into the iterator is not the root of the tree.
|
||||
# Trim the tree by saving the parent subtree into self.parent_subtree and
|
||||
# restore it after launching our C++ pipeline.
|
||||
if self.dataset.parent:
|
||||
logger.info("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
|
||||
self.parent_subtree = self.dataset.parent
|
||||
self.dataset.parent = []
|
||||
|
||||
self.dataset = alter_tree(self.dataset)
|
||||
if not self.__is_tree():
|
||||
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers).")
|
||||
self.depipeline = DEPipeline()
|
||||
|
||||
# for manifest temporary use
|
||||
self.__batch_node(self.dataset, 0)
|
||||
|
||||
root = self.__convert_node_postorder(self.dataset)
|
||||
self.depipeline.AssignRootNode(root)
|
||||
self.depipeline.PrepareTree(self.num_epochs)
|
||||
self._index = 0
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Manually terminate Python iterator instead of relying on out of scope destruction.
|
||||
"""
|
||||
logger.info("Terminating Python iterator. This will also terminate C++ pipeline.")
|
||||
if hasattr(self, 'depipeline') and self.depipeline:
|
||||
del self.depipeline
|
||||
|
||||
def __is_tree_node(self, node):
|
||||
"""Check if a node is tree node."""
|
||||
if not node.children:
|
||||
if len(node.parent) > 1:
|
||||
return False
|
||||
|
||||
if len(node.parent) > 1:
|
||||
return False
|
||||
|
||||
for input_node in node.children:
|
||||
cls = self.__is_tree_node(input_node)
|
||||
if not cls:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __is_tree(self):
|
||||
return self.__is_tree_node(self.dataset)
|
||||
|
||||
@staticmethod
|
||||
def __get_dataset_type(dataset):
|
||||
"""Get the dataset type."""
|
||||
op_type = None
|
||||
if isinstance(dataset, de.ShuffleDataset):
|
||||
op_type = OpName.SHUFFLE
|
||||
elif isinstance(dataset, de.MindDataset):
|
||||
op_type = OpName.MINDRECORD
|
||||
elif isinstance(dataset, de.BatchDataset):
|
||||
op_type = OpName.BATCH
|
||||
elif isinstance(dataset, de.BucketBatchByLengthDataset):
|
||||
op_type = OpName.BUCKETBATCH
|
||||
elif isinstance(dataset, de.SyncWaitDataset):
|
||||
op_type = OpName.BARRIER
|
||||
elif isinstance(dataset, de.ZipDataset):
|
||||
op_type = OpName.ZIP
|
||||
elif isinstance(dataset, de.ConcatDataset):
|
||||
op_type = OpName.CONCAT
|
||||
elif isinstance(dataset, de.MapDataset):
|
||||
op_type = OpName.MAP
|
||||
elif isinstance(dataset, de.FilterDataset):
|
||||
op_type = OpName.FILTER
|
||||
elif isinstance(dataset, de.RepeatDataset):
|
||||
op_type = OpName.REPEAT
|
||||
elif isinstance(dataset, de.SkipDataset):
|
||||
op_type = OpName.SKIP
|
||||
elif isinstance(dataset, de.TakeDataset):
|
||||
op_type = OpName.TAKE
|
||||
elif isinstance(dataset, de.ImageFolderDataset):
|
||||
op_type = OpName.IMAGEFOLDER
|
||||
elif isinstance(dataset, de.GeneratorDataset):
|
||||
op_type = OpName.GENERATOR
|
||||
elif isinstance(dataset, de.TransferDataset):
|
||||
op_type = OpName.DEVICEQUEUE
|
||||
elif isinstance(dataset, de.RenameDataset):
|
||||
op_type = OpName.RENAME
|
||||
elif isinstance(dataset, de.TFRecordDataset):
|
||||
op_type = OpName.TFREADER
|
||||
elif isinstance(dataset, de.ProjectDataset):
|
||||
op_type = OpName.PROJECT
|
||||
elif isinstance(dataset, de.MnistDataset):
|
||||
op_type = OpName.MNIST
|
||||
elif isinstance(dataset, de.ManifestDataset):
|
||||
op_type = OpName.MANIFEST
|
||||
elif isinstance(dataset, de.VOCDataset):
|
||||
op_type = OpName.VOC
|
||||
elif isinstance(dataset, de.CocoDataset):
|
||||
op_type = OpName.COCO
|
||||
elif isinstance(dataset, de.Cifar10Dataset):
|
||||
op_type = OpName.CIFAR10
|
||||
elif isinstance(dataset, de.Cifar100Dataset):
|
||||
op_type = OpName.CIFAR100
|
||||
elif isinstance(dataset, de.CelebADataset):
|
||||
op_type = OpName.CELEBA
|
||||
elif isinstance(dataset, de.RandomDataset):
|
||||
op_type = OpName.RANDOMDATA
|
||||
elif isinstance(dataset, de.TextFileDataset):
|
||||
op_type = OpName.TEXTFILE
|
||||
elif isinstance(dataset, de.BuildVocabDataset):
|
||||
op_type = OpName.BUILDVOCAB
|
||||
elif isinstance(dataset, de.BuildSentencePieceVocabDataset):
|
||||
op_type = OpName.SENTENCEPIECEVOCAB
|
||||
elif isinstance(dataset, de.CLUEDataset):
|
||||
op_type = OpName.CLUE
|
||||
elif isinstance(dataset, de.CSVDataset):
|
||||
op_type = OpName.CSV
|
||||
else:
|
||||
raise ValueError("Unsupported DatasetOp.")
|
||||
|
||||
return op_type
|
||||
|
||||
# Convert Python node into C node and add to C layer execution tree in postorder traversal.
|
||||
def __convert_node_postorder(self, node):
|
||||
self.check_node_type(node)
|
||||
op_type = self.__get_dataset_type(node)
|
||||
c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
|
||||
|
||||
for py_child in node.children:
|
||||
c_child = self.__convert_node_postorder(py_child)
|
||||
self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
|
||||
|
||||
return c_nodes["top"]
|
||||
|
||||
def __batch_node(self, dataset, level):
|
||||
"""Recursively get batch node in the dataset tree."""
|
||||
if isinstance(dataset, de.BatchDataset):
|
||||
return
|
||||
for input_op in dataset.children:
|
||||
self.__batch_node(input_op, level + 1)
|
||||
|
||||
@staticmethod
|
||||
def __print_local(dataset, level):
|
||||
"""Recursively print the name and address of nodes in the dataset tree."""
|
||||
name = dataset.__class__.__name__
|
||||
ptr = hex(id(dataset))
|
||||
for _ in range(level):
|
||||
logger.info("\t", end='')
|
||||
if not dataset.children:
|
||||
logger.info("-%s (%s)", name, ptr)
|
||||
else:
|
||||
logger.info("+%s (%s)", name, ptr)
|
||||
for input_op in dataset.children:
|
||||
Iterator.__print_local(input_op, level + 1)
|
||||
|
||||
def print(self):
|
||||
"""Print the dataset tree"""
|
||||
self.__print_local(self.dataset, 0)
|
||||
if hasattr(self, '_runtime_context') and self._runtime_context:
|
||||
if hasattr(self, '_iterator') and self._iterator:
|
||||
self._runtime_context.Terminate()
|
||||
del self._iterator
|
||||
del self._runtime_context
|
||||
del self.dataset
|
||||
|
||||
def release(self):
|
||||
if hasattr(self, 'depipeline') and self.depipeline:
|
||||
del self.depipeline
|
||||
self.stop()
|
||||
|
||||
def __del__(self):
|
||||
self.release()
|
||||
|
||||
@abstractmethod
|
||||
def get_next(self):
|
||||
def _get_next(self):
|
||||
raise RuntimeError("Calling base class Iterator's get_next is invalid.")
|
||||
|
||||
def __next__(self):
|
||||
if not self.depipeline:
|
||||
if not self._runtime_context:
|
||||
logger.warning("Iterator does not have a running C++ pipeline." +
|
||||
"It might because Iterator stop() had been called, or C++ pipeline crashed silently.")
|
||||
raise RuntimeError("Iterator does not have a running C++ pipeline.")
|
||||
|
||||
data = self.get_next()
|
||||
data = self._get_next()
|
||||
if not data:
|
||||
if self._index == 0:
|
||||
logger.warning("No records available.")
|
||||
|
@ -277,100 +127,56 @@ class Iterator:
|
|||
self._index += 1
|
||||
return data
|
||||
|
||||
@abstractmethod
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def get_output_shapes(self):
|
||||
return [t for t in self.depipeline.GetOutputShapes()]
|
||||
|
||||
def get_output_types(self):
|
||||
return [t for t in self.depipeline.GetOutputTypes()]
|
||||
|
||||
def get_dataset_size(self):
|
||||
return self.depipeline.GetDatasetSize()
|
||||
|
||||
def get_batch_size(self):
|
||||
return self.depipeline.GetBatchSize()
|
||||
|
||||
def get_repeat_count(self):
|
||||
return self.depipeline.GetRepeatCount()
|
||||
|
||||
def num_classes(self):
|
||||
return self.depipeline.GetNumClasses()
|
||||
|
||||
def get_col_names(self):
|
||||
return self.depipeline.GetColumnNames()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self
|
||||
|
||||
def _getters(self):
|
||||
"""
|
||||
Get pipeline information.
|
||||
"""
|
||||
getter = cde.TreeGetters()
|
||||
getter.Init(self.ir_tree)
|
||||
self._runtime_context.AssignConsumer(getter)
|
||||
self._col_names = getter.GetColumnNames()
|
||||
|
||||
class SaveOp(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def __init__(self, dataset, num_epochs=-1):
|
||||
super().__init__(dataset, num_epochs)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
|
||||
def get_next(self):
|
||||
pass
|
||||
|
||||
def check_node_type(self, node):
|
||||
if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
|
||||
logger.warning("Used shuffle, repeat, batch before save operator.")
|
||||
|
||||
def save(self, file_names, file_type):
|
||||
return self.depipeline.SaveDataset(file_names, file_type)
|
||||
def get_col_names(self):
|
||||
"""
|
||||
Get names of the columns in the dataset
|
||||
"""
|
||||
if self._col_names is None:
|
||||
self._getters()
|
||||
return self._col_names
|
||||
|
||||
|
||||
class DictIterator(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with dict type.
|
||||
"""
|
||||
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
|
||||
super().__init__(dataset, num_epochs, output_numpy)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def get_next(self):
|
||||
def _get_next(self):
|
||||
"""
|
||||
Returns the next record in the dataset as dictionary
|
||||
|
||||
Returns:
|
||||
Dict, the next record in the dataset.
|
||||
"""
|
||||
|
||||
if self.output_numpy:
|
||||
return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
|
||||
return {k: Tensor(v.as_array()) for k, v in self.depipeline.GetNextAsMap().items()}
|
||||
return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
|
||||
|
||||
|
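A small end-to-end sketch of the rewritten iterator path: the consumer is built from the IR tree and GetNextAsMap feeds create_dict_iterator (the generator data below is illustrative):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(3):
        yield (np.array([i], dtype=np.int64),)

data = ds.GeneratorDataset(gen, column_names=["number"])

# Each item is a dict of column name -> value; output_numpy=True skips the
# Tensor wrapping done by _transform_tensor.
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(item["number"])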
||||
class TupleIterator(Iterator):
|
||||
"""
|
||||
The derived class of Iterator with list type.
|
||||
"""
|
||||
def check_node_type(self, node):
|
||||
pass
|
||||
|
||||
def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False):
|
||||
if columns is not None:
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
# todo: move next to IR
|
||||
dataset = dataset.project(columns)
|
||||
super().__init__(dataset, num_epochs, output_numpy)
|
||||
self.depipeline.LaunchTreeExec()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def get_next(self):
|
||||
def _get_next(self):
|
||||
"""
|
||||
Returns the next record in the dataset as a list
|
||||
|
||||
|
@ -378,15 +184,14 @@ class TupleIterator(Iterator):
|
|||
List, the next record in the dataset.
|
||||
"""
|
||||
|
||||
if self.output_numpy:
|
||||
return [t.as_array() for t in self.depipeline.GetNextAsList()]
|
||||
return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()]
|
||||
return [self._transform_tensor(t) for t in self._iterator.GetNextAsList()]
|
||||
|
||||
|
||||
class DummyIterator:
|
||||
"""
|
||||
    A DummyIterator only works when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, mode):
|
||||
self.mode = mode
|
||||
self.shapes = dataset.output_shapes()
|
||||
|
|
|
@ -283,9 +283,12 @@ def create_node(node):
|
|||
node.get('shard_id'), sampler)
|
||||
|
||||
elif dataset_op == 'TFRecordDataset':
|
||||
shuffle = node.get('shuffle')
|
||||
if shuffle is not None and isinstance(shuffle, str):
|
||||
shuffle = de.Shuffle(shuffle)
|
||||
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
|
||||
node.get('num_samples'), node.get('num_parallel_workers'),
|
||||
de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))
|
||||
shuffle, node.get('num_shards'), node.get('shard_id'))
|
||||
|
||||
elif dataset_op == 'ManifestDataset':
|
||||
sampler = construct_sampler(node.get('sampler'))
|
||||
|
|
|
@ -293,14 +293,38 @@ def check_save(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_iterator(method):
|
||||
def check_tuple_iterator(method):
|
||||
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
[columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
nreq_param_bool = ['output_numpy']
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
if num_epochs is not None:
|
||||
type_check(num_epochs, (int,), "num_epochs")
|
||||
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
|
||||
|
||||
if columns is not None:
|
||||
check_columns(columns, "column_names")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_dict_iterator(method):
|
||||
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
nreq_param_bool = ['output_numpy']
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
if num_epochs is not None:
|
||||
type_check(num_epochs, (int,), "num_epochs")
|
||||
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
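The split validators guard the two iterator factories; a sketch of the arguments they check (columns, num_epochs, output_numpy), with made-up generator data:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(4):
        yield (np.array([i], dtype=np.int32), np.array([i * i], dtype=np.int32))

data = ds.GeneratorDataset(gen, column_names=["x", "x_squared"])

# check_tuple_iterator validates these arguments before the iterator is built;
# bad values raise TypeError/ValueError immediately.
for row in data.create_tuple_iterator(columns=["x"], num_epochs=1, output_numpy=True):
    print(row[0])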
|
@ -523,6 +547,8 @@ def check_batch(method):
|
|||
sig = ins.signature(batch_size)
|
||||
if len(sig.parameters) != 1:
|
||||
raise ValueError("callable batch_size should take one parameter (BatchInfo).")
|
||||
else:
|
||||
check_pos_int32(int(batch_size), "batch_size")
|
||||
|
||||
if num_parallel_workers is not None:
|
||||
check_num_parallel_workers(num_parallel_workers)
|
||||
|
@ -807,6 +833,21 @@ def check_project(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_schema(method):
|
||||
"""check the input arguments of Schema.__init__."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[schema_file], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if schema_file is not None:
|
||||
type_check(schema_file, (str,), "schema_file")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_add_column(method):
|
||||
"""check the input arguments of add_column."""
|
||||
|
||||
|
@ -1261,3 +1302,23 @@ def check_cache_option(cache):
|
|||
"""Sanity check for cache parameter"""
|
||||
if cache is not None:
|
||||
type_check(cache, (cache_client.DatasetCache,), "cache")
|
||||
|
||||
|
||||
def check_to_device_send(method):
|
||||
"""A wrapper that wraps a parameter checker around the check_to_device_send."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[num_epochs], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if num_epochs is not None:
|
||||
type_check(num_epochs, (int,), "num_epochs")
|
||||
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def replace_none(value, default):
|
||||
return value if value is not None else default
|
||||
|
|
|
@ -18,13 +18,13 @@ use to_bytes and to_str to encode and decode strings into a specified format.
|
|||
"""
|
||||
from enum import IntEnum
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
|
||||
check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Vocab", "SentencePieceVocab", "to_str", "to_bytes"
|
||||
]
|
||||
|
@ -39,8 +39,7 @@ class Vocab(cde.Vocab):
|
|||
|
||||
@classmethod
|
||||
@check_from_dataset
|
||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
|
||||
special_first=True):
|
||||
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, special_first=True):
|
||||
"""
|
||||
Build a vocab from a dataset.
|
||||
|
||||
|
@ -69,21 +68,7 @@ class Vocab(cde.Vocab):
|
|||
Returns:
|
||||
Vocab, Vocab object built from dataset.
|
||||
"""
|
||||
|
||||
vocab = Vocab()
|
||||
if columns is None:
|
||||
columns = []
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
if freq_range is None:
|
||||
freq_range = (None, None)
|
||||
if special_tokens is None:
|
||||
special_tokens = []
|
||||
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
|
||||
for d in root.create_dict_iterator(num_epochs=1):
|
||||
if d is not None:
|
||||
raise ValueError("from_dataset should receive data other than None.")
|
||||
return vocab
|
||||
return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
|
||||
|
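Usage stays the same after this simplification; a sketch assuming a small text column (tokens and special tokens are made up):

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text

def gen_words():
    for w in ["hello", "world", "hello"]:
        yield (np.array(w),)

data = ds.GeneratorDataset(gen_words, column_names=["text"])

# from_dataset now forwards directly to dataset.build_vocab() instead of
# driving a throwaway dict iterator.
vocab = text.Vocab.from_dataset(data, columns=["text"], freq_range=None, top_k=None,
                                special_tokens=["<pad>", "<unk>"], special_first=True)
data = data.map(operations=text.Lookup(vocab, "<unk>"), input_columns=["text"])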
||||
@classmethod
|
||||
@check_from_list
|
||||
|
@ -143,6 +128,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
|
|||
"""
|
||||
    SentencePiece object that is used to segment words
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@check_from_dataset_sentencepiece
|
||||
def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params):
|
||||
|
@ -164,13 +150,8 @@ class SentencePieceVocab(cde.SentencePieceVocab):
|
|||
SentencePiece, SentencePiece object from dataset.
|
||||
"""
|
||||
|
||||
vocab = SentencePieceVocab()
|
||||
root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage,
|
||||
model_type, params)
|
||||
for d in root.create_dict_iterator(num_epochs=1):
|
||||
if d is None:
|
||||
raise ValueError("from_dataset should receive data other than None.")
|
||||
return vocab
|
||||
return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage,
|
||||
DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
|
||||
|
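The same pattern holds for SentencePieceVocab; a sketch assuming a plain-text corpus file (the path is a placeholder):

import mindspore.dataset as ds
import mindspore.dataset.text as text
from mindspore.dataset.text import SentencePieceModel

corpus = ds.TextFileDataset("/path/to/corpus.txt", shuffle=False)  # placeholder path
vocab = text.SentencePieceVocab.from_dataset(corpus, ["text"], vocab_size=100,
                                             character_coverage=0.9995,
                                             model_type=SentencePieceModel.UNIGRAM,
                                             params={})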
||||
@classmethod
|
||||
@check_from_file_sentencepiece
|
||||
|
@ -270,6 +251,7 @@ class SentencePieceModel(IntEnum):
|
|||
CHAR = 2
|
||||
WORD = 3
|
||||
|
||||
|
||||
DE_C_INTER_SENTENCEPIECE_MODE = {
|
||||
SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM,
|
||||
SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE,
|
||||
|
|
|
@ -432,7 +432,7 @@ def check_from_dataset_sentencepiece(method):
|
|||
[_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
if col_names is not None:
|
||||
type_check(col_names, (list,), "col_names")
|
||||
type_check_list(col_names, (str,), "col_names")
|
||||
|
||||
if vocab_size is not None:
|
||||
check_uint32(vocab_size, "vocab_size")
|
||||
|
|
|
@ -146,6 +146,7 @@ if (BUILD_MINDDATA STREQUAL "full")
|
|||
|
||||
list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES
|
||||
"${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/cache/pre_built_dataset_cache.cc"
|
||||
)
|
||||
|
||||
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -123,6 +123,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
||||
return network
|
||||
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
    DatasetHelper is a class to process the MindData dataset and provide information about the dataset.
|
||||
|
@ -197,7 +198,6 @@ class DatasetHelper:
|
|||
def get_data_info(self):
|
||||
return self.iter.get_data_info()
|
||||
|
||||
|
||||
class _DatasetIter:
|
||||
"""Base iter for dataset helper"""
|
||||
|
||||
|
@ -331,7 +331,6 @@ class _DatasetIterPSLite(_DatasetIter):
|
|||
|
||||
class _DatasetIterNormal:
|
||||
"""Iter for normal(non sink) mode, feed the data from host."""
|
||||
|
||||
def __init__(self, dataset, epoch_num=-1):
|
||||
self.dataset = dataset
|
||||
self.device_num = _get_device_num()
|
||||
|
|
|
@ -61,15 +61,15 @@ class MindData:
|
|||
def send(self, num_epochs=-1):
|
||||
pass
|
||||
|
||||
def get_data_info(self):
|
||||
pass
|
||||
|
||||
def stop_send(self):
|
||||
pass
|
||||
|
||||
def continue_send(self):
|
||||
pass
|
||||
|
||||
def get_data_info(self):
|
||||
pass
|
||||
|
||||
def __len__(self):
|
||||
return self._size
|
||||
|
||||
|
|
|
@ -177,8 +177,8 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) {
|
|||
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
// 5 batches of size 2
|
||||
EXPECT_EQ(i, 5);
|
||||
// With 2 boundaries, 3 buckets are created
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
|
|
@ -132,6 +132,6 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
|
|||
// verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not
|
||||
EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos);
|
||||
EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos);
|
||||
EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos);
|
||||
EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos);
|
||||
EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
|
|||
const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}};
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 0};
|
||||
std::vector<size_t> row_sizes = {2, 2, 0};
|
||||
|
||||
TensorRow row;
|
||||
for (size_t sz : row_sizes) {
|
||||
|
@ -75,7 +75,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
|
|||
rc = tree_adapter.GetNext(&row);
|
||||
EXPECT_TRUE(rc.IsError());
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
|
||||
|
@ -97,7 +97,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
|
|||
const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap();
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0};
|
||||
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0};
|
||||
|
||||
TensorRow row;
|
||||
for (size_t sz : row_sizes) {
|
||||
|
@ -107,7 +107,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
|
|||
}
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
||||
|
@ -135,7 +135,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
|||
const std::unordered_map<std::string, int32_t> map = {{"label", 0}};
|
||||
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
|
||||
|
||||
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0};
|
||||
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0};
|
||||
TensorRow row;
|
||||
|
||||
for (size_t sz : row_sizes) {
|
||||
|
@ -145,5 +145,5 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
|
|||
}
|
||||
rc = tree_adapter.GetNext(&row);
|
||||
const std::string err_msg = rc.ToString();
|
||||
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
|
||||
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
|
||||
}
|
||||
|
|
|
@ -451,6 +451,10 @@ def test_batch_exception_13():
|
|||
|
||||
|
||||
def test_batch_exception_14():
|
||||
"""
|
||||
Test per_batch_map and input column name
|
||||
"""
|
||||
logger.info("test_batch_exception_14")
|
||||
batch_size = 2
|
||||
input_columns = ["num"]
|
||||
data1 = ds.TFRecordDataset(DATA_DIR)
|
||||
|
@ -460,6 +464,22 @@ def test_batch_exception_14():
|
|||
assert "per_batch_map and input_columns need to be passed in together." in str(e)
|
||||
|
||||
|
||||
def test_batch_exception_15():
|
||||
"""
|
||||
Test batch_size = int32 max value + 1
|
||||
"""
|
||||
logger.info("test_batch_exception_15")
|
||||
batch_size = 2147483647 + 1
|
||||
input_columns = ["num"]
|
||||
data1 = ds.TFRecordDataset(DATA_DIR)
|
||||
err_msg = ""
|
||||
try:
|
||||
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
|
||||
except ValueError as e:
|
||||
err_msg = str(e)
|
||||
assert "batch_size is not within the required interval of (1 to 2147483647)" in err_msg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_batch_01()
|
||||
test_batch_02()
|
||||
|
@ -486,4 +506,5 @@ if __name__ == '__main__':
|
|||
test_batch_exception_12()
|
||||
test_batch_exception_13()
|
||||
test_batch_exception_14()
|
||||
test_batch_exception_15()
|
||||
logger.info('\n')
|
||||
|
|
|
@ -12,7 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
|
@ -354,6 +355,18 @@ def test_clue_to_device():
|
|||
data.send()
|
||||
|
||||
|
||||
def test_clue_invalid_files():
|
||||
"""
|
||||
Test CLUE with invalid files
|
||||
"""
|
||||
AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
|
||||
afqmc_train_json = os.path.join(AFQMC_DIR)
|
||||
with pytest.raises(ValueError) as info:
|
||||
_ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
|
||||
assert "The following patterns did not match any files" in str(info.value)
|
||||
assert AFQMC_DIR in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_clue()
|
||||
test_clue_num_shards()
|
||||
|
@ -366,3 +379,4 @@ if __name__ == "__main__":
|
|||
test_clue_tnews()
|
||||
test_clue_wsc()
|
||||
test_clue_to_device()
|
||||
test_clue_invalid_files()
|
||||
|
|
|
@ -195,6 +195,19 @@ def test_csv_dataset_size():
|
|||
assert data.get_dataset_size() == 5
|
||||
|
||||
|
||||
def test_csv_dataset_type_error():
|
||||
TEST_FILE = '../data/dataset/testCSV/exception.csv'
|
||||
data = ds.CSVDataset(
|
||||
TEST_FILE,
|
||||
column_defaults=["", 0, "", ""],
|
||||
column_names=['col1', 'col2', 'col3', 'col4'],
|
||||
shuffle=False)
|
||||
with pytest.raises(Exception) as err:
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
pass
|
||||
assert "type does not match" in str(err.value)
|
||||
|
||||
|
||||
def test_csv_dataset_exception():
|
||||
TEST_FILE = '../data/dataset/testCSV/exception.csv'
|
||||
data = ds.CSVDataset(
|
||||
|

@@ -208,17 +221,16 @@ def test_csv_dataset_exception():
    assert "failed to parse file" in str(err.value)


def test_csv_dataset_type_error():
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
def test_csv_dataset_duplicate_columns():
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", 0, "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "type does not match" in str(err.value)
    with pytest.raises(RuntimeError) as info:
        _ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
    assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value)
    assert "column_names" in str(info.value)


if __name__ == "__main__":

@@ -234,5 +246,6 @@ if __name__ == "__main__":
    test_csv_dataset_header()
    test_csv_dataset_number()
    test_csv_dataset_size()
    test_csv_dataset_exception()
    test_csv_dataset_type_error()
    test_csv_dataset_exception()
    test_csv_dataset_duplicate_columns()

@@ -14,6 +14,7 @@
# ==============================================================================

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision

IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",

@@ -21,9 +22,18 @@ IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-000
                       "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
                       "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
MNIST_DATA_DIR = "../data/dataset/testMnistData"
MIND_CV_FILE_NAME = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord"
SCHEMA_FILE = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data"
VOC_DATA_DIR = "../data/dataset/testVOC2012"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
CLUE_FILE = '../data/dataset/testCLUE/afqmc/train.json'
CSV_FILE = '../data/dataset/testCSV/1.csv'
TEXT_DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"


def test_imagenet_rawdata_dataset_size():

@@ -50,8 +60,15 @@ def test_imagenet_tf_file_dataset_size():
    ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0)
    assert ds_shard_2_0.get_dataset_size() == 6

    # FIXME: dataset_size == 6 looks wrong, but it seems intended to match the current code.
    # The correct answer should be 12/3 = 4; the code issue should be addressed.
    ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0)
    assert ds_shard_3_0.get_dataset_size() == 4
    assert ds_shard_3_0.get_dataset_size() == 6

    count = 0
    for _ in ds_shard_3_0.create_dict_iterator():
        count += 1
    assert ds_shard_3_0.get_dataset_size() == count
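    # With 12 rows in total, an even split would give ceil(12 / 2) = 6 rows for two shards
    # and ceil(12 / 3) = 4 for three, which is why the FIXME above flags 6 as the wrong
    # expectation here; the loop only confirms that get_dataset_size() matches what the
    # iterator actually yields.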


def test_mnist_dataset_size():

@@ -76,6 +93,14 @@ def test_mnist_dataset_size():
    assert ds_shard_3_0.get_dataset_size() == 3334


def test_mind_dataset_size():
    dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0")
    assert dataset.get_dataset_size() == 20

    dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 10


def test_manifest_dataset_size():
    ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE)
    assert ds_total.get_dataset_size() == 4

@@ -95,10 +120,11 @@ def test_cifar10_dataset_size():
    assert ds_total.get_dataset_size() == 10000

    # test get_dataset_size with usage flag
    train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
    assert train_size == 0
    train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
    assert train_size == 10000
    test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
    assert test_size == 0

    all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
    assert all_size == 10000
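    # The usage flag selects a split; the CIFAR-10 test data apparently holds only 10000
    # "train" rows, so usage="test" reports 0 and usage="all" matches the train count.
    # The hunk also appears to swap the earlier Cifar100Dataset call, which did not belong
    # in this CIFAR-10 test, for Cifar10Dataset.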


@@ -120,8 +146,6 @@ def test_cifar100_dataset_size():
    assert ds_total.get_dataset_size() == 10000

    # test get_dataset_size with usage flag
    train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
    assert train_size == 0
    test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
    assert test_size == 10000
    all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()

@@ -137,10 +161,97 @@
    assert ds_shard_3_0.get_dataset_size() == 3334


def test_voc_dataset_size():
    dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
    assert dataset.get_dataset_size() == 10

    dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True,
                                      num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 5


def test_coco_dataset_size():
    dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
                             decode=True, shuffle=False)
    assert dataset.get_dataset_size() == 6

    dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
                                       shuffle=False, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 3


def test_celeba_dataset_size():
    dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    assert dataset.get_dataset_size() == 4

    dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2


def test_clue_dataset_size():
    dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
    assert dataset.get_dataset_size() == 3

    dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2


def test_csv_dataset_size():
    dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
                            shuffle=False)
    assert dataset.get_dataset_size() == 3

    dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
                                      shuffle=False, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2


def test_text_file_dataset_size():
    dataset = ds.TextFileDataset(TEXT_DATA_FILE)
    assert dataset.get_dataset_size() == 3

    dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2


def test_padded_dataset_size():
    dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}])
    assert dataset.get_dataset_size() == 2


def test_pipeline_get_dataset_size():
    dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False)
    assert dataset.get_dataset_size() == 12

    dataset = dataset.shuffle(buffer_size=3)
    assert dataset.get_dataset_size() == 12

    decode_op = vision.Decode()
    resize_op = vision.RandomResize(10)

    dataset = dataset.map([decode_op, resize_op], input_columns=["image"])
    assert dataset.get_dataset_size() == 12

    dataset = dataset.batch(batch_size=3)
    assert dataset.get_dataset_size() == 4

    dataset = dataset.repeat(count=2)
    assert dataset.get_dataset_size() == 8
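    # get_dataset_size() tracks the pipeline: shuffle and map leave the row count at 12,
    # batch(3) reports ceil(12 / 3) = 4 batches, and repeat(2) doubles that to 8, matching
    # the asserts above.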


if __name__ == '__main__':
    test_imagenet_rawdata_dataset_size()
    test_imagenet_tf_file_dataset_size()
    test_mnist_dataset_size()
    test_mind_dataset_size()
    test_manifest_dataset_size()
    test_cifar10_dataset_size()
    test_cifar100_dataset_size()
    test_voc_dataset_size()
    test_coco_dataset_size()
    test_celeba_dataset_size()
    test_clue_dataset_size()
    test_csv_dataset_size()
    test_text_file_dataset_size()
    test_padded_dataset_size()
    test_pipeline_get_dataset_size()

@@ -521,7 +521,7 @@ def test_chained_sampler_04():
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 24
    assert data1_size == 6

    # Verify number of iterations
    num_iter = 0

@@ -182,6 +182,15 @@ def test_voc_exception():
        pass


def test_voc_num_classes():
    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
    assert data1.num_classes() is None

    class_index = {'car': 0, 'cat': 1, 'train': 5}
    data2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
    assert data2.num_classes() is None


if __name__ == '__main__':
    test_voc_segmentation()
    test_voc_detection()

@@ -191,3 +200,4 @@ if __name__ == '__main__':
    test_case_1()
    test_case_2()
    test_voc_exception()
    test_voc_num_classes()

@@ -107,7 +107,7 @@ def test_decode_op():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        iter2.__next__()

@@ -205,7 +205,7 @@ def test_generator_dict_3():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_dict_4():

@@ -396,7 +396,7 @@ def test_generator_tuple_3():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_tuple_4():

@@ -546,7 +546,7 @@ def test_generator_tuple_repeat_repeat_2():
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_tuple_repeat_repeat_3():

@@ -74,9 +74,11 @@ def test_case2():


def test_case3():
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
    data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(5)
    data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2)
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
        ["col_sint64"], ["a1"])
    data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(5).rename(
        ["col_sint64"], ["a2"])
    data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).rename(["col_sint64"], ["a3"])

    data4 = ds.zip((data1, data2, data3))

@@ -84,8 +86,9 @@ def test_case3():


def test_case4():
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
    data2 = ds.TFRecordDataset(FILES)
    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
        ["col_sint64"], ["a1"])
    data2 = ds.TFRecordDataset(FILES, columns_list=["col_sint64"]).rename(["col_sint64"], ["a2"])
    assert data2.get_dataset_size() == 12
    data2 = data2.batch(2)
    assert data2.get_dataset_size() == 6
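    # The rename calls give each branch a unique column name (a1, a2), presumably so the
    # datasets can later be zipped without clashing columns; data2 starts with 12 rows,
    # so batch(2) reports 12 / 2 = 6 batches, as the last two asserts check.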