From 289a7997df343dc51a468b62cbc1a264a2ae1889 Mon Sep 17 00:00:00 2001 From: fattycat <908929260@qq.com> Date: Tue, 21 Dec 2021 15:31:04 +0800 Subject: [PATCH] [feat] [assistant] [#I40GYB] add new loader CoNLL2000ChunkingDataset --- .../ccsrc/minddata/dataset/api/datasets.cc | 9 + .../engine/ir/datasetops/source/bindings.cc | 13 + .../engine/datasetops/source/CMakeLists.txt | 1 + .../engine/datasetops/source/conll2000_op.cc | 181 +++++ .../engine/datasetops/source/conll2000_op.h | 96 +++ .../engine/ir/datasetops/dataset_node.h | 1 + .../ir/datasetops/source/CMakeLists.txt | 1 + .../ir/datasetops/source/conll2000_node.cc | 203 ++++++ .../ir/datasetops/source/conll2000_node.h | 130 ++++ .../dataset/include/dataset/datasets.h | 64 ++ .../mindspore/dataset/engine/datasets.py | 61 +- .../mindspore/dataset/engine/validators.py | 30 + tests/ut/cpp/dataset/CMakeLists.txt | 1 + .../dataset/c_api_dataset_conll2000_test.cc | 628 ++++++++++++++++++ .../dataset/testCoNLL2000Dataset/test.txt | 14 + .../dataset/testCoNLL2000Dataset/train.txt | 21 + .../python/dataset/test_datasets_conll2000.py | 345 ++++++++++ 17 files changed, 1798 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.h create mode 100644 tests/ut/cpp/dataset/c_api_dataset_conll2000_test.cc create mode 100755 tests/ut/data/dataset/testCoNLL2000Dataset/test.txt create mode 100755 tests/ut/data/dataset/testCoNLL2000Dataset/train.txt create mode 100644 tests/ut/python/dataset/test_datasets_conll2000.py diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 21357a7883d..350aa03820b 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -95,6 +95,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h" #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/conll2000_node.h" #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" #include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h" #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h" @@ -1033,6 +1034,14 @@ CocoDataset::CocoDataset(const std::vector &dataset_dir, const std::vector ir_node_ = std::static_pointer_cast(ds); } +CoNLL2000Dataset::CoNLL2000Dataset(const std::vector &dataset_dir, const std::vector &usage, + int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + const std::shared_ptr &cache) { + auto ds = std::make_shared(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, + num_shards, shard_id, cache); + ir_node_ = std::static_pointer_cast(ds); +} + CSVDataset::CSVDataset(const std::vector> &dataset_files, char field_delim, const std::vector> &column_defaults, const std::vector> &column_names, int64_t num_samples, ShuffleMode shuffle, diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc index 490b1197180..98a78833f9a 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc @@ -32,6 +32,7 @@ #include "minddata/dataset/engine/ir/datasetops/source/cityscapes_node.h" #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" +#include "minddata/dataset/engine/ir/datasetops/source/conll2000_node.h" #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" #include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h" #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h" @@ -159,6 +160,18 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) { })); })); +PYBIND_REGISTER(CoNLL2000Node, 2, ([](const py::module *m) { + (void)py::class_>( + *m, "CoNLL2000Node", "to create a CoNLL2000Node") + .def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle, + int32_t num_shards, int32_t shard_id) { + std::shared_ptr conll2000 = std::make_shared( + dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); + THROW_IF_ERROR(conll2000->ValidateParams()); + return conll2000; + })); + })); + PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "CSVNode", "to create a CSVNode") .def(py::init([](std::vector csv_files, char field_delim, py::list column_defaults, diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt index d8a655834f8..2254b793a93 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt @@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES cityscapes_op.cc clue_op.cc coco_op.cc + conll2000_op.cc csv_op.cc dbpedia_op.cc div2k_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.cc new file mode 100644 index 00000000000..91200e44ac4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.cc @@ -0,0 +1,181 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/conll2000_op.h" + +#include +#include +#include +#include +#include + +#include "debug/common.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/wait_post.h" +#include "utils/file_utils.h" + +namespace mindspore { +namespace dataset { +CoNLL2000Op::CoNLL2000Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, + std::unique_ptr schema, const std::vector &conll2000_file_list, + int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id) + : TextFileOp(num_workers, total_rows, worker_connector_size, std::move(schema), conll2000_file_list, + op_connector_size, shuffle_files, num_devices, device_id) {} + +// A print method typically used for debugging. +void CoNLL2000Op::Print(std::ostream &out, bool show_all) const { + if (!show_all) { + // Call the super class for displaying any common 1-liner info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op. + out << "\n"; + } else { + // Call the super class for displaying any common detailed info. + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff. + out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCoNLL2000 file list:\n"; + for (size_t i = 0; i < text_files_list_.size(); ++i) { + out << " " << text_files_list_[i]; + } + out << "\nData Schema:\n"; + out << *data_schema_ << "\n\n"; + } +} + +Status CoNLL2000Op::LoadTensor(const std::vector &column, TensorRow *out_row, size_t index) { + RETURN_UNEXPECTED_IF_NULL(out_row); + std::shared_ptr tensor; + RETURN_IF_NOT_OK(Tensor::CreateFromVector(column, &tensor)); + (*out_row)[index] = std::move(tensor); + return Status::OK(); +} + +// Function to split string based on a character delimiter. +std::vector CoNLL2000Op::Split(const std::string &s, char delim) { + std::vector res; + std::stringstream ss(s); + std::string item; + + while (getline(ss, item, delim)) { + res.push_back(item); + } + return res; +} + +// Removes excess space before and after the string. +std::string CoNLL2000Op::Strip(const std::string &str) { + std::int64_t strlen = str.size(); + std::int64_t i, j; + i = 0; + while (i < strlen && str[i] == ' ') { + i++; + } + j = strlen - 1; + while (j >= i && str[j] == ' ') { + j--; + } + j++; + if (i == 0 && j == strlen) { + return str; + } else { + return str.substr(i, j - i); + } +} + +Status CoNLL2000Op::Load(const std::vector &word, const std::vector &pos_tag, + const std::vector &chunk_tag, const std::string &file, int32_t worker_id) { + size_t row_line = 3; + TensorRow tRow(row_line, nullptr); + // Add file path info. + std::vector file_path(row_line, file); + tRow.setPath(file_path); + size_t word_index = 0, pos_tag_index = 1, chunk_tag_index = 2; + RETURN_IF_NOT_OK(LoadTensor(word, &tRow, word_index)); + RETURN_IF_NOT_OK(LoadTensor(pos_tag, &tRow, pos_tag_index)); + RETURN_IF_NOT_OK(LoadTensor(chunk_tag, &tRow, chunk_tag_index)); + RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow))); + return Status::OK(); +} + +Status CoNLL2000Op::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) { + auto realpath = FileUtils::GetRealPath(file.data()); + if (!realpath.has_value()) { + MS_LOG(ERROR) << "Invalid file path, " << DatasetName() << " dataset dir: " << file << " does not exist."; + RETURN_STATUS_UNEXPECTED("Invalid file path, " + DatasetName() + " dataset dir: " + file + " does not exist."); + } + std::ifstream handle(realpath.value()); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + ": " + file); + } + int64_t rows_total = 0; + std::string line; + std::vector word_column; + std::vector pos_tag_column; + std::vector chunk_tag_column; + while (getline(handle, line)) { + if (line.empty() && rows_total < start_offset) { + continue; + } + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + if (word_column.size() != 0) { + Status s = Load(word_column, pos_tag_column, chunk_tag_column, file, worker_id); + if (s.IsError()) { + handle.close(); + return s; + } + } + std::vector().swap(word_column); + std::vector().swap(pos_tag_column); + std::vector().swap(chunk_tag_column); + break; + } + // Skip line before start offset. + if (rows_total < start_offset) { + rows_total++; + continue; + } + line = Strip(line); + if (line.empty() && rows_total >= start_offset) { + if (word_column.size() != 0) { + Status s = Load(word_column, pos_tag_column, chunk_tag_column, file, worker_id); + if (s.IsError()) { + handle.close(); + return s; + } + } + std::vector().swap(word_column); + std::vector().swap(pos_tag_column); + std::vector().swap(chunk_tag_column); + continue; + } else if (!line.empty() && rows_total >= start_offset) { + std::vector column = Split(line, ' '); + size_t word_index = 0, pos_tag_index = 1, chunk_tag_index = 2; + word_column.push_back(column[word_index]); + pos_tag_column.push_back(column[pos_tag_index]); + chunk_tag_column.push_back(column[chunk_tag_index]); + } + rows_total++; + } + handle.close(); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.h new file mode 100644 index 00000000000..698ed35f26e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/conll2000_op.h @@ -0,0 +1,96 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CONLL2000_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CONLL2000_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/util/queue.h" + +namespace mindspore { +namespace dataset { +class JaggedConnector; + +class CoNLL2000Op : public TextFileOp { + public: + /// \Constructor of CoNLL2000Op. + CoNLL2000Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr, + const std::vector &conll2000_file_list, int32_t op_connector_size, bool shuffle_files, + int32_t num_devices, int32_t device_id); + + /// \Default destructor. + ~CoNLL2000Op() = default; + + /// \brief A print method typically used for debugging. + /// \param[in] out The output stream to write output to. + /// \param[in] show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + /// \brief Op name getter. + /// \return Name of the current Op. + std::string Name() const override { return "CoNLL2000Op"; } + + /// \brief brief description DatasetName name getter + /// \param[in] upper Needs to be capitalized or not + /// \return DatasetName of the current Op + std::string DatasetName(bool upper = false) const { return upper ? "CoNLL2000" : "conll2000"; } + + private: + /// \brief Parses a single row and puts the data into multiple TensorRows. + /// \param[in] column The content of the column. + /// \param[in] out_row The tensor table to put the parsed data in. + /// \param[in] index Serial number of column. + /// \return Status The error code returned. + Status LoadTensor(const std::vector &column, TensorRow *out_row, size_t index); + + /// \brief Removes excess space before and after the string. + /// \param[in] str The input string. + /// \return A string. + std::string Strip(const std::string &str); + + /// \brief Split string based on a character delimiter. + /// \param[in] s The input string. + /// \param[in] delim Symbols for separating string. + /// \return A string vector. + std::vector Split(const std::string &s, char delim); + + /// \brief Specify that the corresponding data is translated into Tensor. + /// \param[in] word A list of words in a sentence. + /// \param[in] pos_tag Pos_tag part of speech. + /// \param[in] chunk_tag Chunk_tag part of speech. + /// \param[in] file The file to read. + /// \param[in] worker_id The id of the worker that is executing this function. + /// \return Status The error code returned. + Status Load(const std::vector &word, const std::vector &pos_tag, + const std::vector &chunk_tag, const std::string &file, int32_t worker_id); + + /// \brief Reads a text file and loads the data into multiple TensorRows. + /// \param[in] file The file to read. + /// \param[in] start_offset The start offset of file. + /// \param[in] end_offset The end offset of file. + /// \param[in] worker_id The id of the worker that is executing this function. + /// \return Status The error code returned. + Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_CONLL2000_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index daa056717fb..e07df8828b0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ -84,6 +84,7 @@ constexpr char kCifar10Node[] = "Cifar10Dataset"; constexpr char kCityscapesNode[] = "CityscapesDataset"; constexpr char kCLUENode[] = "CLUEDataset"; constexpr char kCocoNode[] = "CocoDataset"; +constexpr char kCoNLL2000Node[] = "CoNLL2000Dataset"; constexpr char kCSVNode[] = "CSVDataset"; constexpr char kDBpediaNode[] = "DBpediaDataset"; constexpr char kDIV2KNode[] = "DIV2KDataset"; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt index 31aba0a7068..e07ed694728 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt @@ -11,6 +11,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES cityscapes_node.cc clue_node.cc coco_node.cc + conll2000_node.cc csv_node.cc dbpedia_node.cc div2k_node.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.cc new file mode 100644 index 00000000000..885914256a0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.cc @@ -0,0 +1,203 @@ +/** + Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/source/conll2000_node.h" + +#include +#include + +#include "minddata/dataset/engine/datasetops/source/conll2000_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Constructor for CoNLL2000Node. +CoNLL2000Node::CoNLL2000Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + std::shared_ptr cache) + : NonMappableSourceNode(std::move(cache)), + dataset_dir_(dataset_dir), + usage_(usage), + num_samples_(num_samples), + shuffle_(shuffle), + num_shards_(num_shards), + shard_id_(shard_id), + conll2000_file_list_(WalkAllFiles(usage, dataset_dir)) { + // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion + // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't + // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once + // PreBuildSampler is phased out, this can be cleaned up. + GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); +} + +std::shared_ptr CoNLL2000Node::Copy() { + auto node = + std::make_shared(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); + return node; +} + +void CoNLL2000Node::Print(std::ostream &out) const { + out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + + ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); +} + +Status CoNLL2000Node::ValidateParams() { + RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); + RETURN_IF_NOT_OK(ValidateDatasetDirParam("CoNLL2000Node", dataset_dir_)); + RETURN_IF_NOT_OK(ValidateStringValue("CoNLL2000Node", usage_, {"train", "test", "all"})); + + if (num_samples_ < 0) { + std::string err_msg = "CoNLL2000Node: Invalid number of samples: " + std::to_string(num_samples_); + LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + RETURN_IF_NOT_OK(ValidateDatasetShardParams("CoNLL2000Node", num_shards_, shard_id_)); + return Status::OK(); +} + +// Function to build CoNLL2000Node. +Status CoNLL2000Node::Build(std::vector> *const node_ops) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + + // Sort the dataset files in a lexicographical order. + std::vector sorted_dataset_files = conll2000_file_list_; + std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); + + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("word", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("pos_tag", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + RETURN_IF_NOT_OK( + schema->AddColumn(ColDescriptor("chunk_tag", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + // Create and initialize CoNLL2000Op. + std::shared_ptr conll2000_op = + std::make_shared(num_workers_, num_samples_, worker_connector_size_, std::move(schema), + sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_); + RETURN_IF_NOT_OK(conll2000_op->Init()); + + // If a global shuffle is used for CoNLL2000, it will inject a shuffle op over the CoNLL2000. + // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built. + // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset CoNLL2000's shuffle + // option to false. + if (shuffle_ == ShuffleMode::kGlobal) { + // Inject ShuffleOp. + std::shared_ptr shuffle_op = nullptr; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset. + RETURN_IF_NOT_OK(CoNLL2000Op::CountAllFileRows(sorted_dataset_files, &num_rows)); + + // Add the shuffle op after this op. + RETURN_IF_NOT_OK( + AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); + shuffle_op->SetTotalRepeats(GetTotalRepeats()); + shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + node_ops->push_back(shuffle_op); + } + conll2000_op->SetTotalRepeats(GetTotalRepeats()); + conll2000_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); + // Add CoNLL2000Op. + node_ops->push_back(conll2000_op); + + return Status::OK(); +} + +// Get the shard id of node. +Status CoNLL2000Node::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + + return Status::OK(); +} + +// Get Dataset size. +Status CoNLL2000Node::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + int64_t num_rows, sample_size = num_samples_; + RETURN_IF_NOT_OK(CoNLL2000Op::CountAllFileRows(conll2000_file_list_, &num_rows)); + num_rows = static_cast(ceil(num_rows / (1.0 * num_shards_))); + *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; + dataset_size_ = *dataset_size; + return Status::OK(); +} + +Status CoNLL2000Node::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} + +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. +// CoNLL2000 by itself is a non-mappable dataset that does not support sampling. +// However, if a cache operator is injected at some other place higher in the tree, that cache can +// inherit this sampler from the leaf, providing sampling support from the caching layer. +// That is why we setup the sampler for a leaf node that does not use sampling. +Status CoNLL2000Node::SetupSamplerForCache(std::shared_ptr *sampler) { + bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); + *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); + return Status::OK(); +} + +// If a cache has been added into the ascendant tree over this CoNLL2000 node, then the cache will be executing +// a sampler for fetching the data. As such, any options in the CoNLL2000 node need to be reset to its defaults so +// that this CoNLL2000 node will produce the full set of data into the cache. +Status CoNLL2000Node::MakeSimpleProducer() { + shard_id_ = 0; + num_shards_ = 1; + shuffle_ = ShuffleMode::kFalse; + num_samples_ = 0; + return Status::OK(); +} + +std::vector CoNLL2000Node::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) { + std::vector conll2000_file_list; + Path train_prefix("train.txt"); + Path test_prefix("test.txt"); + Path dir(dataset_dir); + + if (usage == "train") { + Path temp_path = dir / train_prefix; + conll2000_file_list.push_back(temp_path.ToString()); + } else if (usage == "test") { + Path temp_path = dir / test_prefix; + conll2000_file_list.push_back(temp_path.ToString()); + } else { + Path temp_path = dir / train_prefix; + conll2000_file_list.push_back(temp_path.ToString()); + Path temp_path1 = dir / test_prefix; + conll2000_file_list.push_back(temp_path1.ToString()); + } + return conll2000_file_list; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.h new file mode 100644 index 00000000000..fe903ded442 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/conll2000_node.h @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONLL2000_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONLL2000_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/ir/datasetops/dataset_node.h" + +namespace mindspore { +namespace dataset { +/// \class CoNLL2000Node. +/// \brief A Dataset derived class to represent CoNLL2000 dataset. +class CoNLL2000Node : public NonMappableSourceNode { + public: + /// \brief Constructor. + CoNLL2000Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle, + int32_t num_shards, int32_t shard_id, std::shared_ptr cache); + + /// \brief Destructor. + ~CoNLL2000Node() = default; + + /// \brief Node name getter. + /// \return Name of the current node. + std::string Name() const override { return "CoNLL2000Node"; } + + /// \brief Print the description. + /// \param[out] out The output stream to write output to. + void Print(std::ostream &out) const override; + + /// \brief Copy the node to a new object. + /// \return A shared pointer to the new copy. + std::shared_ptr Copy() override; + + /// \brief A base class override function to create the required runtime dataset op objects for this class. + /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. + /// \return Status Status::OK() if build successfully. + Status Build(std::vector> *const node_ops) override; + + /// \brief Parameters validation. + /// \return Status Status::OK() if all the parameters are valid. + Status ValidateParams() override; + + /// \brief Get the shard id of node. + /// \param[in] shard_id The shard id. + /// \return Status Status::OK() if get shard id successfully. + Status GetShardId(int32_t *shard_id) override; + + /// \brief Base-class override for GetDatasetSize. + /// \param[in] size_getter Shared pointer to DatasetSizeGetter. + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset. + /// \return Status of the function. + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + + /// \brief Getter functions. + const std::string &DatasetDir() const { return dataset_dir_; } + + /// \brief Getter functions. + const std::string &Usage() const { return usage_; } + + /// \brief Getter functions. + int64_t NumSamples() const { return num_samples_; } + + /// \brief Getter functions. + int32_t NumShards() const { return num_shards_; } + + /// \brief Getter functions. + int32_t ShardId() const { return shard_id_; } + + /// \brief Getter functions. + ShuffleMode Shuffle() const { return shuffle_; } + + /// \brief Get the arguments of node. + /// \param[out] out_json JSON string of all attributes. + /// \return Status of the function. + Status to_json(nlohmann::json *out_json) override; + + /// \brief CoNLL2000 by itself is a non-mappable dataset that does not support sampling. + /// However, if a cache operator is injected at some other place higher in the tree, that cache can + /// inherit this sampler from the leaf, providing sampling support from the caching layer. + /// That is why we setup the sampler for a leaf node that does not use sampling. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \param[in] sampler The sampler to setup. + /// \return Status of the function. + Status SetupSamplerForCache(std::shared_ptr *sampler) override; + + /// \brief If a cache has been added into the ascendant tree over this CoNLL2000 node, then the cache will be + /// executing a sampler for fetching the data. As such, any options in the CoNLL2000 node need to be + /// reset to its defaults so that this CoNLL2000 node will produce the full set of data into the cache. + /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. + /// \return Status of the function. + Status MakeSimpleProducer() override; + + /// \Read all files in the directory. + /// \param[in] usage Part of dataset of CoNLL2000. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \return Status The status code returned. + std::vector WalkAllFiles(const std::string &usage, const std::string &dataset_dir); + + private: + std::string dataset_dir_; + std::string usage_; + int64_t num_samples_; + int32_t num_shards_; + int32_t shard_id_; + ShuffleMode shuffle_; + std::vector conll2000_file_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONLL2000_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h b/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h index d1a46283514..196a5418b77 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h @@ -1776,6 +1776,70 @@ inline std::shared_ptr MS_API Coco(const std::string &dataset_dir, decode, sampler, cache, extra_metadata); } +/// \class CoNLL2000Dataset +/// \brief A source dataset for reading and parsing CoNLL2000Dataset. +class MS_API CoNLL2000Dataset : public Dataset { + public: + /// \brief Constructor of CoNLL2000Dataset. + /// \param[in] dataset_dir Path to the root directory that contains the dataset. + /// \param[in] usage The type of data list txt file to be read, can be "train", "test" or "all". + /// \param[in] num_samples The number of samples to be included in the dataset. + /// \param[in] shuffle The mode for shuffling data every epoch. + /// Can be any of: + /// ShuffleMode.kFalse - No shuffling is performed. + /// ShuffleMode.kFiles - Shuffle files only. + /// ShuffleMode.kGlobal - Shuffle both the files and samples. + /// \param[in] num_shards Number of shards that the dataset should be divided into. + /// \param[in] shard_id The shard ID within num_shards. This argument should be + /// specified only when num_shards is also specified. + /// \param[in] cache Tensor cache to use. + CoNLL2000Dataset(const std::vector &dataset_dir, const std::vector &usage, int64_t num_samples, + ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, + const std::shared_ptr &cache); + + /// \brief Destructor of CoNLL2000Dataset. + ~CoNLL2000Dataset() = default; +}; + +/// \brief Function to create a CoNLL2000Dataset. +/// \note The generated dataset has three column ['word', 'pos_tag', 'chunk_tag']. +/// \param[in] dataset_dir Path to the root directory that contains the dataset. +/// \param[in] usage Part of dataset of CoNLL2000, can be "train", "test" or "all" (default="all"). +/// \param[in] num_samples The number of samples to be included in the dataset +/// (Default = 0, means all samples). +/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal). +/// Can be any of: +/// ShuffleMode::kFalse - No shuffling is performed. +/// ShuffleMode::kFiles - Shuffle files only. +/// ShuffleMode::kGlobal - Shuffle both the files and samples. +/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1). +/// \param[in] shard_id The shard ID within num_shards. This argument should be +/// specified only when num_shards is also specified (Default = 0). +/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). +/// \return Shared pointer to the CoNLL2000Dataset. +/// \par Example +/// \code +/// /* Define dataset path and MindData object */ +/// std::string folder_path = "/path/to/conll2000_dataset_directory"; +/// std::shared_ptr ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kGlobal); +/// +/// /* Create iterator to read dataset */ +/// std::shared_ptr iter = ds->CreateIterator(); +/// std::unordered_map row; +/// iter->GetNextRow(&row); +/// +/// /* Note: In CoNLL2000 dataset, each dictionary has keys "word", "pos_tag", "chunk_tag" */ +/// auto word = row["word"]; +/// \endcode +inline std::shared_ptr MS_API CoNLL2000(const std::string &dataset_dir, + const std::string &usage = "all", int64_t num_samples = 0, + ShuffleMode shuffle = ShuffleMode::kGlobal, + int32_t num_shards = 1, int32_t shard_id = 0, + const std::shared_ptr &cache = nullptr) { + return std::make_shared(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle, + num_shards, shard_id, cache); +} + /// \class CSVDataset /// \brief A source dataset that reads and parses comma-separated values (CSV) datasets. class MS_API CSVDataset : public Dataset { diff --git a/mindspore/python/mindspore/dataset/engine/datasets.py b/mindspore/python/mindspore/dataset/engine/datasets.py index 954f2c30f88..cf5952a2e87 100644 --- a/mindspore/python/mindspore/dataset/engine/datasets.py +++ b/mindspore/python/mindspore/dataset/engine/datasets.py @@ -72,7 +72,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \ check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset, check_svhn_dataset, \ check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \ - check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset + check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset,\ + check_conll2000_dataset from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ get_prefetch_size from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist @@ -6642,6 +6643,64 @@ class CocoDataset(MappableDataset): return self._class_indexing +class CoNLL2000Dataset(SourceDataset): + """ + A source dataset that reads and parses CoNLL2000 dataset. + + The generated dataset has three columns: :py:obj:`[word, pos_tag, chunk_tag]`. + The tensor of column :py:obj:`word` is of the string type. + The tensor of column :py:obj:`pos_tag` is of the string type. + The tensor of column :py:obj:`chunk_tag` is of the string type. + + Args: + dataset_dir (str): Path to the root directory that contains the dataset. + usage (str, optional): Usage of this dataset, can be `train`, `test`, or `all`. `train` will read from + 8936 train samples, `test` will read from 2,012 test samples, + `all` will read from all 1,0948 samples (default=None, all samples). + num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset). + shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch + (default=Shuffle.GLOBAL). + If shuffle is False, no shuffling will be performed; + If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL + Otherwise, there are two levels of shuffling: + + - Shuffle.GLOBAL: Shuffle both the files and samples. + + - Shuffle.FILES: Shuffle files only. + + num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). + When this argument is specified, `num_samples` reflects the max sample number of per shard. + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument can only be specified when num_shards is also specified. + num_parallel_workers (int, optional): Number of workers to read the data + (default=None, number set in the config). + cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing + (default=None, which means no cache is used). + + Raises: + RuntimeError: If dataset_dir does not contain data files. + RuntimeError: If num_parallel_workers exceeds the max thread numbers. + RuntimeError: If num_shards is specified but shard_id is None. + RuntimeError: If shard_id is specified but num_shards is None. + + Examples: + >>> conll2000_dataset_dir = "/path/to/conll2000_dataset_dir" + >>> dataset = ds.CoNLL2000Dataset(dataset_files=conll2000_dataset_dir, usage='all') + """ + + @check_conll2000_dataset + def __init__(self, dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None, + shard_id=None, num_parallel_workers=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, + num_shards=num_shards, shard_id=shard_id, cache=cache) + self.dataset_dir = dataset_dir + self.usage = replace_none(usage, 'all') + + def parse(self, children=None): + return cde.CoNLL2000Node(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards, + self.shard_id) + + class CelebADataset(MappableDataset): """ A source dataset for reading and parsing CelebA dataset. diff --git a/mindspore/python/mindspore/dataset/engine/validators.py b/mindspore/python/mindspore/dataset/engine/validators.py index 707bdf82f27..9b29f2e6189 100644 --- a/mindspore/python/mindspore/dataset/engine/validators.py +++ b/mindspore/python/mindspore/dataset/engine/validators.py @@ -2223,3 +2223,33 @@ def check_yahoo_answers_dataset(method): return method(self, *args, **kwargs) return new_method + + +def check_conll2000_dataset(method): + """ A wrapper that wraps a parameter checker around the original Dataset(CoNLL2000Dataset).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] + + # check dataset_dir + dataset_dir = param_dict.get('dataset_dir') + check_dir(dataset_dir) + + # check usage + usage = param_dict.get('usage') + if usage is not None: + check_valid_str(usage, ["train", "test", "all"], "usage") + + validate_dataset_param_value(nreq_param_int, param_dict, int) + check_sampler_shuffle_shard_options(param_dict) + + cache = param_dict.get('cache') + check_cache_option(cache) + + return method(self, *args, **kwargs) + + return new_method + \ No newline at end of file diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index e06e9cea874..9977dcf58de 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -20,6 +20,7 @@ SET(DE_UT_SRCS c_api_dataset_cityscapes_test.cc c_api_dataset_clue_test.cc c_api_dataset_coco_test.cc + c_api_dataset_conll2000_test.cc c_api_dataset_config_test.cc c_api_dataset_csv_test.cc c_api_dataset_dbpedia_test.cc diff --git a/tests/ut/cpp/dataset/c_api_dataset_conll2000_test.cc b/tests/ut/cpp/dataset/c_api_dataset_conll2000_test.cc new file mode 100644 index 00000000000..7fbe5027888 --- /dev/null +++ b/tests/ut/cpp/dataset/c_api_dataset_conll2000_test.cc @@ -0,0 +1,628 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/include/dataset/datasets.h" + +using namespace mindspore::dataset; + +using mindspore::dataset::ShuffleMode; + +class MindDataTestPipeline : public UT::DatasetOpTesting { +protected: +}; + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetBasic."; + // Test CoNLL2000 Dataset with single text file and many default inputs. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(2); + + // Create a CoNLL2000Dataset, with single text file. + // Note: valid.txt has 3 rows. + // Use 2 samples. + // Use defaults for other input parameters. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + std::shared_ptr ds = CoNLL2000(dataset_dir, "train", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + EXPECT_NE(row.find("word"), row.end()); + + std::vector> expected_result = { + {"Challenge", "NNP", "O"}, {"Her", "PP$", "B-NP"}, {"To", "TO", "I-VP"}}; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto word = row[column_names[j]]; + std::shared_ptr de_word; + ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word)); + std::string_view sv; + ASSERT_OK(de_word->GetItemAt(&sv, {{}})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + EXPECT_EQ(i, 3); + // Expect 3 samples. + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetBasicWithPipeline) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetBasicWithPipeline."; + // Test CoNLL2000 Dataset with single text file and many default inputs. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(2); + + // Create two CoNLL2000Dataset, with single text file. + // Note: test.txt has 3 rows. + // Use 2 samples. + // Use defaults for other input parameters. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds1 = CoNLL2000(dataset_dir, "test", 0, ShuffleMode::kFalse); + std::shared_ptr ds2 = CoNLL2000(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds1, nullptr); + EXPECT_NE(ds2, nullptr); + + // Create two Repeat operation on ds. + int32_t repeat_num = 2; + ds1 = ds1->Repeat(repeat_num); + EXPECT_NE(ds1, nullptr); + repeat_num = 3; + ds2 = ds2->Repeat(repeat_num); + EXPECT_NE(ds2, nullptr); + + // Create a Concat operation on the ds. + ds1 = ds1->Concat({ds2}); + EXPECT_NE(ds1, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("word"), row.end()); + std::vector> expected_result = {{"He", "PBP", "B-NP"}, {"The", "DT", "B-NP"}}; + uint64_t i = 0; + while (row.size() != 0) { + auto word = row["word"]; + MS_LOG(INFO) << "Tensor word shape: " << word.Shape(); + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + + // Expect 10 samples. + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000Getters) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000Getters."; + // Test CoNLL2000 Dataset with single text file and many default inputs. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(987); + GlobalContext::config_manager()->set_num_parallel_workers(2); + + // Create a CoNLL2000 Dataset, with single text file. + // Note: test.txt has 1 rows. + // Use 2 samples. + // Use defaults for other input parameters. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + EXPECT_EQ(ds->GetDatasetSize(), 2); + EXPECT_EQ(ds->GetColumnNames(), column_names); + + std::shared_ptr ds1 = CoNLL2000(dataset_dir, "", 0, ShuffleMode::kFalse); + EXPECT_NE(ds1, nullptr); + + EXPECT_EQ(ds1->GetDatasetSize(), 30); + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail1."; + + // Create a CoNLL2000Dataset. + // with invalid samplers=-1. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "test", -1, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: CoNLL2000 number of samples cannot be negative. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail2."; + + // Attempt to create a CoNLL2000 Dataset. + // with wrongful empty dataset_files input. + std::shared_ptr ds = CoNLL2000("NotExistFile", "test", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: dataset_files is not specified. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail3) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail3."; + + // Create a CoNLL2000 Dataset. + // with non-existent dataset_files input. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "dev", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: specified dataset_files does not exist. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail4) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail4."; + + // Create a CoNLL2000Dataset. + // with empty string dataset_files input. + std::shared_ptr ds = CoNLL2000("", "test", 2, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + std::cout << iter; + // Expect failure: specified dataset_files does not exist. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail5) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail5."; + + // Create a CoNLL2000 Dataset. + // with invalid num_shards=0 value. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse, 0); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: Number of shards cannot be <=0. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail6) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail6."; + + // Create a CoNLL2000Dataset. + // with invalid shard_id=-1 value. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse, -1); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: shard_id cannot be negative. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetFail7) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetFail7."; + + // Create a CoNLL2000 Dataset. + // with invalid shard_id=2 and num_shards=2 combination. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "test", 2, ShuffleMode::kFalse, 2, 2); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: Cannot have shard_id >= num_shards. + EXPECT_EQ(iter, nullptr); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleFalse) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleFalse."; + // Test CoNLL2000 Dataset with two text files and no shuffle, num_parallel_workers=4. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(654); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + // Create a CoNLL2000 Dataset, with two text files, test.txt and train.txt, in lexicographical order. + // Note: test.txt has 2 rows. + // Note: train.txt has 3 rows. + // Use default of all samples. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("word"), row.end()); + std::vector> expected_result = {{"He", "PBP", "B-NP"}, + {"Challenge", "NNP", "O"}, + {"The", "DT", "B-NP"}, + {"Her", "PP$", "B-NP"}, + {"To", "TO", "I-VP"}}; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto word = row[column_names[j]]; + std::shared_ptr de_word; + ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word)); + std::string_view sv; + ASSERT_OK(de_word->GetItemAt(&sv, {{}})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleFilesA) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleFilesA."; + // Test CoNLL2000 Dataset with files shuffle, num_parallel_workers=4. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(135); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + // Create a CoNLL2000 Dataset, with two text files,test.txt and train.txt, in lexicographical order. + // Note: test.txt has 2 rows. + // Note: train.txt has 3 rows. + // Set shuffle to files shuffle. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("word"), row.end()); + std::vector> expected_result = {{"He", "PBP", "B-NP"}, + {"Challenge", "NNP", "O"}, + {"The", "DT", "B-NP"}, + {"Her", "PP$", "B-NP"}, + {"To", "TO", "I-VP"}}; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto word = row[column_names[j]]; + std::shared_ptr de_word; + ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word)); + std::string_view sv; + ASSERT_OK(de_word->GetItemAt(&sv, {{}})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 3 + 1 + 2 = 6 samples. + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleFilesB) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleFilesB."; + // Test CoNLL2000 Dataset with files shuffle, num_parallel_workers=4. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(135); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + // Create a CoNLL2000 Dataset, with two text files test.txt and train.txt, in lexicographical order. + // Note: test.txt has 2 rows. + // Note: train.txt has 3 rows. + // Set shuffle to files shuffle. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("word"), row.end()); + std::vector> expected_result = {{"He", "PBP", "B-NP"}, + {"Challenge", "NNP", "O"}, + {"The", "DT", "B-NP"}, + {"Her", "PP$", "B-NP"}, + {"To", "TO", "I-VP"}}; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto word = row[column_names[j]]; + std::shared_ptr de_word; + ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word)); + std::string_view sv; + ASSERT_OK(de_word->GetItemAt(&sv, {{}})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 3 + 1 + 2 = 6 samples. + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleGlobal1A) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleGlobalA."; + // Test CoNLL2000 Dataset with 1 text file, global shuffle, num_parallel_workers=4. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(246); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + // Create a CoNLL2000 Dataset, with one text files. + // Note: test.txt has 2 rows. + // Set shuffle to global shuffle. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "test", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("word"), row.end()); + std::vector> expected_result = {{"He", "PBP", "B-NP"}, {"The", "DT", "B-NP"}}; + + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto word = row[column_names[j]]; + std::shared_ptr de_word; + ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word)); + std::string_view sv; + ASSERT_OK(de_word->GetItemAt(&sv, {{}})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 1 samples. + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} + +/// Feature: CoNLL2000ChunkingDataset. +/// Description: test CoNLL2000ChunkingDataset in pipeline mode. +/// Expectation: the data is processed successfully. +TEST_F(MindDataTestPipeline, TestCoNLL2000DatasetShuffleGlobalB) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCoNLL2000DatasetShuffleGlobalB."; + // Test CoNLL200 Dataset with 2 text files, global shuffle, num_parallel_workers=4. + + // Set configuration. + uint32_t original_seed = GlobalContext::config_manager()->seed(); + uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); + MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; + GlobalContext::config_manager()->set_seed(246); + GlobalContext::config_manager()->set_num_parallel_workers(4); + + // Create a CoNLL2000 Dataset, with two text files. + // Note: test.txt has 2 rows. + // Note: train.txt has 3 rows. + // Set shuffle to global shuffle. + std::string dataset_dir = datasets_root_path_ + "/testCoNLL2000Dataset"; + std::shared_ptr ds = CoNLL2000(dataset_dir, "all", 0, ShuffleMode::kFalse); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset. + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row. + std::unordered_map row; + std::vector column_names = {"word", "pos_tag", "chunk_tag"}; + ASSERT_OK(iter->GetNextRow(&row)); + + EXPECT_NE(row.find("word"), row.end()); + std::vector> expected_result = {{"He", "PBP", "B-NP"}, + {"Challenge", "NNP", "O"}, + {"The", "DT", "B-NP"}, + {"Her", "PP$", "B-NP"}, + {"To", "TO", "I-VP"}}; + uint64_t i = 0; + while (row.size() != 0) { + for (int j = 0; j < column_names.size(); j++) { + auto word = row[column_names[j]]; + std::shared_ptr de_word; + ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_word)); + std::string_view sv; + ASSERT_OK(de_word->GetItemAt(&sv, {{}})); + std::string ss(sv); + EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str()); + } + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + + // Expect 3 + 1 + 2 = 6 samples. + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline. + iter->Stop(); + + // Restore configuration. + GlobalContext::config_manager()->set_seed(original_seed); + GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); +} diff --git a/tests/ut/data/dataset/testCoNLL2000Dataset/test.txt b/tests/ut/data/dataset/testCoNLL2000Dataset/test.txt new file mode 100755 index 00000000000..de963ebb1de --- /dev/null +++ b/tests/ut/data/dataset/testCoNLL2000Dataset/test.txt @@ -0,0 +1,14 @@ +He PBP B-NP +reckons VBZ B-VP +the DT B-NP +current JJ I-NP +account NN I-NP +. . O + +The DT B-NP +1.8 CD I-NP +billion CD I-NP +in IN B-PP +September NNP B-NP +. . O + diff --git a/tests/ut/data/dataset/testCoNLL2000Dataset/train.txt b/tests/ut/data/dataset/testCoNLL2000Dataset/train.txt new file mode 100755 index 00000000000..d8490a6fe54 --- /dev/null +++ b/tests/ut/data/dataset/testCoNLL2000Dataset/train.txt @@ -0,0 +1,21 @@ +Challenge NNP O +of IN B-PP +the DT B-NP +August NNP B-NP +month NNP B-NP +. . O + +Her PP$ B-NP +'s POS B-NP +chancellor NNP O +at IN B-PP +Lawson NNP I-NP +. . O + +To TO I-VP +economists NNS B-NP +, , O +foreign JJ B-NP +exchange NN I-NP +. . O + diff --git a/tests/ut/python/dataset/test_datasets_conll2000.py b/tests/ut/python/dataset/test_datasets_conll2000.py new file mode 100644 index 00000000000..f0d92b68738 --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_conll2000.py @@ -0,0 +1,345 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import pytest + +import mindspore.dataset as ds +from mindspore import log as logger +from util import config_get_set_num_parallel_workers, config_get_set_seed + +DATA_DIR = '../data/dataset/testCoNLL2000Dataset' + + +def test_conll2000_dataset_one_file(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False) + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["word"])) + count += 1 + assert count == 2 + + +def test_conll2000_dataset_all_file(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False) + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["word"])) + count += 1 + assert count == 5 + + +def test_conll2000_dataset_num_samples_none(): + """ + Feature: CoNLL2000ChunkingDataset + Description: test param check of CoNLL2000ChunkingDataset + Expectation: throw correct error and message + """ + # Do not provide a num_samples argument, so it would be None by default + data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False) + count = 0 + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + logger.info("{}".format(i["word"])) + count += 1 + assert count == 2 + + +def test_conll2000_dataset_shuffle_false_num_parallel_workers_4(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + original_seed = config_get_set_seed(987) + data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False) + count = 0 + numword = 5 + line = ["He", "reckons", "the", "current", "account", ".", + "Challenge", "of", "the", "August", "month", ".", + "The", "1.8", "billion", "in", "September", ".", + "Her", "'s", "chancellor", "at", "Lawson", ".", + "To", "economists", ",", "foreign", "exchange", "."] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + for j in range(numword): + strs = i["word"][j].item().decode("utf8") + assert strs == line[count*6+j] + count += 1 + assert count == 5 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_conll2000_dataset_shuffle_false_num_parallel_workers_1(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + original_seed = config_get_set_seed(987) + data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False) + count = 0 + numword = 6 + line = ["He", "reckons", "the", "current", "account", ".", + "The", "1.8", "billion", "in", "September", ".", + "Challenge", "of", "the", "August", "month", ".", + "Her", "'s", "chancellor", "at", "Lawson", ".", + "To", "economists", ",", "foreign", "exchange", "."] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + for j in range(numword): + strs = i["word"][j].item().decode("utf8") + assert strs == line[count*6+j] + count += 1 + assert count == 5 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_conll2000_dataset_shuffle_files_num_parallel_workers_4(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + original_seed = config_get_set_seed(135) + data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES) + count = 0 + numword = 6 + line = ["He", "reckons", "the", "current", "account", ".", + "Challenge", "of", "the", "August", "month", ".", + "The", "1.8", "billion", "in", "September", ".", + "Her", "'s", "chancellor", "at", "Lawson", ".", + "To", "economists", ",", "foreign", "exchange", "."] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + for j in range(numword): + strs = i["word"][j].item().decode("utf8") + assert strs == line[count*6+j] + count += 1 + assert count == 5 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_conll2000_dataset_shuffle_files_num_parallel_workers_1(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + original_seed = config_get_set_seed(135) + data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES) + count = 0 + numword = 6 + line = ["He", "reckons", "the", "current", "account", ".", + "The", "1.8", "billion", "in", "September", ".", + "Challenge", "of", "the", "August", "month", ".", + "Her", "'s", "chancellor", "at", "Lawson", ".", + "To", "economists", ",", "foreign", "exchange", "."] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + for j in range(numword): + strs = i["word"][j].item().decode("utf8") + assert strs == line[count*6+j] + count += 1 + assert count == 5 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_conll2000_dataset_shuffle_global_num_parallel_workers_4(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(4) + original_seed = config_get_set_seed(246) + data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL) + count = 0 + numword = 6 + line = ["Challenge", "of", "the", "August", "month", ".", + "To", "economists", ",", "foreign", "exchange", ".", + "Her", "'s", "chancellor", "at", "Lawson", ".", + "He", "reckons", "the", "current", "account", ".", + "The", "1.8", "billion", "in", "September", "."] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + for j in range(numword): + strs = i["word"][j].item().decode("utf8") + assert strs == line[count*6+j] + count += 1 + assert count == 5 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_conll2000_dataset_shuffle_global_num_parallel_workers_1(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + original_seed = config_get_set_seed(246) + data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL) + count = 0 + numword = 6 + line = ["Challenge", "of", "the", "August", "month", ".", + "The", "1.8", "billion", "in", "September", ".", + "To", "economists", ",", "foreign", "exchange", ".", + "Her", "'s", "chancellor", "at", "Lawson", ".", + "He", "reckons", "the", "current", "account", "."] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + for j in range(numword): + strs = i["word"][j].item().decode("utf8") + assert strs == line[count*6+j] + count += 1 + assert count == 5 + # Restore configuration + ds.config.set_num_parallel_workers(original_num_parallel_workers) + ds.config.set_seed(original_seed) + + +def test_conll2000_dataset_num_samples(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False, num_samples=2) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 2 + + +def test_conll2000_dataset_distribution(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False, num_shards=2, shard_id=1) + count = 0 + for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): + count += 1 + assert count == 1 + + +def test_conll2000_dataset_repeat(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False) + data = data.repeat(3) + count = 0 + numword = 6 + line = ["He", "reckons", "the", "current", "account", ".", + "The", "1.8", "billion", "in", "September", ".", + "He", "reckons", "the", "current", "account", ".", + "The", "1.8", "billion", "in", "September", ".", + "He", "reckons", "the", "current", "account", ".", + "The", "1.8", "billion", "in", "September", ".",] + for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): + for j in range(numword): + strs = i["word"][j].item().decode("utf8") + assert strs == line[count*6+j] + count += 1 + assert count == 6 + + +def test_conll2000_dataset_get_datasetsize(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False) + size = data.get_dataset_size() + assert size == 12 + + +def test_conll2000_dataset_to_device(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False) + data = data.to_device() + data.send() + + +def test_conll2000_dataset_exceptions(): + """ + Feature: CoNLL2000ChunkingDataset. + Description: test param check of CoNLL2000ChunkingDataset. + Expectation: throw correct error and message. + """ + with pytest.raises(ValueError) as error_info: + _ = ds.CoNLL2000Dataset(DATA_DIR, usage="test", num_samples=-1) + assert "num_samples exceeds the boundary" in str(error_info.value) + + with pytest.raises(ValueError) as error_info: + _ = ds.CoNLL2000Dataset("NotExistFile", usage="test") + assert "The folder NotExistFile does not exist or is not a directory or permission denied!" in str(error_info.value) + + with pytest.raises(ValueError) as error_info: + _ = ds.TextFileDataset("") + assert "The following patterns did not match any files" in str(error_info.value) + + + def exception_func(item): + raise Exception("Error occur!") + with pytest.raises(RuntimeError) as error_info: + data = data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False) + data = data.map(operations=exception_func, input_columns=["word"], num_parallel_workers=1) + for _ in data.__iter__(): + pass + assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value) + + +if __name__ == "__main__": + test_conll2000_dataset_one_file() + test_conll2000_dataset_all_file() + test_conll2000_dataset_num_samples_none() + test_conll2000_dataset_shuffle_false_num_parallel_workers_4() + test_conll2000_dataset_shuffle_false_num_parallel_workers_1() + test_conll2000_dataset_shuffle_files_num_parallel_workers_4() + test_conll2000_dataset_shuffle_files_num_parallel_workers_1() + test_conll2000_dataset_shuffle_global_num_parallel_workers_4() + test_conll2000_dataset_shuffle_global_num_parallel_workers_1() + test_conll2000_dataset_num_samples() + test_conll2000_dataset_distribution() + test_conll2000_dataset_repeat() + test_conll2000_dataset_get_datasetsize() + test_conll2000_dataset_to_device() + test_conll2000_dataset_exceptions()