forked from mindspore-Ecosystem/mindspore
TextFileDataset
This commit is contained in:
parent 18580a7867
commit 2795e492ff
@ -28,10 +28,10 @@
|
|||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "mindrecord/include/shard_category.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_shuffle.h"
|
||||
|
||||
#include "dataset/util/random.h"
|
||||
#include "dataset/util/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -61,7 +61,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
|
|||
{kVoc, &DEPipeline::ParseVOCOp},
|
||||
{kCifar10, &DEPipeline::ParseCifar10Op},
|
||||
{kCifar100, &DEPipeline::ParseCifar100Op},
|
||||
{kCelebA, &DEPipeline::ParseCelebAOp}};
|
||||
{kCelebA, &DEPipeline::ParseCelebAOp},
|
||||
{kTextFile, &DEPipeline::ParseTextFileOp}};
|
||||
|
||||
DEPipeline::DEPipeline() : iterator_(nullptr) {
|
||||
try {
|
||||
|
@ -985,5 +986,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
|
|||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
// Required arguments
|
||||
std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
|
||||
if (!args["dataset_files"].is_none()) {
|
||||
(void)builder->SetTextFilesList(ToStringVector(args["dataset_files"]));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
|
||||
}
|
||||
// Optional arguments
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_shards") {
|
||||
(void)builder->SetNumDevices(ToInt(value));
|
||||
} else if (key == "shard_id") {
|
||||
(void)builder->SetDeviceId(ToInt(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<TextFileOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
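For reference, a minimal sketch (not taken from this change; the values shown are placeholders) of the argument dict that ParseTextFileOp consumes from the Python layer, and how each key maps onto the builder:

# Hypothetical args dict handed down from the Python TextFileDataset node.
text_file_args = {
    "dataset_files": ["/path/to/1.txt", "/path/to/2.txt"],  # required; ToStringVector -> SetTextFilesList
    "num_parallel_workers": 4,   # -> SetNumWorkers
    "shuffle_files": True,       # -> SetShuffleFiles
    "num_samples": 0,            # -> SetNumSamples (0 means read everything)
    "num_shards": 2,             # -> SetNumDevices
    "shard_id": 0,               # -> SetDeviceId
}
# Any other key is simply ignored by the loop above; a missing "dataset_files"
# triggers "Error: dataset_files is missing".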
|
||||
|
|
|
@ -58,7 +58,8 @@ enum OpName {
|
|||
kVoc,
|
||||
kCifar10,
|
||||
kCifar100,
|
||||
kCelebA
|
||||
kCelebA,
|
||||
kTextFile
|
||||
};
|
||||
|
||||
// The C++ binder class that we expose to the python script.
|
||||
|
@ -148,6 +149,8 @@ class DEPipeline {
|
|||
|
||||
Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
private:
|
||||
// Execution tree that links the dataset operators.
|
||||
std::shared_ptr<ExecutionTree> tree_;
|
||||
|
|
|
@ -55,6 +55,7 @@
|
|||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/kernels/data/to_float16_op.h"
|
||||
#include "dataset/util/random.h"
|
||||
#include "mindrecord/include/shard_operator.h"
|
||||
|
@ -176,6 +177,17 @@ void bindDatasetOps(py::module *m) {
|
|||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
|
||||
.def_static("get_num_rows", [](const py::list &files) {
|
||||
int64_t count = 0;
|
||||
std::vector<std::string> filenames;
|
||||
for (auto file : files) {
|
||||
!file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
|
||||
}
|
||||
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
|
||||
return count;
|
||||
});
|
||||
}
|
||||
void bindTensor(py::module *m) {
|
||||
(void)py::class_<GlobalContext>(*m, "GlobalContext")
|
||||
|
@ -463,7 +475,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
.value("VOC", OpName::kVoc)
|
||||
.value("CIFAR10", OpName::kCifar10)
|
||||
.value("CIFAR100", OpName::kCifar100)
|
||||
.value("CELEBA", OpName::kCelebA);
|
||||
.value("CELEBA", OpName::kCelebA)
|
||||
.value("TEXTFILE", OpName::kTextFile);
|
||||
|
||||
(void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
|
||||
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)
|
||||
|
|
|
@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT
|
|||
manifest_op.cc
|
||||
cifar_op.cc
|
||||
celeba_op.cc
|
||||
text_file_op.cc
|
||||
)
|
||||
|
||||
add_dependencies(engine-datasetops-source mindspore::protobuf)
|
||||
|
|
|
@ -0,0 +1,459 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
#include "dataset/util/wait_post.h"
|
||||
#include "dataset/util/random.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
TextFileOp::Builder::Builder()
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
builder_rows_per_buffer_ = config_manager->rows_per_buffer();
|
||||
builder_worker_connector_size_ = config_manager->worker_connector_size();
|
||||
}
|
||||
|
||||
Status TextFileOp::Builder::ValidateInputs() const {
|
||||
std::string err_msg;
|
||||
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
|
||||
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
|
||||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
|
||||
RETURN_IF_NOT_OK(ValidateInputs());
|
||||
|
||||
// Throttle the number of workers if we have more workers than files!
|
||||
if (static_cast<size_t>(builder_num_workers_) > builder_text_files_list_.size()) {
|
||||
builder_num_workers_ = builder_text_files_list_.size();
|
||||
MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
|
||||
}
|
||||
|
||||
builder_schema_ = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(
|
||||
builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
|
||||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
|
||||
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
|
||||
builder_num_devices_, builder_device_id_);
|
||||
RETURN_IF_NOT_OK(text_file_op->Init());
|
||||
*op = std::move(text_file_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
device_id_(device_id),
|
||||
num_devices_(num_device),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
num_samples_(num_samples),
|
||||
text_files_list_(std::move(text_files_list)),
|
||||
shuffle_files_(shuffle_files),
|
||||
data_schema_(std::move(schema)),
|
||||
all_num_rows_(0),
|
||||
num_rows_per_shard_(0),
|
||||
filename_index_(std::make_unique<StringIndex>()),
|
||||
finished_reading_dataset_(false),
|
||||
load_io_block_queue_(true),
|
||||
load_jagged_connector_(true) {
|
||||
worker_connector_size_ = worker_connector_size;
|
||||
}
|
||||
|
||||
Status TextFileOp::Init() {
|
||||
RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_));
|
||||
|
||||
int32_t safe_queue_size = static_cast<int32_t>(std::ceil(text_files_list_.size() / num_workers_) + 1);
|
||||
io_block_queues_.Init(num_workers_, safe_queue_size);
|
||||
|
||||
for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
|
||||
col_name_map_[data_schema_->column(i).name()] = i;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
|
||||
|
||||
jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::Reset() {
|
||||
load_jagged_connector_ = true;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::Reset());
|
||||
NotifyToFillIOBlockQueue();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
|
||||
TensorRow tRow(1, nullptr);
|
||||
(*tensor_table)->push_back(std::move(tRow));
|
||||
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(),
|
||||
TensorShape(std::vector<dsize_t>(1, line.size())), data_schema_->column(0).type(),
|
||||
const_cast<unsigned char *>(reinterpret_cast<const unsigned char *>(common::SafeCStr(line)))));
|
||||
(**tensor_table)[row][0] = std::move(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
|
||||
}
|
||||
|
||||
int64_t rows_each_buffer = 0;
|
||||
int64_t rows_total = 0;
|
||||
std::string line;
|
||||
std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
|
||||
cur_buffer->set_column_name_map(col_name_map_);
|
||||
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
|
||||
|
||||
while (getline(handle, line)) {
|
||||
// If read to the end offset of this file, break.
|
||||
if (rows_total >= end_offset) {
|
||||
break;
|
||||
}
|
||||
// Skip line before start offset.
|
||||
if (rows_total < start_offset) {
|
||||
rows_total++;
|
||||
continue;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer));
|
||||
rows_each_buffer++;
|
||||
rows_total++;
|
||||
if (rows_each_buffer == rows_per_buffer_) {
|
||||
cur_buffer->set_tensor_table(std::move(tensor_table));
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
|
||||
|
||||
cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
|
||||
cur_buffer->set_column_name_map(col_name_map_);
|
||||
tensor_table = std::make_unique<TensorQTable>();
|
||||
rows_each_buffer = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (rows_each_buffer > 0) {
|
||||
cur_buffer->set_tensor_table(std::move(tensor_table));
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::WorkerEntry(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::unique_ptr<FilenameBlock> io_block;
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pops an element from a queue in io_block_queues
|
||||
Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes an element to a queue in io_block_queues
|
||||
Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
Status TextFileOp::PostEndOfData() {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
Status TextFileOp::PostEndOfEpoch(int32_t queue_index) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
|
||||
std::mt19937 rng(seed);
|
||||
std::shuffle(i_keys->begin(), i_keys->end(), rng);
|
||||
}
|
||||
|
||||
bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid";
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
|
||||
*start_offset = start_index - pre_count;
|
||||
push = true;
|
||||
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
if (pre_count >= start_index && pre_count < end_index) {
|
||||
*start_offset = 0;
|
||||
push = true;
|
||||
if (pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
return push;
|
||||
}
|
||||
|
||||
Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
||||
int32_t queue_index = 0;
|
||||
int64_t pre_count = 0;
|
||||
int64_t start_offset = 0;
|
||||
int64_t end_offset = 0;
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
std::vector<std::pair<std::string, int64_t>> file_index;
|
||||
if (!i_keys.empty()) {
|
||||
for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
|
||||
{
|
||||
if (!load_io_block_queue_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto file_it = filename_index_->Search(*it);
|
||||
file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it));
|
||||
}
|
||||
} else {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
{
|
||||
if (!load_io_block_queue_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
|
||||
}
|
||||
}
|
||||
for (auto file_info : file_index) {
|
||||
if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
|
||||
auto ioBlock =
|
||||
std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
||||
queue_index = (queue_index + 1) % num_workers_;
|
||||
}
|
||||
|
||||
pre_count += filename_numrows_[file_info.first];
|
||||
}
|
||||
|
||||
if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
|
||||
finish = false;
|
||||
} else {
|
||||
finish = true;
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::WaitToFillIOBlockQueue() {
|
||||
// must be called first if called by a worker spawned by the task group
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::vector<int64_t> i_keys;
|
||||
if (shuffle_files_) {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
i_keys.push_back(it.key());
|
||||
}
|
||||
}
|
||||
uint32_t seed = 0;
|
||||
while (true) {
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
|
||||
io_block_queue_wait_post_.Clear();
|
||||
|
||||
if (finished_reading_dataset_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (shuffle_files_) {
|
||||
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
|
||||
}
|
||||
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
|
||||
|
||||
Status TextFileOp::operator()() {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
|
||||
// launch one thread, responsible for filling IoBlockQueue
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this)));
|
||||
|
||||
// Read data from disk into buffers
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1)));
|
||||
|
||||
// must be called after launching workers.
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
io_block_queue_wait_post_.Register(tree_->AllTasks());
|
||||
NotifyToFillIOBlockQueue();
|
||||
while (!finished_reading_dataset_) {
|
||||
int64_t buffer_id = 0;
|
||||
int32_t workers_done = 0;
|
||||
int64_t rows_read = 0;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
while (workers_done < num_workers_) {
|
||||
std::unique_ptr<DataBuffer> buffer;
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
||||
if (buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (num_samples_ == 0 || rows_read < num_samples_) {
|
||||
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
|
||||
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
|
||||
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
rows_read += buffer->NumRows();
|
||||
buffer->set_id(buffer_id++);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
|
||||
} else {
|
||||
// end of epoch
|
||||
load_jagged_connector_ = false;
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfData());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t TextFileOp::CountTotalRows(const std::string &file) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
MS_LOG(ERROR) << "Failed to open file: " << file;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
int64_t count = 0;
|
||||
while (getline(handle, line)) {
|
||||
count++;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
Status TextFileOp::CalculateNumRowsPerShard() {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
int64_t count = CountTotalRows(it.value());
|
||||
filename_numrows_[it.value()] = count;
|
||||
all_num_rows_ += count;
|
||||
}
|
||||
if (all_num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Number of rows can not be zero");
|
||||
}
|
||||
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
|
||||
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
|
||||
std::shared_ptr<TextFileOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op));
|
||||
for (auto file : files) {
|
||||
*count += op->CountTotalRows(file);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
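The sharding behaviour implemented by CalculateNumRowsPerShard, NeedPushFileToBlockQueue and FillIOBlockQueue can be modelled with a short sketch. This is an illustrative Python rendering of the row-range math only (it assumes the per-file row counts are already known and ignores shuffling and IO blocks):

import math

def shard_plan(file_rows, num_devices, device_id):
    """file_rows: list of (filename, row_count) in filename-index order."""
    total = sum(rows for _, rows in file_rows)
    rows_per_shard = math.ceil(total / num_devices)
    start_index = device_id * rows_per_shard
    end_index = (device_id + 1) * rows_per_shard
    pre_count = 0
    plan = []
    # Like FillIOBlockQueue, keep cycling over the files until this shard is full,
    # so small datasets wrap around and every shard gets num_rows_per_shard rows.
    while pre_count < end_index:
        for name, rows in file_rows:
            lo = max(start_index, pre_count)
            hi = min(end_index, pre_count + rows)
            if lo < hi:  # the file overlaps this shard's row range
                plan.append((name, lo - pre_count, hi - pre_count))
            pre_count += rows
    return plan

# With the two test files added in this commit (3 rows and 2 rows) and two shards,
# shard 1 reads rows 0-1 of 2.txt and then wraps around to row 0 of 1.txt:
print(shard_plan([("1.txt", 3), ("2.txt", 2)], num_devices=2, device_id=1))
# [('2.txt', 0, 2), ('1.txt', 0, 1)]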
|
|
@ -0,0 +1,263 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/util/status.h"
|
||||
#include "dataset/util/auto_index.h"
|
||||
#include "dataset/engine/data_schema.h"
|
||||
#include "dataset/engine/datasetops/parallel_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/util/queue.h"
|
||||
#include "dataset/util/wait_post.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
using StringIndex = AutoIndexObj<std::string>;
|
||||
|
||||
class TextFileOp : public ParallelOp {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
// Builder constructor. Creates the builder object.
|
||||
// @note No default args
|
||||
// @return This is a constructor.
|
||||
Builder();
|
||||
|
||||
// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
// Checks if the inputs of the builder are valid.
|
||||
// @return Status - the error code returned.
|
||||
Status ValidateInputs() const;
|
||||
|
||||
// Create the final object.
|
||||
// @param op - dataset op.
|
||||
// @return Status - the error code returned.
|
||||
Status Build(std::shared_ptr<TextFileOp> *op);
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
builder_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t op_connector_size) {
|
||||
builder_op_connector_size_ = op_connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
|
||||
builder_rows_per_buffer_ = rows_per_buffer;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumDevices(int64_t num_dev) {
|
||||
builder_num_devices_ = num_dev;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetDeviceId(int64_t dev_id) {
|
||||
builder_device_id_ = dev_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetTextFilesList(const std::vector<std::string> &files_list) {
|
||||
builder_text_files_list_ = files_list;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetShuffleFiles(bool shuffle_files) {
|
||||
builder_shuffle_files_ = shuffle_files;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
int64_t builder_rows_per_buffer_;
|
||||
int64_t builder_num_samples_;
|
||||
int32_t builder_worker_connector_size_;
|
||||
std::vector<std::string> builder_text_files_list_;
|
||||
bool builder_shuffle_files_;
|
||||
std::unique_ptr<DataSchema> builder_schema_;
|
||||
};
|
||||
|
||||
// Constructor of TextFileOp
// @note The builder class should be used to call this constructor.
// @param num_workers - number of worker threads reading data from text files.
// @param rows_per_buffer - number of rows that a full buffer will contain.
// @param num_samples - number of rows to read.
// @param worker_connector_size - size of each worker's internal connector queue.
// @param data_schema - the data schema object.
// @param text_files_list - list of filepaths for the dataset files.
// @param op_connector_size - size of each queue in the connector that the child operator pulls from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param num_devices - number of shards the dataset is divided into.
// @param device_id - the shard id of this op within num_devices.
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
|
||||
// Default destructor
|
||||
~TextFileOp() = default;
|
||||
|
||||
// Instantiates the internal queues and connectors
|
||||
// @return Status - the error code returned
|
||||
Status Init();
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - the error code returned.
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from its previous execution and
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
// @return Status - the error code returned.
|
||||
Status Reset() override;
|
||||
|
||||
// Get total rows in files.
|
||||
// @param files - all text files.
|
||||
// @param count - number of rows.
|
||||
// @return Status - the error code returned.
|
||||
static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Parses a single row and puts the data into a tensor table.
|
||||
// @param line - the content of the row.
|
||||
// @param tensor_table - the tensor table to put the parsed data in.
|
||||
// @param row - the id of the row filled in the tensor table.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
|
||||
|
||||
// Reads a text file and loads the data into multiple buffers.
|
||||
// @param file - the file to read.
|
||||
// @param start_offset - the start offset of file.
|
||||
// @param end_offset - the end offset of file.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id);
|
||||
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
Status CalculateNumRowsPerShard();
|
||||
|
||||
// Count the number of rows in one text file.
// @param file - the text file name.
// @return int64_t - the total number of rows in the file.
int64_t CountTotalRows(const std::string &file);
|
||||
|
||||
// Notifies the thread which called FillIoBlockQueue to resume execution
|
||||
void NotifyToFillIOBlockQueue();
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
// @return Status - the error code returned.
|
||||
Status WaitToFillIOBlockQueue();
|
||||
|
||||
// Fill the IOBlockQueue.
|
||||
// @param i_keys - keys of the files to fill into the IOBlockQueue
|
||||
// @return Status - the error code returned.
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
|
||||
|
||||
// Decides whether a file overlaps this shard's row range and should be pushed to the block queue.
// @param file_name - file name.
// @param start_offset - output, the starting row offset to read within the file.
// @param end_offset - output, the ending row offset (exclusive) to read within the file.
// @param pre_count - total rows of the files already visited.
// @return bool - true if the file should be pushed to the block queue.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
|
||||
|
||||
// Pops an element from a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to pop from.
|
||||
// @param out_block - the popped element.
|
||||
// @return Status - the error code returned.
|
||||
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
|
||||
|
||||
// Pushes an element to a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to push to.
|
||||
// @param io_block - the element to push onto the queue.
|
||||
// @return Status - the error code returned.
|
||||
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfData();
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfEpoch(int32_t queue_index);
|
||||
|
||||
int32_t device_id_;
|
||||
int32_t num_devices_;
|
||||
int64_t rows_per_buffer_;
|
||||
int64_t num_samples_;
|
||||
std::vector<std::string> text_files_list_;
|
||||
bool shuffle_files_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
int64_t all_num_rows_;
|
||||
int64_t num_rows_per_shard_;
|
||||
std::map<std::string, int64_t> filename_numrows_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
WaitPost io_block_queue_wait_post_;
|
||||
bool finished_reading_dataset_;
|
||||
bool load_io_block_queue_;
|
||||
bool load_jagged_connector_;
|
||||
std::unordered_map<std::string, int32_t> col_name_map_;
|
||||
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
|
|
@ -20,8 +20,8 @@ can also create samplers with this module to sample data.
|
|||
|
||||
from .core.configuration import config
|
||||
from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
|
||||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \
|
||||
Shuffle, zip
|
||||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
|
||||
Schema, Shuffle, zip
|
||||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||
WeightedRandomSampler
|
||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||
|
@ -29,5 +29,5 @@ from .engine.serializer_deserializer import serialize, deserialize, show
|
|||
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",
|
||||
"MindDataset", "GeneratorDataset", "TFRecordDataset",
|
||||
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
|
||||
"VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
|
||||
"VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
|
||||
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"]
|
||||
|
|
|
@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset",
|
|||
"ImageFolderDatasetV2", "MnistDataset",
|
||||
"MindDataset", "GeneratorDataset", "TFRecordDataset",
|
||||
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
|
||||
"VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
|
||||
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
|
||||
"VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler",
|
||||
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"]
|
||||
|
|
|
@ -29,7 +29,7 @@ from importlib import import_module
|
|||
|
||||
import numpy as np
|
||||
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
|
||||
MindRecordOp, CBatchInfo
|
||||
MindRecordOp, TextFileOp, CBatchInfo
|
||||
from mindspore._c_expression import typing
|
||||
|
||||
from mindspore import log as logger
|
||||
|
@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator
|
|||
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
|
||||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
|
||||
check_zip_dataset, check_add_column
|
||||
check_zip_dataset, check_add_column, check_textfiledataset
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
||||
try:
|
||||
|
@ -888,6 +888,29 @@ class SourceDataset(Dataset):
|
|||
|
||||
# No need for __init__ since it is the same as the super's init
|
||||
|
||||
@staticmethod
|
||||
def _find_files(patterns):
|
||||
"""
|
||||
Utility function to search for files with the given glob patterns.
|
||||
|
||||
Args:
|
||||
patterns (str or list[str]): string or list of patterns to be searched.
|
||||
|
||||
Returns:
|
||||
List, files.
|
||||
"""
|
||||
|
||||
def flat(lists):
|
||||
return list(np.array(lists).flatten())
|
||||
|
||||
if not isinstance(patterns, list):
|
||||
patterns = [patterns]
|
||||
|
||||
file_list = flat([glob.glob(file, recursive=True) for file in patterns])
|
||||
if file_list: # not empty
|
||||
return file_list
|
||||
raise ValueError("The list of path names matching the patterns is empty.")
|
||||
|
||||
|
||||
class DatasetOp(Dataset):
|
||||
"""
|
||||
|
@ -2126,30 +2149,6 @@ class TFRecordDataset(SourceDataset):
|
|||
>>> # 3) get all rows from dataset_files with schema file "./schema.json":
|
||||
>>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _find_files(patterns):
|
||||
"""
|
||||
Utility function to search for files with the given glob patterns.
|
||||
|
||||
Args:
|
||||
patterns (str or list[str]): string or list of patterns to be searched.
|
||||
|
||||
Returns:
|
||||
List, files.
|
||||
"""
|
||||
|
||||
def flat(lists):
|
||||
return list(np.array(lists).flatten())
|
||||
|
||||
if not isinstance(patterns, list):
|
||||
patterns = [patterns]
|
||||
|
||||
file_list = flat([glob.glob(file, recursive=True) for file in patterns])
|
||||
if file_list: # not empty
|
||||
return file_list
|
||||
raise ValueError("The list of path names matching the patterns is empty.")
|
||||
|
||||
@check_tfrecorddataset
|
||||
def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False):
|
||||
|
@ -2952,3 +2951,82 @@ class CelebADataset(SourceDataset):
|
|||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
return args
|
||||
|
||||
class TextFileDataset(SourceDataset):
|
||||
"""
|
||||
A source dataset that reads and parses datasets stored on disk in text format.
|
||||
The generated dataset has one column ['text'].
|
||||
|
||||
Args:
|
||||
dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
|
||||
files. The list will be sorted in a lexicographical order.
|
||||
num_samples (int, optional): number of samples (rows) to read (default=None, reads the full dataset).
|
||||
num_parallel_workers (int, optional): number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
shuffle (bool or Shuffle, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
|
||||
If shuffle is False, no shuffling will be performed;
|
||||
If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL;
|
||||
Otherwise, there are two levels of shuffling:
|
||||
|
||||
- Shuffle.GLOBAL: Shuffle both the files and samples.
|
||||
|
||||
- Shuffle.FILES: Shuffle files only.
|
||||
|
||||
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument should be specified only when num_shards is also specified.
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
|
||||
>>> dataset = ds.TextFileDataset(dataset_files=dataset_files)
|
||||
"""
|
||||
|
||||
@check_textfiledataset
|
||||
def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None,
|
||||
shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_files = self._find_files(dataset_files)
|
||||
self.dataset_files.sort()
|
||||
self.num_samples = num_samples
|
||||
|
||||
if not isinstance(shuffle, (bool, Shuffle)):
|
||||
raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
|
||||
if not isinstance(shuffle, Shuffle):
|
||||
if shuffle:
|
||||
self.shuffle_level = Shuffle.GLOBAL
|
||||
self.shuffle_files = True
|
||||
else:
|
||||
self.shuffle_level = None
|
||||
self.shuffle_files = False
|
||||
else:
|
||||
self.shuffle_level = shuffle
|
||||
self.shuffle_files = True
|
||||
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_files"] = self.dataset_files
|
||||
args["num_samples"] = self.num_samples
|
||||
if self.shuffle_files is not None:
|
||||
args["shuffle_files"] = self.shuffle_files
|
||||
args["shuffle"] = self.shuffle_level
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
"""
|
||||
Get the number of batches in an epoch.
|
||||
|
||||
Returns:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self._dataset_size is None:
|
||||
num_rows = TextFileOp.get_num_rows(self.dataset_files)
|
||||
num_rows = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is None:
|
||||
return num_rows
|
||||
return min(self.num_samples, num_rows)
|
||||
return self._dataset_size
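A short usage sketch for the new class (paths are placeholders; it assumes the files exist). Shuffle.FILES only reshuffles file order, while the default Shuffle.GLOBAL additionally gets a row-level shuffle injected by _alter_node:

import mindspore.dataset as ds

data = ds.TextFileDataset(["/path/to/1.txt", "/path/to/2.txt"],
                          num_samples=4,
                          shuffle=ds.Shuffle.FILES,
                          num_shards=2, shard_id=0)
print(data.get_dataset_size())          # rows visible to this shard, capped by num_samples
for row in data.create_dict_iterator():
    print(row["text"])                  # one line per row, as a uint8 array of character codes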
|
||||
|
|
|
@ -48,12 +48,16 @@ def alter_tree(node):
|
|||
|
||||
def _alter_node(node):
|
||||
"""Performing some alteration to a dataset node. A common alteration is to insert a node."""
|
||||
if isinstance(node, de.TFRecordDataset) and node.shuffle_level == de.Shuffle.GLOBAL:
|
||||
if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL:
|
||||
# Remove the connection between the parent's node to the current node because we are inserting a node.
|
||||
if node.output:
|
||||
node.output.pop()
|
||||
# Perform a fast scan for average rows per file
|
||||
avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
|
||||
if isinstance(node, de.TFRecordDataset):
|
||||
avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
|
||||
else:
|
||||
avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files)
|
||||
|
||||
# Shuffle between 4 files with a minimum size of 10000 rows
|
||||
new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000))
|
||||
return new_shuffle
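The inserted shuffle uses a buffer of roughly four files' worth of rows, but never fewer than 10000; a quick check of that rule with made-up numbers:

# buffer = max(4 * (dataset_size // number_of_files), 10000)
assert max((100000 // 5) * 4, 10000) == 80000   # large files: four files' worth of rows
assert max((1500 // 3) * 4, 10000) == 10000     # small files: fall back to the 10000-row floor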
|
||||
|
@ -157,6 +161,8 @@ class Iterator:
|
|||
op_type = OpName.CIFAR100
|
||||
elif isinstance(dataset, de.CelebADataset):
|
||||
op_type = OpName.CELEBA
|
||||
elif isinstance(dataset, de.TextFileDataset):
|
||||
op_type = OpName.TEXTFILE
|
||||
else:
|
||||
raise ValueError("Unsupported DatasetOp")
|
||||
|
||||
|
|
|
@ -849,3 +849,25 @@ def check_add_column(method):
|
|||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_textfiledataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
|
||||
# check dataset_files; required argument
|
||||
dataset_files = param_dict.get('dataset_files')
|
||||
if dataset_files is None:
|
||||
raise ValueError("dataset_files is not provided.")
|
||||
if not isinstance(dataset_files, (str, list)):
|
||||
raise TypeError("dataset_files should be of type str or a list of strings.")
|
||||
|
||||
check_param_type(nreq_param_int, param_dict, int)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
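For illustration, what the wrapper accepts and rejects (hypothetical calls, not part of the test suite; the first line assumes the glob matches at least one real file):

# ds.TextFileDataset("/path/to/*.txt", num_samples=2)     # passes the checker
# ds.TextFileDataset()                                     # ValueError: dataset_files is not provided.
# ds.TextFileDataset(123)                                  # TypeError: dataset_files should be of type str or a list of strings.
# ds.TextFileDataset("/path/to/*.txt", num_samples="2")    # rejected by check_param_type: num_samples must be int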
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This module is to support NLP augmentations. It includes two parts:
c_transforms and py_transforms. C_transforms is a high-performance
text processing module implemented in C++, while py_transforms provides
more kinds of text processing utilities implemented in Python.
|
||||
"""
|
||||
from .utils import as_text
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Basic text processing functions for NLP.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
def as_text(array, encoding='utf8'):
|
||||
"""
|
||||
Convert data of array to unicode.
|
||||
|
||||
Args:
|
||||
array (numpy.ndarray): Array whose elements are the byte values of the encoded characters.
encoding (str): The charset used for decoding.
|
||||
Returns:
|
||||
A 'str' object.
|
||||
|
||||
"""
|
||||
|
||||
if not isinstance(array, np.ndarray):
|
||||
raise ValueError('input should be a numpy array')
|
||||
|
||||
byte_array = bytearray(list(array))
|
||||
return byte_array.decode(encoding)
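A small usage sketch for as_text; the array mimics the uint8 'text' column produced by TextFileOp:

import numpy as np

arr = np.array([72, 105, 33], dtype=np.uint8)   # byte values of "Hi!"
print(as_text(arr))                              # -> Hi!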
|
|
@ -65,7 +65,7 @@ SET(DE_UT_SRCS
|
|||
cifar_op_test.cc
|
||||
celeba_op_test.cc
|
||||
take_op_test.cc
|
||||
)
|
||||
text_file_op_test.cc)
|
||||
|
||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "common/utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace common = mindspore::common;
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
class MindDataTestTextFileOp : public UT::DatasetOpTesting {
|
||||
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestTextFileOp, TestTextFileBasic) {
|
||||
// Start with an empty execution tree
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
|
||||
std::shared_ptr<TextFileOp> op;
|
||||
TextFileOp::Builder builder;
|
||||
builder.SetTextFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(16)
|
||||
.SetNumWorkers(16)
|
||||
.SetOpConnectorSize(2);
|
||||
|
||||
Status rc = builder.Build(&op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->AssociateNode(op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->AssignRoot(op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration.";
|
||||
rc = tree->Prepare();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->Launch();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(tree);
|
||||
TensorRow tensor_list;
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
int row_count = 0;
|
||||
while (!tensor_list.empty()) {
|
||||
// Display the tensor by calling the printer on it
|
||||
for (int i = 0; i < tensor_list.size(); i++) {
|
||||
std::ostringstream ss;
|
||||
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
|
||||
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
|
||||
}
|
||||
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
row_count++;
|
||||
}
|
||||
|
||||
ASSERT_EQ(row_count, 3);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTextFileOp, TestTotalRows) {
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
std::vector<std::string> files;
|
||||
files.push_back(tf_file1);
|
||||
int64_t total_rows = 0;
|
||||
TextFileOp::CountAllFileRows(files, &total_rows);
|
||||
ASSERT_EQ(total_rows, 3);
|
||||
files.clear();
|
||||
|
||||
files.push_back(tf_file2);
|
||||
TextFileOp::CountAllFileRows(files, &total_rows);
|
||||
ASSERT_EQ(total_rows, 2);
|
||||
files.clear();
|
||||
|
||||
files.push_back(tf_file1);
|
||||
files.push_back(tf_file2);
|
||||
TextFileOp::CountAllFileRows(files, &total_rows);
|
||||
ASSERT_EQ(total_rows, 5);
|
||||
files.clear();
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
This is a text file.
|
||||
Be happy every day.
|
||||
Good luck to everyone.
|
|
@ -0,0 +1,2 @@
|
|||
Another file.
|
||||
End of file.
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
import mindspore.dataset.transforms.nlp.utils as nlp
|
||||
|
||||
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
|
||||
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
|
||||
|
||||
def test_textline_dataset_one_file():
|
||||
data = ds.TextFileDataset(DATA_FILE)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator():
|
||||
logger.info("{}".format(i["text"]))
|
||||
count += 1
|
||||
assert(count == 3)
|
||||
|
||||
def test_textline_dataset_all_file():
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator():
|
||||
logger.info("{}".format(i["text"]))
|
||||
count += 1
|
||||
assert(count == 5)
|
||||
|
||||
def test_textline_dataset_totext():
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
|
||||
count = 0
|
||||
line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."]
|
||||
for i in data.create_dict_iterator():
|
||||
str = nlp.as_text(i["text"])
|
||||
assert(str == line[count])
|
||||
count += 1
|
||||
assert(count == 5)
|
||||
|
||||
def test_textline_dataset_num_samples():
|
||||
data = ds.TextFileDataset(DATA_FILE, num_samples=2)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator():
|
||||
count += 1
|
||||
assert(count == 2)
|
||||
|
||||
def test_textline_dataset_distribution():
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator():
|
||||
count += 1
|
||||
assert(count == 3)
|
||||
|
||||
def test_textline_dataset_repeat():
|
||||
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
|
||||
data = data.repeat(3)
|
||||
count = 0
|
||||
line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
|
||||
"This is a text file.", "Be happy every day.", "Good luck to everyone.",
|
||||
"This is a text file.", "Be happy every day.", "Good luck to everyone."]
|
||||
for i in data.create_dict_iterator():
|
||||
str = nlp.as_text(i["text"])
|
||||
assert(str == line[count])
|
||||
count += 1
|
||||
assert(count == 9)
|
||||
|
||||
def test_textline_dataset_get_datasetsize():
|
||||
data = ds.TextFileDataset(DATA_FILE)
|
||||
size = data.get_dataset_size()
|
||||
assert(size == 3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_textline_dataset_one_file()
|
||||
test_textline_dataset_all_file()
|
||||
test_textline_dataset_totext()
|
||||
test_textline_dataset_num_samples()
|
||||
test_textline_dataset_distribution()
|
||||
test_textline_dataset_repeat()
|
||||
test_textline_dataset_get_datasetsize()
|