!8594 [MD] Pybind Pushdown Support for dataset

From: @cathwong
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-11-26 05:55:49 +08:00 committed by Gitee
commit adc8e3e707
113 changed files with 3254 additions and 4629 deletions

View File

@@ -2,9 +2,11 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
add_library(APItoPython OBJECT
python/de_pipeline.cc
python/pybind_register.cc
python/bindings.cc
python/pybind_conversion.cc
python/bindings/dataset/include/datasets_bindings.cc
python/bindings/dataset/include/iterator_bindings.cc
python/bindings/dataset/include/schema_bindings.cc
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/core/bindings.cc
python/bindings/dataset/callback/bindings.cc

View File

@@ -115,7 +115,8 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
#ifndef ENABLE_ANDROID
// Function to return a transferred Node that transfers data through a device.
bool Dataset::DeviceQueue(bool send_epoch_end) {
bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t num_epochs, bool send_epoch_end,
int32_t total_batches, bool create_data_info_queue) {
Status rc;
// Build and launch tree
@@ -126,11 +127,12 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
return false;
}
// Add TransferNode IR on top of dataset d
auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end);
// Add TransferNode IR on top of dataset
auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, send_epoch_end,
total_batches, create_data_info_queue);
// Get ToDevice consumer
auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1);
auto consumer = std::make_unique<ToDevice>(num_epochs);
ToDevice *consumer_ = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
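For orientation, a call through the widened DeviceQueue signature could look like the sketch below; the dataset path, queue name, device type, and count values are illustrative assumptions, not values taken from this change.

// Illustrative sketch only: the path, queue name, device type, and counts are assumptions.
std::shared_ptr<Dataset> ds = ImageFolder("/path/to/imagefolder", /*decode=*/true);
ds = ds->Batch(32);
bool ok = ds->DeviceQueue(/*queue_name=*/"queue_0", /*device_type=*/"Ascend",
                          /*num_epochs=*/1, /*send_epoch_end=*/true,
                          /*total_batches=*/0, /*create_data_info_queue=*/false);
if (!ok) {
  MS_LOG(ERROR) << "Device queue launch failed.";
}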
@@ -199,127 +201,55 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
int64_t Dataset::GetDatasetSize() {
int64_t dataset_size;
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetDatasetSize(&dataset_size);
return rc.IsError() ? -1 : dataset_size;
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
return dataset_size;
}
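The getters above lean on RETURN_SECOND_IF_ERROR to replace the earlier log-and-return boilerplate. A plausible shape for that helper, assuming it follows the usual Status-macro pattern (the real definition lives with the Status utilities and may differ):

// Plausible sketch, not the definition from this change: log the error and
// return the caller-supplied fallback value when the Status is not OK.
#define RETURN_SECOND_IF_ERROR(_s, _r) \
  do {                                 \
    const Status &__rc = (_s);         \
    if (__rc.IsError()) {              \
      MS_LOG(ERROR) << __rc;           \
      return (_r);                     \
    }                                  \
  } while (false)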
std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types;
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
return types;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
return types;
}
rc = tree_getters_->GetOutputTypes(&types);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
types.clear();
return types;
}
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputTypes(&types), {});
return types;
}
std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes;
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
return shapes;
}
rc = tree_getters_->Init(this->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
return shapes;
}
rc = tree_getters_->GetOutputShapes(&shapes);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
shapes.clear();
return shapes;
}
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputShapes(&shapes), {});
return shapes;
}
int64_t Dataset::GetNumClasses() {
int64_t num_classes;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetNumClasses(&num_classes);
return rc.IsError() ? -1 : num_classes;
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->GetNumClasses(&num_classes), -1);
return num_classes;
}
std::vector<std::string> Dataset::GetColumnNames() {
std::vector<std::string> col_names;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed.";
return std::vector<std::string>();
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed.";
return std::vector<std::string>();
}
rc = tree_getters_->GetColumnNames(&col_names);
return rc.IsError() ? std::vector<std::string>() : col_names;
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(tree_getters_->GetColumnNames(&col_names), {});
return col_names;
}
std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
return output_class_indexing;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
return output_class_indexing;
}
rc = tree_getters_->GetClassIndexing(&output_class_indexing);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
output_class_indexing.clear();
return output_class_indexing;
}
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
RETURN_SECOND_IF_ERROR(tree_getters_->GetClassIndexing(&output_class_indexing), {});
return output_class_indexing;
}
@@ -501,9 +431,13 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(
std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) {
auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function, pad_info,
pad_to_bucket_boundary, drop_remainder);
std::shared_ptr<TensorOp> c_func = nullptr;
if (element_length_function != nullptr) {
c_func = std::make_shared<CFuncOp>(element_length_function);
}
auto ds =
std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, bucket_batch_sizes,
c_func, pad_info, pad_to_bucket_boundary, drop_remainder);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
@@ -522,7 +456,9 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase
FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate,
std::vector<std::string> input_columns) {
auto ds = std::make_shared<FilterNode>(input->IRNode(), predicate, input_columns);
std::shared_ptr<TensorOp> c_func = nullptr;
if (predicate) c_func = std::make_shared<CFuncOp>(predicate);
auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, input_columns);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
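With this change a C++ predicate only has to satisfy std::function<TensorRow(TensorRow)> before being wrapped in a CFuncOp. A minimal sketch, assuming the Tensor::CreateScalar helper and TensorRow::push_back from the dataset core (assumptions here, not introduced by this change):

// Minimal sketch of a Filter predicate that keeps every row by returning a
// single boolean scalar tensor. Tensor::CreateScalar and TensorRow::push_back
// are assumed dataset-core helpers, not part of this change.
TensorRow KeepEveryRow(TensorRow input) {
  std::shared_ptr<Tensor> keep;
  (void)Tensor::CreateScalar(true, &keep);
  TensorRow output;
  output.push_back(keep);
  return output;
}
// Usage sketch: auto filtered = std::make_shared<FilterDataset>(ds, KeepEveryRow, std::vector<std::string>{});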
@@ -604,40 +540,20 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
#endif
int64_t Dataset::GetBatchSize() {
int64_t batch_size;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetBatchSize(&batch_size);
return rc.IsError() ? -1 : batch_size;
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
return batch_size;
}
int64_t Dataset::GetRepeatCount() {
int64_t repeat_count;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds->IRNode());
if (rc.IsError()) {
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetRepeatCount(&repeat_count);
return rc.IsError() ? 0 : repeat_count;
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
return repeat_count;
}
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
@@ -720,62 +636,65 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
// SchemaObj init function
bool SchemaObj::init() {
if (schema_file_ != "") {
Status SchemaObj::init() {
if (!schema_file_.empty()) {
Path schema_file(schema_file_);
if (!schema_file.Exists()) {
MS_LOG(ERROR) << "The file " << schema_file << " does not exist or permission denied!";
return false;
}
CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
"The file " + schema_file_ + " does not exist or permission denied!");
nlohmann::json js;
try {
std::ifstream in(schema_file_);
in >> js;
if (js.find("columns") == js.end()) {
MS_LOG(ERROR) << "\"columns\" node is required in the schema json file.";
return false;
}
CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
"\"columns\" node is required in the schema json file.");
} catch (const std::exception &err) {
MS_LOG(ERROR) << "Schema file failed to load";
return false;
RETURN_STATUS_SYNTAX_ERROR("Schema file failed to load");
}
return from_json(js);
}
return true;
return Status::OK();
}
// Function to add a column to schema with a mstype de_type
bool SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
nlohmann::json new_column;
new_column["name"] = name;
// if de_type is mstype
// Function to add a column to schema with a mstype de_type and known shape
Status SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
DataType data_type = dataset::MSTypeToDEType(de_type);
new_column["type"] = data_type.ToString();
if (shape.size() > 0) {
new_column["shape"] = shape;
new_column["rank"] = shape.size();
} else {
new_column["rank"] = 1;
}
columns_.push_back(new_column);
return true;
return add_column(name, data_type.ToString(), shape);
}
// Function to add a column to schema with a string de_type
bool SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
// Function to add a column to schema with a string de_type and known shape
Status SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
DataType data_type(de_type);
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
nlohmann::json new_column;
new_column["name"] = name;
DataType data_type(de_type);
new_column["type"] = data_type.ToString();
if (shape.size() > 0) {
new_column["shape"] = shape;
new_column["rank"] = shape.size();
} else {
new_column["rank"] = 1;
}
new_column["shape"] = shape;
new_column["rank"] = shape.size();
columns_.push_back(new_column);
return true;
return Status::OK();
}
// Function to add a column to schema with a mstype de_type and without shape
Status SchemaObj::add_column(std::string name, TypeId de_type) {
DataType data_type = dataset::MSTypeToDEType(de_type);
return add_column(name, data_type.ToString());
}
// Function to add a column to schema with a string de_type and without shape
Status SchemaObj::add_column(std::string name, std::string de_type) {
DataType data_type(de_type);
CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
nlohmann::json new_column;
new_column["name"] = name;
new_column["type"] = data_type.ToString();
new_column["rank"] = 1;
columns_.push_back(new_column);
return Status::OK();
}
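A short sketch of how the Status-returning overloads compose with RETURN_IF_NOT_OK; the column names and types below are illustrative only:

// Illustrative only: build a small schema and propagate any failure upward.
Status BuildExampleSchema(SchemaObj *schema) {
  RETURN_IF_NOT_OK(schema->add_column("image", "uint8", {28, 28, 1}));
  RETURN_IF_NOT_OK(schema->add_column("label", "int32"));  // no shape: rank defaults to 1
  return Status::OK();
}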
std::string SchemaObj::to_json() {
@@ -792,7 +711,7 @@ std::string SchemaObj::to_json() {
return json_file.dump(2);
}
bool SchemaObj::parse_column(nlohmann::json columns) {
Status SchemaObj::parse_column(nlohmann::json columns) {
std::string name, de_type;
std::vector<int32_t> shape;
@@ -802,15 +721,13 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
for (auto column : columns) {
auto key_name = column.find("name");
if (key_name == column.end()) {
MS_LOG(ERROR) << "Column's name is missing";
return false;
RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
}
name = *key_name;
auto key_type = column.find("type");
if (key_type == column.end()) {
MS_LOG(ERROR) << "Column's type is missing";
return false;
RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
}
de_type = *key_type;
@@ -819,17 +736,14 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
if (key_shape != column.end()) {
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
}
if (!add_column(name, de_type, shape)) {
return false;
}
RETURN_IF_NOT_OK(add_column(name, de_type, shape));
}
} else if (columns.type() == nlohmann::json::value_t::object) {
for (const auto &it_child : columns.items()) {
name = it_child.key();
auto key_type = it_child.value().find("type");
if (key_type == it_child.value().end()) {
MS_LOG(ERROR) << "Column's type is missing";
return false;
RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
}
de_type = *key_type;
@@ -839,43 +753,45 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
}
if (!add_column(name, de_type, shape)) {
return false;
}
RETURN_IF_NOT_OK(add_column(name, de_type, shape));
}
} else {
MS_LOG(ERROR) << "columns must be dict or list, columns contain name, type, shape(optional).";
return false;
RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
}
return true;
return Status::OK();
}
bool SchemaObj::from_json(nlohmann::json json_obj) {
Status SchemaObj::from_json(nlohmann::json json_obj) {
for (const auto &it_child : json_obj.items()) {
if (it_child.key() == "datasetType") {
dataset_type_ = it_child.value();
} else if (it_child.key() == "numRows") {
num_rows_ = it_child.value();
} else if (it_child.key() == "columns") {
if (!parse_column(it_child.value())) {
MS_LOG(ERROR) << "parse columns failed";
return false;
}
RETURN_IF_NOT_OK(parse_column(it_child.value()));
} else {
MS_LOG(ERROR) << "Unknown field " << it_child.key();
return false;
RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
}
}
if (columns_.empty()) {
MS_LOG(ERROR) << "Columns are missing.";
return false;
RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
}
if (num_rows_ <= 0) {
MS_LOG(ERROR) << "numRows must be greater than 0";
return false;
if (num_rows_ < 0) {
RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
}
return true;
return Status::OK();
}
Status SchemaObj::FromJSONString(const std::string &json_string) {
try {
nlohmann::json js = nlohmann::json::parse(json_string);
CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
"\"columns\" node is required in the schema json JSON.");
RETURN_IF_NOT_OK(from_json(js));
} catch (const std::exception &err) {
RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse");
}
return Status::OK();
}
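For reference, a minimal JSON string that FromJSONString accepts might look like the following; the field values are placeholders, and the dict form of "columns" shown here is one of the two layouts parse_column handles:

// Illustrative schema string; names, types, and numRows are placeholders.
std::string json_string = R"({
  "columns": {
    "image": { "type": "uint8", "shape": [-1] },
    "label": { "type": "int32", "shape": [1] }
  },
  "numRows": 100
})";
SchemaObj schema("");
Status rc = schema.FromJSONString(json_string);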
// OTHER FUNCTIONS

View File

@@ -1,136 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/api/python/de_pipeline.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
DEPipeline, 0, ([](const py::module *m) {
(void)py::class_<DEPipeline>(*m, "DEPipeline")
.def(py::init<>())
.def(
"AddNodeToTree",
[](DEPipeline &de, const OpName &op_name, const py::dict &args) {
py::dict out;
THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
return out;
},
py::return_value_policy::reference)
.def_static("AddChildToParentNode",
[](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
})
.def("AssignRootNode",
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
.def("SetBatchParameters",
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
.def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); })
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
.def("GetColumnNames",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetColumnNames(&out));
return out;
})
.def("GetNextAsMap",
[](DEPipeline &de) {
py::dict out;
THROW_IF_ERROR(de.GetNextAsMap(&out));
return out;
})
.def("GetNextAsList",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetNextAsList(&out));
return out;
})
.def("GetOutputShapes",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetOutputShapes(&out));
return out;
})
.def("GetOutputTypes",
[](DEPipeline &de) {
py::list out;
THROW_IF_ERROR(de.GetOutputTypes(&out));
return out;
})
.def("GetDataInfo",
[](DEPipeline &de) {
py::list types, shapes;
THROW_IF_ERROR(de.GetDataInfo(&types, &shapes));
return py::make_tuple(types, shapes);
})
.def("GetDatasetSize", &DEPipeline::GetDatasetSize)
.def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;
});
}));
PYBIND_REGISTER(OpName, 0, ([](const py::module *m) {
(void)py::enum_<OpName>(*m, "OpName", py::arithmetic())
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BUCKETBATCH", OpName::kBucketBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
.value("REPEAT", OpName::kRepeat)
.value("SKIP", OpName::kSkip)
.value("TAKE", OpName::kTake)
.value("ZIP", OpName::kZip)
.value("CONCAT", OpName::kConcat)
.value("MAP", OpName::kMap)
.value("FILTER", OpName::kFilter)
.value("DEVICEQUEUE", OpName::kDeviceQueue)
.value("GENERATOR", OpName::kGenerator)
.export_values()
.value("RENAME", OpName::kRename)
.value("TFREADER", OpName::kTfReader)
.value("PROJECT", OpName::kProject)
.value("IMAGEFOLDER", OpName::kImageFolder)
.value("MNIST", OpName::kMnist)
.value("MANIFEST", OpName::kManifest)
.value("VOC", OpName::kVoc)
.value("COCO", OpName::kCoco)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("EPOCHCTRL", OpName::kEpochCtrl)
.value("CSV", OpName::kCsv)
.value("CLUE", OpName::kClue);
}));
} // namespace dataset
} // namespace mindspore

View File

@@ -19,8 +19,10 @@
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/client.h" // DE client
#include "minddata/dataset/util/status.h"
#include "pybind11/numpy.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/api/python/de_pipeline.h"
namespace mindspore {
namespace dataset {

View File

@@ -0,0 +1,551 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/datasets.h"
// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
// IR non-leaf nodes - for android
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#endif
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
// IR leaf nodes disabled for android
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#endif
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
.def("SetNumWorkers",
[](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
return num_workers ? self->SetNumWorkers(*num_workers) : self;
})
.def(
"Zip",
[](std::shared_ptr<DatasetNode> self, py::list datasets) {
auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets)));
THROW_IF_ERROR(zip->ValidateParams());
return zip;
},
py::arg("datasets"));
}));
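Every binding in this file goes through PYBIND_REGISTER rather than one monolithic module definition. A plausible sketch of that machinery, assuming the usual static-registrar pattern (the real identifiers live in pybind_register.h and may differ): each use records a callback keyed by priority, and the callbacks are replayed in ascending priority when the Python module is created, so bases such as Dataset (priority 0 and 1) are bound before the node classes that derive from them (priority 2).

// Plausible sketch of the registration pattern, not the definition from
// pybind_register.h: a static registrar records each binding callback at
// program start-up; module creation later replays them in priority order.
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include "pybind11/pybind11.h"
namespace py = pybind11;

using PybindDefineFunc = std::function<void(const py::module *)>;

inline std::multimap<uint8_t, PybindDefineFunc> &BindingRegistry() {
  static std::multimap<uint8_t, PybindDefineFunc> registry;
  return registry;
}

struct PybindDefineRegisterer {
  PybindDefineRegisterer(const std::string &name, uint8_t priority, PybindDefineFunc fn) {
    (void)name;  // the real implementation likely keeps the name for diagnostics
    BindingRegistry().emplace(priority, std::move(fn));
  }
};

#define PYBIND_REGISTER(name, priority, define) \
  static PybindDefineRegisterer g_pybind_define_f_##name(#name, priority, define)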
// PYBIND FOR LEAF NODES
// (In alphabetical order)
PYBIND_REGISTER(
CelebANode, 2, ([](const py::module *m) {
(void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", "to create a CelebANode")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, bool decode,
std::optional<py::list> extensions, std::optional<std::shared_ptr<CacheClient>> cc) {
auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
toStringSet(extensions), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(celebA->ValidateParams());
return celebA;
}));
}));
PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
"to create a Cifar10Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(cifar10->ValidateParams());
return cifar10;
}));
}));
PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
(void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
"to create a Cifar100Node")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(cifar100->ValidateParams());
return cifar100;
}));
}));
PYBIND_REGISTER(
CLUENode, 2, ([](const py::module *m) {
(void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode")
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<CLUENode> clue_node =
std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(clue_node->ValidateParams());
return clue_node;
}));
}));
PYBIND_REGISTER(
CocoNode, 2, ([](const py::module *m) {
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", "to create a CocoNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, bool decode,
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(coco->ValidateParams());
return coco;
}));
}));
PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
(void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
.def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults),
column_names, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(csv->ValidateParams());
return csv;
}));
}));
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
*m, "GeneratorNode", "to create a GeneratorNode")
.def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
const std::vector<DataType> &column_types) {
auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types);
THROW_IF_ERROR(gen->ValidateParams());
return gen;
}))
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) {
auto gen = std::make_shared<GeneratorNode>(generator_function, schema);
THROW_IF_ERROR(gen->ValidateParams());
return gen;
}));
}));
PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
*m, "ImageFolderNode", "to create an ImageFolderNode")
.def(py::init([](std::string dataset_dir, bool decode, std::optional<py::handle> sampler,
std::optional<py::list> extensions, std::optional<py::dict> class_indexing,
std::optional<std::shared_ptr<CacheClient>> cc) {
bool recursive = true;
auto imagefolder = std::make_shared<ImageFolderNode>(
dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions),
toStringMap(class_indexing), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(imagefolder->ValidateParams());
return imagefolder;
}));
}));
PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
(void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
"to create a ManifestNode")
.def(py::init([](std::string dataset_file, std::string usage, std::optional<py::handle> sampler,
std::optional<py::dict> class_indexing, bool decode,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
toStringMap(class_indexing), decode,
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(manifest->ValidateParams());
return manifest;
}));
}));
PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
"to create a MindDataNode")
.def(py::init([](std::string dataset_file, std::optional<py::list> columns_list,
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
auto minddata =
std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list),
toSamplerObj(sampler, true), padded_sample_json, num_padded);
minddata->SetSampleBytes(&sample_bytes);
THROW_IF_ERROR(minddata->ValidateParams());
return minddata;
}))
.def(py::init([](py::list dataset_file, std::optional<py::list> columns_list,
std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
auto minddata =
std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list),
toSamplerObj(sampler, true), padded_sample_json, num_padded);
minddata->SetSampleBytes(&sample_bytes);
THROW_IF_ERROR(minddata->ValidateParams());
return minddata;
}));
}));
PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
(void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
"to create an MnistNode")
.def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler),
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(mnist->ValidateParams());
return mnist;
}));
}));
PYBIND_REGISTER(
RandomNode, 2, ([](const py::module *m) {
(void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
.def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto random_node =
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(random_node->ValidateParams());
return random_node;
}))
.def(py::init([](int32_t total_rows, std::string schema, std::optional<py::list> columns_list,
std::optional<std::shared_ptr<CacheClient>> cc) {
auto random_node =
std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(random_node->ValidateParams());
return random_node;
}));
}));
PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
(void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
"to create a TextFileNode")
.def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>(
toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id,
toDatasetCache(std::move(cc)));
THROW_IF_ERROR(textfile_node->ValidateParams());
return textfile_node;
}));
}));
PYBIND_REGISTER(
TFRecordNode, 2, ([](const py::module *m) {
(void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode",
"to create a TFRecordNode")
.def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list,
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
std::optional<int32_t> shard_id, bool shard_equal_rows,
std::optional<std::shared_ptr<CacheClient>> cc) {
if (!num_samples) {
num_samples = 0;  // assign the optional itself; dereferencing an empty optional is undefined
}
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}))
.def(py::init([](py::list dataset_files, std::string schema, std::optional<py::list> columns_list,
std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards,
std::optional<int32_t> shard_id, bool shard_equal_rows,
std::optional<std::shared_ptr<CacheClient>> cc) {
if (!num_samples) {
num_samples = 0;  // assign the optional itself; dereferencing an empty optional is undefined
}
std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle),
*num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
THROW_IF_ERROR(tfrecord->ValidateParams());
return tfrecord;
}));
}));
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
.def(
py::init([](std::string dataset_dir, std::string task, std::string usage,
std::optional<py::dict> class_indexing, bool decode,
std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) {
std::shared_ptr<VOCNode> voc =
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
toSamplerObj(sampler), toDatasetCache(std::move(cc)));
THROW_IF_ERROR(voc->ValidateParams());
return voc;
}));
}));
// PYBIND FOR NON-LEAF NODES
// (In alphabetical order)
PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) {
(void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode",
"to create a BatchNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder,
bool pad, py::list in_col_names, py::list out_col_names, py::list col_order,
py::object size_obj, py::object map_obj, py::dict pad_info) {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
if (pad) {
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
}
py::function size_func =
py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function();
py::function map_func =
py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function();
auto batch = std::make_shared<BatchNode>(
self, batch_size, drop_remainder, pad, toStringVector(in_col_names),
toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info);
THROW_IF_ERROR(batch->ValidateParams());
return batch;
}));
}));
PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) {
(void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>(
*m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode")
.def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names,
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes,
py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary,
bool drop_remainder) {
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info;
THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info));
auto bucket_batch = std::make_shared<BucketBatchByLengthNode>(
dataset, toStringVector(column_names), bucket_boundaries, bucket_batch_sizes,
toPyFuncOp(std::move(element_length_function), DataType::DE_INT32), c_pad_info,
pad_to_bucket_boundary, drop_remainder);
THROW_IF_ERROR(bucket_batch->ValidateParams());
return bucket_batch;
}),
py::arg("dataset"), py::arg("column_names"), py::arg("bucket_boundaries"),
py::arg("bucket_batch_sizes"), py::arg("element_length_function") = py::none(),
py::arg("pad_info"), py::arg("pad_to_bucket_boundary"), py::arg("drop_remainder"));
}));
PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>(
*m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size,
float character_coverage, SentencePieceModel model_type,
const std::unordered_map<std::string, std::string> &params) {
auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>(
self, vocab, col_names, vocab_size, character_coverage, model_type, params);
THROW_IF_ERROR(build_sentence_vocab->ValidateParams());
return build_sentence_vocab;
}));
}));
PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) {
(void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>(
*m, "BuildVocabNode", "to create a BuildVocabNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns,
py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) {
auto build_vocab =
std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range),
top_k, toStringVector(special_tokens), special_first);
THROW_IF_ERROR(build_vocab->ValidateParams());
return build_vocab;
}));
}));
PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) {
(void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode",
"to create a ConcatNode")
.def(
py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, std::optional<py::handle> sampler,
py::list children_flag_and_nums, py::list children_start_end_index) {
auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler),
toPairVector(children_flag_and_nums),
toPairVector(children_start_end_index));
THROW_IF_ERROR(concat->ValidateParams());
return concat;
}));
}));
PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
(void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode",
"to create a FilterNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate,
std::vector<std::string> input_columns) {
auto filter =
std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns);
THROW_IF_ERROR(filter->ValidateParams());
return filter;
}));
}));
PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> operations,
std::optional<py::list> input_columns, std::optional<py::list> output_columns,
std::optional<py::list> project_columns,
std::optional<std::shared_ptr<CacheClient>> cc,
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
auto map = std::make_shared<MapNode>(
self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)),
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()));
THROW_IF_ERROR(map->ValidateParams());
return map;
}));
}));
PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) {
(void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode",
"to create a ProjectNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) {
auto project = std::make_shared<ProjectNode>(self, toStringVector(columns));
THROW_IF_ERROR(project->ValidateParams());
return project;
}));
}));
PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) {
(void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode",
"to create a RenameNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> input_columns,
std::optional<py::list> output_columns) {
auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns),
toStringVector(output_columns));
THROW_IF_ERROR(rename->ValidateParams());
return rename;
}));
}));
PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) {
(void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode",
"to create a RepeatNode")
.def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) {
auto repeat = std::make_shared<RepeatNode>(input, count);
THROW_IF_ERROR(repeat->ValidateParams());
return repeat;
}));
}));
PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) {
(void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode",
"to create a ShuffleNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) {
auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch);
THROW_IF_ERROR(shuffle->ValidateParams());
return shuffle;
}));
}));
PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) {
(void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode",
"to create a SkipNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
auto skip = std::make_shared<SkipNode>(self, count);
THROW_IF_ERROR(skip->ValidateParams());
return skip;
}));
}));
PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) {
(void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode",
"to create a SyncWaitNode")
.def(
py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) {
py::function callback_func =
py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function();
auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback_func);
THROW_IF_ERROR(sync_wait->ValidateParams());
return sync_wait;
}));
}));
PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) {
(void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode",
"to create a TakeNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) {
auto take = std::make_shared<TakeNode>(self, count);
THROW_IF_ERROR(take->ValidateParams());
return take;
}));
}));
PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) {
(void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode",
"to create a TransferNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type,
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) {
auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end,
total_batch, create_data_info_queue);
THROW_IF_ERROR(transfer->ValidateParams());
return transfer;
}));
}));
PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
(void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode")
.def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) {
auto zip = std::make_shared<ZipNode>(datasets);
THROW_IF_ERROR(zip->ValidateParams());
return zip;
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,168 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/engine/python_runtime_context.h"
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer");
}));
PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>(
*m, "PythonIteratorConsumer")
.def(py::init<int32_t>())
.def("Init", [](PythonIteratorConsumer &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetNextAsMap",
[](PythonIteratorConsumer &self) {
py::dict output;
THROW_IF_ERROR(self.GetNextAsDict(&output));
return output;
})
.def("GetNextAsList", [](PythonIteratorConsumer &self) {
py::list output;
THROW_IF_ERROR(self.GetNextAsList(&output));
return output;
});
}));
PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) {
(void)py::class_<PythonTreeGetters, TreeConsumer, std::shared_ptr<PythonTreeGetters>>(*m,
"TreeGetters")
.def(py::init<>())
.def("Init",
[](PythonTreeGetters &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("GetOutputShapes",
[](PythonTreeGetters &self) {
std::vector<TensorShape> shapes;
THROW_IF_ERROR(self.GetOutputShapes(&shapes));
return shapesToListOfShape(shapes);
})
.def("GetOutputTypes",
[](PythonTreeGetters &self) {
std::vector<DataType> types;
THROW_IF_ERROR(self.GetOutputTypes(&types));
return typesToListOfType(types);
})
.def("GetNumClasses",
[](PythonTreeGetters &self) {
int64_t num_classes;
THROW_IF_ERROR(self.GetNumClasses(&num_classes));
return num_classes;
})
.def("GetRepeatCount",
[](PythonTreeGetters &self) {
int64_t repeat_count;
THROW_IF_ERROR(self.GetRepeatCount(&repeat_count));
return repeat_count;
})
.def("GetBatchSize",
[](PythonTreeGetters &self) {
int64_t batch_size;
THROW_IF_ERROR(self.GetBatchSize(&batch_size));
return batch_size;
})
.def("GetColumnNames",
[](PythonTreeGetters &self) {
std::vector<std::string> col_names;
THROW_IF_ERROR(self.GetColumnNames(&col_names));
return col_names;
})
.def("GetClassIndexing",
[](PythonTreeGetters &self) {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing));
return output_class_indexing;
})
.def("GetDatasetSize",
[](PythonTreeGetters &self) {
int64_t dataset_size;
THROW_IF_ERROR(self.GetDatasetSize(&dataset_size));
return dataset_size;
})
.def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; });
}));
PYBIND_REGISTER(PythonRuntimeContext, 2, ([](const py::module *m) {
(void)py::class_<PythonRuntimeContext, std::shared_ptr<PythonRuntimeContext>>(*m,
"PythonRuntimeContext")
.def(py::init<>())
.def("Init", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Init()); })
.def("AssignConsumer", &PythonRuntimeContext::AssignConsumer)
.def("Terminate", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Terminate()); })
.def("GetConsumer", &PythonRuntimeContext::GetPythonConsumer, py::return_value_policy::reference)
.def("__deepcopy__", [](py::object &runtime_context, py::dict memo) { return runtime_context; });
}));
PYBIND_REGISTER(PythonBuildVocabConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonBuildVocabConsumer, TreeConsumer, std::shared_ptr<PythonBuildVocabConsumer>>(
*m, "PythonBuildVocabConsumer")
.def(py::init<>())
.def("Init", [](PythonBuildVocabConsumer &self,
std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Start", [](PythonBuildVocabConsumer &self) { THROW_IF_ERROR(self.Start()); });
}));
PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) {
(void)py::class_<ToDevice, TreeConsumer, std::shared_ptr<ToDevice>>(*m, "ToDevice")
.def(py::init<int32_t>())
.def("Init", [](ToDevice &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); })
.def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); })
.def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); })
.def("GetDataInfo",
[](ToDevice &self) {
std::vector<DataType> types_c;
std::vector<TensorShape> shapes_c;
{
py::gil_scoped_release rel;
THROW_IF_ERROR(self.GetDataInfo(&types_c, &shapes_c));
}
py::list types, shapes;
for (auto el : types_c) {
types.append(el.AsNumpyType());
}
for (auto el : shapes_c) {
py::list shape = el.AsPyList();
shapes.append(shape);
}
return py::make_tuple(types, shapes);
})
.def("__deepcopy__", [](py::object &to_device, py::dict memo) { return to_device; });
}));
PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) {
(void)py::class_<PythonSaveToDisk, TreeConsumer, std::shared_ptr<PythonSaveToDisk>>(
*m, "PythonSaveToDisk")
.def(py::init([](std::string &dataset_path, int32_t numFiles, std::string &datasetType) {
auto save = std::make_shared<PythonSaveToDisk>(dataset_path, numFiles, datasetType);
THROW_IF_ERROR(save->ValidateParams());
return save;
}))
.def("Init",
[](PythonSaveToDisk &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); })
.def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); });
}));
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
SchemaObj, 0, ([](const py::module *m) {
(void)py::class_<SchemaObj, std::shared_ptr<SchemaObj>>(*m, "SchemaObj", "to create a SchemaObj")
.def(py::init([](std::string schema_file) {
auto schema = std::make_shared<SchemaObj>(schema_file);
THROW_IF_ERROR(schema->init());
return schema;
}))
.def("add_column", [](SchemaObj &self, std::string name, TypeId de_type,
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
.def("add_column", [](SchemaObj &self, std::string name, std::string de_type,
std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); })
.def("add_column",
[](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("add_column", [](SchemaObj &self, std::string name,
std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("to_json", &SchemaObj::to_json)
.def("to_string", &SchemaObj::to_string)
.def("from_string",
[](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.FromJSONString(json_string)); })
.def("set_dataset_type", [](SchemaObj &self, std::string dataset_type) { self.set_dataset_type(dataset_type); })
.def("set_num_rows", [](SchemaObj &self, int32_t num_rows) { self.set_num_rows(num_rows); })
.def("get_num_rows", &SchemaObj::get_num_rows)
.def("__deepcopy__", [](py::object &schema, py::dict memo) { return schema; });
}));
} // namespace dataset
} // namespace mindspore

View File

@@ -17,7 +17,6 @@
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/api/python/de_pipeline.h"
#include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h"
#include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h"

File diff suppressed because it is too large

View File

@@ -1,265 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_
#include <iostream>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/core/client.h" // DE client
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/util/status.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
using DsOpPtr = std::shared_ptr<DatasetOp>;
class CacheClient;
// enum for the dataset operator names
enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBucketBatch,
kBarrier,
kCache,
kRepeat,
kSkip,
kTake,
kZip,
kConcat,
kMap,
kFilter,
kDeviceQueue,
kGenerator,
kRename,
kTfReader,
kProject,
kImageFolder,
kMnist,
kManifest,
kVoc,
kCoco,
kCifar10,
kCifar100,
kCelebA,
kRandomData,
kTextFile,
kBuildVocab,
kClue,
kEpochCtrl,
kSentencePieceVocab,
kCsv
};
// The C++ binder class that we expose to the python script.
class DEPipeline {
public:
DEPipeline();
~DEPipeline();
// Function to add a Node to the Execution Tree.
Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output);
// Function to add a child and parent relationship.
static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op);
// Function to assign the node as root.
Status AssignRootNode(const DsOpPtr &dataset_op);
// Function to get the column names in the last node in the tree in order
Status GetColumnNames(py::list *output);
// Function to prepare the tree for execution
Status PrepareTree(const int32_t num_epochs);
// Function to launch the tree execution.
Status LaunchTreeExec();
// Get a row of data as dictionary of column name to the value.
Status GetNextAsMap(py::dict *output);
// Get a row of data as list.
Status GetNextAsList(py::list *output);
Status GetOutputShapes(py::list *output);
Status GetOutputTypes(py::list *output);
Status GetDataInfo(py::list *types, py::list *shapes);
Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);
int GetDatasetSize() const;
int GetBatchSize() const;
int GetRepeatCount() const;
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
template <typename T, typename S>
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert = false);
Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map,
const TensorRow &row, json *schema, std::vector<std::string> *index_fields);
Status FetchDataFromTensorRow(const TensorRow &row,
const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data,
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data);
Status BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded);
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);
Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
void PrintTree();
int32_t GetNumClasses() const;
Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status SetBatchParameters(const py::dict &args);
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status StopSend();
Status ContinueSend();
Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;
std::unique_ptr<DatasetIterator> iterator_;
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
/// \brief Helper function to inject a cache operator over top of the current operation being built.
/// \param[in] cache_client The client to use for caching
/// \param[in] num_workers The number of workers to use in the cache op
/// \param[in] input_op The operator to build the cache on top of
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
/// the cache operator
/// \return Status return code
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *cache_op);
/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
/// \param[in] shuffle_size The size to use in the shuffle buffer
/// \param[in] input_op The operator to build shuffle on top of
/// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be
/// the shuffle operator
/// \return Status return code
Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op);
/// \brief Helper function to compute the shuffle size
/// \param[in] num_files The number of files in the dataset
/// \param[in] num_devices The number of devices in the dataset
/// \param[in] num_rows The number of rows in the dataset
/// \param[in] total_rows An upper bound on the total rows in the dataset
/// \param[out] shuffle_size The resultant computed shuffle size
/// \return Status return code
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int64_t *shuffle_size);
int batch_size_;
int repeat_num_;
int num_rows_;
int num_classes_;
int temp_batch_size_;
bool temp_drop_remainder_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_

View File

@ -0,0 +1,265 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/api/python/pybind_conversion.h"
namespace mindspore {
namespace dataset {
float toFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
int toInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
int64_t toInt64(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
bool toBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
std::string toString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); }
std::set<std::string> toStringSet(const std::optional<py::list> list) {
std::set<std::string> set;
if (list) {
for (auto l : *list) {
if (!l.is_none()) {
(void)set.insert(py::str(l));
}
}
}
return set;
}
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict) {
std::map<std::string, int32_t> map;
if (dict) {
for (auto p : *dict) {
(void)map.emplace(toString(p.first), toInt(p.second));
}
}
return map;
}
std::vector<std::string> toStringVector(const std::optional<py::list> list) {
std::vector<std::string> vector;
if (list) {
for (auto l : *list) {
if (l.is_none())
vector.emplace_back("");
else
vector.push_back(py::str(l));
}
}
return vector;
}
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple) {
std::pair<int64_t, int64_t> pair;
if (tuple) {
pair = std::make_pair(toInt64((*tuple)[0]), toInt64((*tuple)[1]));
}
return pair;
}
std::vector<std::pair<int, int>> toPairVector(const py::list list) {
std::vector<std::pair<int, int>> vector;
if (list) {
for (auto data : list) {
auto l = data.cast<py::tuple>();
if (l[1].is_none())
vector.emplace_back(toInt64(l[0]), 0);
else
vector.emplace_back(toInt64(l[0]), toInt64(l[1]));
}
}
return vector;
}
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations) {
std::vector<std::shared_ptr<TensorOperation>> vector;
if (operations) {
for (auto op : *operations) {
std::shared_ptr<TensorOp> tensor_op;
if (py::isinstance<TensorOp>(op)) {
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
} else if (py::isinstance<py::function>(op)) {
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
} else {
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }());
}
vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op));
}
}
return vector;
}
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) {
std::vector<std::shared_ptr<DatasetNode>> vector;
vector.push_back(self);
if (datasets) {
for (auto ds : *datasets) {
if (py::isinstance<DatasetNode>(ds)) {
vector.push_back(ds.cast<std::shared_ptr<DatasetNode>>());
} else {
THROW_IF_ERROR(
[]() { RETURN_STATUS_UNEXPECTED("Error: datasets is not recognised (not a DatasetNode instance)."); }());
}
}
}
return vector;
}
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset) {
if (py_sampler) {
std::shared_ptr<SamplerObj> sampler_obj;
if (!isMindDataset) {
// Common Sampler
std::shared_ptr<SamplerRT> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create");
sampler = create().cast<std::shared_ptr<SamplerRT>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
} else {
// Mindrecord Sampler
std::shared_ptr<mindrecord::ShardOperator> sampler;
auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create_for_minddataset");
sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler));
}
return sampler_obj;
} else {
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: sampler input is not SamplerRT."); }());
}
return nullptr;
}
// Here we take in a Python object that holds a reference to a C++ object
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc) {
if (cc) {
std::shared_ptr<DatasetCache> built_cache;
// Common Sampler
built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc.value()));
return built_cache;
} else {
// don't need to check here as cache is not enabled.
return nullptr;
}
}
ShuffleMode toShuffleMode(const int32_t shuffle) {
if (shuffle == 0) return ShuffleMode::kFalse;
if (shuffle == 1) return ShuffleMode::kFiles;
if (shuffle == 2) return ShuffleMode::kGlobal;
return ShuffleMode();
}
std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases) {
std::vector<std::shared_ptr<CsvBase>> vector;
if (csv_bases) {
for (auto base : *csv_bases) {
if (py::isinstance<py::int_>(base)) {
vector.push_back(std::make_shared<CsvRecord<int>>(CsvType::INT, toInt(base)));
} else if (py::isinstance<py::float_>(base)) {
vector.push_back(std::make_shared<CsvRecord<float>>(CsvType::FLOAT, toFloat(base)));
} else if (py::isinstance<py::str>(base)) {
vector.push_back(std::make_shared<CsvRecord<std::string>>(CsvType::STRING, toString(base)));
} else {
THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: each default value must be int, float, or string"); }());
}
}
}
return vector;
}
Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
std::map<std::string, std::string> *sample_bytes) {
for (const py::handle &key : padded_sample) {
if (py::isinstance<py::bytes>(padded_sample[key])) {
(*sample_bytes)[py::str(key).cast<std::string>()] = padded_sample[key].cast<std::string>();
// the bytes value itself is kept in sample_bytes; add an unused placeholder under the same key in json so ValidateParam still passes
(*padded_sample_json)[py::str(key).cast<std::string>()] = nlohmann::json::object();
} else {
nlohmann::json obj_json;
if (padded_sample[key].is_none()) {
obj_json = nullptr;
} else if (py::isinstance<py::int_>(padded_sample[key])) {
obj_json = padded_sample[key].cast<int64_t>();
} else if (py::isinstance<py::float_>(padded_sample[key])) {
obj_json = padded_sample[key].cast<double>();
} else if (py::isinstance<py::str>(padded_sample[key])) {
obj_json = padded_sample[key].cast<std::string>(); // also catch py::bytes
} else {
MS_LOG(ERROR) << "Python object convert to json failed: " << py::cast<std::string>(padded_sample[key]);
RETURN_STATUS_SYNTAX_ERROR("Python object convert to json failed");
}
(*padded_sample_json)[py::str(key).cast<std::string>()] = obj_json;
}
}
return Status::OK();
}
Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info) {
for (auto p : value) {
if (!p.second.is_none()) {
auto tp = py::reinterpret_borrow<py::tuple>(p.second);
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
std::shared_ptr<Tensor> pad_val = nullptr;
if (py::isinstance<py::str>(tp[1])) {
std::string pad_val_string = tp[1].is_none() ? "" : toString(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(
Tensor::CreateFromVector(std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar(), &pad_val),
"Cannot create pad_value Tensor");
} else {
float pad_val_float = tp[1].is_none() ? 0 : toFloat(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(
Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val),
"Cannot create pad_value Tensor");
pad_val->SetItemAt<float>({}, pad_val_float);
}
(void)pad_info->insert({toString(p.first), {shape, pad_val}});
} else { // tuple is None
(void)pad_info->insert({toString(p.first), {TensorShape({}), nullptr}});
}
}
return Status::OK();
}
std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type) {
std::shared_ptr<TensorOp> py_func;
if (!func.is_none()) {
py::function py_function = func.cast<py::function>();
py_func = std::make_shared<PyFuncOp>(py_function, data_type);
} else {
py_func = nullptr;
}
return py_func;
}
py::list shapesToListOfShape(std::vector<TensorShape> shapes) {
py::list shape_list;
for (const auto &shape : shapes) {
shape_list.append(shape.AsVector());
}
return shape_list;
}
py::list typesToListOfType(std::vector<DataType> types) {
py::list type_list;
for (const auto &type : types) {
type_list.append(type.AsNumpyType());
}
return type_list;
}
} // namespace dataset
} // namespace mindspore
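
A hypothetical helper (not in this commit) showing how a binding lambda in the suppressed datasets_bindings.cc can push Python arguments down into IR inputs using the converters defined above; the parameter set is illustrative only.

#include "minddata/dataset/api/python/pybind_conversion.h"
namespace mindspore {
namespace dataset {
void ConvertPythonArgs(py::list files, std::optional<py::handle> py_sampler,
                       std::optional<py::list> columns, int32_t shuffle,
                       std::optional<std::shared_ptr<CacheClient>> cc) {
  std::vector<std::string> dataset_files = toStringVector(files);   // "" for None entries
  std::vector<std::string> column_names = toStringVector(columns);  // empty if None
  std::shared_ptr<SamplerObj> sampler = toSamplerObj(py_sampler);   // wraps the runtime sampler
  ShuffleMode shuffle_mode = toShuffleMode(shuffle);                // 0/1/2 -> False/Files/Global
  std::shared_ptr<DatasetCache> cache = toDatasetCache(cc);         // nullptr when cache is not enabled
  // ... a real binding would construct the corresponding DatasetNode from these converted values ...
}
}  // namespace dataset
}  // namespace mindspore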

View File

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
#include "minddata/dataset/kernels/py_func_op.h"
namespace py = pybind11;
namespace mindspore {
namespace dataset {
float toFloat(const py::handle &handle);
int toInt(const py::handle &handle);
int64_t toInt64(const py::handle &handle);
bool toBool(const py::handle &handle);
std::string toString(const py::handle &handle);
std::set<std::string> toStringSet(const std::optional<py::list> list);
std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict);
std::vector<std::string> toStringVector(const std::optional<py::list> list);
std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple);
std::vector<std::pair<int, int>> toPairVector(const py::list list);
std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations);
std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets);
std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false);
std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc);
ShuffleMode toShuffleMode(const int32_t shuffle);
std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases);
std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type);
Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json,
std::map<std::string, std::string> *sample_bytes);
Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info);
py::list shapesToListOfShape(std::vector<TensorShape> shapes);
py::list typesToListOfType(std::vector<DataType> types);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_

View File

@ -190,6 +190,23 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
// PreBuiltOperation
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
#endif
bool PreBuiltSamplerObj::ValidateParams() { return true; }
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
#endif
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object

View File

@ -222,6 +222,13 @@ Status OneHotOperation::ValidateParams() {
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
// PreBuiltOperation
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {}
Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
// RandomApplyOperation
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
: transforms_(transforms), prob_(prob) {}

View File

@ -18,6 +18,7 @@ set(SRC_FILES_LIST
dataset_iterator.cc
tree_adapter.cc
runtime_context.cc
python_runtime_context.cc
consumers/tree_consumer.cc
)
if (ENABLE_PYTHON)

View File

@ -32,15 +32,37 @@ Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
}
return Status::OK();
}
Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) {
std::unordered_map<std::string, TensorPtr> row;
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec;
Status s;
{
py::gil_scoped_release gil_release;
RETURN_IF_NOT_OK(GetNextAsMap(&row));
s = GetNextAsOrderedPair(&vec);
}
for (auto el : row) {
(*out)[common::SafeCStr(el.first)] = el.second;
RETURN_IF_NOT_OK(s);
// Generate the Python dict; Python dicts maintain their insertion order
for (const auto &pair : vec) {
(*out)[common::SafeCStr(pair.first)] = pair.second;
}
return Status::OK();
}
Status PythonBuildVocabConsumer::Start() {
py::gil_scoped_release gil_release;
return BuildVocabConsumer::Start();
}
Status PythonSaveToDisk::Save() {
py::gil_scoped_release gil_release;
return SaveToDisk::Save();
}
PythonSaveToDisk::PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType)
: SaveToDisk(datasetPath, numFiles, datasetType) {}
Status PythonTreeGetters::GetRow(TensorRow *r) {
py::gil_scoped_release gil_release;
return TreeGetters::GetRow(r);
}
} // namespace mindspore::dataset
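
The dict-returning getter above is what the Python iterator ultimately calls. A hedged sketch of how such a consumer could be surfaced through pybind follows; the real registration lives in the suppressed iterator_bindings.cc, and the exposed method name here is an assumption.

// Hypothetical binding sketch; "PythonIteratorConsumer" registration details and the
// exposed name "GetNextAsDict" are assumptions, not taken from this diff.
PYBIND_REGISTER(PythonIteratorConsumer, 0, ([](const py::module *m) {
  (void)py::class_<PythonIteratorConsumer, std::shared_ptr<PythonIteratorConsumer>>(*m, "PythonIteratorConsumer")
    .def("GetNextAsDict", [](PythonIteratorConsumer &self) {
      py::dict row;
      THROW_IF_ERROR(self.GetNextAsDict(&row));  // GIL is released only inside, while the pipeline blocks
      return row;
    });
}));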

View File

@ -44,5 +44,21 @@ class PythonIteratorConsumer : public IteratorConsumer {
/// \return Status error code
Status GetNextAsDict(py::dict *out);
};
class PythonBuildVocabConsumer : public BuildVocabConsumer {
public:
Status Start() override;
};
class PythonSaveToDisk : public SaveToDisk {
public:
PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType);
Status Save() override;
};
class PythonTreeGetters : public TreeGetters {
public:
Status GetRow(TensorRow *r) override;
};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_

View File

@ -23,6 +23,7 @@
#include <vector>
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/tree_adapter.h"
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_header.h"
@ -35,7 +36,7 @@ namespace mindspore::dataset {
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->ServiceStop(); }
// IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
@ -73,6 +74,38 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
return Status::OK();
}
Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) {
CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty.");
TensorRow curr_row;
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row));
RETURN_OK_IF_TRUE(curr_row.empty());
size_t num_cols = curr_row.size(); // non-zero, since curr_row is not empty
// order the column names according to their ids
if (column_order_.empty()) {
const int32_t invalid_col_id = -1;
column_order_.resize(num_cols, {std::string(), invalid_col_id});
for (const auto &itr : tree_adapter_->GetColumnNameMap()) {
int32_t ind = itr.second;
CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds.");
column_order_[ind] = std::make_pair(itr.first, ind);
}
// error check, make sure the ids in col_name_id_map are continuous and start from 0
for (const auto &col : column_order_) {
CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous.");
}
}
vec->reserve(num_cols);
std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec),
[curr_row](const auto &col) { return std::make_pair(col.first, curr_row[col.second]); });
return Status::OK();
}
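
To make the id-to-position bookkeeping above concrete, here is a small standalone sketch (standard library types only) of the same ordering and continuity check.

#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Place each column name at the position given by its id, then fail if any slot was never filled.
std::vector<std::pair<std::string, int>> OrderColumns(
    const std::unordered_map<std::string, int> &name_id_map) {
  std::vector<std::pair<std::string, int>> order(name_id_map.size(), {std::string(), -1});
  for (const auto &kv : name_id_map) {
    if (kv.second < 0 || kv.second >= static_cast<int>(order.size()))
      throw std::out_of_range("column id out of bounds");
    order[kv.second] = kv;
  }
  for (const auto &col : order)
    if (col.second == -1) throw std::runtime_error("column ids are not continuous");
  return order;  // e.g. {{"image", 0}, {"label", 1}} for input {{"label", 1}, {"image", 0}}
}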
// ToDevice
Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); }
@ -81,7 +114,6 @@ Status ToDevice::Send() {
RETURN_IF_NOT_OK(tree_adapter_->Launch());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_IF_NOT_OK(root->GetNextBuffer(&db));
return Status::OK();
}
@ -101,9 +133,36 @@ Status ToDevice::Stop() {
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopSend();
return Status::OK();
}
Status ToDevice::GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes) {
// tree_.root() must be DeviceQueueOp
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp");
DATA_INFO data_info;
RETURN_IF_NOT_OK(op->GetDataInfo(&data_info));
for (auto el : data_info) {
types->push_back(el.first);
shapes->push_back(el.second);
}
return Status::OK();
}
Status ToDevice::Terminate() {
#ifdef ENABLE_TDTQUE
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopWaiting();
#endif
return TreeConsumer::Terminate();
}
#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {
@ -282,50 +341,50 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
if (column_type == DataType::DE_INT8) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int8_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<int16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT16) {
std::unique_ptr<int32_t> data;
std::unique_ptr<uint16_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT8) {
std::unique_ptr<uint8_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT32) {
std::unique_ptr<int32_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_UINT32) {
std::unique_ptr<int64_t> data;
std::unique_ptr<uint32_t> dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_INT64) {
std::unique_ptr<int64_t> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT32) {
std::unique_ptr<float> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_FLOAT64) {
std::unique_ptr<double> data, dummy;
s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy);
RETURN_IF_NOT_OK(s);
if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data);
} else if (column_type == DataType::DE_STRING) {
@ -346,7 +405,7 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row,
}
template <typename T, typename S>
Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert) {
if (nullptr == src) {
@ -379,47 +438,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
}
#endif
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) {
tree_adapter_ = std::make_unique<TreeAdapter>();
}
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); }
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
if (init_flag_) {
return Status::OK();
}
Status s = tree_adapter_->Compile(std::move(d), 1);
if (!s.IsError()) {
init_flag_ = true;
}
return s;
}
bool TreeGetters::isInitialized() { return init_flag_; }
Status TreeGetters::GetRow(TensorRow *row) {
if (row_flag_ == false) {
RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
row_flag_ = true;
}
root_ = std::move(d);
return Status::OK();
}
Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); }
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ == -1) {
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize)));
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
dataset_size_ = *dataset_size;
if (*dataset_size == -1) {
RETURN_IF_NOT_OK(GetRow(&row_));
int64_t num_rows = 0;
TensorRow row = row_;
while (row.size() != 0) {
num_rows++;
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
if (*dataset_size == -1) { // run through the tree and get everything
TensorRow row;
RETURN_IF_NOT_OK(GetRow(&row));
int64_t row_cnt = 0;
while (!row.empty()) {
++row_cnt;
RETURN_IF_NOT_OK(GetRow(&row));
}
dataset_size_ = num_rows;
*dataset_size = row_cnt;
}
dataset_size_ = *dataset_size; // save the previous result
}
*dataset_size = dataset_size_;
@ -427,68 +471,88 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
}
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
RETURN_IF_NOT_OK(GetRow(&row_));
for (auto ts : row_) {
DataType dt = ts->type();
types->push_back(dt);
}
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types),
[](const TensorPtr &t) { return t->type(); });
return Status::OK();
}
Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
RETURN_IF_NOT_OK(GetRow(&row_));
for (auto ts : row_) {
TensorShape t = ts->shape();
shapes->push_back(t);
}
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes),
[](const TensorPtr &t) { return t->shape(); });
return Status::OK();
}
Status TreeGetters::GetBatchSize(int64_t *batch_size) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_UNEXPECTED_IF_NULL(root);
*batch_size = root->GetTreeBatchSize();
CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
return Status::OK();
}
Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_UNEXPECTED_IF_NULL(root);
*repeat_count = root->GetTreeRepeatCount();
return Status::OK();
}
Status TreeGetters::GetNumClasses(int64_t *num_classes) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK();
}
Status TreeGetters::GetColumnNames(std::vector<std::string> *output) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
RETURN_UNEXPECTED_IF_NULL(root);
std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map();
if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty.");
std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(),
column_name_id_map.end());
std::sort(column_name_id_vector.begin(), column_name_id_vector.end(),
CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map is empty.");
std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end());
std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
return a.second < b.second;
});
for (auto item : column_name_id_vector) {
(*output).push_back(item.first);
}
std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output),
[](const std::pair<std::string, int32_t> &p) { return p.first; });
return Status::OK();
}
Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
RETURN_IF_NOT_OK(InternalInit());
std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_UNEXPECTED_IF_NULL(root);
RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing));
return Status::OK();
}
Status TreeGetters::InternalInit(int8_t type) {
if (init_flag_) return Status::OK();
tree_adapter_->SetPrePassOverride([&type](OptPass pre) {
pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type)));
return pre;
});
Status s = tree_adapter_->Compile(std::move(root_), 1);
if (!s.IsError()) init_flag_ = true;
return s;
}
Status TreeGetters::InternalInit() {
if (init_flag_) return Status::OK();
Status s = tree_adapter_->Compile(std::move(root_), 1);
if (!s.IsError()) init_flag_ = true;
return s;
}
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }
Status BuildVocabConsumer::Start() {

View File

@ -41,7 +41,7 @@ class TreeConsumer {
/// \return Status error code.
virtual Status Init(std::shared_ptr<DatasetNode> d);
Status Terminate();
virtual Status Terminate();
protected:
/// The class owns the tree_adapter that handles execution tree operations.
@ -72,6 +72,11 @@ class IteratorConsumer : public TreeConsumer {
/// \return Status error code
Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out);
/// Returns the next row in as a map
/// \param[out] out std::map of string to Tensor
/// \return Status error code
Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec);
protected:
/// Method to return the name of the consumer
/// \return string
@ -79,6 +84,7 @@ class IteratorConsumer : public TreeConsumer {
private:
int32_t num_epochs_;
std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id
};
#ifndef ENABLE_ANDROID
@ -101,7 +107,7 @@ class SaveToDisk : public TreeConsumer {
/// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
/// would be written to disk)
/// \return Status error code
Status Save();
virtual Status Save();
protected:
/// Method to return the name of the consumer
@ -110,7 +116,7 @@ class SaveToDisk : public TreeConsumer {
private:
template <typename T, typename S>
Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements,
std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr,
std::unique_ptr<S> *s, bool need_convert = false);
@ -131,24 +137,29 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and send it to a device
class ToDevice : public TreeConsumer {
public:
explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
~ToDevice() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status Terminate() override;
/// Send the data to device
/// \return Status error code
Status Send();
virtual Status Send();
/// Stop to send data to device
/// \return Status error code
Status Stop();
virtual Status Stop();
/// Continue to send data to device
/// \return Status error code
Status Continue();
virtual Status Continue();
/// Get data info from TDT
/// \return Status error code
virtual Status GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes);
protected:
/// Method to return the name of the consumer
@ -156,8 +167,6 @@ class ToDevice : public TreeConsumer {
std::string Name() override { return "ToDevice"; }
private:
std::string device_type_;
bool send_epoch_end_;
int32_t num_epochs_;
};
@ -167,6 +176,7 @@ class TreeGetters : public TreeConsumer {
TreeGetters();
~TreeGetters() = default;
Status Init(std::shared_ptr<DatasetNode> d) override;
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
@ -175,15 +185,17 @@ class TreeGetters : public TreeConsumer {
Status GetNumClasses(int64_t *num_classes);
Status GetColumnNames(std::vector<std::string> *output);
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
bool isInitialized();
std::string Name() override { return "TreeGetters"; }
Status GetRow(TensorRow *r);
virtual Status GetRow(TensorRow *r);
private:
std::shared_ptr<DatasetNode> root_;
int64_t dataset_size_;
TensorRow row_;
TensorRow first_row_;
bool init_flag_; // indicate whether the tree has initialized
bool row_flag_; // indicate whether the first row has been stored in row_
Status InternalInit(int8_t type);
Status InternalInit();
};
class BuildVocabConsumer : public TreeConsumer {
@ -197,7 +209,7 @@ class BuildVocabConsumer : public TreeConsumer {
/// Start consuming
/// \return Status error code
Status Start();
virtual Status Start();
protected:
/// Method to return the name of the consumer

View File

@ -44,9 +44,9 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
}
// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index)
ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
const std::vector<std::pair<int, int>> &children_flag_and_nums,
const std::vector<std::pair<int, int>> &children_start_end_index)
: PipelineOp(op_connector_size),
children_num_(0),
sampler_(sampler),

View File

@ -70,9 +70,9 @@ class ConcatOp : public PipelineOp {
// @note The builder class should be used to call it
// @param op_connector_size - connector size
explicit ConcatOp(int32_t op_connector_size);
explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
std::vector<std::pair<int, int>> children_flag_and_nums,
std::vector<std::pair<int, int>> children_start_end_index);
ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
const std::vector<std::pair<int, int>> &children_flag_and_nums,
const std::vector<std::pair<int, int>> &children_start_end_index);
// Destructor
~ConcatOp() = default;

View File

@ -346,6 +346,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Name of the current Op
virtual std::string Name() const = 0;
/// Op name and ID getter
/// \return Name and ID of the current Op
std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; }
/// Execution Tree getter
/// \return Pointer to the ExecutionTree the current op belongs to, no ownership
ExecutionTree *Tree() { return tree_; }

View File

@ -205,7 +205,6 @@ Status DeviceQueueOp::SendDataToAscend() {
}
tree_->SetFinished();
MS_LOG(INFO) << "Device queue total batch is " << send_batch;
return Status::OK();
}

View File

@ -39,10 +39,10 @@ using mindspore::device::GpuBufferMgr;
namespace mindspore {
namespace dataset {
using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>;
using DATA_INFO_QUEUE = Queue<DATA_INFO>;
const int kDataInfoQueueCapacity = 128;
class DeviceQueueOp : public PipelineOp {
public:
static const uint32_t INVALID_HANDLE = 0xffffffffUL;
@ -184,7 +184,6 @@ class DeviceQueueOp : public PipelineOp {
#ifdef ENABLE_TDTQUE
Status SendDataToAscend();
bool ascend_keep_waiting_;
#endif
#ifdef ENABLE_GPUQUE

View File

@ -169,7 +169,7 @@ Status MapOp::operator()() {
}
// The operator class just starts off threads by calling the tree_ function
rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1));
rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1), NameWithID());
// Synchronize with TaskManager
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(rc);

View File

@ -704,6 +704,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
}
if (image_ids_.size() == 0) {
RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;

View File

@ -480,13 +480,13 @@ Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows = num_rows_, sample_size;
int64_t num_rows = num_rows_;
if (num_rows_ <= 0) {
std::shared_ptr<ShardOperator> op;
// The last operator is the parent sampler
std::shared_ptr<ShardOperator> op = operators_.back();
RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_));
}
sample_size = operators_[0]->GetNumSamples(num_rows, 0);
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
*dataset_size = num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}

View File

@ -1067,6 +1067,19 @@ Status TFReaderOp::PrepareNodePostAction() {
return Status::OK();
}
// Get the file list of the specific shard ID
Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) {
if (!shard_filenames->empty()) {
RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n");
}
for (int index = 0; index < dataset_files_list_.size(); index++) {
if (index % num_devices_ == device_id_) {
shard_filenames->push_back(dataset_files_list_.at(index));
}
}
return Status::OK();
}
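
The selection above is a simple modulo split over the file list; a standalone sketch makes the behavior explicit and shows why the non-equal-rows path now counts only this shard's files.

#include <string>
#include <vector>

// With num_devices = 2, device 0 receives files at indices 0, 2, 4, ... and device 1 the rest.
std::vector<std::string> ShardFiles(const std::vector<std::string> &all_files,
                                    int num_devices, int device_id) {
  std::vector<std::string> mine;
  for (size_t i = 0; i < all_files.size(); ++i) {
    if (static_cast<int>(i) % num_devices == device_id) mine.push_back(all_files[i]);
  }
  return mine;
}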
// Get Dataset size
Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ > 0) {
@ -1080,7 +1093,9 @@ Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
num_rows = num_rows_per_shard_;
} else {
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_));
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list));
}
}
sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows();

View File

@ -400,6 +400,11 @@ class TFReaderOp : public ParallelOp {
// @return - Status
Status ComputeColMap() override;
// Private function for computing the file list of the specific shard ID. This is needed because in a distributed
// scenario, data is divided into shards by row when equal_rows_per_shard is true, and by file otherwise.
// @return - Status - the status code returned.
Status GetShardFileList(std::vector<std::string> *shard_filenames);
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;

View File

@ -536,6 +536,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) {
RETURN_IF_NOT_OK(op->ParseImageIds());
num_rows = static_cast<int64_t>(op->image_ids_.size());
}
} else {
num_rows = image_ids_.size();
}
sample_size = sampler_->CalculateNumSamples(num_rows);
*dataset_size = sample_size;

View File

@ -141,8 +141,6 @@ Status ExecutionTree::Launch() {
" Expected state: " + std::to_string(static_cast<int>(kDeTStateReady));
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::ostringstream ss;
ss << *this;
// Profiling infrastructures need to be initialized before Op launching
if (profiling_manager_->IsProfilingEnable()) {
@ -152,6 +150,8 @@ Status ExecutionTree::Launch() {
RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor());
}
std::ostringstream ss;
ss << *this;
MS_LOG(DEBUG) << "Printing the tree before launching tasks:\n" << ss.str();
for (auto itr = this->begin(); itr != this->end(); ++itr) {
// An inlined operator is one that has an output connector size of 0, and it does not
@ -160,7 +160,7 @@ Status ExecutionTree::Launch() {
// the launching tree/user thread. Do not exec any thread for an inlined op.
itr->state_ = DatasetOp::OpState::kDeOpRunning;
if (!itr->inlined()) {
RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr)));
RETURN_IF_NOT_OK(tg_->CreateAsyncTask(itr->NameWithID(), std::ref(*itr)));
// Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp
}
}
@ -189,10 +189,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_
// Given the number of workers, launches the worker entry function for each. Essentially a
// wrapper for the TaskGroup handling that is stored inside the execution tree.
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func) {
Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) {
// Launch the workers
for (int32_t i = 0; i < num_workers; ++i) {
RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i)));
RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i)));
}
return Status::OK();
}

View File

@ -150,7 +150,7 @@ class ExecutionTree {
// @param num_workers - The number of workers to launch
// @param func - The function entry point that workers will execute
// @return Status - The error code return
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func);
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "");
// Getter method
// @return shared_ptr to the root operator

View File

@ -1,4 +1,5 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-ir-cache OBJECT
dataset_cache_impl.cc)
pre_built_dataset_cache.cc
dataset_cache_impl.cc)

View File

@ -18,8 +18,8 @@
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
namespace mindspore::dataset {
namespace mindspore {
namespace dataset {
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
/// \return Status Error code
Status DatasetCacheImpl::Build() {
@ -40,5 +40,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
return Status::OK();
}
} // namespace mindspore::dataset
} // namespace dataset
} // namespace mindspore

View File

@ -24,8 +24,8 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
namespace mindspore::dataset {
namespace mindspore {
namespace dataset {
/// DatasetCache is the IR of CacheClient
class DatasetCacheImpl : public DatasetCache {
public:
@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache {
std::optional<int32_t> num_connections_;
std::optional<int32_t> prefetch_sz_;
};
} // namespace mindspore::dataset
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
namespace mindspore {
namespace dataset {
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
/// \return Status Error code
Status PreBuiltDatasetCache::Build() {
// we intentionally keep a reference to the runtime object so it can be shared by different pipelines
return Status::OK();
}
Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheOp> cache_op = nullptr;
RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
*ds = cache_op;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
namespace mindspore {
namespace dataset {
/// DatasetCache is the IR of CacheClient
class PreBuiltDatasetCache : public DatasetCache {
public:
/// \brief Constructor
/// \param cc a pre-built cache client
explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc) : cache_client_(std::move(cc)) {}
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
/// \return Status Error code
Status Build() override;
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status ValidateParams() override { return Status::OK(); }
private:
std::shared_ptr<CacheClient> cache_client_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
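
A hedged sketch of the intended flow: a CacheClient that already exists (for example one handed over from the Python layer via toDatasetCache) is wrapped unchanged into the IR; Build() is a no-op, and the CacheOp that reuses the client is only created later through CreateCacheOp() during tree construction. The helper name below is illustrative.

#include <memory>
#include <utility>
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"

std::shared_ptr<mindspore::dataset::DatasetCache> WrapExistingClient(
    std::shared_ptr<mindspore::dataset::CacheClient> cc) {
  // The pre-built cache simply holds the existing client; no new cache server connection is made.
  return std::make_shared<mindspore::dataset::PreBuiltDatasetCache>(std::move(cc));
}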

View File

@ -31,7 +31,7 @@ namespace dataset {
BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function,
std::shared_ptr<TensorOp> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder)
: column_names_(column_names),
@ -47,16 +47,13 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::shared_ptr<TensorOp> c_func;
if (element_length_function_ != nullptr) {
c_func = std::make_shared<CFuncOp>(element_length_function_);
} else {
c_func = nullptr;
bucket_boundaries_.insert(bucket_boundaries_.begin(), 0);
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(
column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_,
pad_to_bucket_boundary_, drop_remainder_, connector_que_size_));
if (bucket_boundaries_[0] == 0) {
bucket_boundaries_.erase(bucket_boundaries_.begin());
}
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
c_func, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_));
return node_ops;
}
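For context, a hedged sketch of constructing the node after this change: the element-length callable now arrives as a pre-built TensorOp (converted at the pybind layer) rather than a raw std::function. The helper name, column names, and bucket values below are illustrative.

// Sketch only: the caller supplies a ready-made TensorOp for the length function.
std::shared_ptr<DatasetNode> MakeBucketNode(std::shared_ptr<DatasetNode> child,
                                            std::shared_ptr<TensorOp> length_op) {
  return std::make_shared<BucketBatchByLengthNode>(
      child,
      std::vector<std::string>{"text"},    // column_names
      std::vector<int32_t>{10, 20, 30},    // bucket_boundaries
      std::vector<int32_t>{8, 8, 8, 8},    // bucket_batch_sizes: boundaries + 1 entries
      length_op);                          // previously a std::function<TensorRow(TensorRow)>
}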

View File

@ -33,7 +33,7 @@ class BucketBatchByLengthNode : public DatasetNode {
/// \brief Constructor
BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
std::shared_ptr<TensorOp> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
@ -52,7 +52,7 @@ class BucketBatchByLengthNode : public DatasetNode {
std::vector<std::string> column_names_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;
std::function<TensorRow(TensorRow)> element_length_function_;
std::shared_ptr<TensorOp> element_length_function_;
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
bool pad_to_bucket_boundary_;
bool drop_remainder_;

View File

@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/concat_op.h"
@ -27,7 +28,15 @@ namespace mindspore {
namespace dataset {
// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
const std::shared_ptr<SamplerObj> &sampler,
const std::vector<std::pair<int, int>> &children_flag_and_nums,
const std::vector<std::pair<int, int>> &children_start_end_index)
: sampler_(sampler),
children_flag_and_nums_(children_flag_and_nums),
children_start_end_index_(children_start_end_index) {
this->children = datasets;
}
Status ConcatNode::ValidateParams() {
if (children.size() < 2) {
@ -42,14 +51,25 @@ Status ConcatNode::ValidateParams() {
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if ((children_flag_and_nums_.empty() && !children_start_end_index_.empty()) ||
(!children_flag_and_nums_.empty() && children_start_end_index_.empty())) {
std::string err_msg = "ConcatNode: children_flag_and_nums and children_start_end_index should be used together";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
} else {
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_,
children_start_end_index_));
}
node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
return node_ops;
}
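A hedged sketch of the sampled construction path this enables; the helper name, the per-child vectors, and their values are illustrative placeholders for what the Python layer computes when a sampler is attached to the concatenation.

// Sketch only: concat driven by a sampler with per-child bookkeeping.
std::shared_ptr<DatasetNode> MakeSampledConcat(std::shared_ptr<DatasetNode> ds1,
                                               std::shared_ptr<DatasetNode> ds2,
                                               std::shared_ptr<SamplerObj> sampler) {
  std::vector<std::pair<int, int>> flags_and_nums = {{0, 100}, {0, 200}};  // per-child flag and row count
  std::vector<std::pair<int, int>> start_end = {{0, 100}, {100, 300}};     // per-child global row range
  return std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>{ds1, ds2},
                                      sampler, flags_and_nums, start_end);
}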

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
@ -29,7 +30,10 @@ namespace dataset {
class ConcatNode : public DatasetNode {
public:
/// \brief Constructor
explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets);
explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets,
const std::shared_ptr<SamplerObj> &sampler = nullptr,
const std::vector<std::pair<int, int>> &children_flag_and_nums = {},
const std::vector<std::pair<int, int>> &children_start_end_index = {});
/// \brief Destructor
~ConcatNode() = default;
@ -41,6 +45,11 @@ class ConcatNode : public DatasetNode {
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::shared_ptr<SamplerObj> sampler_;
std::vector<std::pair<int, int>> children_flag_and_nums_;
std::vector<std::pair<int, int>> children_start_end_index_;
};
} // namespace dataset

View File

@ -240,6 +240,7 @@ DatasetNode::DatasetNode() {
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
build_status = Status::OK(); // remove me after changing return val of Build()
}
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
@ -254,5 +255,13 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
// This method will only be called if its derived class does not implement one.
return p->VisitAfter(shared_from_this(), modified);
}
Status DatasetNode::GetShardId(int32_t *shard_id) {
if (!Children().empty()) {
// Get shard id from the child node
return Children()[0]->GetShardId(shard_id);
} else {
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
}
}
} // namespace dataset
} // namespace mindspore
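Under this default, only leaf (source) nodes still need their own GetShardId; the hedged sketch below shows the shape of such an override for a hypothetical source node MySourceNode.

// Sketch only: a source node reports its own shard id; intermediate nodes inherit the
// default above, which simply forwards the query to their first child.
Status MySourceNode::GetShardId(int32_t *shard_id) {
  RETURN_UNEXPECTED_IF_NULL(shard_id);
  *shard_id = shard_id_;   // assumed member holding the user-supplied shard id
  return Status::OK();
}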

View File

@ -99,9 +99,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \brief Pure virtual function for derived class to get the shard id of specific node
/// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
virtual Status GetShardId(int32_t *shard_id);
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
@ -126,6 +124,10 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status of the node visit
virtual Status AcceptAfter(NodePass *p, bool *modified);
/// \brief Method to get status from Node.Build()
/// \notes Remove me after changing return val of Build()
Status BuildStatus() { return build_status; }
protected:
std::vector<std::shared_ptr<DatasetNode>> children;
std::shared_ptr<DatasetCache> cache_;
@ -135,6 +137,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
Status build_status; // remove me after changing return val of Build()
};
} // namespace dataset

View File

@ -28,7 +28,7 @@ namespace mindspore {
namespace dataset {
// Constructor for FilterNode
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
std::vector<std::string> input_columns)
: predicate_(predicate), input_columns_(input_columns) {
this->children.push_back(child);
@ -38,10 +38,7 @@ std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::shared_ptr<TensorOp> c_func;
c_func = std::make_shared<CFuncOp>(predicate_);
node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, c_func));
node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_));
return node_ops;
}

View File

@ -29,7 +29,7 @@ namespace dataset {
class FilterNode : public DatasetNode {
public:
/// \brief Constructor
FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate,
FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
std::vector<std::string> input_columns = {});
/// \brief Destructor
@ -44,7 +44,7 @@ class FilterNode : public DatasetNode {
Status ValidateParams() override;
private:
std::function<TensorRow(TensorRow)> predicate_;
std::shared_ptr<TensorOp> predicate_;
std::vector<std::string> input_columns_;
};

View File

@ -64,7 +64,8 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
auto project_op = std::make_shared<ProjectOp>(project_columns_);
node_ops.push_back(project_op);
}
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(map_op);
return node_ops;

View File

@ -59,7 +59,8 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
std::vector<std::shared_ptr<DatasetOp>> node_ops;
auto schema = std::make_unique<DataSchema>();
RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_));
build_status = schema->LoadSchemaFile(schema_path_, column_names_);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
// Argument that is not exposed to user in the API.
std::set<std::string> extensions = {};

View File

@ -60,7 +60,8 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, usage_, extensions_, std::move(schema),

View File

@ -56,7 +56,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),

View File

@ -54,7 +54,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),

View File

@ -197,18 +197,23 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(clue_op->Init());
build_status = clue_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));
build_status = ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
node_ops.push_back(shuffle_op);
}
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

View File

@ -111,7 +111,8 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(op);

View File

@ -108,18 +108,23 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_,
rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files,
num_shards_, shard_id_, std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(csv_op->Init());
build_status = csv_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));
build_status = CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
node_ops.push_back(shuffle_op);
}

View File

@ -30,7 +30,25 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
const std::vector<DataType> &column_types)
: generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
: generator_function_(generator_function), schema_(schema) {}
std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
if (schema_ != nullptr) {
column_names_.clear();
column_types_.clear();
std::string schema_json_string = schema_->to_json();
RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, {}));
for (int32_t i = 0; i < data_schema->NumColumns(); i++) {
ColDescriptor col = data_schema->column(i);
column_names_.push_back(col.name());
column_types_.push_back((col.type()));
}
}
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
@ -43,6 +61,8 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
// This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which
// is best delivered when the test cases for this API are ready.
Status rc = op->Init();
build_status = rc; // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (rc.IsOk()) {
node_ops.push_back(op);
@ -56,5 +76,11 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() {
// no validation is needed for generator op.
Status GeneratorNode::ValidateParams() { return Status::OK(); }
Status GeneratorNode::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);
*shard_id = 0;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
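A hedged sketch of the new schema-driven construction path; the helper name is illustrative only.

// Sketch only: hand a SchemaObj to GeneratorNode and let Build() expand it into
// column names and types, instead of passing the two vectors explicitly.
std::shared_ptr<GeneratorNode> MakeGeneratorNode(py::function gen_fn,
                                                 std::shared_ptr<SchemaObj> schema) {
  return std::make_shared<GeneratorNode>(gen_fn, schema);
}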

View File

@ -35,6 +35,9 @@ class GeneratorNode : public DatasetNode {
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
const std::vector<DataType> &column_types);
/// \brief Constructor
GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema);
/// \brief Destructor
~GeneratorNode() = default;
@ -46,10 +49,15 @@ class GeneratorNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of this node; it is always 0 because GeneratorNode does not support sharding
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
py::function generator_function_;
std::vector<std::string> column_names_;
std::vector<DataType> column_types_;
std::shared_ptr<SchemaObj> schema_;
};
} // namespace dataset

View File

@ -62,7 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema),

View File

@ -79,7 +79,8 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
manifest_op =
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
class_index_, std::move(schema), std::move(sampler_->Build()), usage_);
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(manifest_op);

View File

@ -138,7 +138,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::vector<std::shared_ptr<ShardOperator>> operators_;
RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
std::shared_ptr<MindRecordOp> mindrecord_op;
// If pass a string to MindData(), it will be treated as a pattern to search for matched files,
@ -154,7 +155,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
padded_sample_, sample_bytes_);
}
RETURN_EMPTY_IF_ERROR(mindrecord_op->Init());
build_status = mindrecord_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(mindrecord_op);
return node_ops;

View File

@ -51,7 +51,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
TensorShape scalar = TensorShape::CreateScalar();
RETURN_EMPTY_IF_ERROR(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->Build())));

View File

@ -98,7 +98,8 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(op);

View File

@ -78,7 +78,8 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
build_status = text_file_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
@ -86,14 +87,17 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));
build_status = TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
node_ops.push_back(shuffle_op);
}
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
// Add TextFileOp
node_ops.push_back(text_file_op);

View File

@ -118,7 +118,8 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_,
shard_id_, shard_equal_rows_, std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());
build_status = tf_reader_op->Init(); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
@ -127,14 +128,17 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
build_status = TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
// Add the shuffle op after this op
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
build_status = AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op);
RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build()
node_ops.push_back(shuffle_op);
}
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
// Add TFReaderOp
node_ops.push_back(tf_reader_op);

View File

@ -106,7 +106,8 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
std::shared_ptr<VOCOp> voc_op;
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
node_ops.push_back(voc_op);
return node_ops;

View File

@ -27,9 +27,8 @@ namespace mindspore {
namespace dataset {
// Constructor for SyncWaitNode
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
py::function callback)
: condition_name_(condition_name), num_batch_(num_batch), callback_(callback) {
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback)
: condition_name_(condition_name), callback_(callback) {
this->children.push_back(child);
}
@ -38,20 +37,16 @@ std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_));
// Right now barrier should only take num_rows_per_buffer = 1
// The reason is that any other value can lead to blocking issues
// See barrier_op.h for more details
int32_t rows_per_buffer = 1;
node_ops.push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_));
return node_ops;
}
// Function to validate the parameters for SyncWaitNode
Status SyncWaitNode::ValidateParams() {
if (num_batch_ <= 0) {
std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status SyncWaitNode::ValidateParams() { return Status::OK(); }
} // namespace dataset
} // namespace mindspore

View File

@ -31,8 +31,7 @@ namespace dataset {
class SyncWaitNode : public DatasetNode {
public:
/// \brief Constructor
explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch,
py::function callback);
SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback);
/// \brief Destructor
~SyncWaitNode() = default;
@ -47,7 +46,6 @@ class SyncWaitNode : public DatasetNode {
private:
std::string condition_name_;
int32_t num_batch_;
py::function callback_;
};

View File

@ -18,73 +18,81 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/util/status.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace dataset {
// Constructor for TransferNode
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end)
: prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) {
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type,
bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
: prefetch_size_(16),
queue_name_(std::move(queue_name)),
device_type_(std::move(device_type)),
send_epoch_end_(send_epoch_end),
total_batch_(total_batch),
create_data_info_queue_(create_data_info_queue),
device_id_(0) {
this->children.push_back(child);
}
// Validator for TransferNode
Status TransferNode::ValidateParams() {
// Check if device_type_ is in {"CPU", "GPU", "Ascend"}
RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"}));
if (total_batch_ < 0) {
std::string err_msg = "TransferNode: Total batches should be >= 0, value given: ";
MS_LOG(ERROR) << err_msg << total_batch_;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
// Function to build TransferNode
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
// Get a uuid for queue name
queue_name_ = Services::GetUniqueID();
// TODO(CRC):
if (queue_name_.empty()) {
// Get a uuid for queue name
queue_name_ = Services::GetUniqueID();
}
if (device_type_.empty()) {
auto context = MsContext::GetInstance();
if (context == nullptr) {
device_type_ = kCPUDevice;
} else {
device_type_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
}
}
// Get device type from ms context
device_type_ = "CPU";
// Get device ID from children
// Convert device_type_ from string to DeviceType
DeviceQueueOp::DeviceType type;
if (device_type_ == kCPUDevice) {
type = DeviceQueueOp::DeviceType::CPU;
} else if (device_type_ == kGPUDevice) {
type = DeviceQueueOp::DeviceType::GPU;
} else if (device_type_ == kAscendDevice) {
type = DeviceQueueOp::DeviceType::Ascend;
} else {
MS_LOG(ERROR) << "Unknown device target.";
return {};
}
// Get device ID (shard ID) from children
device_id_ = 0;
RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_));
build_status = this->GetShardId(&device_id_); // remove me after changing return val of Build()
RETURN_EMPTY_IF_ERROR(build_status);
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// Convert device_type_ from string to DeviceType
DeviceQueueOp::DeviceType type;
if (device_type_ == "CPU") {
type = DeviceQueueOp::DeviceType::CPU;
} else if (device_type_ == "GPU") {
type = DeviceQueueOp::DeviceType::GPU;
} else if (device_type_ == "Ascend") {
type = DeviceQueueOp::DeviceType::Ascend;
}
node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
total_batch_, false));
total_batch_, create_data_info_queue_));
return node_ops;
}
// Function to get the device_id
Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) {
// Get device id according to the type of dataset
Status rc = ds->GetShardId(device_id);
if (rc != Status::OK()) {
// Get device id from the child node
if (ds->Children().size()) {
ds = ds->Children()[0];
return TransferNode::get_distribution(ds, device_id);
} else {
std::string err_msg = "Unknown dataset type.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
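A hedged sketch of creating the node with the new argument list; an empty queue name or device type is resolved inside Build() as shown above, and the helper name and literal values here are illustrative.

// Sketch only: explicit transfer parameters; "" lets Build() pick a uuid queue name
// and the MsContext device target.
std::shared_ptr<TransferNode> MakeTransferNode(std::shared_ptr<DatasetNode> child) {
  return std::make_shared<TransferNode>(child,
                                        "",        // queue_name: generated in Build()
                                        "Ascend",  // device_type: one of CPU/GPU/Ascend
                                        /*send_epoch_end=*/true,
                                        /*total_batch=*/0,   // 0 sends all batches
                                        /*create_data_info_queue=*/false);
}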

View File

@ -29,7 +29,8 @@ namespace dataset {
class TransferNode : public DatasetNode {
public:
/// \brief Constructor
TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end);
TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end,
int32_t total_batch, bool create_data_info_queue);
/// \brief Destructor
~TransferNode() = default;
@ -42,8 +43,6 @@ class TransferNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
private:
std::string queue_name_;
int32_t device_id_;
@ -51,6 +50,7 @@ class TransferNode : public DatasetNode {
int32_t prefetch_size_;
bool send_epoch_end_;
int32_t total_batch_;
bool create_data_info_queue_;
};
} // namespace dataset

View File

@ -40,21 +40,7 @@ Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *mo
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
if (type_ == kOutputShapeAndType) {
nodes_to_clear_callback_.push_back(node);
} else if (type_ == kDatasetSize) {
nodes_to_remove_.push_back(node);
}
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
return Status::OK();
}
Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
if (type_ == kDatasetSize) nodes_to_remove_.push_back(node);
nodes_to_clear_callback_.push_back(node);
return Status::OK();
}
@ -83,5 +69,6 @@ Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -34,6 +34,10 @@ class GetterPass : public TreePass {
enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 };
/// \brief Constructor
explicit GetterPass(GetterType tp) : pass_(tp) {}
/// \brief default copy Constructor
explicit GetterPass(const GetterPass &) = default;
/// \brief Destructor
~GetterPass() = default;
@ -51,11 +55,10 @@ class GetterPass : public TreePass {
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); }
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
// Whether this is Run or PreRun does not matter here; however, only Accept() is defined in ConcatOp
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
@ -67,7 +70,7 @@ class GetterPass : public TreePass {
std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_;
std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_;
};
// outter class needs only to own the inner class object since it automatically has access to its private variables
// outer class needs only to own the inner class object since it automatically has access to its private variables
GetterNodes pass_;
};
} // namespace dataset

View File

@ -19,7 +19,14 @@
namespace mindspore::dataset {
Status PythonRuntimeContext::Terminate() { return TerminateImpl(); }
Status PythonRuntimeContext::Terminate() {
MS_LOG(INFO) << "Terminating a PythonRuntime";
if (tree_consumer_ != nullptr) {
return TerminateImpl();
}
MS_LOG(WARNING) << "TreeConsumer was not initialized";
return Status::OK();
}
Status PythonRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");

View File

@ -22,7 +22,14 @@ namespace mindspore::dataset {
void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) {
tree_consumer_ = std::move(tree_consumer);
}
Status NativeRuntimeContext::Terminate() { return TerminateImpl(); }
Status NativeRuntimeContext::Terminate() {
MS_LOG(INFO) << "Terminating a NativeRuntime";
if (tree_consumer_ != nullptr) {
return TerminateImpl();
}
MS_LOG(WARNING) << "TreeConsumer was not initialized";
return Status::OK();
}
Status NativeRuntimeContext::TerminateImpl() {
CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized");

View File

@ -97,6 +97,8 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) {
// Build the DatasetOp ExecutionTree from the optimized IR tree
std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build();
RETURN_IF_NOT_OK(ir->BuildStatus()); // remove me after changing return val of Build()
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node.");
(*op) = ops.front(); // return the first op to be added as child by the caller of this function
@ -141,6 +143,8 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op));
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
// Note: We will gradually move the pre pass, optimizer pass, and post pass
// from the ExecutionTree to run on the IR tree.
// Prepare the tree
@ -149,6 +153,11 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
// After the tree is prepared, the col_name_id_map can safely be obtained
column_name_map_ = tree_->root()->column_name_id_map();
// Profiling parameters init
cur_batch_num_ = 0;
cur_connector_size_ = 0;
cur_connector_capacity_ = 0;
return Status::OK();
}
@ -156,21 +165,55 @@ Status TreeAdapter::GetNext(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_UNEXPECTED_IF_NULL(row);
row->clear(); // make sure row is empty
bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
// When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
if (cur_db_ == nullptr) {
RETURN_IF_NOT_OK(tree_->Launch());
// Profiling
std::shared_ptr<Tracing> node;
Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
if (s.IsOk()) {
tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
}
if (tracing_ != nullptr) {
cur_connector_size_ = tree_->root()->ConnectorSize();
cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
}
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag
RETURN_OK_IF_TRUE(cur_db_->eoe()); // return empty tensor if 1st buf is a ctrl buf (no rows)
if (cur_db_->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows)
MS_LOG(INFO) << "End of data iteration.";
if (isProfilingEnable) {
tree_->SetEpochEnd();
}
return Status::OK();
}
}
CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached.");
if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_));
RETURN_OK_IF_TRUE(cur_db_->eoe() || cur_db_->eof()); // return empty if this new buffer is a ctrl flag
if (cur_db_->eoe()) { // return empty if this new buffer is a ctrl flag
MS_LOG(INFO) << "End of data iteration.";
if (isProfilingEnable) {
tree_->SetEpochEnd();
}
return Status::OK();
}
if (cur_db_->eof()) {
tree_->SetFinished();
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
}
RETURN_IF_NOT_OK(cur_db_->PopRow(row));
// Record profiling info
if (tracing_ != nullptr) {
cur_batch_num_++;
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
}
return Status::OK();
}

View File

@ -25,6 +25,7 @@
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
namespace mindspore {
namespace dataset {
@ -60,6 +61,9 @@ class TreeAdapter {
// Set optional optimization pass
void SetOptimize(bool value) { optimize_ = value; }
// Function to override the pre-pass
void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; }
// Optional optimizations status
bool OptimizationEnabled() const { return optimize_; }
@ -82,9 +86,14 @@ class TreeAdapter {
std::unique_ptr<DataBuffer> cur_db_;
std::unordered_map<std::string, int32_t> column_name_map_;
std::unique_ptr<ExecutionTree> tree_;
std::unique_ptr<ExecutionTree> tree_;
int32_t num_epochs_;
bool optimize_; // Flag to enable optional optimization pass
bool optimize_; // Flag to enable optional optimization pass
std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data
int32_t cur_batch_num_; // current batch number, used for profiling
int32_t cur_connector_size_; // current connector size of root op, used for profiling
int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling
std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare()
};
} // namespace dataset
} // namespace mindspore

View File

@ -145,9 +145,16 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Function to transfer data through a device.
/// \notes If the device is Ascend, features of the data will be transferred one by one;
/// each transfer is limited to 256M.
/// \param[in] queue_name Channel name (default="", create new unique name).
/// \param[in] device_type Type of device (default="", get from MSContext).
/// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
/// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
/// of data or not (default=false).
/// \return Returns true if no error encountered else false.
bool DeviceQueue(bool send_epoch_end = true);
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1,
bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false);
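A hedged sketch of calling the expanded API; "ds" stands for an existing std::shared_ptr<Dataset> built earlier in a pipeline, and the argument values are illustrative.

// Sketch only: push the pipeline to the device with explicit parameters.
bool ok = ds->DeviceQueue("", "Ascend", /*num_epochs=*/1, /*send_epoch_end=*/true,
                          /*total_batches=*/0, /*create_data_info_queue=*/false);
if (!ok) {
  MS_LOG(ERROR) << "DeviceQueue launch failed.";
}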
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
/// \note Usage restrictions:
@ -371,21 +378,34 @@ class SchemaObj {
/// \brief SchemaObj init function
/// \return bool true if schema init success
bool init();
Status init();
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(TypeId).
/// \return bool true if schema init success
Status add_column(std::string name, TypeId de_type);
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(std::string).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
Status add_column(std::string name, std::string de_type);
/// \brief Add new column to the schema
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(TypeId).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
Status add_column(std::string name, TypeId de_type, std::vector<int32_t> shape);
/// \brief Add new column to the schema
/// \param[in] name name of the column.
/// \param[in] de_type data type of the column(std::string).
/// \param[in] shape shape of the column.
/// \return bool true if schema init success
bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
Status add_column(std::string name, std::string de_type, std::vector<int32_t> shape);
/// \brief Get a JSON string of the schema
/// \return JSON string of the schema
@ -395,25 +415,27 @@ class SchemaObj {
std::string to_string() { return to_json(); }
/// \brief set a new value to dataset_type
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; }
inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); }
/// \brief set a new value to num_rows
inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; }
/// \brief get the current num_rows
inline int32_t get_num_rows() { return num_rows_; }
inline int32_t get_num_rows() const { return num_rows_; }
Status FromJSONString(const std::string &json_string);
private:
/// \brief Parse the columns and add it to columns
/// \param[in] columns dataset attribution information, decoded from schema file.
/// support both nlohmann::json::value_t::array and nlohmann::json::value_t::object.
/// \return JSON string of the schema
bool parse_column(nlohmann::json columns);
Status parse_column(nlohmann::json columns);
/// \brief Get schema file from json file
/// \param[in] json_obj object of json parsed.
/// \return bool true if json dump success
bool from_json(nlohmann::json json_obj);
Status from_json(nlohmann::json json_obj);
int32_t num_rows_;
std::string dataset_type_;
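A hedged sketch of driving the Status-based schema API end to end; the helper name and column layout are illustrative.

// Sketch only: the schema helpers now report failures through Status instead of bool.
Status BuildExampleSchema(std::shared_ptr<SchemaObj> schema) {
  RETURN_IF_NOT_OK(schema->add_column("image", "uint8", {-1, -1, 3}));  // unknown H/W, 3 channels
  RETURN_IF_NOT_OK(schema->add_column("label", "int32"));               // rank-1 column, unknown shape
  schema->set_num_rows(10000);
  MS_LOG(INFO) << schema->to_json();
  return Status::OK();
}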

View File

@ -61,6 +61,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
class DistributedSamplerObj;
class PKSamplerObj;
class PreBuiltSamplerObj;
class RandomSamplerObj;
class SequentialSamplerObj;
class SubsetRandomSamplerObj;
@ -171,6 +172,31 @@ class PKSamplerObj : public SamplerObj {
int64_t num_samples_;
};
class PreBuiltSamplerObj : public SamplerObj {
public:
#ifndef ENABLE_ANDROID
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
#endif
~PreBuiltSamplerObj() = default;
std::shared_ptr<SamplerRT> Build() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
private:
std::shared_ptr<SamplerRT> sp_;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_;
#endif
};
class RandomSamplerObj : public SamplerObj {
public:
RandomSamplerObj(bool replacement, int64_t num_samples);

View File

@ -70,6 +70,7 @@ namespace transforms {
class ComposeOperation;
class DuplicateOperation;
class OneHotOperation;
class PreBuiltOperation;
class RandomApplyOperation;
class RandomChoiceOperation;
class TypeCastOperation;
@ -164,6 +165,20 @@ class OneHotOperation : public TensorOperation {
float num_classes_;
};
class PreBuiltOperation : public TensorOperation {
public:
explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op);
~PreBuiltOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
private:
std::shared_ptr<TensorOp> op_;
};
class RandomApplyOperation : public TensorOperation {
public:
explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);
@ -192,7 +207,6 @@ class RandomChoiceOperation : public TensorOperation {
private:
std::vector<std::shared_ptr<TensorOperation>> transforms_;
};
class TypeCastOperation : public TensorOperation {
public:
explicit TypeCastOperation(std::string data_type);

View File

@ -71,6 +71,15 @@ namespace dataset {
return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \
} while (false)
#define RETURN_SECOND_IF_ERROR(_s, _r) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return _r; \
} \
} while (false)
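A hedged sketch of the intended use: inside a function that cannot return Status, log the inner error and return a sentinel instead. The helper below is illustrative.

// Sketch only: log the failure from the wrapped call and return the fallback value.
int64_t GetRowCountOrMinusOne(const std::shared_ptr<TreeGetters> &getters) {
  int64_t rows = 0;
  RETURN_SECOND_IF_ERROR(getters->GetDatasetSize(&rows), -1);  // expands to log + return -1
  return rows;
}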
enum class StatusCode : char {
kOK = 0,
kOutOfMemory = 1,

View File

@ -138,7 +138,9 @@ Status Task::Join(WaitFlag blocking) {
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
// We can't tell which conditional_variable this thread is waiting on. So we may need
// to interrupt everything one more time.
MS_LOG(INFO) << "Some threads not responding. Interrupt again";
std::stringstream ss;
ss << get_id();
MS_LOG(ERROR) << MyName() << " Thread ID " << ss.str() << " is not responding. Interrupt again";
interrupt_svc->InterruptAll();
}
} else {

View File

@ -21,7 +21,8 @@ import numpy
import mindspore._c_dataengine as cde
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load']
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
'get_callback_timeout']
INT32_MAX = 2147483647
UINT32_MAX = 4294967295

View File

@ -65,5 +65,7 @@ def mstypelist_to_detypelist(type_list):
for index, _ in enumerate(type_list):
if type_list[index] is not None:
type_list[index] = mstype_to_detype(type_list[index])
else:
type_list[index] = cde.DataType("")
return type_list

File diff suppressed because it is too large

View File

@ -15,17 +15,13 @@
"""Built-in iterators.
"""
from abc import abstractmethod
import copy
import weakref
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore._c_dataengine import DEPipeline
from mindspore._c_dataengine import OpName
import mindspore._c_dataengine as cde
from mindspore import log as logger
from . import datasets as de
_ITERATOR_CLEANUP = False
@ -57,29 +53,6 @@ def _cleanup():
itr.release()
def alter_tree(node):
"""Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
if not node.children:
return _alter_node(node)
converted_children = []
for input_op in node.children:
converted_children.append(alter_tree(input_op))
node.children = converted_children
return _alter_node(node)
def _alter_node(node):
"""DEPRECATED"""
# Please check ccsrc/dataset/engine/opt for tree transformation.
if isinstance(node, de.MapDataset):
if node.python_multiprocessing:
# Bootstrap can only be performed on a copy of the original dataset node.
# Bootstrap on original dataset node will make all iterators share the same process pool
node.iterator_bootstrap()
return node
class Iterator:
"""
General Iterator over a dataset.
@ -89,185 +62,62 @@ class Iterator:
"""
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
self.num_epochs = num_epochs
self.output_numpy = output_numpy
self._col_names = None
# create a copy of tree and work on it.
self.ori_dataset = dataset
self.ir_tree, self.dataset = dataset.create_ir_tree()
self._runtime_context = cde.PythonRuntimeContext()
self._runtime_context.Init()
consumer = cde.PythonIteratorConsumer(num_epochs)
consumer.Init(self.ir_tree)
self._runtime_context.AssignConsumer(consumer)
self._iterator = self._runtime_context.GetConsumer()
self._transform_tensor = lambda t: t.as_array()
if not output_numpy:
self._transform_tensor = lambda t: Tensor(t.as_array())
self._index = 0
# todo remove next when ContextManager is done
ITERATORS_LIST.append(weakref.ref(self))
_unset_iterator_cleanup()
# create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset)
self.ori_dataset = dataset
self.parent_subtree = []
#######
# The dataset passed into the iterator is not the root of the tree.
# Trim the tree by saving the parent subtree into self.parent_subtree and
# restore it after launching our C++ pipeline.
if self.dataset.parent:
logger.info("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
self.parent_subtree = self.dataset.parent
self.dataset.parent = []
self.dataset = alter_tree(self.dataset)
if not self.__is_tree():
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers).")
self.depipeline = DEPipeline()
# for manifest temporary use
self.__batch_node(self.dataset, 0)
root = self.__convert_node_postorder(self.dataset)
self.depipeline.AssignRootNode(root)
self.depipeline.PrepareTree(self.num_epochs)
self._index = 0
def __iter__(self):
return self
def stop(self):
"""
Manually terminate Python iterator instead of relying on out of scope destruction.
"""
logger.info("Terminating Python iterator. This will also terminate C++ pipeline.")
if hasattr(self, 'depipeline') and self.depipeline:
del self.depipeline
def __is_tree_node(self, node):
"""Check if a node is tree node."""
if not node.children:
if len(node.parent) > 1:
return False
if len(node.parent) > 1:
return False
for input_node in node.children:
cls = self.__is_tree_node(input_node)
if not cls:
return False
return True
def __is_tree(self):
return self.__is_tree_node(self.dataset)
@staticmethod
def __get_dataset_type(dataset):
"""Get the dataset type."""
op_type = None
if isinstance(dataset, de.ShuffleDataset):
op_type = OpName.SHUFFLE
elif isinstance(dataset, de.MindDataset):
op_type = OpName.MINDRECORD
elif isinstance(dataset, de.BatchDataset):
op_type = OpName.BATCH
elif isinstance(dataset, de.BucketBatchByLengthDataset):
op_type = OpName.BUCKETBATCH
elif isinstance(dataset, de.SyncWaitDataset):
op_type = OpName.BARRIER
elif isinstance(dataset, de.ZipDataset):
op_type = OpName.ZIP
elif isinstance(dataset, de.ConcatDataset):
op_type = OpName.CONCAT
elif isinstance(dataset, de.MapDataset):
op_type = OpName.MAP
elif isinstance(dataset, de.FilterDataset):
op_type = OpName.FILTER
elif isinstance(dataset, de.RepeatDataset):
op_type = OpName.REPEAT
elif isinstance(dataset, de.SkipDataset):
op_type = OpName.SKIP
elif isinstance(dataset, de.TakeDataset):
op_type = OpName.TAKE
elif isinstance(dataset, de.ImageFolderDataset):
op_type = OpName.IMAGEFOLDER
elif isinstance(dataset, de.GeneratorDataset):
op_type = OpName.GENERATOR
elif isinstance(dataset, de.TransferDataset):
op_type = OpName.DEVICEQUEUE
elif isinstance(dataset, de.RenameDataset):
op_type = OpName.RENAME
elif isinstance(dataset, de.TFRecordDataset):
op_type = OpName.TFREADER
elif isinstance(dataset, de.ProjectDataset):
op_type = OpName.PROJECT
elif isinstance(dataset, de.MnistDataset):
op_type = OpName.MNIST
elif isinstance(dataset, de.ManifestDataset):
op_type = OpName.MANIFEST
elif isinstance(dataset, de.VOCDataset):
op_type = OpName.VOC
elif isinstance(dataset, de.CocoDataset):
op_type = OpName.COCO
elif isinstance(dataset, de.Cifar10Dataset):
op_type = OpName.CIFAR10
elif isinstance(dataset, de.Cifar100Dataset):
op_type = OpName.CIFAR100
elif isinstance(dataset, de.CelebADataset):
op_type = OpName.CELEBA
elif isinstance(dataset, de.RandomDataset):
op_type = OpName.RANDOMDATA
elif isinstance(dataset, de.TextFileDataset):
op_type = OpName.TEXTFILE
elif isinstance(dataset, de.BuildVocabDataset):
op_type = OpName.BUILDVOCAB
elif isinstance(dataset, de.BuildSentencePieceVocabDataset):
op_type = OpName.SENTENCEPIECEVOCAB
elif isinstance(dataset, de.CLUEDataset):
op_type = OpName.CLUE
elif isinstance(dataset, de.CSVDataset):
op_type = OpName.CSV
else:
raise ValueError("Unsupported DatasetOp.")
return op_type
# Convert Python node into C node and add to C layer execution tree in postorder traversal.
def __convert_node_postorder(self, node):
self.check_node_type(node)
op_type = self.__get_dataset_type(node)
c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
for py_child in node.children:
c_child = self.__convert_node_postorder(py_child)
self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
return c_nodes["top"]
def __batch_node(self, dataset, level):
"""Recursively get batch node in the dataset tree."""
if isinstance(dataset, de.BatchDataset):
return
for input_op in dataset.children:
self.__batch_node(input_op, level + 1)
@staticmethod
def __print_local(dataset, level):
"""Recursively print the name and address of nodes in the dataset tree."""
name = dataset.__class__.__name__
ptr = hex(id(dataset))
for _ in range(level):
logger.info("\t", end='')
if not dataset.children:
logger.info("-%s (%s)", name, ptr)
else:
logger.info("+%s (%s)", name, ptr)
for input_op in dataset.children:
Iterator.__print_local(input_op, level + 1)
def print(self):
"""Print the dataset tree"""
self.__print_local(self.dataset, 0)
if hasattr(self, '_runtime_context') and self._runtime_context:
if hasattr(self, '_iterator') and self._iterator:
self._runtime_context.Terminate()
del self._iterator
del self._runtime_context
del self.dataset
def release(self):
if hasattr(self, 'depipeline') and self.depipeline:
del self.depipeline
self.stop()
def __del__(self):
self.release()
@abstractmethod
def get_next(self):
def _get_next(self):
raise RuntimeError("Calling base class Iterator's get_next is invalid.")
def __next__(self):
if not self.depipeline:
if not self._runtime_context:
logger.warning("Iterator does not have a running C++ pipeline." +
"It might because Iterator stop() had been called, or C++ pipeline crashed silently.")
raise RuntimeError("Iterator does not have a running C++ pipeline.")
data = self.get_next()
data = self._get_next()
if not data:
if self._index == 0:
logger.warning("No records available.")
@ -277,100 +127,56 @@ class Iterator:
self._index += 1
return data
@abstractmethod
def check_node_type(self, node):
pass
def get_output_shapes(self):
return [t for t in self.depipeline.GetOutputShapes()]
def get_output_types(self):
return [t for t in self.depipeline.GetOutputTypes()]
def get_dataset_size(self):
return self.depipeline.GetDatasetSize()
def get_batch_size(self):
return self.depipeline.GetBatchSize()
def get_repeat_count(self):
return self.depipeline.GetRepeatCount()
def num_classes(self):
return self.depipeline.GetNumClasses()
def get_col_names(self):
return self.depipeline.GetColumnNames()
def __deepcopy__(self, memo):
return self
def _getters(self):
"""
Get pipeline information.
"""
getter = cde.TreeGetters()
getter.Init(self.ir_tree)
self._runtime_context.AssignConsumer(getter)
self._col_names = getter.GetColumnNames()
class SaveOp(Iterator):
"""
The derived class of Iterator with dict type.
"""
def __init__(self, dataset, num_epochs=-1):
super().__init__(dataset, num_epochs)
self.depipeline.LaunchTreeExec()
def get_next(self):
pass
def check_node_type(self, node):
if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
logger.warning("Used shuffle, repeat, batch before save operator.")
def save(self, file_names, file_type):
return self.depipeline.SaveDataset(file_names, file_type)
def get_col_names(self):
"""
Get names of the columns in the dataset
"""
if self._col_names is None:
self._getters()
return self._col_names
class DictIterator(Iterator):
"""
The derived class of Iterator with dict type.
"""
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
super().__init__(dataset, num_epochs, output_numpy)
self.depipeline.LaunchTreeExec()
def check_node_type(self, node):
pass
def __iter__(self):
return self
def get_next(self):
def _get_next(self):
"""
Returns the next record in the dataset as a dictionary.
Returns:
Dict, the next record in the dataset.
"""
if self.output_numpy:
return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
return {k: Tensor(v.as_array()) for k, v in self.depipeline.GetNextAsMap().items()}
return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
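For reference, a minimal usage sketch of the reworked dict iterator; the TFRecord path is a placeholder, not part of this change:

import mindspore.dataset as ds

data = ds.TFRecordDataset("../data/dataset/test_tf_file_3_images/train-0000-of-0001.data", shuffle=False)
# output_numpy=True keeps values as numpy arrays instead of mindspore Tensors.
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print({k: v.shape for k, v in row.items()})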
class TupleIterator(Iterator):
"""
The derived class of Iterator with list type.
"""
def check_node_type(self, node):
pass
def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False):
if columns is not None:
if not isinstance(columns, list):
columns = [columns]
# todo: move next to IR
dataset = dataset.project(columns)
super().__init__(dataset, num_epochs, output_numpy)
self.depipeline.LaunchTreeExec()
def __iter__(self):
return self
def get_next(self):
def _get_next(self):
"""
Returns the next record in the dataset as a list
@ -378,15 +184,14 @@ class TupleIterator(Iterator):
List, the next record in the dataset.
"""
if self.output_numpy:
return [t.as_array() for t in self.depipeline.GetNextAsList()]
return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()]
return [self._transform_tensor(t) for t in self._iterator.GetNextAsList()]
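Similarly, a short sketch of the tuple iterator with a column projection; the column name and path are illustrative:

import mindspore.dataset as ds

data = ds.TFRecordDataset("../data/dataset/test_tf_file_3_images/train-0000-of-0001.data", shuffle=False)
# Passing columns routes through the project() call above before iteration starts.
for row in data.create_tuple_iterator(columns=["image"], num_epochs=1, output_numpy=True):
    print(row[0].shape)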
class DummyIterator:
"""
A DummyIterator only works when the environment variable MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED".
"""
def __init__(self, dataset, mode):
self.mode = mode
self.shapes = dataset.output_shapes()


@ -283,9 +283,12 @@ def create_node(node):
node.get('shard_id'), sampler)
elif dataset_op == 'TFRecordDataset':
shuffle = node.get('shuffle')
if shuffle is not None and isinstance(shuffle, str):
shuffle = de.Shuffle(shuffle)
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
node.get('num_samples'), node.get('num_parallel_workers'),
de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))
shuffle, node.get('num_shards'), node.get('shard_id'))
elif dataset_op == 'ManifestDataset':
sampler = construct_sampler(node.get('sampler'))


@ -293,14 +293,38 @@ def check_save(method):
return new_method
def check_iterator(method):
def check_tuple_iterator(method):
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
[columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_bool = ['output_numpy']
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
if num_epochs is not None:
type_check(num_epochs, (int,), "num_epochs")
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
if columns is not None:
check_columns(columns, "column_names")
return method(self, *args, **kwargs)
return new_method
def check_dict_iterator(method):
"""A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
@wraps(method)
def new_method(self, *args, **kwargs):
[num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_bool = ['output_numpy']
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
if num_epochs is not None:
type_check(num_epochs, (int,), "num_epochs")
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
return method(self, *args, **kwargs)
return new_method
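All of these validators follow the same decorator pattern; a self-contained toy analogue is shown below (the names are illustrative and not part of the dataset API):

from functools import wraps

def check_positive_int(method):
    """Validate the argument, then delegate to the wrapped method."""
    @wraps(method)
    def new_method(self, value):
        if not isinstance(value, int) or value <= 0:
            raise ValueError("value must be a positive int")
        return method(self, value)
    return new_method

class Demo:
    @check_positive_int
    def set_batch(self, value):
        return value

assert Demo().set_batch(3) == 3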
@ -523,6 +547,8 @@ def check_batch(method):
sig = ins.signature(batch_size)
if len(sig.parameters) != 1:
raise ValueError("callable batch_size should take one parameter (BatchInfo).")
else:
check_pos_int32(int(batch_size), "batch_size")
if num_parallel_workers is not None:
check_num_parallel_workers(num_parallel_workers)
@ -807,6 +833,21 @@ def check_project(method):
return new_method
def check_schema(method):
"""check the input arguments of Schema.__init__."""
@wraps(method)
def new_method(self, *args, **kwargs):
[schema_file], _ = parse_user_args(method, *args, **kwargs)
if schema_file is not None:
type_check(schema_file, (str,), "schema_file")
return method(self, *args, **kwargs)
return new_method
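check_schema guards the Schema constructor; a small usage sketch, where the JSON path is a placeholder:

import mindspore.dataset as ds

schema = ds.Schema()                       # schema_file=None is allowed
schema.add_column("image", "uint8", [2])   # name, type, shape
schema_from_file = ds.Schema("../data/dataset/testTFTestAllTypes/datasetSchema.json")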
def check_add_column(method):
"""check the input arguments of add_column."""
@ -1261,3 +1302,23 @@ def check_cache_option(cache):
"""Sanity check for cache parameter"""
if cache is not None:
type_check(cache, (cache_client.DatasetCache,), "cache")
def check_to_device_send(method):
"""A wrapper that wraps a parameter checker around the check_to_device_send."""
@wraps(method)
def new_method(self, *args, **kwargs):
[num_epochs], _ = parse_user_args(method, *args, **kwargs)
if num_epochs is not None:
type_check(num_epochs, (int,), "num_epochs")
check_value(num_epochs, [-1, INT32_MAX], "num_epochs")
return method(self, *args, **kwargs)
return new_method
def replace_none(value, default):
return value if value is not None else default
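replace_none is a small default-filling helper used when lowering optional Python arguments into the C++ layer; its behaviour in a nutshell (values are arbitrary):

assert replace_none(None, 0) == 0   # None falls back to the default
assert replace_none(8, 0) == 8      # non-None values pass through unchanged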


@ -18,13 +18,13 @@ use to_bytes and to_str to encode and decode strings into a specified format.
"""
from enum import IntEnum
import copy
import numpy as np
import mindspore._c_dataengine as cde
from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model
__all__ = [
"Vocab", "SentencePieceVocab", "to_str", "to_bytes"
]
@ -39,8 +39,7 @@ class Vocab(cde.Vocab):
@classmethod
@check_from_dataset
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None,
special_first=True):
def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, special_first=True):
"""
Build a vocab from a dataset.
@ -69,21 +68,7 @@ class Vocab(cde.Vocab):
Returns:
Vocab, Vocab object built from dataset.
"""
vocab = Vocab()
if columns is None:
columns = []
if not isinstance(columns, list):
columns = [columns]
if freq_range is None:
freq_range = (None, None)
if special_tokens is None:
special_tokens = []
root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first)
for d in root.create_dict_iterator(num_epochs=1):
if d is not None:
raise ValueError("from_dataset should receive data other than None.")
return vocab
return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)
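A minimal usage sketch of the simplified Vocab.from_dataset path; the corpus path and special tokens are placeholders:

import mindspore.dataset as ds
import mindspore.dataset.text as text

corpus = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False)
vocab = text.Vocab.from_dataset(corpus, columns=["text"],
                                special_tokens=["<pad>", "<unk>"], special_first=True)
# Map words to ids with the vocab just built.
corpus = corpus.map(operations=text.Lookup(vocab, "<unk>"), input_columns=["text"])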
@classmethod
@check_from_list
@ -143,6 +128,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
"""
SentencePiece object that is used to segment words.
"""
@classmethod
@check_from_dataset_sentencepiece
def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params):
@ -164,13 +150,8 @@ class SentencePieceVocab(cde.SentencePieceVocab):
SentencePiece, SentencePiece object from dataset.
"""
vocab = SentencePieceVocab()
root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage,
model_type, params)
for d in root.create_dict_iterator(num_epochs=1):
if d is None:
raise ValueError("from_dataset should receive data other than None.")
return vocab
return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage,
DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
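The equivalent sketch for SentencePieceVocab.from_dataset; the corpus path and vocabulary size are placeholders:

import mindspore.dataset as ds
from mindspore.dataset.text import SentencePieceVocab, SentencePieceModel

corpus = ds.TextFileDataset("../data/dataset/test_sentencepiece/corpus.txt", shuffle=False)
vocab = SentencePieceVocab.from_dataset(corpus, ["text"], 5000, 0.9995,
                                        SentencePieceModel.UNIGRAM, {})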
@classmethod
@check_from_file_sentencepiece
@ -270,6 +251,7 @@ class SentencePieceModel(IntEnum):
CHAR = 2
WORD = 3
DE_C_INTER_SENTENCEPIECE_MODE = {
SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM,
SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE,


@ -432,7 +432,7 @@ def check_from_dataset_sentencepiece(method):
[_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
if col_names is not None:
type_check(col_names, (list,), "col_names")
type_check_list(col_names, (str,), "col_names")
if vocab_size is not None:
check_uint32(vocab_size, "vocab_size")


@ -146,6 +146,7 @@ if (BUILD_MINDDATA STREQUAL "full")
list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES
"${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc"
"${MINDDATA_DIR}/engine/ir/cache/pre_built_dataset_cache.cc"
)
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES


@ -123,6 +123,7 @@ def connect_network_with_dataset(network, dataset_helper):
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
return network
class DatasetHelper:
"""
DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.
@ -197,7 +198,6 @@ class DatasetHelper:
def get_data_info(self):
return self.iter.get_data_info()
class _DatasetIter:
"""Base iter for dataset helper"""
@ -331,7 +331,6 @@ class _DatasetIterPSLite(_DatasetIter):
class _DatasetIterNormal:
"""Iter for normal(non sink) mode, feed the data from host."""
def __init__(self, dataset, epoch_num=-1):
self.dataset = dataset
self.device_num = _get_device_num()


@ -61,15 +61,15 @@ class MindData:
def send(self, num_epochs=-1):
pass
def get_data_info(self):
pass
def stop_send(self):
pass
def continue_send(self):
pass
def get_data_info(self):
pass
def __len__(self):
return self._size


@ -177,8 +177,8 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) {
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
// 5 batches of size 2
EXPECT_EQ(i, 5);
// With 2 boundaries, 3 buckets are created
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();


@ -132,6 +132,6 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) {
// verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not
EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos);
EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos);
EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos);
EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos);
EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos);
}


@ -63,7 +63,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}};
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
std::vector<size_t> row_sizes = {2, 2, 0, 0};
std::vector<size_t> row_sizes = {2, 2, 0};
TensorRow row;
for (size_t sz : row_sizes) {
@ -75,7 +75,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) {
rc = tree_adapter.GetNext(&row);
EXPECT_TRUE(rc.IsError());
const std::string err_msg = rc.ToString();
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
}
TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
@ -97,7 +97,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap();
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0};
std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0};
TensorRow row;
for (size_t sz : row_sizes) {
@ -107,7 +107,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) {
}
rc = tree_adapter.GetNext(&row);
const std::string err_msg = rc.ToString();
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
}
TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
@ -135,7 +135,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
const std::unordered_map<std::string, int32_t> map = {{"label", 0}};
EXPECT_EQ(tree_adapter.GetColumnNameMap(), map);
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0};
std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0};
TensorRow row;
for (size_t sz : row_sizes) {
@ -145,5 +145,5 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) {
}
rc = tree_adapter.GetNext(&row);
const std::string err_msg = rc.ToString();
EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos);
EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos);
}


@ -451,6 +451,10 @@ def test_batch_exception_13():
def test_batch_exception_14():
"""
Test per_batch_map and input column name
"""
logger.info("test_batch_exception_14")
batch_size = 2
input_columns = ["num"]
data1 = ds.TFRecordDataset(DATA_DIR)
@ -460,6 +464,22 @@ def test_batch_exception_14():
assert "per_batch_map and input_columns need to be passed in together." in str(e)
def test_batch_exception_15():
"""
Test batch_size = int32 max value + 1
"""
logger.info("test_batch_exception_15")
batch_size = 2147483647 + 1
input_columns = ["num"]
data1 = ds.TFRecordDataset(DATA_DIR)
err_msg = ""
try:
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
except ValueError as e:
err_msg = str(e)
assert "batch_size is not within the required interval of (1 to 2147483647)" in err_msg
if __name__ == '__main__':
test_batch_01()
test_batch_02()
@ -486,4 +506,5 @@ if __name__ == '__main__':
test_batch_exception_12()
test_batch_exception_13()
test_batch_exception_14()
test_batch_exception_15()
logger.info('\n')


@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import pytest
import mindspore.dataset as ds
@ -354,6 +355,18 @@ def test_clue_to_device():
data.send()
def test_clue_invalid_files():
"""
Test CLUE with invalid files
"""
AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
afqmc_train_json = os.path.join(AFQMC_DIR)
with pytest.raises(ValueError) as info:
_ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
assert "The following patterns did not match any files" in str(info.value)
assert AFQMC_DIR in str(info.value)
if __name__ == "__main__":
test_clue()
test_clue_num_shards()
@ -366,3 +379,4 @@ if __name__ == "__main__":
test_clue_tnews()
test_clue_wsc()
test_clue_to_device()
test_clue_invalid_files()


@ -195,6 +195,19 @@ def test_csv_dataset_size():
assert data.get_dataset_size() == 5
def test_csv_dataset_type_error():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", 0, "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
shuffle=False)
with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "type does not match" in str(err.value)
def test_csv_dataset_exception():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
data = ds.CSVDataset(
@ -208,17 +221,16 @@ def test_csv_dataset_exception():
assert "failed to parse file" in str(err.value)
def test_csv_dataset_type_error():
TEST_FILE = '../data/dataset/testCSV/exception.csv'
def test_csv_dataset_duplicate_columns():
data = ds.CSVDataset(
TEST_FILE,
column_defaults=["", 0, "", ""],
column_names=['col1', 'col2', 'col3', 'col4'],
DATA_FILE,
column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'],
shuffle=False)
with pytest.raises(Exception) as err:
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
pass
assert "type does not match" in str(err.value)
with pytest.raises(RuntimeError) as info:
_ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value)
assert "column_names" in str(info.value)
if __name__ == "__main__":
@ -234,5 +246,6 @@ if __name__ == "__main__":
test_csv_dataset_header()
test_csv_dataset_number()
test_csv_dataset_size()
test_csv_dataset_exception()
test_csv_dataset_type_error()
test_csv_dataset_exception()
test_csv_dataset_duplicate_columns()


@ -14,6 +14,7 @@
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
@ -21,9 +22,18 @@ IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-000
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
MNIST_DATA_DIR = "../data/dataset/testMnistData"
MIND_CV_FILE_NAME = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord"
SCHEMA_FILE = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data"
VOC_DATA_DIR = "../data/dataset/testVOC2012"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
CLUE_FILE = '../data/dataset/testCLUE/afqmc/train.json'
CSV_FILE = '../data/dataset/testCSV/1.csv'
TEXT_DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
def test_imagenet_rawdata_dataset_size():
@ -50,8 +60,15 @@ def test_imagenet_tf_file_dataset_size():
ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0)
assert ds_shard_2_0.get_dataset_size() == 6
# FIXME: dataset_size == 6 looks wrong, but it seems intended to match the current code.
# The correct answer should be 12/3=4; the underlying code issue should be addressed.
ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0)
assert ds_shard_3_0.get_dataset_size() == 4
assert ds_shard_3_0.get_dataset_size() == 6
count = 0
for _ in ds_shard_3_0.create_dict_iterator():
count += 1
assert ds_shard_3_0.get_dataset_size() == count
def test_mnist_dataset_size():
@ -76,6 +93,14 @@ def test_mnist_dataset_size():
assert ds_shard_3_0.get_dataset_size() == 3334
def test_mind_dataset_size():
dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0")
assert dataset.get_dataset_size() == 20
dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 10
def test_manifest_dataset_size():
ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE)
assert ds_total.get_dataset_size() == 4
@ -95,10 +120,11 @@ def test_cifar10_dataset_size():
assert ds_total.get_dataset_size() == 10000
# test get_dataset_size with usage flag
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 0
train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 10000
test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
assert test_size == 0
all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
assert all_size == 10000
@ -120,8 +146,6 @@ def test_cifar100_dataset_size():
assert ds_total.get_dataset_size() == 10000
# test get_dataset_size with usage flag
train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
assert train_size == 0
test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
assert test_size == 10000
all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
@ -137,10 +161,97 @@ def test_cifar100_dataset_size():
assert ds_shard_3_0.get_dataset_size() == 3334
def test_voc_dataset_size():
dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
assert dataset.get_dataset_size() == 10
dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True,
num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 5
def test_coco_dataset_size():
dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
decode=True, shuffle=False)
assert dataset.get_dataset_size() == 6
dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 3
def test_celeba_dataset_size():
dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
assert dataset.get_dataset_size() == 4
dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_clue_dataset_size():
dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False)
assert dataset.get_dataset_size() == 3
dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_csv_dataset_size():
dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
shuffle=False)
assert dataset.get_dataset_size() == 3
dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'],
shuffle=False, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_text_file_dataset_size():
dataset = ds.TextFileDataset(TEXT_DATA_FILE)
assert dataset.get_dataset_size() == 3
dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0)
assert dataset_shard_2_0.get_dataset_size() == 2
def test_padded_dataset_size():
dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}])
assert dataset.get_dataset_size() == 2
def test_pipeline_get_dataset_size():
dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False)
assert dataset.get_dataset_size() == 12
dataset = dataset.shuffle(buffer_size=3)
assert dataset.get_dataset_size() == 12
decode_op = vision.Decode()
resize_op = vision.RandomResize(10)
dataset = dataset.map([decode_op, resize_op], input_columns=["image"])
assert dataset.get_dataset_size() == 12
dataset = dataset.batch(batch_size=3)
assert dataset.get_dataset_size() == 4
dataset = dataset.repeat(count=2)
assert dataset.get_dataset_size() == 8
if __name__ == '__main__':
test_imagenet_rawdata_dataset_size()
test_imagenet_tf_file_dataset_size()
test_mnist_dataset_size()
test_mind_dataset_size()
test_manifest_dataset_size()
test_cifar10_dataset_size()
test_cifar100_dataset_size()
test_voc_dataset_size()
test_coco_dataset_size()
test_celeba_dataset_size()
test_clue_dataset_size()
test_csv_dataset_size()
test_text_file_dataset_size()
test_padded_dataset_size()
test_pipeline_get_dataset_size()


@ -521,7 +521,7 @@ def test_chained_sampler_04():
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 24
assert data1_size == 6
# Verify number of iterations
num_iter = 0


@ -182,6 +182,15 @@ def test_voc_exception():
pass
def test_voc_num_classes():
data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
assert data1.num_classes() is None
class_index = {'car': 0, 'cat': 1, 'train': 5}
data2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
assert data2.num_classes() is None
if __name__ == '__main__':
test_voc_segmentation()
test_voc_detection()
@ -191,3 +200,4 @@ if __name__ == '__main__':
test_case_1()
test_case_2()
test_voc_exception()
test_voc_num_classes()


@ -107,7 +107,7 @@ def test_decode_op():
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
assert "object has no attribute '_runtime_context'" in str(info.value)
with pytest.raises(RuntimeError) as info:
iter2.__next__()
@ -205,7 +205,7 @@ def test_generator_dict_3():
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
assert "object has no attribute '_runtime_context'" in str(info.value)
def test_generator_dict_4():
@ -396,7 +396,7 @@ def test_generator_tuple_3():
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
assert "object has no attribute '_runtime_context'" in str(info.value)
def test_generator_tuple_4():
@ -546,7 +546,7 @@ def test_generator_tuple_repeat_repeat_2():
# Expect a AttributeError since iter1 has been stopped.
with pytest.raises(AttributeError) as info:
iter1.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
assert "object has no attribute '_runtime_context'" in str(info.value)
def test_generator_tuple_repeat_repeat_3():


@ -74,9 +74,11 @@ def test_case2():
def test_case3():
data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(5)
data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2)
data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
["col_sint64"], ["a1"])
data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(5).rename(
["col_sint64"], ["a2"])
data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).rename(["col_sint64"], ["a3"])
data4 = ds.zip((data1, data2, data3))
@ -84,8 +86,9 @@ def test_case3():
def test_case4():
data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
data2 = ds.TFRecordDataset(FILES)
data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
["col_sint64"], ["a1"])
data2 = ds.TFRecordDataset(FILES, columns_list=["col_sint64"]).rename(["col_sint64"], ["a2"])
assert data2.get_dataset_size() == 12
data2 = data2.batch(2)
assert data2.get_dataset_size() == 6

Some files were not shown because too many files have changed in this diff.