!22514 [assistant][ops] Add new data operator Multi30kDataset

Merge pull request !22514 from 杨旭华/Multi30kDataset
This commit is contained in:
i-robot 2022-01-20 02:13:17 +00:00 committed by Gitee
commit 501614a61f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
22 changed files with 1743 additions and 2 deletions

View File

@ -104,6 +104,7 @@
#endif
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
@ -1532,6 +1533,16 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
}
#ifndef ENABLE_ANDROID
Multi30kDataset::Multi30kDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &language_pair, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds =
std::make_shared<Multi30kNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(language_pair),
num_samples, shuffle, num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<Multi30kNode>(ds);
}
PennTreebankDataset::PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {

View File

@ -68,6 +68,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
@ -464,6 +465,20 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(Multi30kNode, 2, ([](const py::module *m) {
(void)py::class_<Multi30kNode, DatasetNode, std::shared_ptr<Multi30kNode>>(*m, "Multi30kNode",
"to create a Multi30kNode")
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
const std::vector<std::string> &language_pair, int64_t num_samples,
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
std::shared_ptr<Multi30kNode> multi30k =
std::make_shared<Multi30kNode>(dataset_dir, usage, language_pair, num_samples,
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(multi30k->ValidateParams());
return multi30k;
}));
}));
PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
(void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>(
*m, "PennTreebankNode", "to create a PennTreebankNode")

View File

@ -30,6 +30,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
lj_speech_op.cc
mappable_leaf_op.cc
mnist_op.cc
multi30k_op.cc
nonmappable_leaf_op.cc
penn_treebank_op.cc
photo_tour_op.cc

View File

@ -0,0 +1,152 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/multi30k_op.h"
#include <fstream>
#include <iomanip>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "debug/common.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
// constructor of Multi30k.
Multi30kOp::Multi30kOp(int32_t num_workers, int64_t num_samples, const std::vector<std::string> &language_pair,
int32_t worker_connector_size, std::unique_ptr<DataSchema> schema,
const std::vector<std::string> &text_files_list, int32_t op_connector_size, bool shuffle_files,
int32_t num_devices, int32_t device_id)
: TextFileOp(num_workers, num_samples, worker_connector_size, std::move(schema), std::move(text_files_list),
op_connector_size, shuffle_files, num_devices, device_id),
language_pair_(language_pair) {}
// Print info of operator.
void Multi30kOp::Print(std::ostream &out, bool show_all) {
// Print parameter to debug function.
std::vector<std::string> multi30k_files_list = TextFileOp::FileNames();
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nMulti30k files list:\n";
for (int i = 0; i < multi30k_files_list.size(); ++i) {
out << " " << multi30k_files_list[i];
}
out << "\n\n";
}
}
Status Multi30kOp::LoadTensor(const std::string &line, TensorRow *out_row, size_t index) {
RETURN_UNEXPECTED_IF_NULL(out_row);
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
(*out_row)[index] = std::move(tensor);
return Status::OK();
}
Status Multi30kOp::LoadFile(const std::string &file_en, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
auto realpath_en = FileUtils::GetRealPath(file_en.data());
if (!realpath_en.has_value()) {
MS_LOG(ERROR) << "Invalid file path, " << DatasetName() + " Dataset file: " << file_en << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, " + DatasetName() + " Dataset file: " + file_en + " does not exist.");
}
// We use English files to find Germany files, to make sure that data are ordered.
Path path_en(file_en);
Path parent_path(path_en.ParentPath());
std::string basename = path_en.Basename();
int suffix_len = 3;
std::string suffix_de = ".de";
basename = basename.replace(basename.find("."), suffix_len, suffix_de);
Path BaseName(basename);
Path path_de = parent_path / BaseName;
std::string file_de = path_de.ToString();
auto realpath_de = FileUtils::GetRealPath(file_de.data());
if (!realpath_de.has_value()) {
MS_LOG(ERROR) << "Invalid file path, " << DatasetName() + " Dataset file: " << file_de << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, " + DatasetName() + " Dataset file: " + file_de + " does not exist.");
}
std::ifstream handle_en(realpath_en.value());
CHECK_FAIL_RETURN_UNEXPECTED(handle_en.is_open(), "Invalid file, failed to open en file: " + file_en);
std::ifstream handle_de(realpath_de.value());
CHECK_FAIL_RETURN_UNEXPECTED(handle_de.is_open(), "Invalid file, failed to open de file: " + file_de);
// Set path for path in class TensorRow.
std::string line_en;
std::string line_de;
std::vector<std::string> path = {file_en, file_de};
int row_total = 0;
while (getline(handle_en, line_en) && getline(handle_de, line_de)) {
if (line_en.empty() && line_de.empty()) {
continue;
}
// If read to the end offset of this file, break.
if (row_total >= end_offset) {
break;
}
// Skip line before start offset.
if (row_total < start_offset) {
++row_total;
continue;
}
int tensor_size = 2;
TensorRow tRow(tensor_size, nullptr);
Status rc_en;
Status rc_de;
if (language_pair_[0] == "en") {
rc_en = LoadTensor(line_en, &tRow, 0);
rc_de = LoadTensor(line_de, &tRow, 1);
} else if (language_pair_[0] == "de") {
rc_en = LoadTensor(line_en, &tRow, 1);
rc_de = LoadTensor(line_de, &tRow, 0);
}
if (rc_en.IsError() || rc_de.IsError()) {
handle_en.close();
handle_de.close();
RETURN_IF_NOT_OK(rc_en);
RETURN_IF_NOT_OK(rc_de);
}
(&tRow)->setPath(path);
Status rc = jagged_rows_connector_->Add(worker_id, std::move(tRow));
if (rc.IsError()) {
handle_en.close();
handle_de.close();
return rc;
}
++row_total;
}
handle_en.close();
handle_de.close();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,85 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MULTI30K_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MULTI30K_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
namespace mindspore {
namespace dataset {
class JaggedConnector;
using StringIndex = AutoIndexObj<std::string>;
class Multi30kOp : public TextFileOp {
public:
/// \brief Constructor of Multi30kOp
/// \note The builder class should be used to call this constructor.
/// \param[in] num_workers Number of worker threads reading data from multi30k_file files.
/// \param[in] num_samples Number of rows to read.
/// \param[in] language_pair List containing text and translation language.
/// \param[in] worker_connector_size List of filepaths for the dataset files.
/// \param[in] schema The data schema object.
/// \param[in] text_files_list File path of multi30k files.
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
/// \param[in] num_devices Shards of data.
/// \param[in] device_id The device ID within num_devices.
Multi30kOp(int32_t num_workers, int64_t num_samples, const std::vector<std::string> &language_pair,
int32_t worker_connector_size, std::unique_ptr<DataSchema> schema,
const std::vector<std::string> &text_files_list, int32_t op_connector_size, bool shuffle_files,
int32_t num_devices, int32_t device_id);
/// \Default destructor.
~Multi30kOp() = default;
/// \brief A print method typically used for debugging.
/// \param[in] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all);
/// \brief Return the name of Operator.
/// \return Status - return the name of Operator.
std::string Name() const override { return "Multi30kOp"; }
/// \brief DatasetName name getter.
/// \param[in] upper If true, the return value is uppercase, otherwise, it is lowercase.
/// \return std::string DatasetName of the current Op.
std::string DatasetName(bool upper = false) const { return upper ? "Multi30k" : "multi30k"; }
private:
/// \brief Load data into Tensor.
/// \param[in] line Data read from files.
/// \param[in] out_row Output tensor.
/// \param[in] index The index of Tensor.
Status LoadTensor(const std::string &line, TensorRow *out_row, size_t index);
/// \brief Read data from files.
/// \param[in] file_en The paths of multi30k dataset files.
/// \param[in] start_offset The location of reading start.
/// \param[in] end_offset The location of reading finished.
/// \param[in] worker_id The id of the worker that is executing this function.
Status LoadFile(const std::string &file_en, int64_t start_offset, int64_t end_offset, int32_t worker_id);
std::vector<std::string> language_pair_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MULTI30K_OP_H_

View File

@ -106,6 +106,7 @@ constexpr char kLJSpeechNode[] = "LJSpeechDataset";
constexpr char kManifestNode[] = "ManifestDataset";
constexpr char kMindDataNode[] = "MindDataDataset";
constexpr char kMnistNode[] = "MnistDataset";
constexpr char kMulti30kNode[] = "Multi30kDataset";
constexpr char kPennTreebankNode[] = "PennTreebankDataset";
constexpr char kPhotoTourNode[] = "PhotoTourDataset";
constexpr char kPlaces365Node[] = "Places365Dataset";

View File

@ -32,6 +32,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
manifest_node.cc
minddata_node.cc
mnist_node.cc
multi30k_node.cc
penn_treebank_node.cc
photo_tour_node.cc
places365_node.cc

View File

@ -0,0 +1,198 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/multi30k_op.h"
namespace mindspore {
namespace dataset {
Multi30kNode::Multi30kNode(const std::string &dataset_dir, const std::string &usage,
const std::vector<std::string> &language_pair, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: NonMappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
language_pair_(language_pair),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
multi30k_files_list_(WalkAllFiles(usage, dataset_dir)) {
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
void Multi30kNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
std::shared_ptr<DatasetNode> Multi30kNode::Copy() {
auto node = std::make_shared<Multi30kNode>(dataset_dir_, usage_, language_pair_, num_samples_, shuffle_, num_shards_,
shard_id_, cache_);
return node;
}
// Function to build Multi30kNode
Status Multi30kNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
std::vector<std::string> sorted_dataset_files = multi30k_files_list_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("translation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1)));
std::shared_ptr<Multi30kOp> multi30k_op =
std::make_shared<Multi30kOp>(num_workers_, num_samples_, language_pair_, worker_connector_size_, std::move(schema),
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(multi30k_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_IF_NOT_OK(Multi30kOp::CountAllFileRows(sorted_dataset_files, &num_rows));
// Add the shuffle op after this op
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
multi30k_op->SetTotalRepeats(GetTotalRepeats());
multi30k_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
// Add Multi30kOp
node_ops->push_back(multi30k_op);
return Status::OK();
}
Status Multi30kNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Multi30kDataset", dataset_dir_));
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("Multi30kDataset", multi30k_files_list_));
RETURN_IF_NOT_OK(ValidateStringValue("Multi30kDataset", usage_, {"train", "valid", "test", "all"}));
const int kLanguagePairSize = 2;
if (language_pair_.size() != kLanguagePairSize) {
std::string err_msg =
"Multi30kDataset: language_pair expecting size 2, but got: " + std::to_string(language_pair_.size());
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
const std::vector<std::vector<std::string>> support_language_pair = {{"en", "de"}, {"de", "en"}};
if (language_pair_ != support_language_pair[0] && language_pair_ != support_language_pair[1]) {
std::string err_msg = R"(Multi30kDataset: language_pair must be {"en", "de"} or {"de", "en"}.)";
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateScalar("Multi30kDataset", "num_samples", num_samples_, {0}, false));
RETURN_IF_NOT_OK(ValidateDatasetShardParams("Multi30kDataset", num_shards_, shard_id_));
return Status::OK();
}
Status Multi30kNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
Status Multi30kNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size = num_samples_;
RETURN_IF_NOT_OK(Multi30kOp::CountAllFileRows(multi30k_files_list_, &num_rows));
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status Multi30kNode::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
args["language_pair"] = language_pair_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
Status Multi30kNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
Status Multi30kNode::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
std::vector<std::string> Multi30kNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
std::vector<std::string> multi30k_files_list;
Path train_en("training/train.en");
Path test_en("mmt16_task1_test/test.en");
Path valid_en("validation/val.en");
Path dir(dataset_dir);
if (usage == "train") {
Path temp_path = dir / train_en;
multi30k_files_list.push_back(temp_path.ToString());
} else if (usage == "test") {
Path temp_path = dir / test_en;
multi30k_files_list.push_back(temp_path.ToString());
} else if (usage == "valid") {
Path temp_path = dir / valid_en;
multi30k_files_list.push_back(temp_path.ToString());
} else {
Path temp_path = dir / train_en;
multi30k_files_list.push_back(temp_path.ToString());
Path temp_path1 = dir / test_en;
multi30k_files_list.push_back(temp_path1.ToString());
Path temp_path2 = dir / valid_en;
multi30k_files_list.push_back(temp_path2.ToString());
}
return multi30k_files_list;
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,134 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class Multi30kNode : public NonMappableSourceNode {
public:
/// \brief Constructor of Multi30kNode.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of MULTI30K, can be "train", "test", "valid" or "all".
/// \param[in] language_pair List containing text and translation language.
/// \param[in] num_samples The number of samples to be included in the dataset
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shared_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
Multi30kNode(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &language_pair,
int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shared_id,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor of Multi30kNode.
~Multi30kNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kMulti30kNode; }
/// \brief Print the description.
/// \param[in] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops);
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[in] shard_id The shard id.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting.
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Multi30k by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
/// That is why we setup the sampler for a leaf node that does not use sampling.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this Multi30k node, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the Multi30k node need to be reset to its defaults
/// so that this Multi30k node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function
Status MakeSimpleProducer() override;
/// \brief Getter functions
int32_t NumSamples() const { return num_samples_; }
int32_t NumShards() const { return num_shards_; }
int32_t ShardId() const { return shard_id_; }
ShuffleMode Shuffle() const { return shuffle_; }
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
const std::vector<std::string> &LanguagePair() const { return language_pair_; }
/// \brief Generate a list of read file names according to usage.
/// \param[in] usage Part of dataset of Multi30k.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \return std::vector<std::string> A list of read file names.
std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
private:
std::string dataset_dir_;
std::string usage_;
std::vector<std::string> language_pair_;
int32_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
std::vector<std::string> multi30k_files_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_

View File

@ -3723,6 +3723,75 @@ inline std::shared_ptr<MnistDataset> MS_API Mnist(const std::string &dataset_dir
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
/// \class Multi30kDataset
/// \brief A source dataset that reads and parses Multi30k dataset.
class MS_API Multi30kDataset : public Dataset {
public:
/// \brief Constructor of Multi30kDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of MULTI30K, can be "train", "test", "valid" or "all".
/// \param[in] language_pair List containing text and translation language.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
Multi30kDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::vector<std::vector<char>> &language_pair, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of Multi30kDataset.
~Multi30kDataset() = default;
};
/// \brief Function to create a Multi30kDataset.
/// \note The generated dataset has two columns ["text", "translation"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of MULTI30K, can be "train", "test", "valid" or "all" (default = "all").
/// \param[in] language_pair List containing text and translation language (default = {"en", "de"}).
/// \param[in] num_samples The number of samples to be included in the dataset
/// (Default = 0, means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the Multi30kDataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string dataset_dir = "/path/to/multi30k_dataset_directory";
/// std::shared_ptr<Dataset> ds = Multi30k(dataset_dir, "all");
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In Multi30kdataset, each dictionary has keys "text" and "translation" */
/// auto text = row["text"];
/// \endcode
inline std::shared_ptr<Multi30kDataset> MS_API Multi30k(const std::string &dataset_dir,
const std::string &usage = "all",
const std::vector<std::string> &language_pair = {"en", "de"},
int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal,
int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<Multi30kDataset>(StringToChar(dataset_dir), StringToChar(usage),
VectorStringToChar(language_pair), num_samples, shuffle, num_shards,
shard_id, cache);
}
/// \class PennTreebankDataset
/// \brief A source dataset for reading and parsing PennTreebank dataset.
class MS_API PennTreebankDataset : public Dataset {

View File

@ -74,6 +74,7 @@ __all__ = ["Caltech101Dataset", # Vision
"IMDBDataset", # Text
"IWSLT2016Dataset", # Text
"IWSLT2017Dataset", # Text
"Multi30kDataset", # Text
"PennTreebankDataset", # Text
"SogouNewsDataset", # Text
"TextFileDataset", # Text

View File

@ -30,7 +30,7 @@ from .validators import check_imdb_dataset, check_iwslt2016_dataset, check_iwslt
check_penn_treebank_dataset, check_ag_news_dataset, check_amazon_review_dataset, check_udpos_dataset, \
check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, \
check_sogou_news_dataset, check_textfiledataset, check_dbpedia_dataset, check_yelp_review_dataset, \
check_en_wik9_dataset, check_yahoo_answers_dataset
check_en_wik9_dataset, check_yahoo_answers_dataset, check_multi30k_dataset
from ..core.validator_helpers import replace_none
@ -961,6 +961,106 @@ class IWSLT2017Dataset(SourceDataset, TextBaseDataset):
self.shuffle_flag, self.num_shards, self.shard_id)
class Multi30kDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses Multi30k dataset.
The generated dataset has two columns :py:obj:`[text, translation]`.
The tensor of column :py:obj:'text' is of the string type.
The tensor of column :py:obj:'translation' is of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Acceptable usages include `train`, `test, `valid` or `all` (default=`all`).
language_pair (str, optional): Acceptable language_pair include ['en', 'de'], ['de', 'en']
(default=['en', 'de']).
num_samples (int, optional): The number of images to be included in the dataset
(default=None, all samples).
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided
into (default=None). When this argument is specified, `num_samples` reflects
the max sample number of per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_dir does not contain data files.
RuntimeError: If usage is not "train", "test", "valid" or "all".
RuntimeError: If the length of language_pair is not equal to 2.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
RuntimeError: If num_samples is less than 0.
Examples:
>>> multi30k_dataset_dir = "/path/to/multi30k_dataset_directory"
>>> data = ds.Multi30kDataset(dataset_dir=multi30k_dataset_dir, usage='all', language_pair=['de', 'en'])
About Multi30k dataset:
Multi30K is a dataset to stimulate multilingual multimodal research for English-German.
It is based on the Flickr30k dataset, which contains images sourced from online
photo-sharing websites. Each image is paired with five English descriptions, which were
collected from Amazon Mechanical Turk. The Multi30K dataset extends the Flickr30K
dataset with translated and independent German sentences.
You can unzip the dataset files into the following directory structure and read by MindSpore's API.
.. code-block::
multi30k_dataset_directory
training
train.de
train.en
validation
val.de
val.en
mmt16_task1_test
val.de
val.en
Citation:
.. code-block::
@article{elliott-EtAl:2016:VL16,
author = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
title = {Multi30K: Multilingual English-German Image Descriptions},
booktitle = {Proceedings of the 5th Workshop on Vision and Language},
year = {2016},
pages = {70--74},
year = 2016
}
"""
@check_multi30k_dataset
def __init__(self, dataset_dir, usage=None, language_pair=None, num_samples=None,
num_parallel_workers=None, shuffle=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, 'all')
self.language_pair = replace_none(language_pair, ["en", "de"])
self.shuffle = replace_none(shuffle, Shuffle.GLOBAL)
def parse(self, children=None):
return cde.Multi30kNode(self.dataset_dir, self.usage, self.language_pair, self.num_samples,
self.shuffle_flag, self.num_shards, self.shard_id)
class PennTreebankDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses PennTreebank datasets.

View File

@ -2558,3 +2558,40 @@ def check_en_wik9_dataset(method):
return new_method
def check_multi30k_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset (Multi30kDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "valid", "all"], "usage")
language_pair = param_dict.get('language_pair')
support_language_pair = [['en', 'de'], ['de', 'en'], ('en', 'de'), ('de', 'en')]
if language_pair is not None:
type_check(language_pair, (list, tuple), "language_pair")
if len(language_pair) != 2:
raise ValueError(
"language_pair should be a list or tuple of length 2, but got {0}".format(len(language_pair)))
if language_pair not in support_language_pair:
raise ValueError(
"language_pair can only be ['en', 'de'] or ['en', 'de'], but got {0}".format(language_pair))
validate_dataset_param_value(nreq_param_int, param_dict, int)
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
check_sampler_shuffle_shard_options(param_dict)
return method(self, *args, **kwargs)
return new_method

View File

@ -40,6 +40,7 @@ SET(DE_UT_SRCS
c_api_dataset_lj_speech_test.cc
c_api_dataset_manifest_test.cc
c_api_dataset_minddata_test.cc
c_api_dataset_multi30k_test.cc
c_api_dataset_ops_test.cc
c_api_dataset_penn_treebank_test.cc
c_api_dataset_photo_tour_test.cc

View File

@ -0,0 +1,652 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::Status;
using mindspore::dataset::ShuffleMode;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: Test Multi30k Dataset(English).
/// Description: read Multi30kDataset data and get data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kSuccessEn) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kSuccessEn.";
// Test Multi30k English files with default parameters
// Create a Multi30k dataset
std::string en_file = datasets_root_path_ + "/testMulti30kDataset";
// test train
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected = {"This is the first English sentence in train.",
"This is the second English sentence in train.",
"This is the third English sentence in train."};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// test valid
usage = "valid";
expected = {"This is the first English sentence in valid.",
"This is the second English sentence in valid."};
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
iter = ds->CreateIterator();
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
i = 0;
while (row.size() != 0) {
auto text = row["text"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 2);
iter->Stop();
// test test
usage = "test";
expected = {"This is the first English sentence in test.",
"This is the second English sentence in test.",
"This is the third English sentence in test."};
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
iter = ds->CreateIterator();
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
i = 0;
while (row.size() != 0) {
auto text = row["text"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 3);
iter->Stop();
}
/// Feature: Test Multi30k Dataset(Germany).
/// Description: read Multi30kDataset data and get data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kSuccessDe) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kSuccessDe.";
// Test Multi30k Germany files with default parameters
// Create a Multi30k dataset
std::string en_file = datasets_root_path_ + "/testMulti30kDataset";
// test train
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("translation"), row.end());
std::vector<std::string> expected = {"This is the first Germany sentence in train.",
"This is the second Germany sentence in train.",
"This is the third Germany sentence in train."};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["translation"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// test valid
usage = "valid";
expected = {"This is the first Germany sentence in valid.",
"This is the second Germany sentence in valid."};
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
iter = ds->CreateIterator();
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("translation"), row.end());
i = 0;
while (row.size() != 0) {
auto text = row["translation"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 2);
iter->Stop();
// test test
usage = "test";
expected = {"This is the first Germany sentence in test.",
"This is the second Germany sentence in test.",
"This is the third Germany sentence in test."};
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
iter = ds->CreateIterator();
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("translation"), row.end());
i = 0;
while (row.size() != 0) {
auto text = row["translation"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 3);
iter->Stop();
}
/// Feature: Test Multi30k Dataset(Germany).
/// Description: read Multi30kDataset data and get data.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetBasicWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetBasicWithPipeline.";
// Create two Multi30kFile Dataset, with single Multi30k file
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds1 = Multi30k(train_en_file, usage, {"en", "de"}, 2, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds2 = Multi30k(train_en_file, usage, {"en", "de"}, 2, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 3;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operation on ds
std::vector<std::string> column_project = {"text"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 10 samples
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: Test Getters.
/// Description: includes tests for shape, type, size.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kGetters.";
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 2, ShuffleMode::kFalse);
std::vector<std::string> column_names = {"text","translation"};
EXPECT_NE(ds, nullptr);
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: Test Multi30kDataset in distribution.
/// Description: test interface in a distributed state.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetDistribution) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetDistribution.";
// Create a Multi30kFile Dataset, with single Multi30k file
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kGlobal, 3, 2);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 1 samples
EXPECT_EQ(i, 1);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: Error Test.
/// Description: test the wrong input.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidFilePath) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidFilePath.";
// Create a Multi30k Dataset
// with invalid file path
std::string train_en_file = datasets_root_path_ + "/invalid/file.path";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"});
EXPECT_NE(ds, nullptr);
}
/// Feature: Error Test.
/// Description: test the wrong input.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidUsage) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvaildUsage.";
// Create a Multi30k Dataset
// with invalid usage
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "invalid_usage";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"});
EXPECT_NE(ds, nullptr);
}
/// Feature: Error Test.
/// Description: test the wrong input.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidLanguagePair) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailLanguagePair.";
// Create a Multi30k Dataset
// with invalid usage
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::vector<std::string> language_pair0 = {"ch", "ja"};
std::shared_ptr<Dataset> ds0 = Multi30k(train_en_file, usage, language_pair0);
EXPECT_NE(ds0, nullptr);
std::vector<std::string> language_pair1 = {"en", "de", "aa"};
std::shared_ptr<Dataset> ds1 = Multi30k(train_en_file, usage, language_pair1);
EXPECT_NE(ds1, nullptr);
}
/// Feature: Error Test.
/// Description: test the wrong input.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidNumSamples) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidNumSamples.";
// Create a Multi30k Dataset
// with invalid samplers=-1
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, -1);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: TextFile number of samples cannot be negative
EXPECT_EQ(iter, nullptr);
}
/// Feature: Error Test.
/// Description: test the wrong input.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidShards) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidShards.";
// Create a Multi30k Dataset
// with invalid shards.
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse, 2, 3);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: TextFile number of samples cannot be negative
EXPECT_EQ(iter, nullptr);
}
/// Feature: Error Test.
/// Description: test the wrong input.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidShardID) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidShardID.";
// Create a Multi30k Dataset
// with invalid shard ID.
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse, 0, -1);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: TextFile number of samples cannot be negative
EXPECT_EQ(iter, nullptr);
}
/// Feature: Error Test.
/// Description: test the wrong input.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetLanguagePair) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetLanguagePair.";
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"de", "en"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("translation"), row.end());
std::vector<std::string> expected = {"This is the first English sentence in train.",
"This is the second English sentence in train.",
"This is the third English sentence in train."};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["translation"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 3);
iter->Stop();
}
/// Feature: Test Multi30k Dataset(shufflemode=kFalse).
/// Description: test Multi30k Dataset interface with different ShuffleMode.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetShuffleFilesFalse) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetShuffleFilesFalse.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(135);
GlobalContext::config_manager()->set_num_parallel_workers(1);
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected = {"This is the first English sentence in train.",
"This is the second English sentence in train.",
"This is the third English sentence in train."};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test Multi30k Dataset(shufflemode=kFiles).
/// Description: test Multi30k Dataset interface with different ShuffleMode.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetShuffleFilesFiles) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetShuffleFilesFiles.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(135);
GlobalContext::config_manager()->set_num_parallel_workers(1);
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
std::vector<std::string> expected = {"This is the first English sentence in train.",
"This is the second English sentence in train.",
"This is the third English sentence in train."};
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
// Compare against expected result
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test Multi30k Dataset(shufflemode=kGlobal).
/// Description: test Multi30k Dataset interface with different ShuffleMode.
/// Expectation: the data is processed successfully.
TEST_F(MindDataTestPipeline, TestMulti30kDatasetShuffleFilesGlobal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetShuffleFilesGlobal.";
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
std::string usage = "train";
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("text"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["text"];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
}

View File

@ -0,0 +1,3 @@
This is the first Germany sentence in test.
This is the second Germany sentence in test.
This is the third Germany sentence in test.

View File

@ -0,0 +1,3 @@
This is the first English sentence in test.
This is the second English sentence in test.
This is the third English sentence in test.

View File

@ -0,0 +1,3 @@
This is the first Germany sentence in train.
This is the second Germany sentence in train.
This is the third Germany sentence in train.

View File

@ -0,0 +1,3 @@
This is the first English sentence in train.
This is the second English sentence in train.
This is the third English sentence in train.

View File

@ -0,0 +1,2 @@
This is the first Germany sentence in valid.
This is the second Germany sentence in valid.

View File

@ -0,0 +1,2 @@
This is the first English sentence in valid.
This is the second English sentence in valid.

View File

@ -0,0 +1,267 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
import mindspore.dataset.text.transforms as a_c_trans
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed
INVALID_FILE = '../data/dataset/testMulti30kDataset/invalid_dir'
DATA_ALL_FILE = '../data/dataset/testMulti30kDataset'
def test_data_file_multi30k_text():
"""
Feature: Test Multi30k Dataset.
Description: read data from a single file.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage="train", shuffle=False)
count = 0
line = ["This is the first English sentence in train.",
"This is the second English sentence in train.",
"This is the third English sentence in train."
]
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 3
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_data_file_multi30k_translation():
"""
Feature: Test Multi30k Dataset.
Description: read data from a single file.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage="train", shuffle=False)
count = 0
line = ["This is the first Germany sentence in train.",
"This is the second Germany sentence in train.",
"This is the third Germany sentence in train."
]
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["translation"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 3
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_all_file_multi30k():
"""
Feature: Test Multi30k Dataset.
Description: read data from all file.
Expectation: the data is processed successfully.
"""
dataset = ds.Multi30kDataset(DATA_ALL_FILE)
count = 0
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 8
def test_dataset_num_samples_none():
"""
Feature: Test Multi30k Dataset(num_samples = default).
Description: test get num_samples.
Expectation: the data is processed successfully.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
dataset = ds.Multi30kDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is the first English sentence in test.",
"This is the second English sentence in test.",
"This is the third English sentence in test.",
"This is the first English sentence in train.",
"This is the second English sentence in train.",
"This is the third English sentence in train.",
"This is the first English sentence in valid.",
"This is the second English sentence in valid."
]
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == line[count]
count += 1
assert count == 8
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_num_shards_multi30k():
"""
Feature: Test Multi30k Dataset(num_shards = 3).
Description: test get num_samples.
Expectation: the data is processed successfully.
"""
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage='train', num_shards=3, shard_id=1)
count = 0
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
count += 1
assert count == 1
def test_multi30k_dataset_num_samples():
"""
Feature: Test Multi30k Dataset(num_samples = 2).
Description: test get num_samples.
Expectation: the data is processed successfully.
"""
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage="test", num_samples=2)
count = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2
def test_multi30k_dataset_shuffle_files():
"""
Feature: Test Multi30k Dataset.
Description: test get all files.
Expectation: the data is processed successfully.
"""
dataset = ds.Multi30kDataset(DATA_ALL_FILE, shuffle=True)
count = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 8
def test_multi30k_dataset_shuffle_false():
"""
Feature: Test Multi30k Dataset (shuffle = false).
Description: test get all files.
Expectation: the data is processed successfully.
"""
dataset = ds.Multi30kDataset(DATA_ALL_FILE, shuffle=False)
count = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 8
def test_multi30k_dataset_repeat():
"""
Feature: Test Multi30k in distribution (repeat 3 times).
Description: test in a distributed state.
Expectation: the data is processed successfully.
"""
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage='train')
dataset = dataset.repeat(3)
count = 0
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 9
def test_multi30k_dataset_get_datasetsize():
"""
Feature: Test Getters.
Description: test get_dataset_size of Multi30k dataset.
Expectation: the data is processed successfully.
"""
dataset = ds.Multi30kDataset(DATA_ALL_FILE)
size = dataset.get_dataset_size()
assert size == 8
def test_multi30k_dataset_exceptions():
"""
Feature: Test Multi30k Dataset.
Description: Test exceptions.
Expectation: Exception thrown to be caught
"""
with pytest.raises(ValueError) as error_info:
_ = ds.Multi30kDataset(INVALID_FILE)
assert "The folder ../data/dataset/testMulti30kDataset/invalid_dir does not exist or is not a directory or" \
" permission denied" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage="INVALID")
assert "Input usage is not within the valid set of ['train', 'test', 'valid', 'all']." in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage="test", language_pair=["ch", "ja"])
assert "language_pair can only be ['en', 'de'] or ['en', 'de'], but got ['ch', 'ja']" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage="test", language_pair=["en", "en", "de"])
assert "language_pair should be a list or tuple of length 2, but got 3" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage='test', num_samples=-1)
assert "num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)!" in str(error_info.value)
def test_multi30k_dataset_en_pipeline():
"""
Feature: Multi30kDataset
Description: test Multi30kDataset in pipeline mode
Expectation: the data is processed successfully
"""
expected = ["this is the first english sentence in train.",
"this is the second english sentence in train.",
"this is the third english sentence in train."]
dataset = ds.Multi30kDataset(DATA_ALL_FILE, 'train', shuffle=False)
filter_wikipedia_xml_op = a_c_trans.CaseFold()
dataset = dataset.map(input_columns=["text"], operations=filter_wikipedia_xml_op, num_parallel_workers=1)
count = 0
for i in dataset.create_dict_iterator(output_numpy=True):
strs = i["text"].item().decode("utf8")
assert strs == expected[count]
count += 1
def test_multi30k_dataset_de_pipeline():
"""
Feature: Multi30kDataset
Description: test Multi30kDataset in pipeline mode
Expectation: the data is processed successfully
"""
expected = ["this is the first germany sentence in train.",
"this is the second germany sentence in train.",
"this is the third germany sentence in train."]
dataset = ds.Multi30kDataset(DATA_ALL_FILE, 'train', shuffle=False)
filter_wikipedia_xml_op = a_c_trans.CaseFold()
dataset = dataset.map(input_columns=["translation"], operations=filter_wikipedia_xml_op, num_parallel_workers=1)
count = 0
for i in dataset.create_dict_iterator(output_numpy=True):
strs = i["translation"].item().decode("utf8")
assert strs == expected[count]
count += 1
if __name__ == "__main__":
test_data_file_multi30k_text()
test_data_file_multi30k_translation()
test_all_file_multi30k()
test_dataset_num_samples_none()
test_num_shards_multi30k()
test_multi30k_dataset_num_samples()
test_multi30k_dataset_shuffle_files()
test_multi30k_dataset_shuffle_false()
test_multi30k_dataset_repeat()
test_multi30k_dataset_get_datasetsize()
test_multi30k_dataset_exceptions()
test_multi30k_dataset_en_pipeline()
test_multi30k_dataset_de_pipeline()