From 2795e492ffe14fa02939a8c3c315af6b9d3dbfaf Mon Sep 17 00:00:00 2001 From: yanghaitao Date: Thu, 16 Apr 2020 15:03:41 +0800 Subject: [PATCH] TextFileDataset --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 37 +- mindspore/ccsrc/dataset/api/de_pipeline.h | 5 +- .../ccsrc/dataset/api/python_bindings.cc | 15 +- .../engine/datasetops/source/CMakeLists.txt | 1 + .../engine/datasetops/source/text_file_op.cc | 459 ++++++++++++++++++ .../engine/datasetops/source/text_file_op.h | 263 ++++++++++ mindspore/dataset/__init__.py | 6 +- mindspore/dataset/engine/__init__.py | 4 +- mindspore/dataset/engine/datasets.py | 130 ++++- mindspore/dataset/engine/iterators.py | 10 +- mindspore/dataset/engine/validators.py | 22 + mindspore/dataset/transforms/nlp/__init__.py | 20 + mindspore/dataset/transforms/nlp/utils.py | 35 ++ tests/ut/cpp/dataset/CMakeLists.txt | 2 +- tests/ut/cpp/dataset/text_file_op_test.cc | 112 +++++ .../ut/data/dataset/testTextFileDataset/1.txt | 3 + .../ut/data/dataset/testTextFileDataset/2.txt | 2 + .../dataset/test_datasets_textfileop.py | 87 ++++ 18 files changed, 1175 insertions(+), 38 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h create mode 100644 mindspore/dataset/transforms/nlp/__init__.py create mode 100644 mindspore/dataset/transforms/nlp/utils.py create mode 100644 tests/ut/cpp/dataset/text_file_op_test.cc create mode 100644 tests/ut/data/dataset/testTextFileDataset/1.txt create mode 100644 tests/ut/data/dataset/testTextFileDataset/2.txt create mode 100644 tests/ut/python/dataset/test_datasets_textfileop.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 5f61c86f06e..f6440710b1b 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -28,10 +28,10 @@ #include "dataset/engine/datasetops/source/manifest_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" +#include "dataset/engine/datasetops/source/text_file_op.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" - #include "dataset/util/random.h" #include "dataset/util/status.h" #include "utils/log_adapter.h" @@ -61,7 +61,8 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kVoc, &DEPipeline::ParseVOCOp}, {kCifar10, &DEPipeline::ParseCifar10Op}, {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}}; + {kCelebA, &DEPipeline::ParseCelebAOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}}; DEPipeline::DEPipeline() : iterator_(nullptr) { try { @@ -985,5 +986,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr) { + // Required arguments + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"])); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + (void)builder->SetNumDevices(ToInt(value)); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 6ff7bb091cd..eadde2c1910 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -58,7 +58,8 @@ enum OpName { kVoc, kCifar10, kCifar100, - kCelebA + kCelebA, + kTextFile }; // The C++ binder class that we expose to the python script. @@ -148,6 +149,8 @@ class DEPipeline { Status ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseTextFileOp(const py::dict &args, std::shared_ptr *ptr); + private: // Execution tree that links the dataset operators. std::shared_ptr tree_; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 076f2ecc364..5399e7e4259 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -55,6 +55,7 @@ #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/jagged_connector.h" +#include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/kernels/data/to_float16_op.h" #include "dataset/util/random.h" #include "mindrecord/include/shard_operator.h" @@ -176,6 +177,17 @@ void bindDatasetOps(py::module *m) { THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count)); return count; }); + + (void)py::class_>(*m, "TextFileOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); + return count; + }); } void bindTensor(py::module *m) { (void)py::class_(*m, "GlobalContext") @@ -463,7 +475,8 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("VOC", OpName::kVoc) .value("CIFAR10", OpName::kCifar10) .value("CIFAR100", OpName::kCifar100) - .value("CELEBA", OpName::kCelebA); + .value("CELEBA", OpName::kCelebA) + .value("TEXTFILE", OpName::kTextFile); (void)py::enum_(m, "InterpolationMode", py::arithmetic()) .value("DE_INTER_LINEAR", InterpolationMode::kLinear) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt index a7c0dfd725b..8801205f6cb 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt @@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT manifest_op.cc cifar_op.cc celeba_op.cc + text_file_op.cc ) add_dependencies(engine-datasetops-source mindspore::protobuf) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc new file mode 100644 index 00000000000..2b626163669 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc @@ -0,0 +1,459 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/core/config_manager.h" +#include "dataset/util/task_manager.h" +#include "dataset/util/wait_post.h" +#include "dataset/util/random.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +TextFileOp::Builder::Builder() + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status TextFileOp::Builder::ValidateInputs() const { + std::string err_msg; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : ""; + err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +Status TextFileOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { + builder_num_workers_ = builder_text_files_list_.size(); + MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + + std::shared_ptr text_file_op = std::make_shared( + builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, + std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, + builder_num_devices_, builder_device_id_); + RETURN_IF_NOT_OK(text_file_op->Init()); + *op = std::move(text_file_op); + + return Status::OK(); +} + +TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + std::unique_ptr schema, std::vector text_files_list, + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), + device_id_(device_id), + num_devices_(num_device), + rows_per_buffer_(rows_per_buffer), + num_samples_(num_samples), + text_files_list_(std::move(text_files_list)), + shuffle_files_(shuffle_files), + data_schema_(std::move(schema)), + all_num_rows_(0), + num_rows_per_shard_(0), + filename_index_(std::make_unique()), + finished_reading_dataset_(false), + load_io_block_queue_(true), + load_jagged_connector_(true) { + worker_connector_size_ = worker_connector_size; +} + +Status TextFileOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(text_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + col_name_map_[data_schema_->column(i).name()] = i; + } + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + return Status::OK(); +} + +Status TextFileOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { + TensorRow tRow(1, nullptr); + (*tensor_table)->push_back(std::move(tRow)); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(), + TensorShape(std::vector(1, line.size())), data_schema_->column(0).type(), + const_cast(reinterpret_cast(common::SafeCStr(line))))); + (**tensor_table)[row][0] = std::move(tensor); + return Status::OK(); +} + +Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + std::ifstream handle(file); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Failed to open file " + file); + } + + int64_t rows_each_buffer = 0; + int64_t rows_total = 0; + std::string line; + std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + cur_buffer->set_column_name_map(col_name_map_); + std::unique_ptr tensor_table = std::make_unique(); + + while (getline(handle, line)) { + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + break; + } + // Skip line before start offset. + if (rows_total < start_offset) { + rows_total++; + continue; + } + + RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); + rows_each_buffer++; + rows_total++; + if (rows_each_buffer == rows_per_buffer_) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + + cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + cur_buffer->set_column_name_map(col_name_map_); + tensor_table = std::make_unique(); + rows_each_buffer = 0; + } + } + + if (rows_each_buffer > 0) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + } + + return Status::OK(); +} + +Status TextFileOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// Pops an element from a queue in io_block_queues +Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status TextFileOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +Status TextFileOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + auto file_it = filename_index_->Search(*it); + file_index.emplace_back(std::pair(file_it.value(), *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +Status TextFileOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +Status TextFileOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); + + // Read data from disk into buffers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + + io_block_queue_wait_post_.Register(tree_->AllTasks()); + NotifyToFillIOBlockQueue(); + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (num_samples_ == 0 || rows_read < num_samples_) { + if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { + int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + + return Status::OK(); +} + +int64_t TextFileOp::CountTotalRows(const std::string &file) { + std::ifstream handle(file); + if (!handle.is_open()) { + MS_LOG(ERROR) << "Failed to open file: " << file; + return 0; + } + + std::string line; + int64_t count = 0; + while (getline(handle, line)) { + count++; + } + + return count; +} + +Status TextFileOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED("Number of rows can not be zero"); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +Status TextFileOp::CountAllFileRows(const std::vector &files, int64_t *count) { + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h new file mode 100644 index 00000000000..49f224ffc35 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h @@ -0,0 +1,263 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "dataset/util/status.h" +#include "dataset/util/auto_index.h" +#include "dataset/engine/data_schema.h" +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/util/queue.h" +#include "dataset/util/wait_post.h" +#include "dataset/engine/jagged_connector.h" + +namespace mindspore { +namespace dataset { +using StringIndex = AutoIndexObj; + +class TextFileOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetTextFilesList(const std::vector &files_list) { + builder_text_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumSamples(int64_t num_samples) { + builder_num_samples_ = num_samples; + return *this; + } + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_num_samples_; + int32_t builder_worker_connector_size_; + std::vector builder_text_files_list_; + bool builder_shuffle_files_; + std::unique_ptr builder_schema_; + }; + + // Constructor of TextFileOp + // @note The builder class should be used to call this constructor. + // @param num_workers - number of worker threads reading data from tf_file files. + // @param rows_per_buffer - number of rows that a full buffer will contain. + // @param total_num_rows - number of rows to read + // @param dataset_files_list - list of filepaths for the dataset files. + // @param data_schema - the data schema object. + // @param op_connector_size - size of each queue in the connector that the child operator pulls from. + // @param columns_to_load - the names of the columns to load data from. + // @param shuffle_files - whether or not to shuffle the files before reading data. + // @param equal_rows_per_shard - whether or not to get equal rows for each process. + TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Default destructor + ~TextFileOp() = default; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all text files. + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, int64_t *count); + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a text file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - text file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + int32_t device_id_; + int32_t num_devices_; + int64_t rows_per_buffer_; + int64_t num_samples_; + std::vector text_files_list_; + bool shuffle_files_; + std::unique_ptr data_schema_; + int64_t all_num_rows_; + int64_t num_rows_per_shard_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + QueueList> io_block_queues_; + WaitPost io_block_queue_wait_post_; + bool finished_reading_dataset_; + bool load_io_block_queue_; + bool load_jagged_connector_; + std::unordered_map col_name_map_; + std::unique_ptr jagged_buffer_connector_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 479c66045fe..2a30b616ad0 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -20,8 +20,8 @@ can also create samplers with this module to sample data. from .core.configuration import config from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \ - GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \ - Shuffle, zip + GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ + Schema, Shuffle, zip from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ WeightedRandomSampler from .engine.serializer_deserializer import serialize, deserialize, show @@ -29,5 +29,5 @@ from .engine.serializer_deserializer import serialize, deserialize, show __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", - "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", + "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"] diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index 720b56b96d5..86d29713324 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", - "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", - "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] + "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", + "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8de56a6dff2..ca717643c91 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -29,7 +29,7 @@ from importlib import import_module import numpy as np from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ - MindRecordOp, CBatchInfo + MindRecordOp, TextFileOp, CBatchInfo from mindspore._c_expression import typing from mindspore import log as logger @@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_zip_dataset, check_add_column + check_zip_dataset, check_add_column, check_textfiledataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -888,6 +888,29 @@ class SourceDataset(Dataset): # No need for __init__ since it is the same as the super's init + @staticmethod + def _find_files(patterns): + """ + Utility function to search for files with the given glob patterns. + + Args: + patterns (str or list[str]): string or list of patterns to be searched. + + Returns: + List, files. + """ + + def flat(lists): + return list(np.array(lists).flatten()) + + if not isinstance(patterns, list): + patterns = [patterns] + + file_list = flat([glob.glob(file, recursive=True) for file in patterns]) + if file_list: # not empty + return file_list + raise ValueError("The list of path names matching the patterns is empty.") + class DatasetOp(Dataset): """ @@ -2126,30 +2149,6 @@ class TFRecordDataset(SourceDataset): >>> # 3) get all rows from dataset_files with schema file "./schema.json": >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json") """ - - @staticmethod - def _find_files(patterns): - """ - Utility function to search for files with the given glob patterns. - - Args: - patterns (str or list[str]): string or list of patterns to be searched. - - Returns: - List, files. - """ - - def flat(lists): - return list(np.array(lists).flatten()) - - if not isinstance(patterns, list): - patterns = [patterns] - - file_list = flat([glob.glob(file, recursive=True) for file in patterns]) - if file_list: # not empty - return file_list - raise ValueError("The list of path names matching the patterns is empty.") - @check_tfrecorddataset def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): @@ -2952,3 +2951,82 @@ class CelebADataset(SourceDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id return args + +class TextFileDataset(SourceDataset): + """ + A source dataset that reads and parses datasets stored on disk in text format. + The generated dataset has one columns ['text']. + + Args: + dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of + files. The list will be sorted in a lexicographical order. + num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). + num_parallel_workers (int, optional): number of workers to read the data + (default=None, number set in the config). + shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). + If shuffle is False, no shuffling will be performed; + If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL + Otherwise, there are two levels of shuffling: + + - Shuffle.GLOBAL: Shuffle both the files and samples. + + - Shuffle.FILES: Shuffle files only. + + num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. + Examples: + >>> import mindspore.dataset as ds + >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files + >>> dataset = ds.TextFileDataset(dataset_files=dataset_files) + """ + + @check_textfiledataset + def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, + shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + super().__init__(num_parallel_workers) + self.dataset_files = self._find_files(dataset_files) + self.dataset_files.sort() + self.num_samples = num_samples + + if not isinstance(shuffle, (bool, Shuffle)): + raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") + if not isinstance(shuffle, Shuffle): + if shuffle: + self.shuffle_level = Shuffle.GLOBAL + self.shuffle_files = True + else: + self.shuffle_level = None + self.shuffle_files = False + else: + self.shuffle_level = shuffle + self.shuffle_files = True + + self.num_shards = num_shards + self.shard_id = shard_id + + def get_args(self): + args = super().get_args() + args["dataset_files"] = self.dataset_files + args["num_samples"] = self.num_samples + if self.shuffle_files is not None: + args["shuffle_files"] = self.shuffle_files + args["shuffle"] = self.shuffle_level + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + + Return: + Number, number of batches. + """ + if self._dataset_size is None: + num_rows = TextFileOp.get_num_rows(self.dataset_files) + num_rows = get_num_rows(num_rows, self.num_shards) + if self.num_samples is None: + return num_rows + return min(self.num_samples, num_rows) + return self._dataset_size diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 2bb130f3038..a74d69b9c7b 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -48,12 +48,16 @@ def alter_tree(node): def _alter_node(node): """Performing some alteration to a dataset node. A common alteration is to insert a node.""" - if isinstance(node, de.TFRecordDataset) and node.shuffle_level == de.Shuffle.GLOBAL: + if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL: # Remove the connection between the parent's node to the current node because we are inserting a node. if node.output: node.output.pop() # Perform a fast scan for average rows per file - avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) + if isinstance(node, de.TFRecordDataset): + avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) + else: + avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files) + # Shuffle between 4 files with a minimum size of 10000 rows new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000)) return new_shuffle @@ -157,6 +161,8 @@ class Iterator: op_type = OpName.CIFAR100 elif isinstance(dataset, de.CelebADataset): op_type = OpName.CELEBA + elif isinstance(dataset, de.TextFileDataset): + op_type = OpName.TEXTFILE else: raise ValueError("Unsupported DatasetOp") diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b74e913202f..a340eb5affa 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -849,3 +849,25 @@ def check_add_column(method): return method(*args, **kwargs) return new_method + + +def check_textfiledataset(method): + """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] + + # check dataset_files; required argument + dataset_files = param_dict.get('dataset_files') + if dataset_files is None: + raise ValueError("dataset_files is not provided.") + if not isinstance(dataset_files, (str, list)): + raise TypeError("dataset_files should be of type str or a list of strings.") + + check_param_type(nreq_param_int, param_dict, int) + + return method(*args, **kwargs) + + return new_method diff --git a/mindspore/dataset/transforms/nlp/__init__.py b/mindspore/dataset/transforms/nlp/__init__.py new file mode 100644 index 00000000000..01d425e2ebc --- /dev/null +++ b/mindspore/dataset/transforms/nlp/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module is to support nlp augmentations. It includes two parts: +c_transforms and py_transforms. C_transforms is a high performance +image augmentation module which is developed with c++ opencv. Py_transforms +provide more kinds of image augmentations which is developed with python PIL. +""" +from .utils import as_text diff --git a/mindspore/dataset/transforms/nlp/utils.py b/mindspore/dataset/transforms/nlp/utils.py new file mode 100644 index 00000000000..adcc7cc71d7 --- /dev/null +++ b/mindspore/dataset/transforms/nlp/utils.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Some basic function for nlp +""" +import numpy as np + +def as_text(array, encoding='utf8'): + """ + Convert data of array to unicode. + + Args: + array (numpy array): Data of array should be ASCII values of each character after converted. + encoding (string): Indicating the charset for decoding. + Returns: + A 'str' object. + + """ + + if not isinstance(array, np.ndarray): + raise ValueError('input should be a numpy array') + + byte_array = bytearray(list(array)) + return byte_array.decode(encoding) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index ae9c46e62c9..b05f12eee12 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -65,7 +65,7 @@ SET(DE_UT_SRCS cifar_op_test.cc celeba_op_test.cc take_op_test.cc - ) + text_file_op_test.cc) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/text_file_op_test.cc b/tests/ut/cpp/dataset/text_file_op_test.cc new file mode 100644 index 00000000000..7887eda9552 --- /dev/null +++ b/tests/ut/cpp/dataset/text_file_op_test.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "dataset/core/client.h" +#include "common/common.h" +#include "common/utils.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/util/status.h" + +namespace common = mindspore::common; + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestTextFileOp : public UT::DatasetOpTesting { + +}; + +TEST_F(MindDataTestTextFileOp, TestTextFileBasic) { + // Start with an empty execution tree + auto tree = std::make_shared(); + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/testTextFileDataset/1.txt"; + + std::shared_ptr op; + TextFileOp::Builder builder; + builder.SetTextFilesList({dataset_path}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16) + .SetOpConnectorSize(2); + + Status rc = builder.Build(&op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssociateNode(op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssignRoot(op); + ASSERT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration."; + rc = tree->Prepare(); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->Launch(); + ASSERT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(tree); + TensorRow tensor_list; + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + + int row_count = 0; + while (!tensor_list.empty()) { + // Display the tensor by calling the printer on it + for (int i = 0; i < tensor_list.size(); i++) { + std::ostringstream ss; + ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; + MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; + } + + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + row_count++; + } + + ASSERT_EQ(row_count, 3); +} + +TEST_F(MindDataTestTextFileOp, TestTotalRows) { + std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt"; + std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt"; + std::vector files; + files.push_back(tf_file1); + int64_t total_rows = 0; + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 3); + files.clear(); + + files.push_back(tf_file2); + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 2); + files.clear(); + + files.push_back(tf_file1); + files.push_back(tf_file2); + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 5); + files.clear(); +} diff --git a/tests/ut/data/dataset/testTextFileDataset/1.txt b/tests/ut/data/dataset/testTextFileDataset/1.txt new file mode 100644 index 00000000000..9d911eacc07 --- /dev/null +++ b/tests/ut/data/dataset/testTextFileDataset/1.txt @@ -0,0 +1,3 @@ +This is a text file. +Be happy every day. +Good luck to everyone. diff --git a/tests/ut/data/dataset/testTextFileDataset/2.txt b/tests/ut/data/dataset/testTextFileDataset/2.txt new file mode 100644 index 00000000000..7382722eb8b --- /dev/null +++ b/tests/ut/data/dataset/testTextFileDataset/2.txt @@ -0,0 +1,2 @@ +Another file. +End of file. diff --git a/tests/ut/python/dataset/test_datasets_textfileop.py b/tests/ut/python/dataset/test_datasets_textfileop.py new file mode 100644 index 00000000000..720fcdcce09 --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_textfileop.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.transforms.nlp.utils as nlp + +DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" +DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*" + +def test_textline_dataset_one_file(): + data = ds.TextFileDataset(DATA_FILE) + count = 0 + for i in data.create_dict_iterator(): + logger.info("{}".format(i["text"])) + count += 1 + assert(count == 3) + +def test_textline_dataset_all_file(): + data = ds.TextFileDataset(DATA_ALL_FILE) + count = 0 + for i in data.create_dict_iterator(): + logger.info("{}".format(i["text"])) + count += 1 + assert(count == 5) + +def test_textline_dataset_totext(): + data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False) + count = 0 + line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."] + for i in data.create_dict_iterator(): + str = nlp.as_text(i["text"]) + assert(str == line[count]) + count += 1 + assert(count == 5) + +def test_textline_dataset_num_samples(): + data = ds.TextFileDataset(DATA_FILE, num_samples=2) + count = 0 + for i in data.create_dict_iterator(): + count += 1 + assert(count == 2) + +def test_textline_dataset_distribution(): + data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1) + count = 0 + for i in data.create_dict_iterator(): + count += 1 + assert(count == 3) + +def test_textline_dataset_repeat(): + data = ds.TextFileDataset(DATA_FILE, shuffle=False) + data = data.repeat(3) + count = 0 + line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.", + "This is a text file.", "Be happy every day.", "Good luck to everyone.", + "This is a text file.", "Be happy every day.", "Good luck to everyone."] + for i in data.create_dict_iterator(): + str = nlp.as_text(i["text"]) + assert(str == line[count]) + count += 1 + assert(count == 9) + +def test_textline_dataset_get_datasetsize(): + data = ds.TextFileDataset(DATA_FILE) + size = data.get_dataset_size() + assert(size == 3) + +if __name__ == "__main__": + test_textline_dataset_one_file() + test_textline_dataset_all_file() + test_textline_dataset_totext() + test_textline_dataset_num_samples() + test_textline_dataset_distribution() + test_textline_dataset_repeat() + test_textline_dataset_get_datasetsize()