add CLUE dataset

This commit is contained in:
jiangzhiwen 2020-05-28 19:30:24 +08:00
parent 3536185f5b
commit e0e167a000
33 changed files with 1676 additions and 12 deletions

View File

@ -31,6 +31,7 @@
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_distributed_sample.h"
@ -72,7 +73,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp}};
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
@ -1210,6 +1212,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
if (!p.second.is_none()) {
@ -1236,6 +1239,7 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
}
return Status::OK();
}
Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
for (auto arg : args) {
@ -1267,5 +1271,45 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
return Status::OK();
}
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
if (!args["dataset_files"].is_none()) {
(void)builder->SetClueFilesList(ToStringVector(args["dataset_files"]));
} else {
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
}
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_shards") {
(void)builder->SetNumDevices(ToInt(value));
} else if (key == "shard_id") {
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "cols_to_keyword") {
std::map<std::string, std::string> map_dict;
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
if (!p.second.is_none()) {
map_dict.insert({ToString(p.first), ToString(p.second)});
} else {
map_dict.insert({ToString(p.first), ToString(p.first)});
}
}
(void)builder->SetColsKeyMap(map_dict);
}
}
}
std::shared_ptr<ClueOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -64,7 +64,8 @@ enum OpName {
kCelebA,
kRandomData,
kTextFile,
kBuildVocab
kBuildVocab,
kClue
};
// The C++ binder class that we expose to the python script.
@ -166,6 +167,8 @@ class DEPipeline {
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;

View File

@ -55,6 +55,7 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/gnn/graph.h"
@ -201,6 +202,18 @@ void bindDatasetOps(py::module *m) {
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
return count;
});
(void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
.def_static("get_num_rows", [](const py::list &files) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
}
THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
return count;
});
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
@ -629,7 +642,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("RANDOMDATA", OpName::kRandomData)
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);
.value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)

View File

@ -19,4 +19,5 @@ add_library(engine-datasetops-source OBJECT
random_data_op.cc
celeba_op.cc
text_file_op.cc
clue_op.cc
)

View File

