!23670 [assistant][ops] add new dataset loading operator CoNLL2000ChunkingDataset
Merge pull request !23670 from 杨旭华/CoNLL2000ChunkingDataset
This commit is contained in: commit 14df5d3984
@@ -95,6 +95,7 @@
 #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
+#include "minddata/dataset/engine/ir/datasetops/source/conll2000_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
@@ -1033,6 +1034,14 @@ CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
 
+CoNLL2000Dataset::CoNLL2000Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
+                                   int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
+                                   const std::shared_ptr<DatasetCache> &cache) {
+  auto ds = std::make_shared<CoNLL2000Node>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
+                                            num_shards, shard_id, cache);
+  ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
+}
+
 CSVDataset::CSVDataset(const std::vector<std::vector<char>> &dataset_files, char field_delim,
                        const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
                        const std::vector<std::vector<char>> &column_names, int64_t num_samples, ShuffleMode shuffle,
@@ -32,6 +32,7 @@
 #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
+#include "minddata/dataset/engine/ir/datasetops/source/conll2000_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h"
 #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"

@@ -159,6 +160,18 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
                     }));
                 }));
 
+PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) {
+                  (void)py::class_<CoNLL2000Node, DatasetNode, std::shared_ptr<CoNLL2000Node>>(
+                    *m, "CoNLL2000Node", "to create a CoNLL2000Node")
+                    .def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
+                                     int32_t num_shards, int32_t shard_id) {
+                      std::shared_ptr<CoNLL2000Node> conll2000 = std::make_shared<CoNLL2000Node>(
+                        dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
+                      THROW_IF_ERROR(conll2000->ValidateParams());
+                      return conll2000;
+                    }));
+                }));
+
 PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
                   (void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
                     .def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
@@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
     cityscapes_op.cc
     clue_op.cc
     coco_op.cc
+    conll2000_op.cc
     csv_op.cc
     dbpedia_op.cc
     div2k_op.cc
@@ -0,0 +1,181 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/engine/datasetops/source/conll2000_op.h"
+
+#include <algorithm>
+#include <fstream>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "debug/common.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/util/random.h"
+#include "minddata/dataset/util/wait_post.h"
+#include "utils/file_utils.h"
+
+namespace mindspore {
+namespace dataset {
+CoNLL2000Op::CoNLL2000Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
+                         std::unique_ptr<DataSchema> schema, const std::vector<std::string> &conll2000_file_list,
+                         int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
+    : TextFileOp(num_workers, total_rows, worker_connector_size, std::move(schema), conll2000_file_list,
+                 op_connector_size, shuffle_files, num_devices, device_id) {}
+
+// A print method typically used for debugging.
+void CoNLL2000Op::Print(std::ostream &out, bool show_all) const {
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info.
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op.
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info.
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff.
+    out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
+        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCoNLL2000 file list:\n";
+    for (size_t i = 0; i < text_files_list_.size(); ++i) {
+      out << " " << text_files_list_[i];
+    }
+    out << "\nData Schema:\n";
+    out << *data_schema_ << "\n\n";
+  }
+}
+
+Status CoNLL2000Op::LoadTensor(const std::vector<std::string> &column, TensorRow *out_row, size_t index) {
+  RETURN_UNEXPECTED_IF_NULL(out_row);
+  std::shared_ptr<Tensor> tensor;
+  RETURN_IF_NOT_OK(Tensor::CreateFromVector(column, &tensor));
+  (*out_row)[index] = std::move(tensor);
+  return Status::OK();
+}
+
+// Function to split string based on a character delimiter.
+std::vector<std::string> CoNLL2000Op::Split(const std::string &s, char delim) {
+  std::vector<std::string> res;
+  std::stringstream ss(s);
+  std::string item;
+
+  while (getline(ss, item, delim)) {
+    res.push_back(item);
+  }
+  return res;
+}
+
+// Removes excess space before and after the string.
+std::string CoNLL2000Op::Strip(const std::string &str) {
+  std::int64_t strlen = str.size();
+  std::int64_t i, j;
+  i = 0;
+  while (i < strlen && str[i] == ' ') {
+    i++;
+  }
+  j = strlen - 1;
+  while (j >= i && str[j] == ' ') {
+    j--;
+  }
+  j++;
+  if (i == 0 && j == strlen) {
+    return str;
+  } else {
+    return str.substr(i, j - i);
+  }
+}
+
+Status CoNLL2000Op::Load(const std::vector<std::string> &word, const std::vector<std::string> &pos_tag,
+                         const std::vector<std::string> &chunk_tag, const std::string &file, int32_t worker_id) {
+  size_t row_line = 3;
+  TensorRow tRow(row_line, nullptr);
+  // Add file path info.
+  std::vector<std::string> file_path(row_line, file);
+  tRow.setPath(file_path);
+  size_t word_index = 0, pos_tag_index = 1, chunk_tag_index = 2;
+  RETURN_IF_NOT_OK(LoadTensor(word, &tRow, word_index));
+  RETURN_IF_NOT_OK(LoadTensor(pos_tag, &tRow, pos_tag_index));
+  RETURN_IF_NOT_OK(LoadTensor(chunk_tag, &tRow, chunk_tag_index));
+  RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow)));
+  return Status::OK();
+}
+
+Status CoNLL2000Op::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
+  auto realpath = FileUtils::GetRealPath(file.data());
+  if (!realpath.has_value()) {
+    MS_LOG(ERROR) << "Invalid file path, " << DatasetName() << " dataset dir: " << file << " does not exist.";
+    RETURN_STATUS_UNEXPECTED("Invalid file path, " + DatasetName() + " dataset dir: " + file + " does not exist.");
+  }
+  std::ifstream handle(realpath.value());
+  if (!handle.is_open()) {
+    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + ": " + file);
+  }
+  int64_t rows_total = 0;
+  std::string line;
+  std::vector<std::string> word_column;
+  std::vector<std::string> pos_tag_column;
+  std::vector<std::string> chunk_tag_column;
+  while (getline(handle, line)) {
+    if (line.empty() && rows_total < start_offset) {
+      continue;
+    }
+    // If read to the end offset of this file, break.
+    if (rows_total >= end_offset) {
+      if (word_column.size() != 0) {
+        Status s = Load(word_column, pos_tag_column, chunk_tag_column, file, worker_id);
+        if (s.IsError()) {
+          handle.close();
+          return s;
+        }
+      }
+      std::vector<std::string>().swap(word_column);
+      std::vector<std::string>().swap(pos_tag_column);
+      std::vector<std::string>().swap(chunk_tag_column);
+      break;
+    }
+    // Skip lines before the start offset.
+    if (rows_total < start_offset) {
+      rows_total++;
+      continue;
+    }
+    line = Strip(line);
+    if (line.empty() && rows_total >= start_offset) {
+      if (word_column.size() != 0) {
+        Status s = Load(word_column, pos_tag_column, chunk_tag_column, file, worker_id);
+        if (s.IsError()) {
+          handle.close();
+          return s;
+        }
+      }
+      std::vector<std::string>().swap(word_column);
+      std::vector<std::string>().swap(pos_tag_column);
+      std::vector<std::string>().swap(chunk_tag_column);
+      continue;
+    } else if (!line.empty() && rows_total >= start_offset) {
+      std::vector<std::string> column = Split(line, ' ');
+      size_t word_index = 0, pos_tag_index = 1, chunk_tag_index = 2;
+      word_column.push_back(column[word_index]);
+      pos_tag_column.push_back(column[pos_tag_index]);
+      chunk_tag_column.push_back(column[chunk_tag_index]);
+    }
+    rows_total++;
+  }
+  handle.close();
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
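For context, the file layout that LoadFile walks is the standard CoNLL-2000 chunking format: one token per line with three space-separated fields (word, POS tag, chunk tag), and a blank line terminating each sentence. The standalone sketch below is illustrative only, not part of the patch (ParseCoNLL2000 and Sentence are hypothetical names); it mirrors the same blank-line grouping that LoadFile performs with its word/pos_tag/chunk_tag column vectors.

// A minimal sketch of the CoNLL-2000 parsing scheme used by CoNLL2000Op::LoadFile above.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct Sentence {
  std::vector<std::string> words, pos_tags, chunk_tags;
};

// Reads a CoNLL2000-style file and groups token lines into sentences, one per blank line.
std::vector<Sentence> ParseCoNLL2000(const std::string &path) {
  std::vector<Sentence> sentences;
  Sentence current;
  std::ifstream in(path);
  std::string line;
  while (std::getline(in, line)) {
    if (line.empty()) {  // A blank line terminates the current sentence block.
      if (!current.words.empty()) {
        sentences.push_back(current);
        current = Sentence();
      }
      continue;
    }
    std::istringstream fields(line);
    std::string word, pos, chunk;
    fields >> word >> pos >> chunk;  // e.g. "Confidence NN B-NP"
    current.words.push_back(word);
    current.pos_tags.push_back(pos);
    current.chunk_tags.push_back(chunk);
  }
  if (!current.words.empty()) {
    sentences.push_back(current);  // Flush a trailing sentence with no final blank line.
  }
  return sentences;
}

int main() {
  // Each returned Sentence corresponds to one dataset row: three equal-length columns.
  std::vector<Sentence> rows = ParseCoNLL2000("train.txt");
  std::cout << "sentences: " << rows.size() << "\n";
  return 0;
}

Grouping at blank lines is what makes each output row a whole sentence rather than a single token, which is why the three columns come out as 1-D string tensors of equal length.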
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CONLL2000_OP_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CONLL2000_OP_H_
+
+#include <map>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
+#include "minddata/dataset/util/queue.h"
+
+namespace mindspore {
+namespace dataset {
+class JaggedConnector;
+
+class CoNLL2000Op : public TextFileOp {
+ public:
+  /// \brief Constructor of CoNLL2000Op.
+  CoNLL2000Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr<DataSchema> schema,
+              const std::vector<std::string> &conll2000_file_list, int32_t op_connector_size, bool shuffle_files,
+              int32_t num_devices, int32_t device_id);
+
+  /// \brief Default destructor.
+  ~CoNLL2000Op() = default;
+
+  /// \brief A print method typically used for debugging.
+  /// \param[in] out The output stream to write output to.
+  /// \param[in] show_all A bool to control if you want to show all info or just a summary.
+  void Print(std::ostream &out, bool show_all) const override;
+
+  /// \brief Op name getter.
+  /// \return Name of the current Op.
+  std::string Name() const override { return "CoNLL2000Op"; }
+
+  /// \brief Dataset name getter.
+  /// \param[in] upper Whether to return the name capitalized.
+  /// \return Dataset name of the current Op.
+  std::string DatasetName(bool upper = false) const { return upper ? "CoNLL2000" : "conll2000"; }
+
+ private:
+  /// \brief Loads the content of one column into a tensor of the output row.
+  /// \param[in] column The content of the column.
+  /// \param[out] out_row The tensor row to put the parsed data in.
+  /// \param[in] index Serial number of the column.
+  /// \return Status The error code returned.
+  Status LoadTensor(const std::vector<std::string> &column, TensorRow *out_row, size_t index);
+
+  /// \brief Removes excess space before and after the string.
+  /// \param[in] str The input string.
+  /// \return A string.
+  std::string Strip(const std::string &str);
+
+  /// \brief Splits a string based on a character delimiter.
+  /// \param[in] s The input string.
+  /// \param[in] delim The delimiter used to separate the string.
+  /// \return A string vector.
+  std::vector<std::string> Split(const std::string &s, char delim);
+
+  /// \brief Translates the collected columns of one sentence into tensors.
+  /// \param[in] word A list of words in a sentence.
+  /// \param[in] pos_tag The part-of-speech tags of the sentence.
+  /// \param[in] chunk_tag The chunk tags of the sentence.
+  /// \param[in] file The file to read.
+  /// \param[in] worker_id The id of the worker that is executing this function.
+  /// \return Status The error code returned.
+  Status Load(const std::vector<std::string> &word, const std::vector<std::string> &pos_tag,
+              const std::vector<std::string> &chunk_tag, const std::string &file, int32_t worker_id);
+
+  /// \brief Reads a text file and loads the data into multiple TensorRows.
+  /// \param[in] file The file to read.
+  /// \param[in] start_offset The start offset of file.
+  /// \param[in] end_offset The end offset of file.
+  /// \param[in] worker_id The id of the worker that is executing this function.
+  /// \return Status The error code returned.
+  Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CONLL2000_OP_H_
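A quick note on the two private helpers declared above: Strip trims ASCII spaces at both ends only (tabs are untouched), and Split is built on std::getline, so consecutive delimiters yield empty fields. The self-contained check below is an illustrative sketch, not part of the patch; the local Strip and Split re-implement the same logic to make the semantics concrete.

// A standalone check of the Strip/Split helper semantics (illustrative sketch).
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

static std::string Strip(const std::string &str) {
  // Same behavior as CoNLL2000Op::Strip: trim ASCII spaces at both ends.
  int64_t i = 0, j = static_cast<int64_t>(str.size()) - 1;
  while (i <= j && str[i] == ' ') i++;
  while (j >= i && str[j] == ' ') j--;
  return str.substr(i, j - i + 1);
}

static std::vector<std::string> Split(const std::string &s, char delim) {
  // Same behavior as CoNLL2000Op::Split: getline keeps empty fields between delimiters.
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) res.push_back(item);
  return res;
}

int main() {
  assert(Strip("  He PRP B-NP ") == "He PRP B-NP");
  assert(Split("He PRP B-NP", ' ').size() == 3);  // word, pos_tag, chunk_tag
  assert(Split("a  b", ' ')[1].empty());          // double space -> empty field
  return 0;
}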
@@ -84,6 +84,7 @@ constexpr char kCifar10Node[] = "Cifar10Dataset";
 constexpr char kCityscapesNode[] = "CityscapesDataset";
 constexpr char kCLUENode[] = "CLUEDataset";
 constexpr char kCocoNode[] = "CocoDataset";
+constexpr char kCoNLL2000Node[] = "CoNLL2000Dataset";
 constexpr char kCSVNode[] = "CSVDataset";
 constexpr char kDBpediaNode[] = "DBpediaDataset";
 constexpr char kDIV2KNode[] = "DIV2KDataset";
@@ -11,6 +11,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
     cityscapes_node.cc
     clue_node.cc
    coco_node.cc
+    conll2000_node.cc
     csv_node.cc
     dbpedia_node.cc
     div2k_node.cc
@@ -0,0 +1,203 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/engine/ir/datasetops/source/conll2000_node.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "minddata/dataset/engine/datasetops/source/conll2000_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+// Constructor for CoNLL2000Node.
+CoNLL2000Node::CoNLL2000Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
+                             ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
+                             std::shared_ptr<DatasetCache> cache)
+    : NonMappableSourceNode(std::move(cache)),
+      dataset_dir_(dataset_dir),
+      usage_(usage),
+      num_samples_(num_samples),
+      shuffle_(shuffle),
+      num_shards_(num_shards),
+      shard_id_(shard_id),
+      conll2000_file_list_(WalkAllFiles(usage, dataset_dir)) {
+  // Update the num_shards_ in global context. This number is only used for now by auto_num_worker_pass. User
+  // discretion is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the
+  // num_shards_ isn't 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return
+  // num_shards. Once PreBuildSampler is phased out, this can be cleaned up.
+  GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
+}
+
+std::shared_ptr<DatasetNode> CoNLL2000Node::Copy() {
+  auto node =
+    std::make_shared<CoNLL2000Node>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
+  return node;
+}
+
+void CoNLL2000Node::Print(std::ostream &out) const {
+  out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
+          ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
+}
+
+Status CoNLL2000Node::ValidateParams() {
+  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
+  RETURN_IF_NOT_OK(ValidateDatasetDirParam("CoNLL2000Node", dataset_dir_));
+  RETURN_IF_NOT_OK(ValidateStringValue("CoNLL2000Node", usage_, {"train", "test", "all"}));
+
+  if (num_samples_ < 0) {
+    std::string err_msg = "CoNLL2000Node: Invalid number of samples: " + std::to_string(num_samples_);
+    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
+  }
+  RETURN_IF_NOT_OK(ValidateDatasetShardParams("CoNLL2000Node", num_shards_, shard_id_));
+  return Status::OK();
+}
+
+// Function to build CoNLL2000Node.
+Status CoNLL2000Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
+  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
+
+  // Sort the dataset files in a lexicographical order.
+  std::vector<std::string> sorted_dataset_files = conll2000_file_list_;
+  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
+
+  // Do internal Schema generation.
+  auto schema = std::make_unique<DataSchema>();
+  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("word", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
+  TensorShape scalar = TensorShape::CreateScalar();
+  RETURN_IF_NOT_OK(
+    schema->AddColumn(ColDescriptor("pos_tag", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
+  RETURN_IF_NOT_OK(
+    schema->AddColumn(ColDescriptor("chunk_tag", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
+
+  // Create and initialize CoNLL2000Op.
+  std::shared_ptr<CoNLL2000Op> conll2000_op =
+    std::make_shared<CoNLL2000Op>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
+                                  sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+  RETURN_IF_NOT_OK(conll2000_op->Init());
+
+  // If a global shuffle is used for CoNLL2000, it will inject a shuffle op over the CoNLL2000.
+  // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
+  // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset CoNLL2000's shuffle
+  // option to false.
+  if (shuffle_ == ShuffleMode::kGlobal) {
+    // Inject ShuffleOp.
+    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
+    int64_t num_rows = 0;
+
+    // First, get the number of rows in the dataset.
+    RETURN_IF_NOT_OK(CoNLL2000Op::CountAllFileRows(sorted_dataset_files, &num_rows));
+
+    // Add the shuffle op after this op.
+    RETURN_IF_NOT_OK(
+      AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
+    shuffle_op->SetTotalRepeats(GetTotalRepeats());
+    shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
+    node_ops->push_back(shuffle_op);
+  }
+  conll2000_op->SetTotalRepeats(GetTotalRepeats());
+  conll2000_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
+  // Add CoNLL2000Op.
+  node_ops->push_back(conll2000_op);
+
+  return Status::OK();
+}
+
+// Get the shard id of node.
+Status CoNLL2000Node::GetShardId(int32_t *shard_id) {
+  *shard_id = shard_id_;
+
+  return Status::OK();
+}
+
+// Get Dataset size.
+Status CoNLL2000Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                                     int64_t *dataset_size) {
+  if (dataset_size_ > 0) {
+    *dataset_size = dataset_size_;
+    return Status::OK();
+  }
+  int64_t num_rows, sample_size = num_samples_;
+  RETURN_IF_NOT_OK(CoNLL2000Op::CountAllFileRows(conll2000_file_list_, &num_rows));
+  num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
+  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
+  dataset_size_ = *dataset_size;
+  return Status::OK();
+}
+
+Status CoNLL2000Node::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["num_parallel_workers"] = num_workers_;
+  args["dataset_dir"] = dataset_dir_;
+  args["usage"] = usage_;
+  args["num_samples"] = num_samples_;
+  args["shuffle"] = shuffle_;
+  args["num_shards"] = num_shards_;
+  args["shard_id"] = shard_id_;
+  if (cache_ != nullptr) {
+    nlohmann::json cache_args;
+    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
+    args["cache"] = cache_args;
+  }
+  *out_json = args;
+  return Status::OK();
+}
+
+// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
+// CoNLL2000 by itself is a non-mappable dataset that does not support sampling.
+// However, if a cache operator is injected at some other place higher in the tree, that cache can
+// inherit this sampler from the leaf, providing sampling support from the caching layer.
+// That is why we set up the sampler for a leaf node that does not use sampling.
+Status CoNLL2000Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
+  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
+  *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
+  return Status::OK();
+}
+
+// If a cache has been added into the ascendant tree over this CoNLL2000 node, then the cache will be executing
+// a sampler for fetching the data. As such, any options in the CoNLL2000 node need to be reset to its defaults so
+// that this CoNLL2000 node will produce the full set of data into the cache.
+Status CoNLL2000Node::MakeSimpleProducer() {
+  shard_id_ = 0;
+  num_shards_ = 1;
+  shuffle_ = ShuffleMode::kFalse;
+  num_samples_ = 0;
+  return Status::OK();
+}
+
+std::vector<std::string> CoNLL2000Node::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
+  std::vector<std::string> conll2000_file_list;
+  Path train_prefix("train.txt");
+  Path test_prefix("test.txt");
+  Path dir(dataset_dir);
+
+  if (usage == "train") {
+    Path temp_path = dir / train_prefix;
+    conll2000_file_list.push_back(temp_path.ToString());
+  } else if (usage == "test") {
+    Path temp_path = dir / test_prefix;
+    conll2000_file_list.push_back(temp_path.ToString());
+  } else {
+    Path temp_path = dir / train_prefix;
+    conll2000_file_list.push_back(temp_path.ToString());
+    Path temp_path1 = dir / test_prefix;
+    conll2000_file_list.push_back(temp_path1.ToString());
+  }
+  return conll2000_file_list;
+}
+}  // namespace dataset
+}  // namespace mindspore
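The size computation in GetDatasetSize above splits the total row count evenly across shards, rounding up, and then caps the result with num_samples when one was given. The standalone sketch below works through that arithmetic (illustrative only; ShardedDatasetSize is a hypothetical name, not part of the patch):

// A sketch of the per-shard size arithmetic used by CoNLL2000Node::GetDatasetSize.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int64_t ShardedDatasetSize(int64_t total_rows, int32_t num_shards, int64_t num_samples) {
  // Each shard gets ceil(total_rows / num_shards) rows.
  int64_t per_shard = static_cast<int64_t>(std::ceil(total_rows / (1.0 * num_shards)));
  // num_samples == 0 means "all samples"; otherwise it caps the per-shard count.
  return num_samples > 0 ? std::min(per_shard, num_samples) : per_shard;
}

int main() {
  // 10948 total rows over 4 shards -> 2737 rows per shard; capped at 1000 when requested.
  std::cout << ShardedDatasetSize(10948, 4, 0) << "\n";     // 2737
  std::cout << ShardedDatasetSize(10948, 4, 1000) << "\n";  // 1000
  return 0;
}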
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONLL2000_NODE_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONLL2000_NODE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
+
+namespace mindspore {
+namespace dataset {
+/// \class CoNLL2000Node
+/// \brief A Dataset derived class to represent the CoNLL2000 dataset.
+class CoNLL2000Node : public NonMappableSourceNode {
+ public:
+  /// \brief Constructor.
+  CoNLL2000Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
+                int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
+
+  /// \brief Destructor.
+  ~CoNLL2000Node() = default;
+
+  /// \brief Node name getter.
+  /// \return Name of the current node.
+  std::string Name() const override { return "CoNLL2000Node"; }
+
+  /// \brief Print the description.
+  /// \param[out] out The output stream to write output to.
+  void Print(std::ostream &out) const override;
+
+  /// \brief Copy the node to a new object.
+  /// \return A shared pointer to the new copy.
+  std::shared_ptr<DatasetNode> Copy() override;
+
+  /// \brief A base class override function to create the required runtime dataset op objects for this class.
+  /// \param[in] node_ops A vector containing shared pointers to the Dataset Ops that this object will create.
+  /// \return Status Status::OK() if build successfully.
+  Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
+
+  /// \brief Parameters validation.
+  /// \return Status Status::OK() if all the parameters are valid.
+  Status ValidateParams() override;
+
+  /// \brief Get the shard id of node.
+  /// \param[out] shard_id The shard id.
+  /// \return Status Status::OK() if get shard id successfully.
+  Status GetShardId(int32_t *shard_id) override;
+
+  /// \brief Base-class override for GetDatasetSize.
+  /// \param[in] size_getter Shared pointer to DatasetSizeGetter.
+  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
+  ///     dataset size at the expense of accuracy.
+  /// \param[out] dataset_size The size of the dataset.
+  /// \return Status of the function.
+  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
+                        int64_t *dataset_size) override;
+
+  /// \brief Getter functions.
+  const std::string &DatasetDir() const { return dataset_dir_; }
+
+  /// \brief Getter functions.
+  const std::string &Usage() const { return usage_; }
+
+  /// \brief Getter functions.
+  int64_t NumSamples() const { return num_samples_; }
+
+  /// \brief Getter functions.
+  int32_t NumShards() const { return num_shards_; }
+
+  /// \brief Getter functions.
+  int32_t ShardId() const { return shard_id_; }
+
+  /// \brief Getter functions.
+  ShuffleMode Shuffle() const { return shuffle_; }
+
+  /// \brief Get the arguments of node.
+  /// \param[out] out_json JSON string of all attributes.
+  /// \return Status of the function.
+  Status to_json(nlohmann::json *out_json) override;
+
+  /// \brief CoNLL2000 by itself is a non-mappable dataset that does not support sampling.
+  ///     However, if a cache operator is injected at some other place higher in the tree, that cache can
+  ///     inherit this sampler from the leaf, providing sampling support from the caching layer.
+  ///     That is why we set up the sampler for a leaf node that does not use sampling.
+  ///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
+  /// \param[in] sampler The sampler to setup.
+  /// \return Status of the function.
+  Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
+
+  /// \brief If a cache has been added into the ascendant tree over this CoNLL2000 node, then the cache will be
+  ///     executing a sampler for fetching the data. As such, any options in the CoNLL2000 node need to be
+  ///     reset to their defaults so that this CoNLL2000 node will produce the full set of data into the cache.
+  ///     Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
+  /// \return Status of the function.
+  Status MakeSimpleProducer() override;
+
+  /// \brief Read all files in the directory.
+  /// \param[in] usage Part of dataset of CoNLL2000.
+  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
+  /// \return The list of files to be read.
+  std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
+
+ private:
+  std::string dataset_dir_;
+  std::string usage_;
+  int64_t num_samples_;
+  int32_t num_shards_;
+  int32_t shard_id_;
+  ShuffleMode shuffle_;
+  std::vector<std::string> conll2000_file_list_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONLL2000_NODE_H_
@@ -1776,6 +1776,70 @@ inline std::shared_ptr<CocoDataset> MS_API Coco(const std::string &dataset_dir,
                                       decode, sampler, cache, extra_metadata);
 }
 
+/// \class CoNLL2000Dataset
+/// \brief A source dataset for reading and parsing the CoNLL2000 dataset.
+class MS_API CoNLL2000Dataset : public Dataset {
+ public:
+  /// \brief Constructor of CoNLL2000Dataset.
+  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
+  /// \param[in] usage The type of data list txt file to be read, can be "train", "test" or "all".
+  /// \param[in] num_samples The number of samples to be included in the dataset.
+  /// \param[in] shuffle The mode for shuffling data every epoch.
+  ///     Can be any of:
+  ///     ShuffleMode::kFalse - No shuffling is performed.
+  ///     ShuffleMode::kFiles - Shuffle files only.
+  ///     ShuffleMode::kGlobal - Shuffle both the files and samples.
+  /// \param[in] num_shards Number of shards that the dataset should be divided into.
+  /// \param[in] shard_id The shard ID within num_shards. This argument should be
+  ///     specified only when num_shards is also specified.
+  /// \param[in] cache Tensor cache to use.
+  CoNLL2000Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
+                   ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
+                   const std::shared_ptr<DatasetCache> &cache);
+
+  /// \brief Destructor of CoNLL2000Dataset.
+  ~CoNLL2000Dataset() = default;
+};
+
+/// \brief Function to create a CoNLL2000Dataset.
+/// \note The generated dataset has three columns ['word', 'pos_tag', 'chunk_tag'].
+/// \param[in] dataset_dir Path to the root directory that contains the dataset.
+/// \param[in] usage Part of dataset of CoNLL2000, can be "train", "test" or "all" (default="all").
+/// \param[in] num_samples The number of samples to be included in the dataset
+///     (Default = 0, means all samples).
+/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode::kGlobal).
+///     Can be any of:
+///     ShuffleMode::kFalse - No shuffling is performed.
+///     ShuffleMode::kFiles - Shuffle files only.
+///     ShuffleMode::kGlobal - Shuffle both the files and samples.
+/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
+/// \param[in] shard_id The shard ID within num_shards. This argument should be
+///     specified only when num_shards is also specified (Default = 0).
+/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
+/// \return Shared pointer to the CoNLL2000Dataset.
+/// \par Example
+/// \code
+///      /* Define dataset path and MindData object */
+///      std::string folder_path = "/path/to/conll2000_dataset_directory";
+///      std::shared_ptr<Dataset> ds = CoNLL2000(folder_path, "all", 0, ShuffleMode::kGlobal);
+///
+///      /* Create iterator to read dataset */
+///      std::shared_ptr<Iterator> iter = ds->CreateIterator();
+///      std::unordered_map<std::string, mindspore::MSTensor> row;
+///      iter->GetNextRow(&row);
+///
+///      /* Note: In CoNLL2000 dataset, each dictionary has keys "word", "pos_tag", "chunk_tag" */
+///      auto word = row["word"];
+/// \endcode
+inline std::shared_ptr<CoNLL2000Dataset> MS_API CoNLL2000(const std::string &dataset_dir,
+                                                          const std::string &usage = "all", int64_t num_samples = 0,
+                                                          ShuffleMode shuffle = ShuffleMode::kGlobal,
+                                                          int32_t num_shards = 1, int32_t shard_id = 0,
+                                                          const std::shared_ptr<DatasetCache> &cache = nullptr) {
+  return std::make_shared<CoNLL2000Dataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle,
+                                            num_shards, shard_id, cache);
+}
+
 /// \class CSVDataset
 /// \brief A source dataset that reads and parses comma-separated values (CSV) datasets.
 class MS_API CSVDataset : public Dataset {
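Pulling the new C++ surface together, a complete consumer of the API declared above might look like the following sketch (illustrative, not part of the patch; it assumes the include path of the header above and a directory holding train.txt/test.txt, as CoNLL2000Node::WalkAllFiles expects):

// A minimal end-to-end sketch of consuming the new CoNLL2000 C++ API.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

#include "minddata/dataset/include/dataset/datasets.h"

int main() {
  using namespace mindspore::dataset;
  // The directory must contain train.txt / test.txt, per CoNLL2000Node::WalkAllFiles.
  std::string dataset_dir = "/path/to/conll2000_dataset_directory";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "train", 0, ShuffleMode::kFalse);

  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  std::unordered_map<std::string, mindspore::MSTensor> row;
  iter->GetNextRow(&row);  // Row becomes empty once the dataset is exhausted.
  while (!row.empty()) {
    // Each row holds one sentence: "word", "pos_tag", "chunk_tag" tensors of equal length.
    std::cout << "sentence length: " << row["word"].Shape()[0] << "\n";
    iter->GetNextRow(&row);
  }
  iter->Stop();
  return 0;
}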
@@ -74,7 +74,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \
     check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset, check_svhn_dataset, \
     check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \
-    check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset
+    check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset, \
+    check_conll2000_dataset
 from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
     get_prefetch_size
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

@@ -6662,6 +6663,64 @@ class CocoDataset(MappableDataset):
         return self._class_indexing
 
 
+class CoNLL2000Dataset(SourceDataset):
+    """
+    A source dataset that reads and parses the CoNLL2000 dataset.
+
+    The generated dataset has three columns: :py:obj:`[word, pos_tag, chunk_tag]`.
+    The tensor of column :py:obj:`word` is of the string type.
+    The tensor of column :py:obj:`pos_tag` is of the string type.
+    The tensor of column :py:obj:`chunk_tag` is of the string type.
+
+    Args:
+        dataset_dir (str): Path to the root directory that contains the dataset.
+        usage (str, optional): Usage of this dataset, can be `train`, `test`, or `all`. `train` will read from
+            8,936 train samples, `test` will read from 2,012 test samples,
+            `all` will read from all 10,948 samples (default=None, all samples).
+        num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
+        shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
+            (default=Shuffle.GLOBAL).
+            If shuffle is False, no shuffling will be performed;
+            if shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
+            otherwise, there are two levels of shuffling:
+
+            - Shuffle.GLOBAL: Shuffle both the files and samples.
+
+            - Shuffle.FILES: Shuffle files only.
+
+        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
+            When this argument is specified, `num_samples` reflects the max sample number per shard.
+        shard_id (int, optional): The shard ID within num_shards (default=None). This
+            argument can only be specified when num_shards is also specified.
+        num_parallel_workers (int, optional): Number of workers to read the data
+            (default=None, number set in the config).
+        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
+            (default=None, which means no cache is used).
+
+    Raises:
+        RuntimeError: If dataset_dir does not contain data files.
+        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
+        RuntimeError: If num_shards is specified but shard_id is None.
+        RuntimeError: If shard_id is specified but num_shards is None.
+
+    Examples:
+        >>> conll2000_dataset_dir = "/path/to/conll2000_dataset_dir"
+        >>> dataset = ds.CoNLL2000Dataset(dataset_dir=conll2000_dataset_dir, usage='all')
+    """
+
+    @check_conll2000_dataset
+    def __init__(self, dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None,
+                 shard_id=None, num_parallel_workers=None, cache=None):
+        super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
+                         num_shards=num_shards, shard_id=shard_id, cache=cache)
+        self.dataset_dir = dataset_dir
+        self.usage = replace_none(usage, 'all')
+
+    def parse(self, children=None):
+        return cde.CoNLL2000Node(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
+                                 self.shard_id)
+
+
 class CelebADataset(MappableDataset):
     """
     A source dataset for reading and parsing CelebA dataset.
@@ -2223,3 +2223,33 @@ def check_yahoo_answers_dataset(method):
         return method(self, *args, **kwargs)
 
     return new_method
+
+
+def check_conll2000_dataset(method):
+    """A wrapper that wraps a parameter checker around the original Dataset(CoNLL2000Dataset)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        _, param_dict = parse_user_args(method, *args, **kwargs)
+
+        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
+
+        # check dataset_dir
+        dataset_dir = param_dict.get('dataset_dir')
+        check_dir(dataset_dir)
+
+        # check usage
+        usage = param_dict.get('usage')
+        if usage is not None:
+            check_valid_str(usage, ["train", "test", "all"], "usage")
+
+        validate_dataset_param_value(nreq_param_int, param_dict, int)
+        check_sampler_shuffle_shard_options(param_dict)
+
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
+        return method(self, *args, **kwargs)
+
+    return new_method
@@ -20,6 +20,7 @@ SET(DE_UT_SRCS
     c_api_dataset_cityscapes_test.cc
     c_api_dataset_clue_test.cc
     c_api_dataset_coco_test.cc
+    c_api_dataset_conll2000_test.cc
     c_api_dataset_config_test.cc
     c_api_dataset_csv_test.cc
     c_api_dataset_dbpedia_test.cc
@ -0,0 +1,628 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "minddata/dataset/core/global_context.h"
|
||||||
|
#include "minddata/dataset/include/dataset/datasets.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
|
||||||
|
using mindspore::dataset::ShuffleMode;
|
||||||
|
|
||||||
|
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||||
|
protected:
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Feature: CoNLL2000ChunkingDataset.
|
||||||
|
/// Description: test CoNLL2000ChunkingDataset in pipeline mode.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetBasic) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetBasic.";
|
||||||
|
// Test CoNLL2000 Dataset with single text file and many default inputs.
|
||||||
|
|
||||||
|
// Set configuration.
|
||||||
|
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||||
|
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||||
|
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||||
|
GlobalContext::config_manager()->set_seed(987);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(2);
|
||||||
|
|
||||||
|
// Create a CoNLL2000Dataset, with single text file.
|
||||||
|
// Note: valid.txt has 3 rows.
|
||||||
|
// Use 2 samples.
|
||||||
|
// Use defaults for other input parameters.
|
||||||
|
std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
|
||||||
|
std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
|
||||||
|
std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "train", 0, ShuffleMode::kFalse);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row.
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
EXPECT_NE(row.find("word"), row.end());
|
||||||
|
|
||||||
|
std::vector<std::vector<std::string>> expected_result = {
|
||||||
|
{"Challenge", "NNP", "O"}, {"Her", "PP$", "B-NP"}, {"To", "TO", "I-VP"}};
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
for (int j = 0; j < column_names.size(); j++) {
|
||||||
|
auto word = row[column_names[j]];
|
||||||
|
std::shared_ptr<Tensor> de_word;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_word->GetItemAt(&sv, {{}}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
|
||||||
|
}
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
// Expect 3 samples.
|
||||||
|
// Manually terminate the pipeline.
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// Restore configuration.
|
||||||
|
GlobalContext::config_manager()->set_seed(original_seed);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: CoNLL2000ChunkingDataset.
|
||||||
|
/// Description: test CoNLL2000ChunkingDataset in pipeline mode.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetBasicWithPipeline) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetBasicWithPipeline.";
|
||||||
|
// Test CoNLL2000 Dataset with single text file and many default inputs.
|
||||||
|
|
||||||
|
// Set configuration.
|
||||||
|
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||||
|
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||||
|
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||||
|
GlobalContext::config_manager()->set_seed(987);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(2);
|
||||||
|
|
||||||
|
// Create two CoNLL2000Dataset, with single text file.
|
||||||
|
// Note: test.txt has 3 rows.
|
||||||
|
// Use 2 samples.
|
||||||
|
// Use defaults for other input parameters.
|
||||||
|
std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
|
||||||
|
std::shared_ptr<Dataset> ds1 = CoNLL2000(dataset_dir, "test", 0, ShuffleMode::kFalse);
|
||||||
|
std::shared_ptr<Dataset> ds2 = CoNLL2000(dataset_dir, "test", 0, ShuffleMode::kFalse);
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
EXPECT_NE(ds2, nullptr);
|
||||||
|
|
||||||
|
// Create two Repeat operation on ds.
|
||||||
|
int32_t repeat_num = 2;
|
||||||
|
ds1 = ds1->Repeat(repeat_num);
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
repeat_num = 3;
|
||||||
|
ds2 = ds2->Repeat(repeat_num);
|
||||||
|
EXPECT_NE(ds2, nullptr);
|
||||||
|
|
||||||
|
// Create a Concat operation on the ds.
|
||||||
|
ds1 = ds1->Concat({ds2});
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row.
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("word"), row.end());
|
||||||
|
std::vector<std::vector<std::string>> expected_result = {{"He", "PBP", "B-NP"}, {"The", "DT", "B-NP"}};
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto word = row["word"];
|
||||||
|
MS_LOG(INFO) << "Tensor word shape: " << word.Shape();
|
||||||
|
i++;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect 10 samples.
|
||||||
|
EXPECT_EQ(i, 10);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline.
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// Restore configuration.
|
||||||
|
GlobalContext::config_manager()->set_seed(original_seed);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||||
|
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test usage of getters of CoNLL2000ChunkingDataset.
/// Expectation: the getters output correct results.
TEST_F(MindDataTestPipeline, TestCoNLL2000Getters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000Getters.";
  // Test CoNLL2000 Dataset with single text file and many default inputs.

  // Set configuration.
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(987);
  GlobalContext::config_manager()->set_num_parallel_workers(2);

  // Create a CoNLL2000 Dataset, with a single text file.
  // Note: test.txt has 2 rows.
  // Use 2 samples.
  // Use defaults for other input parameters.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
  EXPECT_EQ(ds->GetDatasetSize(), 2);
  EXPECT_EQ(ds->GetColumnNames(), column_names);

  std::shared_ptr<Dataset> ds1 = CoNLL2000(dataset_dir, "", 0, ShuffleMode::kFalse);
  EXPECT_NE(ds1, nullptr);

  EXPECT_EQ(ds1->GetDatasetSize(), 30);
  // Restore configuration.
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with invalid num_samples=-1.
/// Expectation: error message is logged, and CreateIterator() for invalid pipeline returns nullptr.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail1) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail1.";

  // Create a CoNLL2000Dataset
  // with invalid num_samples=-1.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "test", -1, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: CoNLL2000 number of samples cannot be negative.
  EXPECT_EQ(iter, nullptr);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with a non-existent dataset_dir.
/// Expectation: error message is logged, and CreateIterator() for invalid pipeline returns nullptr.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail2) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail2.";

  // Attempt to create a CoNLL2000 Dataset
  // with a non-existent dataset_dir input.
  std::shared_ptr<Dataset> ds = CoNLL2000("NotExistFile", "test", 2, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: specified dataset_dir does not exist.
  EXPECT_EQ(iter, nullptr);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with an invalid usage value.
/// Expectation: error message is logged, and CreateIterator() for invalid pipeline returns nullptr.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail3) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail3.";

  // Create a CoNLL2000 Dataset
  // with an invalid usage="dev" input.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "dev", 2, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: usage must be "train", "test" or "all".
  EXPECT_EQ(iter, nullptr);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with an empty dataset_dir.
/// Expectation: error message is logged, and CreateIterator() for invalid pipeline returns nullptr.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail4) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail4.";

  // Create a CoNLL2000Dataset
  // with an empty string dataset_dir input.
  std::shared_ptr<Dataset> ds = CoNLL2000("", "test", 2, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: specified dataset_dir is invalid.
  EXPECT_EQ(iter, nullptr);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with invalid num_shards=0.
/// Expectation: error message is logged, and CreateIterator() for invalid pipeline returns nullptr.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail5) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail5.";

  // Create a CoNLL2000 Dataset
  // with invalid num_shards=0 value.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse, 0);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: Number of shards cannot be <=0.
  EXPECT_EQ(iter, nullptr);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with invalid shard_id=-1.
/// Expectation: error message is logged, and CreateIterator() for invalid pipeline returns nullptr.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail6) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail6.";

  // Create a CoNLL2000Dataset
  // with invalid shard_id=-1 value (num_shards=2 so the shard_id argument is the one exercised).
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse, 2, -1);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: shard_id cannot be negative.
  EXPECT_EQ(iter, nullptr);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with shard_id equal to num_shards.
/// Expectation: error message is logged, and CreateIterator() for invalid pipeline returns nullptr.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail7) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail7.";

  // Create a CoNLL2000 Dataset
  // with invalid shard_id=2 and num_shards=2 combination.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse, 2, 2);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: Cannot have shard_id >= num_shards.
  EXPECT_EQ(iter, nullptr);
}
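
// Taken together, Fail5, Fail6 and Fail7 exercise the sharding constraint that
// every sharded read must satisfy: num_shards >= 1 and 0 <= shard_id < num_shards.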

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with ShuffleMode::kFalse.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleFalse) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleFalse.";
  // Test CoNLL2000 Dataset with two text files and no shuffle, num_parallel_workers=4.

  // Set configuration.
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(654);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

  // Create a CoNLL2000 Dataset, with two text files, test.txt and train.txt, in lexicographical order.
  // Note: test.txt has 2 rows.
  // Note: train.txt has 3 rows.
  // Use default of all samples.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("word"), row.end());
  std::vector<std::vector<std::string>> expected_result = {{"He", "PBP", "B-NP"},
                                                           {"Challenge", "NNP", "O"},
                                                           {"The", "DT", "B-NP"},
                                                           {"Her", "PP$", "B-NP"},
                                                           {"To", "TO", "I-VP"}};

  uint64_t i = 0;
  while (row.size() != 0) {
    for (size_t j = 0; j < column_names.size(); j++) {
      auto word = row[column_names[j]];
      std::shared_ptr<Tensor> de_word;
      ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word));
      std::string_view sv;
      ASSERT_OK(de_word->GetItemAt(&sv, {0}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // Expect 2 + 3 = 5 samples.
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline.
  iter->Stop();

  // Restore configuration.
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
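
// A note on the extraction pattern used in the shuffle tests: each row's "word",
// "pos_tag" and "chunk_tag" columns are 1-D string tensors holding one sentence,
// so the loops convert the MSTensor back to a dataset Tensor with
// Tensor::CreateFromMSTensor and compare only the first token of each column,
// read out through GetItemAt(&sv, {0}).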

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with ShuffleMode::kFiles.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleFilesA) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleFilesA.";
  // Test CoNLL2000 Dataset with files shuffle, num_parallel_workers=4.

  // Set configuration.
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(135);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

  // Create a CoNLL2000 Dataset, with two text files, test.txt and train.txt, in lexicographical order.
  // Note: test.txt has 2 rows.
  // Note: train.txt has 3 rows.
  // Set shuffle to files shuffle.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kFiles);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("word"), row.end());
  std::vector<std::vector<std::string>> expected_result = {{"He", "PBP", "B-NP"},
                                                           {"Challenge", "NNP", "O"},
                                                           {"The", "DT", "B-NP"},
                                                           {"Her", "PP$", "B-NP"},
                                                           {"To", "TO", "I-VP"}};
  uint64_t i = 0;
  while (row.size() != 0) {
    for (size_t j = 0; j < column_names.size(); j++) {
      auto word = row[column_names[j]];
      std::shared_ptr<Tensor> de_word;
      ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word));
      std::string_view sv;
      ASSERT_OK(de_word->GetItemAt(&sv, {0}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // Expect 2 + 3 = 5 samples.
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline.
  iter->Stop();

  // Restore configuration.
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with ShuffleMode::kFiles.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleFilesB) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleFilesB.";
  // Test CoNLL2000 Dataset with files shuffle, num_parallel_workers=4.

  // Set configuration.
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(135);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

  // Create a CoNLL2000 Dataset, with two text files, test.txt and train.txt, in lexicographical order.
  // Note: test.txt has 2 rows.
  // Note: train.txt has 3 rows.
  // Set shuffle to files shuffle.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kFiles);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("word"), row.end());
  std::vector<std::vector<std::string>> expected_result = {{"He", "PBP", "B-NP"},
                                                           {"Challenge", "NNP", "O"},
                                                           {"The", "DT", "B-NP"},
                                                           {"Her", "PP$", "B-NP"},
                                                           {"To", "TO", "I-VP"}};
  uint64_t i = 0;
  while (row.size() != 0) {
    for (size_t j = 0; j < column_names.size(); j++) {
      auto word = row[column_names[j]];
      std::shared_ptr<Tensor> de_word;
      ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word));
      std::string_view sv;
      ASSERT_OK(de_word->GetItemAt(&sv, {0}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // Expect 2 + 3 = 5 samples.
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline.
  iter->Stop();

  // Restore configuration.
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with ShuffleMode::kGlobal on a single file.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleGlobalA) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleGlobalA.";
  // Test CoNLL2000 Dataset with 1 text file, global shuffle, num_parallel_workers=4.

  // Set configuration.
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(246);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

  // Create a CoNLL2000 Dataset, with one text file.
  // Note: test.txt has 2 rows.
  // Set shuffle to global shuffle.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "test", 0, ShuffleMode::kGlobal);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("word"), row.end());
  std::vector<std::vector<std::string>> expected_result = {{"He", "PBP", "B-NP"}, {"The", "DT", "B-NP"}};

  uint64_t i = 0;
  while (row.size() != 0) {
    for (size_t j = 0; j < column_names.size(); j++) {
      auto word = row[column_names[j]];
      std::shared_ptr<Tensor> de_word;
      ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word));
      std::string_view sv;
      ASSERT_OK(de_word->GetItemAt(&sv, {0}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // Expect 2 samples.
  EXPECT_EQ(i, 2);

  // Manually terminate the pipeline.
  iter->Stop();

  // Restore configuration.
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

/// Feature: CoNLL2000ChunkingDataset.
/// Description: test CoNLL2000ChunkingDataset with ShuffleMode::kGlobal on two files.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleGlobalB) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleGlobalB.";
  // Test CoNLL2000 Dataset with 2 text files, global shuffle, num_parallel_workers=4.

  // Set configuration.
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(246);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

  // Create a CoNLL2000 Dataset, with two text files.
  // Note: test.txt has 2 rows.
  // Note: train.txt has 3 rows.
  // Set shuffle to global shuffle.
  std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset";
  std::shared_ptr<Dataset> ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kGlobal);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  std::vector<std::string> column_names = {"word", "pos_tag", "chunk_tag"};
  ASSERT_OK(iter->GetNextRow(&row));

  EXPECT_NE(row.find("word"), row.end());
  // Row order under global shuffle with seed 246, matching the equivalent Python test below.
  std::vector<std::vector<std::string>> expected_result = {{"Challenge", "NNP", "O"},
                                                           {"To", "TO", "I-VP"},
                                                           {"Her", "PP$", "B-NP"},
                                                           {"He", "PBP", "B-NP"},
                                                           {"The", "DT", "B-NP"}};
  uint64_t i = 0;
  while (row.size() != 0) {
    for (size_t j = 0; j < column_names.size(); j++) {
      auto word = row[column_names[j]];
      std::shared_ptr<Tensor> de_word;
      ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word));
      std::string_view sv;
      ASSERT_OK(de_word->GetItemAt(&sv, {0}));
      std::string ss(sv);
      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
    }
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // Expect 2 + 3 = 5 samples.
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline.
  iter->Stop();

  // Restore configuration.
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

@ -0,0 +1,14 @@
He PBP B-NP
reckons VBZ B-VP
the DT B-NP
current JJ I-NP
account NN I-NP
. . O

The DT B-NP
1.8 CD I-NP
billion CD I-NP
in IN B-PP
September NNP B-NP
. . O

@ -0,0 +1,21 @@
Challenge NNP O
of IN B-PP
the DT B-NP
August NNP B-NP
month NNP B-NP
. . O

Her PP$ B-NP
's POS B-NP
chancellor NNP O
at IN B-PP
Lawson NNP I-NP
. . O

To TO I-VP
economists NNS B-NP
, , O
foreign JJ B-NP
exchange NN I-NP
. . O

@ -0,0 +1,345 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest

import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed

DATA_DIR = '../data/dataset/testCoNLL2000Dataset'
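
# The CoNLL2000 chunking files read by these tests store one token per line as
# "word pos_tag chunk_tag", with a blank line closing each sentence; every
# dataset row therefore yields three 1-D string columns ("word", "pos_tag"
# and "chunk_tag") holding one entry per token of the sentence.
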
def test_conll2000_dataset_one_file():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with one file (usage="test").
    Expectation: the data is processed successfully.
    """
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["word"]))
        count += 1
    assert count == 2

def test_conll2000_dataset_all_file():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with all files (usage="all").
    Expectation: the data is processed successfully.
    """
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["word"]))
        count += 1
    assert count == 5

def test_conll2000_dataset_num_samples_none():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with num_samples=None (read all samples).
    Expectation: the data is processed successfully.
    """
    # Do not provide a num_samples argument, so it would be None by default.
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["word"]))
        count += 1
    assert count == 2

def test_conll2000_dataset_shuffle_false_num_parallel_workers_4():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with shuffle=False and num_parallel_workers=4.
    Expectation: the data is processed successfully.
    """
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(987)
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False)
    count = 0
    numword = 6
    line = ["He", "reckons", "the", "current", "account", ".",
            "Challenge", "of", "the", "August", "month", ".",
            "The", "1.8", "billion", "in", "September", ".",
            "Her", "'s", "chancellor", "at", "Lawson", ".",
            "To", "economists", ",", "foreign", "exchange", "."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for j in range(numword):
            strs = i["word"][j].item().decode("utf8")
            assert strs == line[count*6+j]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)

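
# Each sentence in the test data holds exactly six tokens, so in the checks
# above (and in the shuffle tests that follow) row "count" is compared against
# the slice line[count*6 : count*6 + 6] of the flat expected-token list.
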
def test_conll2000_dataset_shuffle_false_num_parallel_workers_1():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with shuffle=False and num_parallel_workers=1.
    Expectation: the data is processed successfully.
    """
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(987)
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False)
    count = 0
    numword = 6
    line = ["He", "reckons", "the", "current", "account", ".",
            "The", "1.8", "billion", "in", "September", ".",
            "Challenge", "of", "the", "August", "month", ".",
            "Her", "'s", "chancellor", "at", "Lawson", ".",
            "To", "economists", ",", "foreign", "exchange", "."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for j in range(numword):
            strs = i["word"][j].item().decode("utf8")
            assert strs == line[count*6+j]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)

def test_conll2000_dataset_shuffle_files_num_parallel_workers_4():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with shuffle=Shuffle.FILES and num_parallel_workers=4.
    Expectation: the data is processed successfully.
    """
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(135)
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
    count = 0
    numword = 6
    line = ["He", "reckons", "the", "current", "account", ".",
            "Challenge", "of", "the", "August", "month", ".",
            "The", "1.8", "billion", "in", "September", ".",
            "Her", "'s", "chancellor", "at", "Lawson", ".",
            "To", "economists", ",", "foreign", "exchange", "."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for j in range(numword):
            strs = i["word"][j].item().decode("utf8")
            assert strs == line[count*6+j]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)

def test_conll2000_dataset_shuffle_files_num_parallel_workers_1():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with shuffle=Shuffle.FILES and num_parallel_workers=1.
    Expectation: the data is processed successfully.
    """
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(135)
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
    count = 0
    numword = 6
    line = ["He", "reckons", "the", "current", "account", ".",
            "The", "1.8", "billion", "in", "September", ".",
            "Challenge", "of", "the", "August", "month", ".",
            "Her", "'s", "chancellor", "at", "Lawson", ".",
            "To", "economists", ",", "foreign", "exchange", "."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for j in range(numword):
            strs = i["word"][j].item().decode("utf8")
            assert strs == line[count*6+j]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)

def test_conll2000_dataset_shuffle_global_num_parallel_workers_4():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with shuffle=Shuffle.GLOBAL and num_parallel_workers=4.
    Expectation: the data is processed successfully.
    """
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(246)
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
    count = 0
    numword = 6
    line = ["Challenge", "of", "the", "August", "month", ".",
            "To", "economists", ",", "foreign", "exchange", ".",
            "Her", "'s", "chancellor", "at", "Lawson", ".",
            "He", "reckons", "the", "current", "account", ".",
            "The", "1.8", "billion", "in", "September", "."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for j in range(numword):
            strs = i["word"][j].item().decode("utf8")
            assert strs == line[count*6+j]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)

def test_conll2000_dataset_shuffle_global_num_parallel_workers_1():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with shuffle=Shuffle.GLOBAL and num_parallel_workers=1.
    Expectation: the data is processed successfully.
    """
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(246)
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
    count = 0
    numword = 6
    line = ["Challenge", "of", "the", "August", "month", ".",
            "The", "1.8", "billion", "in", "September", ".",
            "To", "economists", ",", "foreign", "exchange", ".",
            "Her", "'s", "chancellor", "at", "Lawson", ".",
            "He", "reckons", "the", "current", "account", "."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for j in range(numword):
            strs = i["word"][j].item().decode("utf8")
            assert strs == line[count*6+j]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)

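
# With ds.Shuffle.FILES and ds.Shuffle.GLOBAL the resulting sample order is a
# deterministic function of the configured seed, which is why each shuffle test
# pins the seed via config_get_set_seed before iterating and restores it after.
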
def test_conll2000_dataset_num_samples():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with num_samples=2.
    Expectation: the data is processed successfully.
    """
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False, num_samples=2)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2

def test_conll2000_dataset_distribution():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset in a distributed state (num_shards=2, shard_id=1).
    Expectation: the data is processed successfully.
    """
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False, num_shards=2, shard_id=1)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 1

def test_conll2000_dataset_repeat():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with a repeat operation.
    Expectation: the data is processed successfully.
    """
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
    data = data.repeat(3)
    count = 0
    numword = 6
    line = ["He", "reckons", "the", "current", "account", ".",
            "The", "1.8", "billion", "in", "September", ".",
            "He", "reckons", "the", "current", "account", ".",
            "The", "1.8", "billion", "in", "September", ".",
            "He", "reckons", "the", "current", "account", ".",
            "The", "1.8", "billion", "in", "September", "."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for j in range(numword):
            strs = i["word"][j].item().decode("utf8")
            assert strs == line[count*6+j]
        count += 1
    assert count == 6

def test_conll2000_dataset_get_datasetsize():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test get_dataset_size of CoNLL2000Dataset.
    Expectation: the data is processed successfully.
    """
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
    size = data.get_dataset_size()
    assert size == 12

def test_conll2000_dataset_to_device():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test CoNLL2000Dataset with to_device.
    Expectation: the data is processed successfully.
    """
    data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
    data = data.to_device()
    data.send()

def test_conll2000_dataset_exceptions():
    """
    Feature: CoNLL2000ChunkingDataset.
    Description: test error cases of CoNLL2000ChunkingDataset.
    Expectation: throw correct error and message.
    """
    with pytest.raises(ValueError) as error_info:
        _ = ds.CoNLL2000Dataset(DATA_DIR, usage="test", num_samples=-1)
    assert "num_samples exceeds the boundary" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.CoNLL2000Dataset("NotExistFile", usage="test")
    assert "The folder NotExistFile does not exist or is not a directory or permission denied!" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset("")
    assert "The following patterns did not match any files" in str(error_info.value)

    def exception_func(item):
        raise Exception("Error occur!")

    with pytest.raises(RuntimeError) as error_info:
        data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
        data = data.map(operations=exception_func, input_columns=["word"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)

if __name__ == "__main__":
    test_conll2000_dataset_one_file()
    test_conll2000_dataset_all_file()
    test_conll2000_dataset_num_samples_none()
    test_conll2000_dataset_shuffle_false_num_parallel_workers_4()
    test_conll2000_dataset_shuffle_false_num_parallel_workers_1()
    test_conll2000_dataset_shuffle_files_num_parallel_workers_4()
    test_conll2000_dataset_shuffle_files_num_parallel_workers_1()
    test_conll2000_dataset_shuffle_global_num_parallel_workers_4()
    test_conll2000_dataset_shuffle_global_num_parallel_workers_1()
    test_conll2000_dataset_num_samples()
    test_conll2000_dataset_distribution()
    test_conll2000_dataset_repeat()
    test_conll2000_dataset_get_datasetsize()
    test_conll2000_dataset_to_device()
    test_conll2000_dataset_exceptions()