add CLUE dataset
This commit is contained in:
parent
3536185f5b
commit
e0e167a000
|
@ -31,6 +31,7 @@
|
|||
#include "dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "dataset/engine/datasetops/filter_op.h"
|
||||
#include "mindrecord/include/shard_category.h"
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
|
@ -72,7 +73,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
|
|||
{kCelebA, &DEPipeline::ParseCelebAOp},
|
||||
{kRandomData, &DEPipeline::ParseRandomDataOp},
|
||||
{kTextFile, &DEPipeline::ParseTextFileOp},
|
||||
{kBuildVocab, &DEPipeline::ParseBuildVocabOp}};
|
||||
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
|
||||
{kClue, &DEPipeline::ParseClueOp}};
|
||||
|
||||
DEPipeline::DEPipeline() : iterator_(nullptr) {
|
||||
try {
|
||||
|
@ -1210,6 +1212,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
|
@ -1236,6 +1239,7 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
|
||||
for (auto arg : args) {
|
||||
|
@ -1267,5 +1271,45 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
|
||||
if (!args["dataset_files"].is_none()) {
|
||||
(void)builder->SetClueFilesList(ToStringVector(args["dataset_files"]));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
|
||||
}
|
||||
// Optional arguments
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_shards") {
|
||||
(void)builder->SetNumDevices(ToInt(value));
|
||||
} else if (key == "shard_id") {
|
||||
(void)builder->SetDeviceId(ToInt(value));
|
||||
} else if (key == "cols_to_keyword") {
|
||||
std::map<std::string, std::string> map_dict;
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
map_dict.insert({ToString(p.first), ToString(p.second)});
|
||||
} else {
|
||||
map_dict.insert({ToString(p.first), ToString(p.first)});
|
||||
}
|
||||
}
|
||||
(void)builder->SetColsKeyMap(map_dict);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<ClueOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -64,7 +64,8 @@ enum OpName {
|
|||
kCelebA,
|
||||
kRandomData,
|
||||
kTextFile,
|
||||
kBuildVocab
|
||||
kBuildVocab,
|
||||
kClue
|
||||
};
|
||||
|
||||
// The C++ binder class that we expose to the python script.
|
||||
|
@ -166,6 +167,8 @@ class DEPipeline {
|
|||
|
||||
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
private:
|
||||
// Execution tree that links the dataset operators.
|
||||
std::shared_ptr<ExecutionTree> tree_;
|
||||
|
|
|
@ -55,6 +55,7 @@
|
|||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
|
@ -201,6 +202,18 @@ void bindDatasetOps(py::module *m) {
|
|||
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
|
||||
.def_static("get_num_rows", [](const py::list &files) {
|
||||
int64_t count = 0;
|
||||
std::vector<std::string> filenames;
|
||||
for (auto file : files) {
|
||||
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
|
||||
}
|
||||
THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
|
||||
.def_static("get_num_rows",
|
||||
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
|
@ -629,7 +642,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
.value("RANDOMDATA", OpName::kRandomData)
|
||||
.value("BUILDVOCAB", OpName::kBuildVocab)
|
||||
.value("CELEBA", OpName::kCelebA)
|
||||
.value("TEXTFILE", OpName::kTextFile);
|
||||
.value("TEXTFILE", OpName::kTextFile)
|
||||
.value("CLUE", OpName::kClue);
|
||||
|
||||
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
|
||||
.value("DE_JIEBA_MIX", JiebaMode::kMix)
|
||||
|
|
|
@ -19,4 +19,5 @@ add_library(engine-datasetops-source OBJECT
|
|||
random_data_op.cc
|
||||
celeba_op.cc
|
||||
text_file_op.cc
|
||||
clue_op.cc
|
||||
)
|
|
@ -0,0 +1,551 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/util/task_manager.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/engine/execution_tree.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
ClueOp::Builder::Builder()
|
||||
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
builder_num_workers_ = config_manager->num_parallel_workers();
|
||||
builder_op_connector_size_ = config_manager->op_connector_size();
|
||||
builder_rows_per_buffer_ = config_manager->rows_per_buffer();
|
||||
builder_worker_connector_size_ = config_manager->worker_connector_size();
|
||||
}
|
||||
|
||||
Status ClueOp::Builder::ValidateInputs() const {
|
||||
std::string err;
|
||||
err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
|
||||
err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
|
||||
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
|
||||
}
|
||||
|
||||
Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
|
||||
RETURN_IF_NOT_OK(ValidateInputs());
|
||||
|
||||
// Throttle the number of workers if we have more workers than files!
|
||||
if (static_cast<size_t>(builder_num_workers_) > builder_clue_files_list_.size()) {
|
||||
builder_num_workers_ = builder_clue_files_list_.size();
|
||||
MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
|
||||
}
|
||||
|
||||
ColKeyMap ck_map;
|
||||
for (auto &p : builder_cols_to_keyword_) {
|
||||
ck_map.insert({p.first, split(p.second, '/')});
|
||||
}
|
||||
|
||||
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
|
||||
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
|
||||
builder_device_id_);
|
||||
RETURN_IF_NOT_OK(clue_op->Init());
|
||||
*op = std::move(clue_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim) {
|
||||
std::vector<std::string> res;
|
||||
std::stringstream ss(s);
|
||||
std::string item;
|
||||
|
||||
while (getline(ss, item, delim)) {
|
||||
res.push_back(item);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
num_rows_per_shard_(0),
|
||||
all_num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
filename_index_(std::make_unique<StringIndex>()),
|
||||
clue_files_list_(std::move(clue_files_list)),
|
||||
load_jagged_connector_(true),
|
||||
cols_to_keyword_(cols_to_keyword),
|
||||
shuffle_files_(shuffle_files),
|
||||
finished_reading_dataset_(false),
|
||||
num_devices_(num_device),
|
||||
device_id_(device_id),
|
||||
load_io_block_queue_(true) {
|
||||
worker_connector_size_ = worker_connector_size;
|
||||
}
|
||||
|
||||
Status ClueOp::Init() {
|
||||
RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));
|
||||
|
||||
int32_t safe_queue_size = static_cast<int32_t>(std::ceil(clue_files_list_.size() / num_workers_) + 1);
|
||||
io_block_queues_.Init(num_workers_, safe_queue_size);
|
||||
|
||||
// Set the column name mapping (base class field)
|
||||
int count = 0;
|
||||
for (auto &p : cols_to_keyword_) {
|
||||
column_name_id_map_[p.first] = count;
|
||||
count++;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
|
||||
jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::Reset() {
|
||||
load_jagged_connector_ = true;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::Reset());
|
||||
NotifyToFillIOBlockQueue();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
|
||||
TensorRow tRow(1, nullptr);
|
||||
(*tensor_table)->push_back(std::move(tRow));
|
||||
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar()));
|
||||
(**tensor_table)[row][0] = std::move(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
|
||||
nlohmann::json cursor = js;
|
||||
for (int i = 0; i < key_chain.size(); i++) {
|
||||
if (cursor.find(key_chain[i]) != cursor.end()) {
|
||||
cursor = cursor[key_chain[i]];
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]);
|
||||
}
|
||||
}
|
||||
std::string final_str = key_chain.back();
|
||||
switch (cursor.type()) {
|
||||
case nlohmann::detail::value_t::string:
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::string>()}, TensorShape::CreateScalar()));
|
||||
break;
|
||||
|
||||
case nlohmann::detail::value_t::number_integer:
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
|
||||
(*t)->SetItemAt<int32_t>({0}, cursor.get<int32_t>());
|
||||
break;
|
||||
case nlohmann::detail::value_t::number_unsigned:
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32)));
|
||||
(*t)->SetItemAt<int32_t>({0}, cursor.get<uint32_t>());
|
||||
break;
|
||||
case nlohmann::detail::value_t::number_float:
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32)));
|
||||
(*t)->SetItemAt<int32_t>({0}, cursor.get<float>());
|
||||
break;
|
||||
case nlohmann::detail::value_t::array:
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::vector<std::string>>()}, TensorShape::CreateScalar()));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
|
||||
}
|
||||
|
||||
int64_t rows_each_buffer = 0;
|
||||
int64_t rows_total = 0;
|
||||
std::string line;
|
||||
std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
|
||||
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
|
||||
|
||||
while (getline(handle, line)) {
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
// If read to the end offset of this file, break.
|
||||
if (rows_total >= end_offset) {
|
||||
break;
|
||||
}
|
||||
// Skip line before start offset.
|
||||
if (rows_total < start_offset) {
|
||||
rows_total++;
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
nlohmann::json js = nlohmann::json::parse(line);
|
||||
int cols_count = cols_to_keyword_.size();
|
||||
TensorRow tRow(cols_count, nullptr);
|
||||
tensor_table->push_back(std::move(tRow));
|
||||
|
||||
int cout = 0;
|
||||
for (auto &p : cols_to_keyword_) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
|
||||
(*tensor_table)[rows_each_buffer][cout] = std::move(tensor);
|
||||
cout++;
|
||||
}
|
||||
} catch (const std::exception &err) {
|
||||
// Catch any exception and convert to Status return code
|
||||
RETURN_STATUS_UNEXPECTED("Failed to load json file");
|
||||
}
|
||||
|
||||
// RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer));
|
||||
rows_each_buffer++;
|
||||
rows_total++;
|
||||
if (rows_each_buffer == rows_per_buffer_) {
|
||||
cur_buffer->set_tensor_table(std::move(tensor_table));
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
|
||||
|
||||
cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
|
||||
tensor_table = std::make_unique<TensorQTable>();
|
||||
rows_each_buffer = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (rows_each_buffer > 0) {
|
||||
cur_buffer->set_tensor_table(std::move(tensor_table));
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::operator()() {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
|
||||
// launch one thread, responsible for filling IoBlockQueue
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this)));
|
||||
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1)));
|
||||
|
||||
// must be called after launching workers.
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
|
||||
NotifyToFillIOBlockQueue();
|
||||
|
||||
while (!finished_reading_dataset_) {
|
||||
int64_t buffer_id = 0;
|
||||
int32_t workers_done = 0;
|
||||
int64_t rows_read = 0;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
while (workers_done < num_workers_) {
|
||||
std::unique_ptr<DataBuffer> buffer;
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
||||
if (buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (num_samples_ == 0 || rows_read < num_samples_) {
|
||||
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
|
||||
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
|
||||
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
rows_read += buffer->NumRows();
|
||||
buffer->set_id(buffer_id++);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
|
||||
} else {
|
||||
// end of epoch
|
||||
load_jagged_connector_ = false;
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
}
|
||||
}
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfData());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::WorkerEntry(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
std::unique_ptr<FilenameBlock> io_block;
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void ClueOp::Print(std::ostream &out, bool show_all) const {
|
||||
// Always show the id and name as first line regardless if this summary or detailed print
|
||||
out << "(" << std::setw(2) << operator_id_ << ") <ClueOp>:";
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
|
||||
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
|
||||
for (int i = 0; i < clue_files_list_.size(); ++i) {
|
||||
out << " " << clue_files_list_[i];
|
||||
}
|
||||
out << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Pops an element from a queue in io_block_queues
|
||||
Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes an element to a queue in io_block_queues
|
||||
Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
|
||||
std::mt19937 rng(seed);
|
||||
std::shuffle(i_keys->begin(), i_keys->end(), rng);
|
||||
}
|
||||
|
||||
Status ClueOp::WaitToFillIOBlockQueue() {
|
||||
// must be called first if called by worker spanwed by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::vector<int64_t> i_keys;
|
||||
if (shuffle_files_) {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
i_keys.push_back(it.key());
|
||||
}
|
||||
}
|
||||
uint32_t seed = 0;
|
||||
while (true) {
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
|
||||
io_block_queue_wait_post_.Clear();
|
||||
|
||||
if (finished_reading_dataset_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (shuffle_files_) {
|
||||
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
|
||||
}
|
||||
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
||||
int32_t queue_index = 0;
|
||||
int64_t pre_count = 0;
|
||||
int64_t start_offset = 0;
|
||||
int64_t end_offset = 0;
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
std::vector<std::pair<std::string, int64_t>> file_index;
|
||||
if (!i_keys.empty()) {
|
||||
for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
|
||||
{
|
||||
if (!load_io_block_queue_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto file_it = filename_index_->Search(*it);
|
||||
file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it));
|
||||
}
|
||||
} else {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
{
|
||||
if (!load_io_block_queue_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
|
||||
}
|
||||
}
|
||||
for (auto file_info : file_index) {
|
||||
if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
|
||||
auto ioBlock =
|
||||
std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
||||
queue_index = (queue_index + 1) % num_workers_;
|
||||
}
|
||||
|
||||
pre_count += filename_numrows_[file_info.first];
|
||||
}
|
||||
|
||||
if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
|
||||
finish = false;
|
||||
} else {
|
||||
finish = true;
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
|
||||
|
||||
bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid";
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
|
||||
*start_offset = start_index - pre_count;
|
||||
push = true;
|
||||
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
if (pre_count >= start_index && pre_count < end_index) {
|
||||
*start_offset = 0;
|
||||
push = true;
|
||||
if (pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
return push;
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::CalculateNumRowsPerShard() {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
int64_t count = CountTotalRows(it.value());
|
||||
filename_numrows_[it.value()] = count;
|
||||
all_num_rows_ += count;
|
||||
}
|
||||
if (all_num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API "
|
||||
"validation first.");
|
||||
}
|
||||
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
|
||||
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t ClueOp::CountTotalRows(const std::string &file) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
MS_LOG(ERROR) << "Failed to open file: " << file;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
int64_t count = 0;
|
||||
while (getline(handle, line)) {
|
||||
if (!line.empty()) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
Status ClueOp::PostEndOfData() {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
|
||||
std::shared_ptr<ClueOp> op;
|
||||
*count = 0;
|
||||
RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op));
|
||||
for (auto file : files) {
|
||||
*count += op->CountTotalRows(file);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,270 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
|
||||
#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "dataset/util/auto_index.h"
|
||||
#include "dataset/engine/datasetops/parallel_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
using StringIndex = AutoIndexObj<std::string>;
|
||||
using ColKeyMap = std::map<std::string, std::vector<std::string>>;
|
||||
|
||||
class JaggedConnector;
|
||||
|
||||
class ClueOp : public ParallelOp {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
// Builder constructor. Creates the builder object.
|
||||
// @note No default args
|
||||
// @return This is a constructor.
|
||||
Builder();
|
||||
|
||||
// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
// Checks if the inputs of the builder is valid.
|
||||
// @return Status - the error code returned.
|
||||
Status ValidateInputs() const;
|
||||
|
||||
// Create the final object.
|
||||
// @param op - dataset op.
|
||||
// @return - the error code return.
|
||||
Status Build(std::shared_ptr<ClueOp> *op);
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
builder_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t op_connector_size) {
|
||||
builder_op_connector_size_ = op_connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
|
||||
builder_rows_per_buffer_ = rows_per_buffer;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumDevices(int64_t num_dev) {
|
||||
builder_num_devices_ = num_dev;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetDeviceId(int64_t dev_id) {
|
||||
builder_device_id_ = dev_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetClueFilesList(const std::vector<std::string> &files_list) {
|
||||
builder_clue_files_list_ = files_list;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetShuffleFiles(bool shuffle_files) {
|
||||
builder_shuffle_files_ = shuffle_files;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetNumSamples(int64_t num_samples) {
|
||||
builder_num_samples_ = num_samples;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Setter method.
|
||||
// @return Builder - setter method returns reference to the builder.
|
||||
Builder &SetColsKeyMap(const std::map<std::string, std::string> &cols_to_key) {
|
||||
builder_cols_to_keyword_ = cols_to_key;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Split string based on a character delimiter
|
||||
// @return - the a string vector
|
||||
std::vector<std::string> split(const std::string &s, char delim);
|
||||
|
||||
private:
|
||||
int32_t builder_device_id_;
|
||||
int32_t builder_num_devices_;
|
||||
int32_t builder_num_workers_;
|
||||
int32_t builder_op_connector_size_;
|
||||
int64_t builder_rows_per_buffer_;
|
||||
int64_t builder_num_samples_;
|
||||
int32_t builder_worker_connector_size_;
|
||||
std::vector<std::string> builder_clue_files_list_;
|
||||
bool builder_shuffle_files_;
|
||||
std::map<std::string, std::string> builder_cols_to_keyword_;
|
||||
};
|
||||
|
||||
// Constructor of ClueOp
|
||||
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
|
||||
// Default destructor
|
||||
~ClueOp() = default;
|
||||
|
||||
// A print method typically used for debugging
|
||||
// @param out - The output stream to write output to
|
||||
// @param show_all - A bool to control if you want to show all info or just a summary
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Instantiates the internal queues and connectors
|
||||
// @return Status - the error code returned
|
||||
Status Init();
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - the error code returned.
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
// @return Status - the error code returned.
|
||||
Status Reset() override;
|
||||
|
||||
// Get total rows in files.
|
||||
// @param files - all clue files.
|
||||
// @param count - number of rows.
|
||||
// @return Status - the error coed returned.
|
||||
static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Parses a single row and puts the data into a tensor table.
|
||||
// @param line - the content of the row.
|
||||
// @param tensor_table - the tensor table to put the parsed data in.
|
||||
// @param row - the id of the row filled in the tensor table.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
|
||||
|
||||
// Reads a clue file and loads the data into multiple buffers.
|
||||
// @param file - the file to read.
|
||||
// @param start_offset - the start offset of file.
|
||||
// @param end_offset - the end offset of file.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id);
|
||||
|
||||
// Pops an element from a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to pop from.
|
||||
// @param out_block - the popped element.
|
||||
// @return Status - the error code returned.
|
||||
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
|
||||
|
||||
// Pushes an element to a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to push to.
|
||||
// @param io_block - the element to push onto the queue.
|
||||
// @return Status - the error code returned.
|
||||
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
// @return Status - the error code returned.
|
||||
Status WaitToFillIOBlockQueue();
|
||||
|
||||
// Fill the IOBlockQueue.
|
||||
// @para i_keys - keys of file to fill to the IOBlockQueue
|
||||
// @return Status - the error code returned.
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
|
||||
|
||||
// Notifies the thread which called FillIoBlockQueue to resume execution
|
||||
void NotifyToFillIOBlockQueue();
|
||||
|
||||
// Select file and push it to the block queue.
|
||||
// @param file_name - File name.
|
||||
// @param start_file - If file contains the first sample of data.
|
||||
// @param end_file - If file contains the end sample of data.
|
||||
// @param pre_count - Total rows of previous files.
|
||||
// @return Status - the error code returned.
|
||||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfEpoch(int32_t queue_index);
|
||||
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
Status CalculateNumRowsPerShard();
|
||||
|
||||
// Count number of rows in each file.
|
||||
// @param filename - clue file name.
|
||||
// @return int64_t - the total number of rows in file.
|
||||
int64_t CountTotalRows(const std::string &file);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfData();
|
||||
|
||||
// @return Status - the error code returned.
|
||||
Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);
|
||||
|
||||
int32_t device_id_;
|
||||
bool shuffle_files_;
|
||||
bool finished_reading_dataset_;
|
||||
int32_t num_devices_;
|
||||
int64_t rows_per_buffer_;
|
||||
bool load_io_block_queue_;
|
||||
int64_t num_rows_per_shard_;
|
||||
int64_t all_num_rows_;
|
||||
int64_t num_samples_;
|
||||
std::map<std::string, int64_t> filename_numrows_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
std::vector<std::string> clue_files_list_;
|
||||
WaitPost io_block_queue_wait_post_;
|
||||
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
bool load_jagged_connector_;
|
||||
ColKeyMap cols_to_keyword_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
|
|
@ -43,7 +43,7 @@ TextFileOp::Builder::Builder()
|
|||
|
||||
Status TextFileOp::Builder::ValidateInputs() const {
|
||||
std::string err_msg;
|
||||
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : "";
|
||||
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
|
||||
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
|
||||
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
|
|||
from .core.configuration import config
|
||||
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
|
||||
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
|
||||
TextFileDataset, Schema, Shuffle, zip, RandomDataset
|
||||
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
|
||||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||
WeightedRandomSampler, Sampler
|
||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||
|
@ -29,6 +29,6 @@ from .engine.graphdata import GraphData
|
|||
|
||||
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
|
||||
"MindDataset", "GeneratorDataset", "TFRecordDataset",
|
||||
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset",
|
||||
"VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
|
||||
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
|
||||
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
|
||||
"CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler",
|
||||
"RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
|
||||
|
|
|
@ -30,7 +30,7 @@ from ..core.configuration import config, ConfigurationManager
|
|||
|
||||
__all__ = ["config", "ConfigurationManager", "zip",
|
||||
"ImageFolderDatasetV2", "MnistDataset",
|
||||
"MindDataset", "GeneratorDataset", "TFRecordDataset",
|
||||
"MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
|
||||
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
|
||||
"VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema", "Schema",
|
||||
"DistributedSampler", "PKSampler",
|
||||
|
|
|
@ -33,7 +33,7 @@ import copy
|
|||
import numpy as np
|
||||
|
||||
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
|
||||
MindRecordOp, TextFileOp, VOCOp, CocoOp, CBatchInfo
|
||||
MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
|
||||
from mindspore._c_expression import typing
|
||||
|
||||
from mindspore import log as logger
|
||||
|
@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
|
||||
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
|
||||
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
|
||||
check_split
|
||||
check_split, check_cluedataset
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
||||
try:
|
||||
|
@ -4317,6 +4317,222 @@ class CelebADataset(MappableDataset):
|
|||
return self.sampler.is_sharded()
|
||||
|
||||
|
||||
class CLUEDataset(SourceDataset):
|
||||
"""
|
||||
A source dataset that reads and parses CLUE datasets.
|
||||
CLUE, the Chinese Language Understanding Evaluation Benchmark, a collection of datasets, baselines, pre-trained
|
||||
models, corpus and leaderboard. Here we bring in classification task of CLUE, which are AFQMC, TNEWS, IFLYTEK,
|
||||
CMNLI, WSC and CSL.
|
||||
|
||||
Args:
|
||||
dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
|
||||
files. The list will be sorted in a lexicographical order.
|
||||
task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'.
|
||||
(default=AFQMC).
|
||||
usage (str, optional): Need train, test or eval data (default="train").
|
||||
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
|
||||
num_parallel_workers (int, optional): number of workers to read the data
|
||||
(default=None, number set in the config).
|
||||
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
|
||||
If shuffle is False, no shuffling will be performed;
|
||||
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
|
||||
Otherwise, there are two levels of shuffling:
|
||||
|
||||
- Shuffle.GLOBAL: Shuffle both the files and samples.
|
||||
|
||||
- Shuffle.FILES: Shuffle files only.
|
||||
|
||||
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
|
||||
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||
argument should be specified only when num_shards is also specified.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
|
||||
>>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train')
|
||||
|
||||
"""
|
||||
|
||||
@check_cluedataset
|
||||
def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
|
||||
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.dataset_files = self._find_files(dataset_files)
|
||||
self.dataset_files.sort()
|
||||
self.num_samples = num_samples
|
||||
self.task_dict = {
|
||||
'AFQMC': {
|
||||
'train': {
|
||||
'sentence1': 'sentence1',
|
||||
'sentence2': 'sentence2',
|
||||
'label': 'label'
|
||||
},
|
||||
'test': {
|
||||
'id': 'id',
|
||||
'sentence1': 'sentence1',
|
||||
'sentence2': 'sentence2'
|
||||
},
|
||||
'eval': {
|
||||
'sentence1': 'sentence1',
|
||||
'sentence2': 'sentence2',
|
||||
'label': 'label'
|
||||
}
|
||||
},
|
||||
'CMNLI': {
|
||||
'train': {
|
||||
'sentence1': 'sentence1',
|
||||
'sentence2': 'sentence2',
|
||||
'label': 'label'
|
||||
},
|
||||
'test': {
|
||||
'id': 'id',
|
||||
'sentence1': 'sentence1',
|
||||
'sentence2': 'sentence2'
|
||||
},
|
||||
'eval': {
|
||||
'sentence1': 'sentence1',
|
||||
'sentence2': 'sentence2',
|
||||
'label': 'label'
|
||||
}
|
||||
},
|
||||
'CSL': {
|
||||
'train': {
|
||||
'id': 'id',
|
||||
'abst': 'abst',
|
||||
'keyword': 'keyword',
|
||||
'label': 'label'
|
||||
},
|
||||
'test': {
|
||||
'id': 'id',
|
||||
'abst': 'abst',
|
||||
'keyword': 'keyword'
|
||||
},
|
||||
'eval': {
|
||||
'id': 'id',
|
||||
'abst': 'abst',
|
||||
'keyword': 'keyword',
|
||||
'label': 'label'
|
||||
}
|
||||
},
|
||||
'IFLYTEK': {
|
||||
'train': {
|
||||
'label': 'label',
|
||||
'label_des': 'label_des',
|
||||
'sentence': 'sentence'
|
||||
},
|
||||
'test': {
|
||||
'id': 'id',
|
||||
'sentence': 'sentence',
|
||||
},
|
||||
'eval': {
|
||||
'label': 'label',
|
||||
'label_des': 'label_des',
|
||||
'sentence': 'sentence'
|
||||
}
|
||||
},
|
||||
'TNEWS': {
|
||||
'train': {
|
||||
'label': 'label',
|
||||
'label_desc': 'label_desc',
|
||||
'sentence': 'sentence',
|
||||
'keywords': 'keywords'
|
||||
},
|
||||
'test': {
|
||||
'id': 'id',
|
||||
'sentence': 'sentence',
|
||||
'keywords': 'keywords'
|
||||
},
|
||||
'eval': {
|
||||
'label': 'label',
|
||||
'label_desc': 'label_desc',
|
||||
'sentence': 'sentence',
|
||||
'keywords': 'keywords'
|
||||
}
|
||||
},
|
||||
'WSC': {
|
||||
'train': {
|
||||
'span1_index': 'target/span1_index',
|
||||
'span2_index': 'target/span2_index',
|
||||
'span1_text': 'target/span1_text',
|
||||
'span2_text': 'target/span2_text',
|
||||
'idx': 'idx',
|
||||
'label': 'label',
|
||||
'text': 'text'
|
||||
},
|
||||
'test': {
|
||||
'span1_index': 'target/span1_index',
|
||||
'span2_index': 'target/span2_index',
|
||||
'span1_text': 'target/span1_text',
|
||||
'span2_text': 'target/span2_text',
|
||||
'idx': 'idx',
|
||||
'text': 'text'
|
||||
},
|
||||
'eval': {
|
||||
'span1_index': 'target/span1_index',
|
||||
'span2_index': 'target/span2_index',
|
||||
'span1_text': 'target/span1_text',
|
||||
'span2_text': 'target/span2_text',
|
||||
'idx': 'idx',
|
||||
'label': 'label',
|
||||
'text': 'text'
|
||||
}
|
||||
}
|
||||
}
|
||||
self.cols_to_keyword = self.task_dict[task][usage]
|
||||
|
||||
if not isinstance(shuffle, (bool, Shuffle)):
|
||||
raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
|
||||
if not isinstance(shuffle, Shuffle):
|
||||
if shuffle:
|
||||
self.shuffle_level = Shuffle.GLOBAL
|
||||
self.shuffle_files = True
|
||||
else:
|
||||
self.shuffle_level = None
|
||||
self.shuffle_files = False
|
||||
else:
|
||||
self.shuffle_level = shuffle
|
||||
self.shuffle_files = True
|
||||
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
args["dataset_files"] = self.dataset_files
|
||||
args["num_samples"] = self.num_samples
|
||||
if self.shuffle_files is not None:
|
||||
args["shuffle_files"] = self.shuffle_files
|
||||
args["shuffle"] = self.shuffle_level
|
||||
args["num_shards"] = self.num_shards
|
||||
args["shard_id"] = self.shard_id
|
||||
args["cols_to_keyword"] = self.cols_to_keyword
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
"""
|
||||
Get the number of batches in an epoch.
|
||||
|
||||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self._dataset_size is None:
|
||||
num_rows = ClueOp.get_num_rows(self.dataset_files)
|
||||
num_rows = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is None:
|
||||
return num_rows
|
||||
return min(self.num_samples, num_rows)
|
||||
return self._dataset_size
|
||||
|
||||
def is_shuffled(self):
|
||||
return self.shuffle_files
|
||||
|
||||
def is_sharded(self):
|
||||
if self.num_shards is not None:
|
||||
return self.num_shards > 1
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TextFileDataset(SourceDataset):
|
||||
"""
|
||||
A source dataset that reads and parses datasets stored on disk in text format.
|
||||
|
|
|
@ -50,7 +50,8 @@ def alter_tree(node):
|
|||
|
||||
def _alter_node(node):
|
||||
"""Performing some alteration to a dataset node. A common alteration is to insert a node."""
|
||||
if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL:
|
||||
if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \
|
||||
and node.shuffle_level == de.Shuffle.GLOBAL:
|
||||
# Remove the connection between the parent's node to the current node because we are inserting a node.
|
||||
if node.output:
|
||||
node.output.pop()
|
||||
|
@ -179,6 +180,8 @@ class Iterator:
|
|||
op_type = OpName.TEXTFILE
|
||||
elif isinstance(dataset, de.BuildVocabDataset):
|
||||
op_type = OpName.BUILDVOCAB
|
||||
elif isinstance(dataset, de.CLUEDataset):
|
||||
op_type = OpName.CLUE
|
||||
else:
|
||||
raise ValueError("Unsupported DatasetOp")
|
||||
|
||||
|
|
|
@ -1075,6 +1075,41 @@ def check_add_column(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_cluedataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
|
||||
# check dataset_files; required argument
|
||||
dataset_files = param_dict.get('dataset_files')
|
||||
if dataset_files is None:
|
||||
raise ValueError("dataset_files is not provided.")
|
||||
if not isinstance(dataset_files, (str, list)):
|
||||
raise TypeError("dataset_files should be of type str or a list of strings.")
|
||||
|
||||
# check task
|
||||
task_param = param_dict.get('task')
|
||||
if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
|
||||
raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
|
||||
|
||||
# check usage
|
||||
usage_param = param_dict.get('usage')
|
||||
if usage_param not in ['train', 'test', 'eval']:
|
||||
raise ValueError("usage should be train, test or eval")
|
||||
|
||||
check_param_type(nreq_param_int, param_dict, int)
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_textfiledataset(method):
|
||||
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
|
||||
|
||||
|
|
|
@ -65,6 +65,7 @@ SET(DE_UT_SRCS
|
|||
cifar_op_test.cc
|
||||
celeba_op_test.cc
|
||||
take_op_test.cc
|
||||
clue_op_test.cc
|
||||
text_file_op_test.cc
|
||||
filter_op_test.cc
|
||||
concat_op_test.cc
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "common/utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace common = mindspore::common;
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
class MindDataTestCLUEOp : public UT::DatasetOpTesting {
|
||||
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestCLUEOp, TestCLUEBasic) {
|
||||
// Start with an empty execution tree
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testCLUE/afqmc/train.json";
|
||||
std::map<std::string, std::string> key_map;
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
|
||||
std::shared_ptr<ClueOp> op;
|
||||
ClueOp::Builder builder;
|
||||
builder.SetClueFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(16)
|
||||
.SetNumWorkers(16)
|
||||
.SetOpConnectorSize(2)
|
||||
.SetColsKeyMap(key_map);
|
||||
|
||||
Status rc = builder.Build(&op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->AssociateNode(op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->AssignRoot(op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration.";
|
||||
rc = tree->Prepare();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = tree->Launch();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// Start the loop of reading tensors from our pipeline
|
||||
DatasetIterator di(tree);
|
||||
TensorRow tensor_list;
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
int row_count = 0;
|
||||
while (!tensor_list.empty()) {
|
||||
// Display the tensor by calling the printer on it
|
||||
for (int i = 0; i < tensor_list.size(); i++) {
|
||||
std::ostringstream ss;
|
||||
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
|
||||
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
|
||||
}
|
||||
|
||||
rc = di.FetchNextTensorRow(&tensor_list);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
row_count++;
|
||||
}
|
||||
|
||||
ASSERT_EQ(row_count, 3);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCLUEOp, TestTotalRows) {
|
||||
std::string tf_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
|
||||
std::vector<std::string> files;
|
||||
files.push_back(tf_file1);
|
||||
int64_t total_rows = 0;
|
||||
ClueOp::CountAllFileRows(files, &total_rows);
|
||||
ASSERT_EQ(total_rows, 3);
|
||||
files.clear();
|
||||
|
||||
files.push_back(tf_file2);
|
||||
ClueOp::CountAllFileRows(files, &total_rows);
|
||||
ASSERT_EQ(total_rows, 3);
|
||||
files.clear();
|
||||
|
||||
files.push_back(tf_file1);
|
||||
files.push_back(tf_file2);
|
||||
ClueOp::CountAllFileRows(files, &total_rows);
|
||||
ASSERT_EQ(total_rows, 6);
|
||||
files.clear();
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
{"sentence1": "你有花呗吗", "sentence2": "我的花呗没额度了", "label": "0"}
|
||||
{"sentence1": "吃饭能用花呗吗", "sentence2": "花呗太方便了", "label": "0"}
|
||||
{"sentence1": "蚂蚁花呗支付金额有什么限制", "sentence2": "我到实体店消费用花呗支付受金额限制", "label": "1"}
|
|
@ -0,0 +1,3 @@
|
|||
{"id": 0, "sentence1": "借呗取消的时间", "sentence2": "蚂蚁借呗恢复的月数"}
|
||||
{"id": 1, "sentence1": "网商贷用什么方法转变成借呗", "sentence2": "什么手段能将网商贷切换为借呗"}
|
||||
{"id": 2, "sentence1": "我的借呗为什么开通不了", "sentence2": "我为啥没法开通借呗"}
|
|
@ -0,0 +1,3 @@
|
|||
{"sentence1": "蚂蚁借呗等额还款能否换成先息后本", "sentence2": "借呗可以先息到期还本吗", "label": "0"}
|
||||
{"sentence1": "蚂蚁花呗说我违约了", "sentence2": "蚂蚁花呗违约行为是啥", "label": "0"}
|
||||
{"sentence1": "帮我看看本月花呗账单结清了没", "sentence2": "上月的花呗账单", "label": "0"}
|
|
@ -0,0 +1,3 @@
|
|||
{"sentence1": "每个人都有权利", "sentence2": "每个人都有福利", "label": "neutral"}
|
||||
{"sentence1": "有时候我喜欢他,但我也喜欢看到有人打他", "sentence2": "说实话,我有点喜欢他,但还是喜欢看到有人打他。", "label": "entailment"}
|
||||
{"sentence1": "我最喜欢的餐馆是离你最近的一家", "sentence2": "我最喜欢的餐馆离你家至少一百英里远。", "label": "contradiction"}
|
|
@ -0,0 +1,3 @@
|
|||
{"id": 0, "sentence1": "今天,全球都在看着最新航天飞机的处女航。", "sentence2": "全世界都在看最新的航天飞机发射。"}
|
||||
{"id": 1, "sentence1": "而我们把竹篮放在一个地方,把玻璃瓶放在另一处,把书放在另一处,满了要把它放到车里", "sentence2": "我们没有分开任何东西,都把它全扔进一个箱子里。"}
|
||||
{"id": 2, "sentence1": "她占用了我的很多时间,她给我读了很多关于灵异的故事,我觉得很无聊。", "sentence2": "我喜欢和她一起读鬼故事。"}
|
|
@ -0,0 +1,3 @@
|
|||
{"sentence1": "你应该给这件衣服定一个价格。", "sentence2": "不同的衣服有不同的价格。", "label": "neutral"}
|
||||
{"sentence1": "我怎么知道他要说什么", "sentence2": "他说什么我并不知道。", "label": "entailment"}
|
||||
{"sentence1": "向左。", "sentence2": "向右。", "label": "contradiction"}
|
|
@ -0,0 +1,3 @@
|
|||
{"id": 1, "abst": "这是第一段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
|
||||
{"id": 2, "abst": "这是第二段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
|
||||
{"id": 3, "abst": "这是第三段很长的文本", "keyword": ["1", "2", "3"], "label": "0"}
|
|
@ -0,0 +1,3 @@
|
|||
{"id": 2415, "abst": "长文本1", "keyword": ["关键词1", "关键词2"]}
|
||||
{"id": 2565, "abst": "长文本2", "keyword": ["关键词1", "关键词2", "关键词3"]}
|
||||
{"id": 2625, "abst": "长文本3", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"]}
|
|
@ -0,0 +1,3 @@
|
|||
{"id": 1, "abst": "这是一段长文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "0"}
|
||||
{"id": 2, "abst": "这是一段长文本", "keyword": ["关键词5", "关键词6", "关键词7", "关键词8"], "label": "0"}
|
||||
{"id": 3, "abst": "这是一段长文本", "keyword": ["关键词9", "关键词10", "关键词11", "关键词12"], "label": "0"}
|
|
@ -0,0 +1,3 @@
|
|||
{"label": "110", "label_des": "社区超市", "sentence": "这是第一段文本"}
|
||||
{"label": "70", "label_des": "工具", "sentence": "这是第二段文本"}
|
||||
{"label": "10", "label_des": "社区服务", "sentence": "这是第三段文本"}
|
|
@ -0,0 +1,3 @@
|
|||
{"id": 0, "sentence": "文本1"}
|
||||
{"id": 1, "sentence": "文本2"}
|
||||
{"id": 2, "sentence": "文本3"}
|
|
@ -0,0 +1,3 @@
|
|||
{"label": "11", "label_des": "薅羊毛", "sentence": "第一个文本"}
|
||||
{"label": "95", "label_des": "借贷", "sentence": "第二个文本"}
|
||||
{"label": "74", "label_des": "违章", "sentence": "第三个文本"}
|
|
@ -0,0 +1,3 @@
|
|||
{"label": "102", "label_desc": "news_entertainment", "sentence": "新闻1", "keywords": "关键词一,关键词二,关键词三,关键词四"}
|
||||
{"label": "110", "label_desc": "news_military", "sentence": "新闻2", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}
|
||||
{"label": "104", "label_desc": "news_finance", "sentence": "新闻3", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}
|
|
@ -0,0 +1,3 @@
|
|||
{"id": 0, "sentence": "新闻1", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5"}
|
||||
{"id": 1, "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4"}
|
||||
{"id": 2, "sentence": "新闻3", "keywords": ""}
|
|
@ -0,0 +1,3 @@
|
|||
{"label": "108", "label_desc": "news_edu", "sentence": "新闻1", "keywords": ""}
|
||||
{"label": "104", "label_desc": "news_finance", "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5,关键词6"}
|
||||
{"label": "106", "label_desc": "news_house", "sentence": "新闻3", "keywords": ""}
|
|
@ -0,0 +1,3 @@
|
|||
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"}
|
||||
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"}
|
||||
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"}
|
|
@ -0,0 +1,3 @@
|
|||
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?"}
|
||||
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场"}
|
||||
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业"}
|
|
@ -0,0 +1,3 @@
|
|||
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"}
|
||||
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"}
|
||||
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"}
|
|
@ -0,0 +1,355 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
def test_clue():
|
||||
"""
|
||||
Test CLUE with repeat, skip and so on
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
|
||||
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
|
||||
data = data.repeat(2)
|
||||
data = data.skip(3)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'sentence1': d['sentence1'].item().decode("utf8"),
|
||||
'sentence2': d['sentence2'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_clue_num_shards():
|
||||
"""
|
||||
Test num_shards param of CLUE dataset
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
|
||||
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'sentence1': d['sentence1'].item().decode("utf8"),
|
||||
'sentence2': d['sentence2'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_clue_num_samples():
|
||||
"""
|
||||
Test num_samples param of CLUE dataset
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
|
||||
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
|
||||
count = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
count += 1
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_textline_dataset_get_datasetsize():
|
||||
"""
|
||||
Test get_dataset_size of CLUE dataset
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
|
||||
|
||||
data = ds.TextFileDataset(TRAIN_FILE)
|
||||
size = data.get_dataset_size()
|
||||
assert size == 3
|
||||
|
||||
|
||||
def test_clue_afqmc():
|
||||
"""
|
||||
Test AFQMC for train, test and evaluation
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
|
||||
TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
|
||||
EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'
|
||||
|
||||
# train
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'sentence1': d['sentence1'].item().decode("utf8"),
|
||||
'sentence2': d['sentence2'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# test
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'id': d['id'],
|
||||
'sentence1': d['sentence1'].item().decode("utf8"),
|
||||
'sentence2': d['sentence2'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# evaluation
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'sentence1': d['sentence1'].item().decode("utf8"),
|
||||
'sentence2': d['sentence2'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_clue_cmnli():
|
||||
"""
|
||||
Test CMNLI for train, test and evaluation
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
|
||||
TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
|
||||
EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'
|
||||
|
||||
# train
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'sentence1': d['sentence1'].item().decode("utf8"),
|
||||
'sentence2': d['sentence2'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# test
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'id': d['id'],
|
||||
'sentence1': d['sentence1'],
|
||||
'sentence2': d['sentence2']
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# eval
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'],
|
||||
'sentence1': d['sentence1'],
|
||||
'sentence2': d['sentence2']
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_clue_csl():
|
||||
"""
|
||||
Test CSL for train, test and evaluation
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
|
||||
TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
|
||||
EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'
|
||||
|
||||
# train
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'id': d['id'],
|
||||
'abst': d['abst'].item().decode("utf8"),
|
||||
'keyword': [i.item().decode("utf8") for i in d['keyword']],
|
||||
'label': d['label'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# test
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'id': d['id'],
|
||||
'abst': d['abst'].item().decode("utf8"),
|
||||
'keyword': [i.item().decode("utf8") for i in d['keyword']],
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# eval
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'id': d['id'],
|
||||
'abst': d['abst'].item().decode("utf8"),
|
||||
'keyword': [i.item().decode("utf8") for i in d['keyword']],
|
||||
'label': d['label'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_clue_iflytek():
|
||||
"""
|
||||
Test IFLYTEK for train, test and evaluation
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
|
||||
TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
|
||||
EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'
|
||||
|
||||
# train
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'label_des': d['label_des'].item().decode("utf8"),
|
||||
'sentence': d['sentence'].item().decode("utf8"),
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# test
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'id': d['id'],
|
||||
'sentence': d['sentence'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# eval
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'label_des': d['label_des'].item().decode("utf8"),
|
||||
'sentence': d['sentence'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_clue_tnews():
|
||||
"""
|
||||
Test TNEWS for train, test and evaluation
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
|
||||
TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
|
||||
EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'
|
||||
|
||||
# train
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'label_desc': d['label_desc'].item().decode("utf8"),
|
||||
'sentence': d['sentence'].item().decode("utf8"),
|
||||
'keywords':
|
||||
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# test
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'id': d['id'],
|
||||
'sentence': d['sentence'].item().decode("utf8"),
|
||||
'keywords':
|
||||
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# eval
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'label_desc': d['label_desc'].item().decode("utf8"),
|
||||
'sentence': d['sentence'].item().decode("utf8"),
|
||||
'keywords':
|
||||
d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_clue_wsc():
|
||||
"""
|
||||
Test WSC for train, test and evaluation
|
||||
"""
|
||||
TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
|
||||
TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
|
||||
EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'
|
||||
|
||||
# train
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'span1_index': d['span1_index'],
|
||||
'span2_index': d['span2_index'],
|
||||
'span1_text': d['span1_text'].item().decode("utf8"),
|
||||
'span2_text': d['span2_text'].item().decode("utf8"),
|
||||
'idx': d['idx'],
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'text': d['text'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# test
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'span1_index': d['span1_index'],
|
||||
'span2_index': d['span2_index'],
|
||||
'span1_text': d['span1_text'].item().decode("utf8"),
|
||||
'span2_text': d['span2_text'].item().decode("utf8"),
|
||||
'idx': d['idx'],
|
||||
'text': d['text'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
# eval
|
||||
buffer = []
|
||||
data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
|
||||
for d in data.create_dict_iterator():
|
||||
buffer.append({
|
||||
'span1_index': d['span1_index'],
|
||||
'span2_index': d['span2_index'],
|
||||
'span1_text': d['span1_text'].item().decode("utf8"),
|
||||
'span2_text': d['span2_text'].item().decode("utf8"),
|
||||
'idx': d['idx'],
|
||||
'label': d['label'].item().decode("utf8"),
|
||||
'text': d['text'].item().decode("utf8")
|
||||
})
|
||||
assert len(buffer) == 3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_clue()
|
||||
test_clue_afqmc()
|
||||
test_clue_cmnli()
|
||||
test_clue_csl()
|
||||
test_clue_iflytek()
|
||||
test_clue_tnews()
|
||||
test_clue_wsc()
|
Loading…
Reference in New Issue