@ -0,0 +1,551 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/clue_op.h"
#include <string>
#include <vector>
#include <fstream>
#include <iomanip>
#include <utility>
#include "dataset/core/config_manager.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/random.h"
namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
builder_rows_per_buffer_ = config_manager->rows_per_buffer();
builder_worker_connector_size_ = config_manager->worker_connector_size();
}
Status ClueOp::Builder::ValidateInputs() const {
std::string err;
err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
}
Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
RETURN_IF_NOT_OK(ValidateInputs());
// Throttle the number of workers if we have more workers than files!
if (static_cast<size_t>(builder_num_workers_) > builder_clue_files_list_.size()) {
builder_num_workers_ = builder_clue_files_list_.size();
MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
}
ColKeyMap ck_map;
for (auto &p : builder_cols_to_keyword_) {
ck_map.insert({p.first, split(p.second, '/')});
}
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
builder_device_id_);
RETURN_IF_NOT_OK(clue_op->Init());
*op = std::move(clue_op);
return Status::OK();
}
std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim) {
std::vector<std::string> res;
std::stringstream ss(s);
std::string item;
while (getline(ss, item, delim)) {
res.push_back(item);
}
return res;
}
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
num_samples_(num_samples),
filename_index_(std::make_unique<StringIndex>()),
clue_files_list_(std::move(clue_files_list)),
load_jagged_connector_(true),
cols_to_keyword_(cols_to_keyword),
shuffle_files_(shuffle_files),
finished_reading_dataset_(false),
num_devices_(num_device),
device_id_(device_id),
load_io_block_queue_(true) {
worker_connector_size_ = worker_connector_size;
}
Status ClueOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));
int32_t safe_queue_size = static_cast<int32_t>(std::ceil(clue_files_list_.size() / num_workers_) + 1);
io_block_queues_.Init(num_workers_, safe_queue_size);
// Set the column name mapping (base class field)
int count = 0;
for (auto &p : cols_to_keyword_) {
column_name_id_map_[p.first] = count;
count++;
}
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
return Status::OK();
}
Status ClueOp::Reset() {
load_jagged_connector_ = true;
load_io_block_queue_ = true;
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
TensorRow tRow(1, nullptr);
(*tensor_table)->push_back(std::move(tRow));
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar()));
(**tensor_table)[row][0] = std::move(tensor);
return Status::OK();
}
Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
nlohmann::json cursor = js;
for (int i = 0; i < key_chain.size(); i++) {
if (cursor.find(key_chain[i]) != cursor.end()) {
cursor = cursor[key_chain[i]];
} else {
RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]);
}
}
std::string final_str = key_chain.back();
switch (cursor.type()) {
case nlohmann::detail::value_t::string:
RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::string>()}, TensorShape::CreateScalar()));
break;
case nlohmann::detail::value_t::number_integer:
RETURN_IF_NOT_OK(
Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
(*t)->SetItemAt<int32_t>({0}, cursor.get<int32_t>());
break;
case nlohmann::detail::value_t::number_unsigned:
RETURN_IF_NOT_OK(
Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
(*t)->SetItemAt<int32_t>({0}, cursor.get<uint32_t>());
break;
case nlohmann::detail::value_t::number_float:
RETURN_IF_NOT_OK(
Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)));
(*t)->SetItemAt<int32_t>({0}, cursor.get<float>());
break;
case nlohmann::detail::value_t::array:
RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::vector<std::string>>()}, TensorShape::CreateScalar()));
break;
default:
break;
}
return Status::OK();
}
Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
std::ifstream handle(file);
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
}
int64_t rows_each_buffer = 0;
int64_t rows_total = 0;
std::string line;
std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
while (getline(handle, line)) {
if (line.empty()) {
continue;
}
// If read to the end offset of this file, break.
if (rows_total >= end_offset) {
break;
}
// Skip line before start offset.
if (rows_total < start_offset) {
rows_total++;
continue;
}
try {
nlohmann::json js = nlohmann::json::parse(line);
int cols_count = cols_to_keyword_.size();
TensorRow tRow(cols_count, nullptr);
tensor_table->push_back(std::move(tRow));
int cout = 0;
for (auto &p : cols_to_keyword_) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
(*tensor_table)[rows_each_buffer][cout] = std::move(tensor);
cout++;
}
} catch (const std::exception &err) {
// Catch any exception and convert to Status return code
RETURN_STATUS_UNEXPECTED("Failed to load json file");
}
// RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer));
rows_each_buffer++;
rows_total++;
if (rows_each_buffer == rows_per_buffer_) {
cur_buffer->set_tensor_table(std::move(tensor_table));
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
tensor_table = std::make_unique<TensorQTable>();
rows_each_buffer = 0;
}
}
if (rows_each_buffer > 0) {
cur_buffer->set_tensor_table(std::move(tensor_table));
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
}
return Status::OK();
}
Status ClueOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1)));
// must be called after launching workers.
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
load_io_block_queue_ = true;
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (num_samples_ == 0 || rows_read < num_samples_) {
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();
buffer->set_id(buffer_id++);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
} else {
// end of epoch
load_jagged_connector_ = false;
load_io_block_queue_ = false;
}
}
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
}
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
Status ClueOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// A print method typically used for debugging
void ClueOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <ClueOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
for (int i = 0; i < clue_files_list_.size(); ++i) {
out << " " << clue_files_list_[i];
}
out << "\n\n";
}
}
// Pops an element from a queue in io_block_queues
Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
Status ClueOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spanwed by taskgroup
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0;
int64_t pre_count = 0;
int64_t start_offset = 0;
int64_t end_offset = 0;
bool finish = false;
while (!finish) {
std::vector<std::pair<std::string, int64_t>> file_index;
if (!i_keys.empty()) {
for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
{
if (!load_io_block_queue_) {
break;
}
}
auto file_it = filename_index_->Search(*it);
file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it));
}
} else {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
{
if (!load_io_block_queue_) {
break;
}
}
file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
}
}
for (auto file_info : file_index) {
if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
auto ioBlock =
std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
queue_index = (queue_index + 1) % num_workers_;
}
pre_count += filename_numrows_[file_info.first];
}
if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
finish = false;
} else {
finish = true;
}
}
RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
return Status::OK();
}
void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
Status ClueOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count;
all_num_rows_ += count;
}
if (all_num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API "
"validation first.");
}
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK();
}
int64_t ClueOp::CountTotalRows(const std::string &file) {
std::ifstream handle(file);
if (!handle.is_open()) {
MS_LOG(ERROR) << "Failed to open file: " << file;
return 0;
}
std::string line;
int64_t count = 0;
while (getline(handle, line)) {
if (!line.empty()) {
count++;
}
}
return count;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status ClueOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
std::shared_ptr<ClueOp> op;
*count = 0;
RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op));
for (auto file : files) {
*count += op->CountTotalRows(file);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,270 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
#include <memory>
#include <map>
#include <mutex>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>
#include "dataset/util/auto_index.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
namespace mindspore {
namespace dataset {
using StringIndex = AutoIndexObj<std::string>;
using ColKeyMap = std::map<std::string, std::vector<std::string>>;
class JaggedConnector;
class ClueOp : public ParallelOp {
public:
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Checks if the inputs of the builder is valid.
// @return Status - the error code returned.
Status ValidateInputs() const;
// Create the final object.
// @param op - dataset op.
// @return - the error code return.
Status Build(std::shared_ptr<ClueOp> *op);
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumDevices(int64_t num_dev) {
builder_num_devices_ = num_dev;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetDeviceId(int64_t dev_id) {
builder_device_id_ = dev_id;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetClueFilesList(const std::vector<std::string> &files_list) {
builder_clue_files_list_ = files_list;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetShuffleFiles(bool shuffle_files) {
builder_shuffle_files_ = shuffle_files;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetNumSamples(int64_t num_samples) {
builder_num_samples_ = num_samples;
return *this;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder &SetColsKeyMap(const std::map<std::string, std::string> &cols_to_key) {
builder_cols_to_keyword_ = cols_to_key;
return *this;
}
// Split string based on a character delimiter
// @return - the a string vector
std::vector<std::string> split(const std::string &s, char delim);
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
int64_t builder_num_samples_;
int32_t builder_worker_connector_size_;
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_;
};
// Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~ClueOp() = default;
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Get total rows in files.
// @param files - all clue files.
// @param count - number of rows.
// @return Status - the error coed returned.
static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
// @param row - the id of the row filled in the tensor table.
// @return Status - the error code returned.
Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
// Reads a clue file and loads the data into multiple buffers.
// @param file - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
// Count number of rows in each file.
// @param filename - clue file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// @return Status - the error code returned.
Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);
int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> clue_files_list_;
WaitPost io_block_queue_wait_post_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
ColKeyMap cols_to_keyword_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_

View File

@ -43,7 +43,7 @@ TextFileOp::Builder::Builder()
Status TextFileOp::Builder::ValidateInputs() const {
std::string err_msg;
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : "";
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
}

View File

@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from .core.configuration import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
TextFileDataset, Schema, Shuffle, zip, RandomDataset
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
@ -29,6 +29,6 @@ from .engine.graphdata import GraphData
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset",
"VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
"CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler",
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]

View File

@ -30,7 +30,7 @@ from ..core.configuration import config, ConfigurationManager
__all__ = ["config", "ConfigurationManager", "zip",
"ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema", "Schema",
"DistributedSampler", "PKSampler",

View File

@ -33,7 +33,7 @@ import copy
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
MindRecordOp, TextFileOp, VOCOp, CocoOp, CBatchInfo
MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split
check_split, check_cluedataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
@ -4317,6 +4317,222 @@ class CelebADataset(MappableDataset):
return self.sampler.is_sharded()
class CLUEDataset(SourceDataset):
"""
A source dataset that reads and parses CLUE datasets.
CLUE, the Chinese Language Understanding Evaluation Benchmark, a collection of datasets, baselines, pre-trained
models, corpus and leaderboard. Here we bring in classification task of CLUE, which are AFQMC, TNEWS, IFLYTEK,
CMNLI, WSC and CSL.
Args:
dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
files. The list will be sorted in a lexicographical order.
task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'.
(default=AFQMC).
usage (str, optional): Need train, test or eval data (default="train").
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Examples:
>>> import mindspore.dataset as ds
>>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
>>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train')
"""
@check_cluedataset
def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files)
self.dataset_files.sort()
self.num_samples = num_samples
self.task_dict = {
'AFQMC': {
'train': {
'sentence1': 'sentence1',
'sentence2': 'sentence2',
'label': 'label'
},
'test': {
'id': 'id',
'sentence1': 'sentence1',
'sentence2': 'sentence2'
},
'eval': {
'sentence1': 'sentence1',
'sentence2': 'sentence2',
'label': 'label'
}
},
'CMNLI': {
'train': {
'sentence1': 'sentence1',
'sentence2': 'sentence2',
'label': 'label'
},
'test': {
'id': 'id',
'sentence1': 'sentence1',
'sentence2': 'sentence2'
},
'eval': {
'sentence1': 'sentence1',
'sentence2': 'sentence2',
'label': 'label'
}
},
'CSL': {
'train': {
'id': 'id',
'abst': 'abst',
'keyword': 'keyword',
'label': 'label'
},
'test': {
'id': 'id',
'abst': 'abst',
'keyword': 'keyword'
},
'eval': {
'id': 'id',
'abst': 'abst',
'keyword': 'keyword',
'label': 'label'
}
},
'IFLYTEK': {
'train': {
'label': 'label',
'label_des': 'label_des',
'sentence': 'sentence'
},
'test': {
'id': 'id',
'sentence': 'sentence',
},
'eval': {
'label': 'label',
'label_des': 'label_des',
'sentence': 'sentence'
}
},
'TNEWS': {
'train': {
'label': 'label',
'label_desc': 'label_desc',
'sentence': 'sentence',
'keywords': 'keywords'
},
'test': {
'id': 'id',
'sentence': 'sentence',
'keywords': 'keywords'
},
'eval': {
'label': 'label',
'label_desc': 'label_desc',
'sentence': 'sentence',
'keywords': 'keywords'
}
},
'WSC': {
'train': {
'span1_index': 'target/span1_index',
'span2_index': 'target/span2_index',
'span1_text': 'target/span1_text',
'span2_text': 'target/span2_text',
'idx': 'idx',
'label': 'label',
'text': 'text'
},
'test': {
'span1_index': 'target/span1_index',
'span2_index': 'target/span2_index',
'span1_text': 'target/span1_text',
'span2_text': 'target/span2_text',
'idx': 'idx',
'text': 'text'
},
'eval': {
'span1_index': 'target/span1_index',
'span2_index': 'target/span2_index',
'span1_text': 'target/span1_text',
'span2_text': 'target/span2_text',
'idx': 'idx',
'label': 'label',
'text': 'text'
}
}
}
self.cols_to_keyword = self.task_dict[task][usage]
if not isinstance(shuffle, (bool, Shuffle)):
raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
if not isinstance(shuffle, Shuffle):
if shuffle:
self.shuffle_level = Shuffle.GLOBAL
self.shuffle_files = True
else:
self.shuffle_level = None
self.shuffle_files = False
else:
self.shuffle_level = shuffle
self.shuffle_files = True
self.num_shards = num_shards
self.shard_id = shard_id
def get_args(self):
args = super().get_args()
args["dataset_files"] = self.dataset_files
args["num_samples"] = self.num_samples
if self.shuffle_files is not None:
args["shuffle_files"] = self.shuffle_files
args["shuffle"] = self.shuffle_level
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
args["cols_to_keyword"] = self.cols_to_keyword
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if self._dataset_size is None:
num_rows = ClueOp.get_num_rows(self.dataset_files)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples is None:
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
def is_shuffled(self):
return self.shuffle_files
def is_sharded(self):
if self.num_shards is not None:
return self.num_shards > 1
return False
class TextFileDataset(SourceDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.

View File

@ -50,7 +50,8 @@ def alter_tree(node):
def _alter_node(node):
"""Performing some alteration to a dataset node. A common alteration is to insert a node."""
if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL:
if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \
and node.shuffle_level == de.Shuffle.GLOBAL:
# Remove the connection between the parent's node to the current node because we are inserting a node.
if node.output:
node.output.pop()
@ -179,6 +180,8 @@ class Iterator:
op_type = OpName.TEXTFILE
elif isinstance(dataset, de.BuildVocabDataset):
op_type = OpName.BUILDVOCAB
elif isinstance(dataset, de.CLUEDataset):
op_type = OpName.CLUE
else:
raise ValueError("Unsupported DatasetOp")

View File

@ -1075,6 +1075,41 @@ def check_add_column(method):
return new_method
def check_cluedataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset)."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_files; required argument
dataset_files = param_dict.get('dataset_files')
if dataset_files is None:
raise ValueError("dataset_files is not provided.")
if not isinstance(dataset_files, (str, list)):
raise TypeError("dataset_files should be of type str or a list of strings.")
# check task
task_param = param_dict.get('task')
if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
# check usage
usage_param = param_dict.get('usage')
if usage_param not in ['train', 'test', 'eval']:
raise ValueError("usage should be train, test or eval")
check_param_type(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
return method(*args, **kwargs)
return new_method
def check_textfiledataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""

View File

@ -65,6 +65,7 @@ SET(DE_UT_SRCS
cifar_op_test.cc
celeba_op_test.cc
take_op_test.cc
clue_op_test.cc
text_file_op_test.cc
filter_op_test.cc
concat_op_test.cc

View File

@ -0,0 +1,117 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/util/status.h"
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestCLUEOp : public UT::DatasetOpTesting {
};
TEST_F(MindDataTestCLUEOp, TestCLUEBasic) {
// Start with an empty execution tree
auto tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testCLUE/afqmc/train.json";
std::map<std::string, std::string> key_map;
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
std::shared_ptr<ClueOp> op;
ClueOp::Builder builder;
builder.SetClueFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetNumWorkers(16)
.SetOpConnectorSize(2)
.SetColsKeyMap(key_map);
Status rc = builder.Build(&op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssociateNode(op);
ASSERT_TRUE(rc.IsOk());
rc = tree->AssignRoot(op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 3);
}
TEST_F(MindDataTestCLUEOp, TestTotalRows) {
std::string tf_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
std::string tf_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
std::vector<std::string> files;
files.push_back(tf_file1);
int64_t total_rows = 0;
ClueOp::CountAllFileRows(files, &total_rows);
ASSERT_EQ(total_rows, 3);
files.clear();
files.push_back(tf_file2);
ClueOp::CountAllFileRows(files, &total_rows);
ASSERT_EQ(total_rows, 3);
files.clear();
files.push_back(tf_file1);
files.push_back(tf_file2);
ClueOp::CountAllFileRows(files, &total_rows);
ASSERT_EQ(total_rows, 6);
files.clear();
}

View File

@ -0,0 +1,3 @@
{"sentence1": "你有花呗吗", "sentence2": "我的花呗没额度了", "label": "0"}
{"sentence1": "吃饭能用花呗吗", "sentence2": "花呗太方便了", "label": "0"}
{"sentence1": "蚂蚁花呗支付金额有什么限制", "sentence2": "我到实体店消费用花呗支付受金额限制", "label": "1"}

View File

@ -0,0 +1,3 @@
{"id": 0, "sentence1": "借呗取消的时间", "sentence2": "蚂蚁借呗恢复的月数"}
{"id": 1, "sentence1": "网商贷用什么方法转变成借呗", "sentence2": "什么手段能将网商贷切换为借呗"}
{"id": 2, "sentence1": "我的借呗为什么开通不了", "sentence2": "我为啥没法开通借呗"}

View File

@ -0,0 +1,3 @@
{"sentence1": "蚂蚁借呗等额还款能否换成先息后本", "sentence2": "借呗可以先息到期还本吗", "label": "0"}
{"sentence1": "蚂蚁花呗说我违约了", "sentence2": "蚂蚁花呗违约行为是啥", "label": "0"}
{"sentence1": "帮我看看本月花呗账单结清了没", "sentence2": "上月的花呗账单", "label": "0"}

View File

@ -0,0 +1,3 @@
{"sentence1": "每个人都有权利", "sentence2": "每个人都有福利", "label": "neutral"}
{"sentence1": "有时候我喜欢他,但我也喜欢看到有人打他", "sentence2": "说实话,我有点喜欢他,但还是喜欢看到有人打他。", "label": "entailment"}
{"sentence1": "我最喜欢的餐馆是离你最近的一家", "sentence2": "我最喜欢的餐馆离你家至少一百英里远。", "label": "contradiction"}

View File

@ -0,0 +1,3 @@
{"id": 0, "sentence1": "今天,全球都在看着最新航天飞机的处女航。", "sentence2": "全世界都在看最新的航天飞机发射。"}
{"id": 1, "sentence1": "而我们把竹篮放在一个地方,把玻璃瓶放在另一处,把书放在另一处,满了要把它放到车里", "sentence2": "我们没有分开任何东西,都把它全扔进一个箱子里。"}
{"id": 2, "sentence1": "她占用了我的很多时间,她给我读了很多关于灵异的故事,我觉得很无聊。", "sentence2": "我喜欢和她一起读鬼故事。"}

View File

@ -0,0 +1,3 @@
{"sentence1": "你应该给这件衣服定一个价格。", "sentence2": "不同的衣服有不同的价格。", "label": "neutral"}
{"sentence1": "我怎么知道他要说什么", "sentence2": "他说什么我并不知道。", "label": "entailment"}
{"sentence1": "向左。", "sentence2": "向右。", "label": "contradiction"}

View File

@ -0,0 +1,3 @@
{"id": 1, "abst": "这是第一段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
{"id": 2, "abst": "这是第二段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
{"id": 3, "abst": "这是第三段很长的文本", "keyword": ["1", "2", "3"], "label": "0"}

View File

@ -0,0 +1,3 @@
{"id": 2415, "abst": "长文本1", "keyword": ["关键词1", "关键词2"]}
{"id": 2565, "abst": "长文本2", "keyword": ["关键词1", "关键词2", "关键词3"]}
{"id": 2625, "abst": "长文本3", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"]}

View File

@ -0,0 +1,3 @@
{"id": 1, "abst": "这是一段长文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "0"}
{"id": 2, "abst": "这是一段长文本", "keyword": ["关键词5", "关键词6", "关键词7", "关键词8"], "label": "0"}
{"id": 3, "abst": "这是一段长文本", "keyword": ["关键词9", "关键词10", "关键词11", "关键词12"], "label": "0"}

View File

@ -0,0 +1,3 @@
{"label": "110", "label_des": "社区超市", "sentence": "这是第一段文本"}
{"label": "70", "label_des": "工具", "sentence": "这是第二段文本"}
{"label": "10", "label_des": "社区服务", "sentence": "这是第三段文本"}

View File

@ -0,0 +1,3 @@
{"id": 0, "sentence": "文本1"}
{"id": 1, "sentence": "文本2"}
{"id": 2, "sentence": "文本3"}

View File

@ -0,0 +1,3 @@
{"label": "11", "label_des": "薅羊毛", "sentence": "第一个文本"}
{"label": "95", "label_des": "借贷", "sentence": "第二个文本"}
{"label": "74", "label_des": "违章", "sentence": "第三个文本"}

View File

@ -0,0 +1,3 @@
{"label": "102", "label_desc": "news_entertainment", "sentence": "新闻1", "keywords": "关键词一,关键词二,关键词三,关键词四"}
{"label": "110", "label_desc": "news_military", "sentence": "新闻2", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}
{"label": "104", "label_desc": "news_finance", "sentence": "新闻3", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}

View File

@ -0,0 +1,3 @@
{"id": 0, "sentence": "新闻1", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5"}
{"id": 1, "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4"}
{"id": 2, "sentence": "新闻3", "keywords": ""}

View File

@ -0,0 +1,3 @@
{"label": "108", "label_desc": "news_edu", "sentence": "新闻1", "keywords": ""}
{"label": "104", "label_desc": "news_finance", "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5,关键词6"}
{"label": "106", "label_desc": "news_house", "sentence": "新闻3", "keywords": ""}

View File

@ -0,0 +1,3 @@
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"}
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"}
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"}

View File

@ -0,0 +1,3 @@
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?"}
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场"}
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业"}

View File

@ -0,0 +1,3 @@
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"}
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"}
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"}

View File

@ -0,0 +1,355 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
def test_clue():
"""
Test CLUE with repeat, skip and so on
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
data = data.repeat(2)
data = data.skip(3)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_num_shards():
"""
Test num_shards param of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 1
def test_clue_num_samples():
"""
Test num_samples param of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
count = 0
for _ in data.create_dict_iterator():
count += 1
assert count == 2
def test_textline_dataset_get_datasetsize():
"""
Test get_dataset_size of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.TextFileDataset(TRAIN_FILE)
size = data.get_dataset_size()
assert size == 3
def test_clue_afqmc():
"""
Test AFQMC for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'id': d['id'],
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
# evaluation
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_cmnli():
"""
Test CMNLI for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'sentence1': d['sentence1'].item().decode("utf8"),
'sentence2': d['sentence2'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'id': d['id'],
'sentence1': d['sentence1'],
'sentence2': d['sentence2']
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'],
'sentence1': d['sentence1'],
'sentence2': d['sentence2']
})
assert len(buffer) == 3
def test_clue_csl():
"""
Test CSL for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'id': d['id'],
'abst': d['abst'].item().decode("utf8"),
'keyword': [i.item().decode("utf8") for i in d['keyword']],
'label': d['label'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'id': d['id'],
'abst': d['abst'].item().decode("utf8"),
'keyword': [i.item().decode("utf8") for i in d['keyword']],
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'id': d['id'],
'abst': d['abst'].item().decode("utf8"),
'keyword': [i.item().decode("utf8") for i in d['keyword']],
'label': d['label'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_iflytek():
"""
Test IFLYTEK for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_des': d['label_des'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8"),
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'id': d['id'],
'sentence': d['sentence'].item().decode("utf8")
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_des': d['label_des'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8")
})
assert len(buffer) == 3
def test_clue_tnews():
"""
Test TNEWS for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_desc': d['label_desc'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8"),
'keywords':
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'id': d['id'],
'sentence': d['sentence'].item().decode("utf8"),
'keywords':
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
for d in data.create_dict_iterator():
buffer.append({
'label': d['label'].item().decode("utf8"),
'label_desc': d['label_desc'].item().decode("utf8"),
'sentence': d['sentence'].item().decode("utf8"),
'keywords':
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
})
assert len(buffer) == 3
def test_clue_wsc():
"""
Test WSC for train, test and evaluation
"""
TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'
# train
buffer = []
data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
for d in data.create_dict_iterator():
buffer.append({
'span1_index': d['span1_index'],
'span2_index': d['span2_index'],
'span1_text': d['span1_text'].item().decode("utf8"),
'span2_text': d['span2_text'].item().decode("utf8"),
'idx': d['idx'],
'label': d['label'].item().decode("utf8"),
'text': d['text'].item().decode("utf8")
})
assert len(buffer) == 3
# test
buffer = []
data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
for d in data.create_dict_iterator():
buffer.append({
'span1_index': d['span1_index'],
'span2_index': d['span2_index'],
'span1_text': d['span1_text'].item().decode("utf8"),
'span2_text': d['span2_text'].item().decode("utf8"),
'idx': d['idx'],
'text': d['text'].item().decode("utf8")
})
assert len(buffer) == 3
# eval
buffer = []
data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
for d in data.create_dict_iterator():
buffer.append({
'span1_index': d['span1_index'],
'span2_index': d['span2_index'],
'span1_text': d['span1_text'].item().decode("utf8"),
'span2_text': d['span2_text'].item().decode("utf8"),
'idx': d['idx'],
'label': d['label'].item().decode("utf8"),
'text': d['text'].item().decode("utf8")
})
assert len(buffer) == 3
if __name__ == "__main__":
test_clue()
test_clue_afqmc()
test_clue_cmnli()
test_clue_csl()
test_clue_iflytek()
test_clue_tnews()
test_clue_wsc()