forked from mindspore-Ecosystem/mindspore
!22514 [assistant][ops] Add new data operator Multi30kDataset
Merge pull request !22514 from 杨旭华/Multi30kDataset
This commit is contained in:
commit
501614a61f
|
@ -104,6 +104,7 @@
|
||||||
#endif
|
#endif
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
||||||
|
@ -1532,6 +1533,16 @@ MnistDataset::MnistDataset(const std::vector<char> &dataset_dir, const std::vect
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
#ifndef ENABLE_ANDROID
|
||||||
|
Multi30kDataset::Multi30kDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
|
const std::vector<std::vector<char>> &language_pair, int64_t num_samples,
|
||||||
|
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||||
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
|
auto ds =
|
||||||
|
std::make_shared<Multi30kNode>(CharToString(dataset_dir), CharToString(usage), VectorCharToString(language_pair),
|
||||||
|
num_samples, shuffle, num_shards, shard_id, cache);
|
||||||
|
ir_node_ = std::static_pointer_cast<Multi30kNode>(ds);
|
||||||
|
}
|
||||||
|
|
||||||
PennTreebankDataset::PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
PennTreebankDataset::PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
|
||||||
const std::shared_ptr<DatasetCache> &cache) {
|
const std::shared_ptr<DatasetCache> &cache) {
|
||||||
|
|
|
@ -68,6 +68,7 @@
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/places365_node.h"
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h"
|
||||||
|
@ -464,6 +465,20 @@ PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
PYBIND_REGISTER(Multi30kNode, 2, ([](const py::module *m) {
|
||||||
|
(void)py::class_<Multi30kNode, DatasetNode, std::shared_ptr<Multi30kNode>>(*m, "Multi30kNode",
|
||||||
|
"to create a Multi30kNode")
|
||||||
|
.def(py::init([](const std::string &dataset_dir, const std::string &usage,
|
||||||
|
const std::vector<std::string> &language_pair, int64_t num_samples,
|
||||||
|
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
|
||||||
|
std::shared_ptr<Multi30kNode> multi30k =
|
||||||
|
std::make_shared<Multi30kNode>(dataset_dir, usage, language_pair, num_samples,
|
||||||
|
toShuffleMode(shuffle), num_shards, shard_id, nullptr);
|
||||||
|
THROW_IF_ERROR(multi30k->ValidateParams());
|
||||||
|
return multi30k;
|
||||||
|
}));
|
||||||
|
}));
|
||||||
|
|
||||||
PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>(
|
(void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>(
|
||||||
*m, "PennTreebankNode", "to create a PennTreebankNode")
|
*m, "PennTreebankNode", "to create a PennTreebankNode")
|
||||||
|
|
|
@ -30,6 +30,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||||
lj_speech_op.cc
|
lj_speech_op.cc
|
||||||
mappable_leaf_op.cc
|
mappable_leaf_op.cc
|
||||||
mnist_op.cc
|
mnist_op.cc
|
||||||
|
multi30k_op.cc
|
||||||
nonmappable_leaf_op.cc
|
nonmappable_leaf_op.cc
|
||||||
penn_treebank_op.cc
|
penn_treebank_op.cc
|
||||||
photo_tour_op.cc
|
photo_tour_op.cc
|
||||||
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/engine/datasetops/source/multi30k_op.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "debug/common.h"
|
||||||
|
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
||||||
|
#include "utils/file_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
// constructor of Multi30k.
|
||||||
|
Multi30kOp::Multi30kOp(int32_t num_workers, int64_t num_samples, const std::vector<std::string> &language_pair,
|
||||||
|
int32_t worker_connector_size, std::unique_ptr<DataSchema> schema,
|
||||||
|
const std::vector<std::string> &text_files_list, int32_t op_connector_size, bool shuffle_files,
|
||||||
|
int32_t num_devices, int32_t device_id)
|
||||||
|
: TextFileOp(num_workers, num_samples, worker_connector_size, std::move(schema), std::move(text_files_list),
|
||||||
|
op_connector_size, shuffle_files, num_devices, device_id),
|
||||||
|
language_pair_(language_pair) {}
|
||||||
|
|
||||||
|
// Print info of operator.
|
||||||
|
void Multi30kOp::Print(std::ostream &out, bool show_all) {
|
||||||
|
// Print parameter to debug function.
|
||||||
|
std::vector<std::string> multi30k_files_list = TextFileOp::FileNames();
|
||||||
|
if (!show_all) {
|
||||||
|
// Call the super class for displaying any common 1-liner info.
|
||||||
|
ParallelOp::Print(out, show_all);
|
||||||
|
// Then show any custom derived-internal 1-liner info for this op.
|
||||||
|
out << "\n";
|
||||||
|
} else {
|
||||||
|
// Call the super class for displaying any common detailed info.
|
||||||
|
ParallelOp::Print(out, show_all);
|
||||||
|
// Then show any custom derived-internal stuff.
|
||||||
|
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||||
|
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nMulti30k files list:\n";
|
||||||
|
for (int i = 0; i < multi30k_files_list.size(); ++i) {
|
||||||
|
out << " " << multi30k_files_list[i];
|
||||||
|
}
|
||||||
|
out << "\n\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kOp::LoadTensor(const std::string &line, TensorRow *out_row, size_t index) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(out_row);
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
|
||||||
|
(*out_row)[index] = std::move(tensor);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kOp::LoadFile(const std::string &file_en, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
|
||||||
|
auto realpath_en = FileUtils::GetRealPath(file_en.data());
|
||||||
|
if (!realpath_en.has_value()) {
|
||||||
|
MS_LOG(ERROR) << "Invalid file path, " << DatasetName() + " Dataset file: " << file_en << " does not exist.";
|
||||||
|
RETURN_STATUS_UNEXPECTED("Invalid file path, " + DatasetName() + " Dataset file: " + file_en + " does not exist.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// We use English files to find Germany files, to make sure that data are ordered.
|
||||||
|
Path path_en(file_en);
|
||||||
|
Path parent_path(path_en.ParentPath());
|
||||||
|
std::string basename = path_en.Basename();
|
||||||
|
int suffix_len = 3;
|
||||||
|
std::string suffix_de = ".de";
|
||||||
|
basename = basename.replace(basename.find("."), suffix_len, suffix_de);
|
||||||
|
Path BaseName(basename);
|
||||||
|
Path path_de = parent_path / BaseName;
|
||||||
|
std::string file_de = path_de.ToString();
|
||||||
|
auto realpath_de = FileUtils::GetRealPath(file_de.data());
|
||||||
|
if (!realpath_de.has_value()) {
|
||||||
|
MS_LOG(ERROR) << "Invalid file path, " << DatasetName() + " Dataset file: " << file_de << " does not exist.";
|
||||||
|
RETURN_STATUS_UNEXPECTED("Invalid file path, " + DatasetName() + " Dataset file: " + file_de + " does not exist.");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ifstream handle_en(realpath_en.value());
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(handle_en.is_open(), "Invalid file, failed to open en file: " + file_en);
|
||||||
|
std::ifstream handle_de(realpath_de.value());
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(handle_de.is_open(), "Invalid file, failed to open de file: " + file_de);
|
||||||
|
|
||||||
|
// Set path for path in class TensorRow.
|
||||||
|
std::string line_en;
|
||||||
|
std::string line_de;
|
||||||
|
std::vector<std::string> path = {file_en, file_de};
|
||||||
|
|
||||||
|
int row_total = 0;
|
||||||
|
while (getline(handle_en, line_en) && getline(handle_de, line_de)) {
|
||||||
|
if (line_en.empty() && line_de.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// If read to the end offset of this file, break.
|
||||||
|
if (row_total >= end_offset) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// Skip line before start offset.
|
||||||
|
if (row_total < start_offset) {
|
||||||
|
++row_total;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
int tensor_size = 2;
|
||||||
|
TensorRow tRow(tensor_size, nullptr);
|
||||||
|
|
||||||
|
Status rc_en;
|
||||||
|
Status rc_de;
|
||||||
|
if (language_pair_[0] == "en") {
|
||||||
|
rc_en = LoadTensor(line_en, &tRow, 0);
|
||||||
|
rc_de = LoadTensor(line_de, &tRow, 1);
|
||||||
|
} else if (language_pair_[0] == "de") {
|
||||||
|
rc_en = LoadTensor(line_en, &tRow, 1);
|
||||||
|
rc_de = LoadTensor(line_de, &tRow, 0);
|
||||||
|
}
|
||||||
|
if (rc_en.IsError() || rc_de.IsError()) {
|
||||||
|
handle_en.close();
|
||||||
|
handle_de.close();
|
||||||
|
RETURN_IF_NOT_OK(rc_en);
|
||||||
|
RETURN_IF_NOT_OK(rc_de);
|
||||||
|
}
|
||||||
|
(&tRow)->setPath(path);
|
||||||
|
|
||||||
|
Status rc = jagged_rows_connector_->Add(worker_id, std::move(tRow));
|
||||||
|
if (rc.IsError()) {
|
||||||
|
handle_en.close();
|
||||||
|
handle_de.close();
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
++row_total;
|
||||||
|
}
|
||||||
|
|
||||||
|
handle_en.close();
|
||||||
|
handle_de.close();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,85 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MULTI30K_OP_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MULTI30K_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class JaggedConnector;
|
||||||
|
using StringIndex = AutoIndexObj<std::string>;
|
||||||
|
|
||||||
|
class Multi30kOp : public TextFileOp {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor of Multi30kOp
|
||||||
|
/// \note The builder class should be used to call this constructor.
|
||||||
|
/// \param[in] num_workers Number of worker threads reading data from multi30k_file files.
|
||||||
|
/// \param[in] num_samples Number of rows to read.
|
||||||
|
/// \param[in] language_pair List containing text and translation language.
|
||||||
|
/// \param[in] worker_connector_size List of filepaths for the dataset files.
|
||||||
|
/// \param[in] schema The data schema object.
|
||||||
|
/// \param[in] text_files_list File path of multi30k files.
|
||||||
|
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
|
||||||
|
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
|
||||||
|
/// \param[in] num_devices Shards of data.
|
||||||
|
/// \param[in] device_id The device ID within num_devices.
|
||||||
|
Multi30kOp(int32_t num_workers, int64_t num_samples, const std::vector<std::string> &language_pair,
|
||||||
|
int32_t worker_connector_size, std::unique_ptr<DataSchema> schema,
|
||||||
|
const std::vector<std::string> &text_files_list, int32_t op_connector_size, bool shuffle_files,
|
||||||
|
int32_t num_devices, int32_t device_id);
|
||||||
|
|
||||||
|
/// \Default destructor.
|
||||||
|
~Multi30kOp() = default;
|
||||||
|
|
||||||
|
/// \brief A print method typically used for debugging.
|
||||||
|
/// \param[in] out The output stream to write output to.
|
||||||
|
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
|
||||||
|
void Print(std::ostream &out, bool show_all);
|
||||||
|
|
||||||
|
/// \brief Return the name of Operator.
|
||||||
|
/// \return Status - return the name of Operator.
|
||||||
|
std::string Name() const override { return "Multi30kOp"; }
|
||||||
|
|
||||||
|
/// \brief DatasetName name getter.
|
||||||
|
/// \param[in] upper If true, the return value is uppercase, otherwise, it is lowercase.
|
||||||
|
/// \return std::string DatasetName of the current Op.
|
||||||
|
std::string DatasetName(bool upper = false) const { return upper ? "Multi30k" : "multi30k"; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// \brief Load data into Tensor.
|
||||||
|
/// \param[in] line Data read from files.
|
||||||
|
/// \param[in] out_row Output tensor.
|
||||||
|
/// \param[in] index The index of Tensor.
|
||||||
|
Status LoadTensor(const std::string &line, TensorRow *out_row, size_t index);
|
||||||
|
|
||||||
|
/// \brief Read data from files.
|
||||||
|
/// \param[in] file_en The paths of multi30k dataset files.
|
||||||
|
/// \param[in] start_offset The location of reading start.
|
||||||
|
/// \param[in] end_offset The location of reading finished.
|
||||||
|
/// \param[in] worker_id The id of the worker that is executing this function.
|
||||||
|
Status LoadFile(const std::string &file_en, int64_t start_offset, int64_t end_offset, int32_t worker_id);
|
||||||
|
|
||||||
|
std::vector<std::string> language_pair_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MULTI30K_OP_H_
|
|
@ -106,6 +106,7 @@ constexpr char kLJSpeechNode[] = "LJSpeechDataset";
|
||||||
constexpr char kManifestNode[] = "ManifestDataset";
|
constexpr char kManifestNode[] = "ManifestDataset";
|
||||||
constexpr char kMindDataNode[] = "MindDataDataset";
|
constexpr char kMindDataNode[] = "MindDataDataset";
|
||||||
constexpr char kMnistNode[] = "MnistDataset";
|
constexpr char kMnistNode[] = "MnistDataset";
|
||||||
|
constexpr char kMulti30kNode[] = "Multi30kDataset";
|
||||||
constexpr char kPennTreebankNode[] = "PennTreebankDataset";
|
constexpr char kPennTreebankNode[] = "PennTreebankDataset";
|
||||||
constexpr char kPhotoTourNode[] = "PhotoTourDataset";
|
constexpr char kPhotoTourNode[] = "PhotoTourDataset";
|
||||||
constexpr char kPlaces365Node[] = "Places365Dataset";
|
constexpr char kPlaces365Node[] = "Places365Dataset";
|
||||||
|
|
|
@ -32,6 +32,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
||||||
manifest_node.cc
|
manifest_node.cc
|
||||||
minddata_node.cc
|
minddata_node.cc
|
||||||
mnist_node.cc
|
mnist_node.cc
|
||||||
|
multi30k_node.cc
|
||||||
penn_treebank_node.cc
|
penn_treebank_node.cc
|
||||||
photo_tour_node.cc
|
photo_tour_node.cc
|
||||||
places365_node.cc
|
places365_node.cc
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <fstream>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/datasetops/source/multi30k_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
Multi30kNode::Multi30kNode(const std::string &dataset_dir, const std::string &usage,
|
||||||
|
const std::vector<std::string> &language_pair, int32_t num_samples, ShuffleMode shuffle,
|
||||||
|
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
||||||
|
: NonMappableSourceNode(std::move(cache)),
|
||||||
|
dataset_dir_(dataset_dir),
|
||||||
|
usage_(usage),
|
||||||
|
language_pair_(language_pair),
|
||||||
|
num_samples_(num_samples),
|
||||||
|
shuffle_(shuffle),
|
||||||
|
num_shards_(num_shards),
|
||||||
|
shard_id_(shard_id),
|
||||||
|
multi30k_files_list_(WalkAllFiles(usage, dataset_dir)) {
|
||||||
|
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Multi30kNode::Print(std::ostream &out) const {
|
||||||
|
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
|
||||||
|
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<DatasetNode> Multi30kNode::Copy() {
|
||||||
|
auto node = std::make_shared<Multi30kNode>(dataset_dir_, usage_, language_pair_, num_samples_, shuffle_, num_shards_,
|
||||||
|
shard_id_, cache_);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to build Multi30kNode
|
||||||
|
Status Multi30kNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||||
|
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||||
|
|
||||||
|
std::vector<std::string> sorted_dataset_files = multi30k_files_list_;
|
||||||
|
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||||
|
|
||||||
|
auto schema = std::make_unique<DataSchema>();
|
||||||
|
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1)));
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
schema->AddColumn(ColDescriptor("translation", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 1)));
|
||||||
|
|
||||||
|
std::shared_ptr<Multi30kOp> multi30k_op =
|
||||||
|
std::make_shared<Multi30kOp>(num_workers_, num_samples_, language_pair_, worker_connector_size_, std::move(schema),
|
||||||
|
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||||
|
RETURN_IF_NOT_OK(multi30k_op->Init());
|
||||||
|
|
||||||
|
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||||
|
// Inject ShuffleOp
|
||||||
|
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||||
|
int64_t num_rows = 0;
|
||||||
|
|
||||||
|
// First, get the number of rows in the dataset
|
||||||
|
RETURN_IF_NOT_OK(Multi30kOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||||
|
|
||||||
|
// Add the shuffle op after this op
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||||
|
shuffle_op->SetTotalRepeats(GetTotalRepeats());
|
||||||
|
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||||
|
node_ops->push_back(shuffle_op);
|
||||||
|
}
|
||||||
|
multi30k_op->SetTotalRepeats(GetTotalRepeats());
|
||||||
|
multi30k_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||||
|
// Add Multi30kOp
|
||||||
|
node_ops->push_back(multi30k_op);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kNode::ValidateParams() {
|
||||||
|
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||||
|
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Multi30kDataset", dataset_dir_));
|
||||||
|
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("Multi30kDataset", multi30k_files_list_));
|
||||||
|
RETURN_IF_NOT_OK(ValidateStringValue("Multi30kDataset", usage_, {"train", "valid", "test", "all"}));
|
||||||
|
|
||||||
|
const int kLanguagePairSize = 2;
|
||||||
|
if (language_pair_.size() != kLanguagePairSize) {
|
||||||
|
std::string err_msg =
|
||||||
|
"Multi30kDataset: language_pair expecting size 2, but got: " + std::to_string(language_pair_.size());
|
||||||
|
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::vector<std::string>> support_language_pair = {{"en", "de"}, {"de", "en"}};
|
||||||
|
if (language_pair_ != support_language_pair[0] && language_pair_ != support_language_pair[1]) {
|
||||||
|
std::string err_msg = R"(Multi30kDataset: language_pair must be {"en", "de"} or {"de", "en"}.)";
|
||||||
|
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(ValidateScalar("Multi30kDataset", "num_samples", num_samples_, {0}, false));
|
||||||
|
RETURN_IF_NOT_OK(ValidateDatasetShardParams("Multi30kDataset", num_shards_, shard_id_));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kNode::GetShardId(int32_t *shard_id) {
|
||||||
|
*shard_id = shard_id_;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||||
|
int64_t *dataset_size) {
|
||||||
|
if (dataset_size_ > 0) {
|
||||||
|
*dataset_size = dataset_size_;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
int64_t num_rows, sample_size = num_samples_;
|
||||||
|
RETURN_IF_NOT_OK(Multi30kOp::CountAllFileRows(multi30k_files_list_, &num_rows));
|
||||||
|
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
|
||||||
|
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
|
||||||
|
dataset_size_ = *dataset_size;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kNode::to_json(nlohmann::json *out_json) {
|
||||||
|
nlohmann::json args;
|
||||||
|
args["num_parallel_workers"] = num_workers_;
|
||||||
|
args["dataset_dir"] = dataset_dir_;
|
||||||
|
args["num_samples"] = num_samples_;
|
||||||
|
args["shuffle"] = shuffle_;
|
||||||
|
args["num_shards"] = num_shards_;
|
||||||
|
args["shard_id"] = shard_id_;
|
||||||
|
args["language_pair"] = language_pair_;
|
||||||
|
if (cache_ != nullptr) {
|
||||||
|
nlohmann::json cache_args;
|
||||||
|
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||||
|
args["cache"] = cache_args;
|
||||||
|
}
|
||||||
|
*out_json = args;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||||
|
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||||
|
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Multi30kNode::MakeSimpleProducer() {
|
||||||
|
shard_id_ = 0;
|
||||||
|
num_shards_ = 1;
|
||||||
|
shuffle_ = ShuffleMode::kFalse;
|
||||||
|
num_samples_ = 0;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> Multi30kNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
|
||||||
|
std::vector<std::string> multi30k_files_list;
|
||||||
|
Path train_en("training/train.en");
|
||||||
|
Path test_en("mmt16_task1_test/test.en");
|
||||||
|
Path valid_en("validation/val.en");
|
||||||
|
Path dir(dataset_dir);
|
||||||
|
|
||||||
|
if (usage == "train") {
|
||||||
|
Path temp_path = dir / train_en;
|
||||||
|
multi30k_files_list.push_back(temp_path.ToString());
|
||||||
|
} else if (usage == "test") {
|
||||||
|
Path temp_path = dir / test_en;
|
||||||
|
multi30k_files_list.push_back(temp_path.ToString());
|
||||||
|
} else if (usage == "valid") {
|
||||||
|
Path temp_path = dir / valid_en;
|
||||||
|
multi30k_files_list.push_back(temp_path.ToString());
|
||||||
|
} else {
|
||||||
|
Path temp_path = dir / train_en;
|
||||||
|
multi30k_files_list.push_back(temp_path.ToString());
|
||||||
|
Path temp_path1 = dir / test_en;
|
||||||
|
multi30k_files_list.push_back(temp_path1.ToString());
|
||||||
|
Path temp_path2 = dir / valid_en;
|
||||||
|
multi30k_files_list.push_back(temp_path2.ToString());
|
||||||
|
}
|
||||||
|
return multi30k_files_list;
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,134 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class Multi30kNode : public NonMappableSourceNode {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor of Multi30kNode.
|
||||||
|
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||||
|
/// \param[in] usage Part of dataset of MULTI30K, can be "train", "test", "valid" or "all".
|
||||||
|
/// \param[in] language_pair List containing text and translation language.
|
||||||
|
/// \param[in] num_samples The number of samples to be included in the dataset
|
||||||
|
/// \param[in] shuffle The mode for shuffling data every epoch.
|
||||||
|
/// Can be any of:
|
||||||
|
/// ShuffleMode::kFalse - No shuffling is performed.
|
||||||
|
/// ShuffleMode::kFiles - Shuffle files only.
|
||||||
|
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
|
||||||
|
/// \param[in] num_shards Number of shards that the dataset should be divided into.
|
||||||
|
/// \param[in] shared_id The shard ID within num_shards. This argument should be
|
||||||
|
/// specified only when num_shards is also specified.
|
||||||
|
/// \param[in] cache Tensor cache to use.
|
||||||
|
Multi30kNode(const std::string &dataset_dir, const std::string &usage, const std::vector<std::string> &language_pair,
|
||||||
|
int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shared_id,
|
||||||
|
std::shared_ptr<DatasetCache> cache);
|
||||||
|
|
||||||
|
/// \brief Destructor of Multi30kNode.
|
||||||
|
~Multi30kNode() = default;
|
||||||
|
|
||||||
|
/// \brief Node name getter.
|
||||||
|
/// \return Name of the current node.
|
||||||
|
std::string Name() const override { return kMulti30kNode; }
|
||||||
|
|
||||||
|
/// \brief Print the description.
|
||||||
|
/// \param[in] out The output stream to write output to.
|
||||||
|
void Print(std::ostream &out) const override;
|
||||||
|
|
||||||
|
/// \brief Copy the node to a new object.
|
||||||
|
/// \return A shared pointer to the new copy.
|
||||||
|
std::shared_ptr<DatasetNode> Copy() override;
|
||||||
|
|
||||||
|
/// \brief a base class override function to create the required runtime dataset op objects for this class.
|
||||||
|
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||||
|
/// \return Status Status::OK() if build successfully.
|
||||||
|
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops);
|
||||||
|
|
||||||
|
/// \brief Parameters validation.
|
||||||
|
/// \return Status Status::OK() if all the parameters are valid.
|
||||||
|
Status ValidateParams() override;
|
||||||
|
|
||||||
|
/// \brief Get the shard id of node.
|
||||||
|
/// \param[in] shard_id The shard id.
|
||||||
|
/// \return Status Status::OK() if get shard id successfully.
|
||||||
|
Status GetShardId(int32_t *shard_id) override;
|
||||||
|
|
||||||
|
/// \brief Base-class override for GetDatasetSize.
|
||||||
|
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
|
||||||
|
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting.
|
||||||
|
/// dataset size at the expense of accuracy.
|
||||||
|
/// \param[out] dataset_size the size of the dataset.
|
||||||
|
/// \return Status of the function.
|
||||||
|
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||||
|
int64_t *dataset_size) override;
|
||||||
|
|
||||||
|
/// \brief Get the arguments of node.
|
||||||
|
/// \param[out] out_json JSON string of all attributes.
|
||||||
|
/// \return Status of the function.
|
||||||
|
Status to_json(nlohmann::json *out_json) override;
|
||||||
|
|
||||||
|
/// \brief Multi30k by itself is a non-mappable dataset that does not support sampling.
|
||||||
|
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||||
|
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||||
|
/// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||||
|
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||||
|
/// \param[in] sampler The sampler to setup.
|
||||||
|
/// \return Status of the function.
|
||||||
|
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
|
||||||
|
|
||||||
|
/// \brief If a cache has been added into the ascendant tree over this Multi30k node, then the cache will be executing
|
||||||
|
/// a sampler for fetching the data. As such, any options in the Multi30k node need to be reset to its defaults
|
||||||
|
/// so that this Multi30k node will produce the full set of data into the cache.
|
||||||
|
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
|
||||||
|
/// \return Status of the function
|
||||||
|
Status MakeSimpleProducer() override;
|
||||||
|
|
||||||
|
/// \brief Getter functions
|
||||||
|
int32_t NumSamples() const { return num_samples_; }
|
||||||
|
int32_t NumShards() const { return num_shards_; }
|
||||||
|
int32_t ShardId() const { return shard_id_; }
|
||||||
|
ShuffleMode Shuffle() const { return shuffle_; }
|
||||||
|
const std::string &DatasetDir() const { return dataset_dir_; }
|
||||||
|
const std::string &Usage() const { return usage_; }
|
||||||
|
const std::vector<std::string> &LanguagePair() const { return language_pair_; }
|
||||||
|
|
||||||
|
/// \brief Generate a list of read file names according to usage.
|
||||||
|
/// \param[in] usage Part of dataset of Multi30k.
|
||||||
|
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||||
|
/// \return std::vector<std::string> A list of read file names.
|
||||||
|
std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string dataset_dir_;
|
||||||
|
std::string usage_;
|
||||||
|
std::vector<std::string> language_pair_;
|
||||||
|
int32_t num_samples_;
|
||||||
|
ShuffleMode shuffle_;
|
||||||
|
int32_t num_shards_;
|
||||||
|
int32_t shard_id_;
|
||||||
|
std::vector<std::string> multi30k_files_list_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MULTI30K_NODE_H_
|
|
@ -3723,6 +3723,75 @@ inline std::shared_ptr<MnistDataset> MS_API Mnist(const std::string &dataset_dir
|
||||||
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// \class Multi30kDataset
|
||||||
|
/// \brief A source dataset that reads and parses Multi30k dataset.
|
||||||
|
class MS_API Multi30kDataset : public Dataset {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor of Multi30kDataset.
|
||||||
|
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||||
|
/// \param[in] usage Part of dataset of MULTI30K, can be "train", "test", "valid" or "all".
|
||||||
|
/// \param[in] language_pair List containing text and translation language.
|
||||||
|
/// \param[in] num_samples The number of samples to be included in the dataset.
|
||||||
|
/// \param[in] shuffle The mode for shuffling data every epoch.
|
||||||
|
/// Can be any of:
|
||||||
|
/// ShuffleMode::kFalse - No shuffling is performed.
|
||||||
|
/// ShuffleMode::kFiles - Shuffle files only.
|
||||||
|
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
|
||||||
|
/// \param[in] num_shards Number of shards that the dataset should be divided into.
|
||||||
|
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||||
|
/// specified only when num_shards is also specified.
|
||||||
|
/// \param[in] cache Tensor cache to use.
|
||||||
|
Multi30kDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
|
const std::vector<std::vector<char>> &language_pair, int64_t num_samples, ShuffleMode shuffle,
|
||||||
|
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
|
||||||
|
|
||||||
|
/// \brief Destructor of Multi30kDataset.
|
||||||
|
~Multi30kDataset() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// \brief Function to create a Multi30kDataset.
|
||||||
|
/// \note The generated dataset has two columns ["text", "translation"].
|
||||||
|
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
|
||||||
|
/// \param[in] usage Part of dataset of MULTI30K, can be "train", "test", "valid" or "all" (default = "all").
|
||||||
|
/// \param[in] language_pair List containing text and translation language (default = {"en", "de"}).
|
||||||
|
/// \param[in] num_samples The number of samples to be included in the dataset
|
||||||
|
/// (Default = 0, means all samples).
|
||||||
|
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal).
|
||||||
|
/// Can be any of:
|
||||||
|
/// ShuffleMode::kFalse - No shuffling is performed.
|
||||||
|
/// ShuffleMode::kFiles - Shuffle files only.
|
||||||
|
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
|
||||||
|
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
|
||||||
|
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||||
|
/// specified only when num_shards is also specified (Default = 0).
|
||||||
|
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
|
||||||
|
/// \return Shared pointer to the Multi30kDataset.
|
||||||
|
/// \par Example
|
||||||
|
/// \code
|
||||||
|
/// /* Define dataset path and MindData object */
|
||||||
|
/// std::string dataset_dir = "/path/to/multi30k_dataset_directory";
|
||||||
|
/// std::shared_ptr<Dataset> ds = Multi30k(dataset_dir, "all");
|
||||||
|
///
|
||||||
|
/// /* Create iterator to read dataset */
|
||||||
|
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
/// std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
/// iter->GetNextRow(&row);
|
||||||
|
///
|
||||||
|
/// /* Note: In Multi30kdataset, each dictionary has keys "text" and "translation" */
|
||||||
|
/// auto text = row["text"];
|
||||||
|
/// \endcode
|
||||||
|
inline std::shared_ptr<Multi30kDataset> MS_API Multi30k(const std::string &dataset_dir,
|
||||||
|
const std::string &usage = "all",
|
||||||
|
const std::vector<std::string> &language_pair = {"en", "de"},
|
||||||
|
int64_t num_samples = 0,
|
||||||
|
ShuffleMode shuffle = ShuffleMode::kGlobal,
|
||||||
|
int32_t num_shards = 1, int32_t shard_id = 0,
|
||||||
|
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||||
|
return std::make_shared<Multi30kDataset>(StringToChar(dataset_dir), StringToChar(usage),
|
||||||
|
VectorStringToChar(language_pair), num_samples, shuffle, num_shards,
|
||||||
|
shard_id, cache);
|
||||||
|
}
|
||||||
|
|
||||||
/// \class PennTreebankDataset
|
/// \class PennTreebankDataset
|
||||||
/// \brief A source dataset for reading and parsing PennTreebank dataset.
|
/// \brief A source dataset for reading and parsing PennTreebank dataset.
|
||||||
class MS_API PennTreebankDataset : public Dataset {
|
class MS_API PennTreebankDataset : public Dataset {
|
||||||
|
|
|
@ -74,6 +74,7 @@ __all__ = ["Caltech101Dataset", # Vision
|
||||||
"IMDBDataset", # Text
|
"IMDBDataset", # Text
|
||||||
"IWSLT2016Dataset", # Text
|
"IWSLT2016Dataset", # Text
|
||||||
"IWSLT2017Dataset", # Text
|
"IWSLT2017Dataset", # Text
|
||||||
|
"Multi30kDataset", # Text
|
||||||
"PennTreebankDataset", # Text
|
"PennTreebankDataset", # Text
|
||||||
"SogouNewsDataset", # Text
|
"SogouNewsDataset", # Text
|
||||||
"TextFileDataset", # Text
|
"TextFileDataset", # Text
|
||||||
|
|
|
@ -30,7 +30,7 @@ from .validators import check_imdb_dataset, check_iwslt2016_dataset, check_iwslt
|
||||||
check_penn_treebank_dataset, check_ag_news_dataset, check_amazon_review_dataset, check_udpos_dataset, \
|
check_penn_treebank_dataset, check_ag_news_dataset, check_amazon_review_dataset, check_udpos_dataset, \
|
||||||
check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, \
|
check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, \
|
||||||
check_sogou_news_dataset, check_textfiledataset, check_dbpedia_dataset, check_yelp_review_dataset, \
|
check_sogou_news_dataset, check_textfiledataset, check_dbpedia_dataset, check_yelp_review_dataset, \
|
||||||
check_en_wik9_dataset, check_yahoo_answers_dataset
|
check_en_wik9_dataset, check_yahoo_answers_dataset, check_multi30k_dataset
|
||||||
|
|
||||||
from ..core.validator_helpers import replace_none
|
from ..core.validator_helpers import replace_none
|
||||||
|
|
||||||
|
@ -961,6 +961,106 @@ class IWSLT2017Dataset(SourceDataset, TextBaseDataset):
|
||||||
self.shuffle_flag, self.num_shards, self.shard_id)
|
self.shuffle_flag, self.num_shards, self.shard_id)
|
||||||
|
|
||||||
|
|
||||||
|
class Multi30kDataset(SourceDataset, TextBaseDataset):
|
||||||
|
"""
|
||||||
|
A source dataset that reads and parses Multi30k dataset.
|
||||||
|
|
||||||
|
The generated dataset has two columns :py:obj:`[text, translation]`.
|
||||||
|
The tensor of column :py:obj:'text' is of the string type.
|
||||||
|
The tensor of column :py:obj:'translation' is of the string type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_dir (str): Path to the root directory that contains the dataset.
|
||||||
|
usage (str, optional): Acceptable usages include `train`, `test, `valid` or `all` (default=`all`).
|
||||||
|
language_pair (str, optional): Acceptable language_pair include ['en', 'de'], ['de', 'en']
|
||||||
|
(default=['en', 'de']).
|
||||||
|
num_samples (int, optional): The number of images to be included in the dataset
|
||||||
|
(default=None, all samples).
|
||||||
|
num_parallel_workers (int, optional): Number of workers to read the data
|
||||||
|
(default=None, number set in the config).
|
||||||
|
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
|
||||||
|
(default=Shuffle.GLOBAL).
|
||||||
|
If shuffle is False, no shuffling will be performed;
|
||||||
|
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
|
||||||
|
Otherwise, there are two levels of shuffling:
|
||||||
|
|
||||||
|
- Shuffle.GLOBAL: Shuffle both the files and samples.
|
||||||
|
|
||||||
|
- Shuffle.FILES: Shuffle files only.
|
||||||
|
|
||||||
|
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||||
|
into (default=None). When this argument is specified, `num_samples` reflects
|
||||||
|
the max sample number of per shard.
|
||||||
|
shard_id (int, optional): The shard ID within num_shards (default=None). This
|
||||||
|
argument can only be specified when num_shards is also specified.
|
||||||
|
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
|
||||||
|
(default=None, which means no cache is used).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If dataset_dir does not contain data files.
|
||||||
|
RuntimeError: If usage is not "train", "test", "valid" or "all".
|
||||||
|
RuntimeError: If the length of language_pair is not equal to 2.
|
||||||
|
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||||
|
RuntimeError: If num_shards is specified but shard_id is None.
|
||||||
|
RuntimeError: If shard_id is specified but num_shards is None.
|
||||||
|
RuntimeError: If num_samples is less than 0.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> multi30k_dataset_dir = "/path/to/multi30k_dataset_directory"
|
||||||
|
>>> data = ds.Multi30kDataset(dataset_dir=multi30k_dataset_dir, usage='all', language_pair=['de', 'en'])
|
||||||
|
|
||||||
|
About Multi30k dataset:
|
||||||
|
|
||||||
|
Multi30K is a dataset to stimulate multilingual multimodal research for English-German.
|
||||||
|
It is based on the Flickr30k dataset, which contains images sourced from online
|
||||||
|
photo-sharing websites. Each image is paired with five English descriptions, which were
|
||||||
|
collected from Amazon Mechanical Turk. The Multi30K dataset extends the Flickr30K
|
||||||
|
dataset with translated and independent German sentences.
|
||||||
|
|
||||||
|
You can unzip the dataset files into the following directory structure and read by MindSpore's API.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
└── multi30k_dataset_directory
|
||||||
|
├── training
|
||||||
|
│ ├── train.de
|
||||||
|
│ └── train.en
|
||||||
|
├── validation
|
||||||
|
│ ├── val.de
|
||||||
|
│ └── val.en
|
||||||
|
└── mmt16_task1_test
|
||||||
|
├── val.de
|
||||||
|
└── val.en
|
||||||
|
|
||||||
|
Citation:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
@article{elliott-EtAl:2016:VL16,
|
||||||
|
author = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
|
||||||
|
title = {Multi30K: Multilingual English-German Image Descriptions},
|
||||||
|
booktitle = {Proceedings of the 5th Workshop on Vision and Language},
|
||||||
|
year = {2016},
|
||||||
|
pages = {70--74},
|
||||||
|
year = 2016
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_multi30k_dataset
|
||||||
|
def __init__(self, dataset_dir, usage=None, language_pair=None, num_samples=None,
|
||||||
|
num_parallel_workers=None, shuffle=None, num_shards=None, shard_id=None, cache=None):
|
||||||
|
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
|
||||||
|
num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||||
|
self.dataset_dir = dataset_dir
|
||||||
|
self.usage = replace_none(usage, 'all')
|
||||||
|
self.language_pair = replace_none(language_pair, ["en", "de"])
|
||||||
|
self.shuffle = replace_none(shuffle, Shuffle.GLOBAL)
|
||||||
|
|
||||||
|
def parse(self, children=None):
|
||||||
|
return cde.Multi30kNode(self.dataset_dir, self.usage, self.language_pair, self.num_samples,
|
||||||
|
self.shuffle_flag, self.num_shards, self.shard_id)
|
||||||
|
|
||||||
|
|
||||||
class PennTreebankDataset(SourceDataset, TextBaseDataset):
|
class PennTreebankDataset(SourceDataset, TextBaseDataset):
|
||||||
"""
|
"""
|
||||||
A source dataset that reads and parses PennTreebank datasets.
|
A source dataset that reads and parses PennTreebank datasets.
|
||||||
|
|
|
@ -2558,3 +2558,40 @@ def check_en_wik9_dataset(method):
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_multi30k_dataset(method):
|
||||||
|
"""A wrapper that wraps a parameter checker around the original Dataset (Multi30kDataset)."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||||
|
|
||||||
|
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
|
||||||
|
nreq_param_bool = ['shuffle', 'decode']
|
||||||
|
|
||||||
|
dataset_dir = param_dict.get('dataset_dir')
|
||||||
|
check_dir(dataset_dir)
|
||||||
|
|
||||||
|
usage = param_dict.get('usage')
|
||||||
|
if usage is not None:
|
||||||
|
check_valid_str(usage, ["train", "test", "valid", "all"], "usage")
|
||||||
|
|
||||||
|
language_pair = param_dict.get('language_pair')
|
||||||
|
support_language_pair = [['en', 'de'], ['de', 'en'], ('en', 'de'), ('de', 'en')]
|
||||||
|
if language_pair is not None:
|
||||||
|
type_check(language_pair, (list, tuple), "language_pair")
|
||||||
|
if len(language_pair) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"language_pair should be a list or tuple of length 2, but got {0}".format(len(language_pair)))
|
||||||
|
if language_pair not in support_language_pair:
|
||||||
|
raise ValueError(
|
||||||
|
"language_pair can only be ['en', 'de'] or ['en', 'de'], but got {0}".format(language_pair))
|
||||||
|
|
||||||
|
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||||
|
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||||
|
|
||||||
|
check_sampler_shuffle_shard_options(param_dict)
|
||||||
|
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
|
@ -40,6 +40,7 @@ SET(DE_UT_SRCS
|
||||||
c_api_dataset_lj_speech_test.cc
|
c_api_dataset_lj_speech_test.cc
|
||||||
c_api_dataset_manifest_test.cc
|
c_api_dataset_manifest_test.cc
|
||||||
c_api_dataset_minddata_test.cc
|
c_api_dataset_minddata_test.cc
|
||||||
|
c_api_dataset_multi30k_test.cc
|
||||||
c_api_dataset_ops_test.cc
|
c_api_dataset_ops_test.cc
|
||||||
c_api_dataset_penn_treebank_test.cc
|
c_api_dataset_penn_treebank_test.cc
|
||||||
c_api_dataset_photo_tour_test.cc
|
c_api_dataset_photo_tour_test.cc
|
||||||
|
|
|
@ -0,0 +1,652 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "minddata/dataset/include/dataset/datasets.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::Status;
|
||||||
|
using mindspore::dataset::ShuffleMode;
|
||||||
|
|
||||||
|
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||||
|
protected:
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Feature: Test Multi30k Dataset(English).
|
||||||
|
/// Description: read Multi30kDataset data and get data.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kSuccessEn) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kSuccessEn.";
|
||||||
|
// Test Multi30k English files with default parameters
|
||||||
|
|
||||||
|
// Create a Multi30k dataset
|
||||||
|
std::string en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
|
||||||
|
// test train
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
std::vector<std::string> expected = {"This is the first English sentence in train.",
|
||||||
|
"This is the second English sentence in train.",
|
||||||
|
"This is the third English sentence in train."};
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// test valid
|
||||||
|
usage = "valid";
|
||||||
|
expected = {"This is the first English sentence in valid.",
|
||||||
|
"This is the second English sentence in valid."};
|
||||||
|
|
||||||
|
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
|
||||||
|
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
iter = ds->CreateIterator();
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
|
||||||
|
i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 2);
|
||||||
|
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// test test
|
||||||
|
usage = "test";
|
||||||
|
expected = {"This is the first English sentence in test.",
|
||||||
|
"This is the second English sentence in test.",
|
||||||
|
"This is the third English sentence in test."};
|
||||||
|
|
||||||
|
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
|
||||||
|
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
iter = ds->CreateIterator();
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
|
||||||
|
i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test Multi30k Dataset(Germany).
|
||||||
|
/// Description: read Multi30kDataset data and get data.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kSuccessDe) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kSuccessDe.";
|
||||||
|
// Test Multi30k Germany files with default parameters
|
||||||
|
|
||||||
|
// Create a Multi30k dataset
|
||||||
|
std::string en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
|
||||||
|
// test train
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("translation"), row.end());
|
||||||
|
std::vector<std::string> expected = {"This is the first Germany sentence in train.",
|
||||||
|
"This is the second Germany sentence in train.",
|
||||||
|
"This is the third Germany sentence in train."};
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["translation"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// test valid
|
||||||
|
usage = "valid";
|
||||||
|
expected = {"This is the first Germany sentence in valid.",
|
||||||
|
"This is the second Germany sentence in valid."};
|
||||||
|
|
||||||
|
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
|
||||||
|
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
iter = ds->CreateIterator();
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("translation"), row.end());
|
||||||
|
|
||||||
|
i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["translation"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 2);
|
||||||
|
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// test test
|
||||||
|
usage = "test";
|
||||||
|
expected = {"This is the first Germany sentence in test.",
|
||||||
|
"This is the second Germany sentence in test.",
|
||||||
|
"This is the third Germany sentence in test."};
|
||||||
|
|
||||||
|
ds = Multi30k(en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
|
||||||
|
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
iter = ds->CreateIterator();
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("translation"), row.end());
|
||||||
|
|
||||||
|
i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["translation"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test Multi30k Dataset(Germany).
|
||||||
|
/// Description: read Multi30kDataset data and get data.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetBasicWithPipeline) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetBasicWithPipeline.";
|
||||||
|
|
||||||
|
// Create two Multi30kFile Dataset, with single Multi30k file
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds1 = Multi30k(train_en_file, usage, {"en", "de"}, 2, ShuffleMode::kFalse);
|
||||||
|
std::shared_ptr<Dataset> ds2 = Multi30k(train_en_file, usage, {"en", "de"}, 2, ShuffleMode::kFalse);
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
EXPECT_NE(ds2, nullptr);
|
||||||
|
|
||||||
|
// Create two Repeat operation on ds
|
||||||
|
int32_t repeat_num = 2;
|
||||||
|
ds1 = ds1->Repeat(repeat_num);
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
repeat_num = 3;
|
||||||
|
ds2 = ds2->Repeat(repeat_num);
|
||||||
|
EXPECT_NE(ds2, nullptr);
|
||||||
|
|
||||||
|
// Create two Project operation on ds
|
||||||
|
std::vector<std::string> column_project = {"text"};
|
||||||
|
ds1 = ds1->Project(column_project);
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
ds2 = ds2->Project(column_project);
|
||||||
|
EXPECT_NE(ds2, nullptr);
|
||||||
|
|
||||||
|
// Create a Concat operation on the ds
|
||||||
|
ds1 = ds1->Concat({ds2});
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
|
||||||
|
i++;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect 10 samples
|
||||||
|
EXPECT_EQ(i, 10);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test Getters.
|
||||||
|
/// Description: includes tests for shape, type, size.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kGetters) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kGetters.";
|
||||||
|
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 2, ShuffleMode::kFalse);
|
||||||
|
std::vector<std::string> column_names = {"text","translation"};
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
EXPECT_EQ(ds->GetDatasetSize(), 2);
|
||||||
|
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test Multi30kDataset in distribution.
|
||||||
|
/// Description: test interface in a distributed state.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetDistribution) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetDistribution.";
|
||||||
|
|
||||||
|
// Create a Multi30kFile Dataset, with single Multi30k file
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kGlobal, 3, 2);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
|
||||||
|
i++;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect 1 samples
|
||||||
|
EXPECT_EQ(i, 1);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Error Test.
|
||||||
|
/// Description: test the wrong input.
|
||||||
|
/// Expectation: unable to read in data.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidFilePath) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidFilePath.";
|
||||||
|
|
||||||
|
// Create a Multi30k Dataset
|
||||||
|
// with invalid file path
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/invalid/file.path";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"});
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Error Test.
|
||||||
|
/// Description: test the wrong input.
|
||||||
|
/// Expectation: unable to read in data.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidUsage) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvaildUsage.";
|
||||||
|
|
||||||
|
// Create a Multi30k Dataset
|
||||||
|
// with invalid usage
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "invalid_usage";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"});
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Error Test.
|
||||||
|
/// Description: test the wrong input.
|
||||||
|
/// Expectation: unable to read in data.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidLanguagePair) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailLanguagePair.";
|
||||||
|
|
||||||
|
// Create a Multi30k Dataset
|
||||||
|
// with invalid usage
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::vector<std::string> language_pair0 = {"ch", "ja"};
|
||||||
|
std::shared_ptr<Dataset> ds0 = Multi30k(train_en_file, usage, language_pair0);
|
||||||
|
EXPECT_NE(ds0, nullptr);
|
||||||
|
std::vector<std::string> language_pair1 = {"en", "de", "aa"};
|
||||||
|
std::shared_ptr<Dataset> ds1 = Multi30k(train_en_file, usage, language_pair1);
|
||||||
|
EXPECT_NE(ds1, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Error Test.
|
||||||
|
/// Description: test the wrong input.
|
||||||
|
/// Expectation: unable to read in data.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidNumSamples) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidNumSamples.";
|
||||||
|
|
||||||
|
// Create a Multi30k Dataset
|
||||||
|
// with invalid samplers=-1
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, -1);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
// Expect failure: TextFile number of samples cannot be negative
|
||||||
|
EXPECT_EQ(iter, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Error Test.
|
||||||
|
/// Description: test the wrong input.
|
||||||
|
/// Expectation: unable to read in data.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidShards) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidShards.";
|
||||||
|
|
||||||
|
// Create a Multi30k Dataset
|
||||||
|
// with invalid shards.
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse, 2, 3);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
// Expect failure: TextFile number of samples cannot be negative
|
||||||
|
EXPECT_EQ(iter, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Error Test.
|
||||||
|
/// Description: test the wrong input.
|
||||||
|
/// Expectation: unable to read in data.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetFailInvalidShardID) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetFailInvalidShardID.";
|
||||||
|
|
||||||
|
// Create a Multi30k Dataset
|
||||||
|
// with invalid shard ID.
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse, 0, -1);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
// Expect failure: TextFile number of samples cannot be negative
|
||||||
|
EXPECT_EQ(iter, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Error Test.
|
||||||
|
/// Description: test the wrong input.
|
||||||
|
/// Expectation: unable to read in data.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetLanguagePair) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetLanguagePair.";
|
||||||
|
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"de", "en"}, 0, ShuffleMode::kFalse);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("translation"), row.end());
|
||||||
|
std::vector<std::string> expected = {"This is the first English sentence in train.",
|
||||||
|
"This is the second English sentence in train.",
|
||||||
|
"This is the third English sentence in train."};
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["translation"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
iter->Stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test Multi30k Dataset(shufflemode=kFalse).
|
||||||
|
/// Description: test Multi30k Dataset interface with different ShuffleMode.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetShuffleFilesFalse) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetShuffleFilesFalse.";
|
||||||
|
|
||||||
|
// Set configuration
|
||||||
|
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||||
|
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||||
|
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||||
|
GlobalContext::config_manager()->set_seed(135);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||||
|
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFalse);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
std::vector<std::string> expected = {"This is the first English sentence in train.",
|
||||||
|
"This is the second English sentence in train.",
|
||||||
|
"This is the third English sentence in train."};
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||||
|
// Compare against expected result
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
i++;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// Restore configuration
|
||||||
|
GlobalContext::config_manager()->set_seed(original_seed);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test Multi30k Dataset(shufflemode=kFiles).
|
||||||
|
/// Description: test Multi30k Dataset interface with different ShuffleMode.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetShuffleFilesFiles) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetShuffleFilesFiles.";
|
||||||
|
|
||||||
|
// Set configuration
|
||||||
|
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||||
|
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||||
|
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||||
|
GlobalContext::config_manager()->set_seed(135);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||||
|
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kFiles);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
std::vector<std::string> expected = {"This is the first English sentence in train.",
|
||||||
|
"This is the second English sentence in train.",
|
||||||
|
"This is the third English sentence in train."};
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||||
|
// Compare against expected result
|
||||||
|
EXPECT_STREQ(ss.c_str(), expected[i].c_str());
|
||||||
|
i++;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
|
||||||
|
// Restore configuration
|
||||||
|
GlobalContext::config_manager()->set_seed(original_seed);
|
||||||
|
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test Multi30k Dataset(shufflemode=kGlobal).
|
||||||
|
/// Description: test Multi30k Dataset interface with different ShuffleMode.
|
||||||
|
/// Expectation: the data is processed successfully.
|
||||||
|
TEST_F(MindDataTestPipeline, TestMulti30kDatasetShuffleFilesGlobal) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMulti30kDatasetShuffleFilesGlobal.";
|
||||||
|
|
||||||
|
std::string train_en_file = datasets_root_path_ + "/testMulti30kDataset";
|
||||||
|
std::string usage = "train";
|
||||||
|
std::shared_ptr<Dataset> ds = Multi30k(train_en_file, usage, {"en", "de"}, 0, ShuffleMode::kGlobal);
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create an iterator over the result of the above dataset.
|
||||||
|
// This will trigger the creation of the Execution Tree and launch it.
|
||||||
|
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||||
|
EXPECT_NE(iter, nullptr);
|
||||||
|
|
||||||
|
// Iterate the dataset and get each row
|
||||||
|
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
|
||||||
|
EXPECT_NE(row.find("text"), row.end());
|
||||||
|
|
||||||
|
uint64_t i = 0;
|
||||||
|
while (row.size() != 0) {
|
||||||
|
auto text = row["text"];
|
||||||
|
std::shared_ptr<Tensor> de_text;
|
||||||
|
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
|
||||||
|
std::string_view sv;
|
||||||
|
ASSERT_OK(de_text->GetItemAt(&sv, {}));
|
||||||
|
std::string ss(sv);
|
||||||
|
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||||
|
i++;
|
||||||
|
ASSERT_OK(iter->GetNextRow(&row));
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(i, 3);
|
||||||
|
|
||||||
|
// Manually terminate the pipeline
|
||||||
|
iter->Stop();
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
This is the first Germany sentence in test.
|
||||||
|
This is the second Germany sentence in test.
|
||||||
|
This is the third Germany sentence in test.
|
|
@ -0,0 +1,3 @@
|
||||||
|
This is the first English sentence in test.
|
||||||
|
This is the second English sentence in test.
|
||||||
|
This is the third English sentence in test.
|
|
@ -0,0 +1,3 @@
|
||||||
|
This is the first Germany sentence in train.
|
||||||
|
This is the second Germany sentence in train.
|
||||||
|
This is the third Germany sentence in train.
|
|
@ -0,0 +1,3 @@
|
||||||
|
This is the first English sentence in train.
|
||||||
|
This is the second English sentence in train.
|
||||||
|
This is the third English sentence in train.
|
|
@ -0,0 +1,2 @@
|
||||||
|
This is the first Germany sentence in valid.
|
||||||
|
This is the second Germany sentence in valid.
|
|
@ -0,0 +1,2 @@
|
||||||
|
This is the first English sentence in valid.
|
||||||
|
This is the second English sentence in valid.
|
|
@ -0,0 +1,267 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.text.transforms as a_c_trans
|
||||||
|
from mindspore import log as logger
|
||||||
|
from util import config_get_set_num_parallel_workers, config_get_set_seed
|
||||||
|
|
||||||
|
INVALID_FILE = '../data/dataset/testMulti30kDataset/invalid_dir'
|
||||||
|
DATA_ALL_FILE = '../data/dataset/testMulti30kDataset'
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_file_multi30k_text():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset.
|
||||||
|
Description: read data from a single file.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
original_seed = config_get_set_seed(987)
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage="train", shuffle=False)
|
||||||
|
count = 0
|
||||||
|
line = ["This is the first English sentence in train.",
|
||||||
|
"This is the second English sentence in train.",
|
||||||
|
"This is the third English sentence in train."
|
||||||
|
]
|
||||||
|
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
strs = i["text"].item().decode("utf8")
|
||||||
|
assert strs == line[count]
|
||||||
|
count += 1
|
||||||
|
assert count == 3
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_file_multi30k_translation():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset.
|
||||||
|
Description: read data from a single file.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
original_seed = config_get_set_seed(987)
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage="train", shuffle=False)
|
||||||
|
count = 0
|
||||||
|
line = ["This is the first Germany sentence in train.",
|
||||||
|
"This is the second Germany sentence in train.",
|
||||||
|
"This is the third Germany sentence in train."
|
||||||
|
]
|
||||||
|
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
strs = i["translation"].item().decode("utf8")
|
||||||
|
assert strs == line[count]
|
||||||
|
count += 1
|
||||||
|
assert count == 3
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_file_multi30k():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset.
|
||||||
|
Description: read data from all file.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE)
|
||||||
|
count = 0
|
||||||
|
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
logger.info("{}".format(i["text"]))
|
||||||
|
count += 1
|
||||||
|
assert count == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_num_samples_none():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset(num_samples = default).
|
||||||
|
Description: test get num_samples.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||||
|
original_seed = config_get_set_seed(987)
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, shuffle=False)
|
||||||
|
count = 0
|
||||||
|
line = ["This is the first English sentence in test.",
|
||||||
|
"This is the second English sentence in test.",
|
||||||
|
"This is the third English sentence in test.",
|
||||||
|
"This is the first English sentence in train.",
|
||||||
|
"This is the second English sentence in train.",
|
||||||
|
"This is the third English sentence in train.",
|
||||||
|
"This is the first English sentence in valid.",
|
||||||
|
"This is the second English sentence in valid."
|
||||||
|
]
|
||||||
|
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
strs = i["text"].item().decode("utf8")
|
||||||
|
assert strs == line[count]
|
||||||
|
count += 1
|
||||||
|
assert count == 8
|
||||||
|
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||||
|
ds.config.set_seed(original_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_shards_multi30k():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset(num_shards = 3).
|
||||||
|
Description: test get num_samples.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage='train', num_shards=3, shard_id=1)
|
||||||
|
count = 0
|
||||||
|
for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
logger.info("{}".format(i["text"]))
|
||||||
|
count += 1
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_num_samples():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset(num_samples = 2).
|
||||||
|
Description: test get num_samples.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage="test", num_samples=2)
|
||||||
|
count = 0
|
||||||
|
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
count += 1
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_shuffle_files():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset.
|
||||||
|
Description: test get all files.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, shuffle=True)
|
||||||
|
count = 0
|
||||||
|
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
count += 1
|
||||||
|
assert count == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_shuffle_false():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset (shuffle = false).
|
||||||
|
Description: test get all files.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, shuffle=False)
|
||||||
|
count = 0
|
||||||
|
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
count += 1
|
||||||
|
assert count == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_repeat():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k in distribution (repeat 3 times).
|
||||||
|
Description: test in a distributed state.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, usage='train')
|
||||||
|
dataset = dataset.repeat(3)
|
||||||
|
count = 0
|
||||||
|
for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
count += 1
|
||||||
|
assert count == 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_get_datasetsize():
|
||||||
|
"""
|
||||||
|
Feature: Test Getters.
|
||||||
|
Description: test get_dataset_size of Multi30k dataset.
|
||||||
|
Expectation: the data is processed successfully.
|
||||||
|
"""
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE)
|
||||||
|
size = dataset.get_dataset_size()
|
||||||
|
assert size == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_exceptions():
|
||||||
|
"""
|
||||||
|
Feature: Test Multi30k Dataset.
|
||||||
|
Description: Test exceptions.
|
||||||
|
Expectation: Exception thrown to be caught
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
_ = ds.Multi30kDataset(INVALID_FILE)
|
||||||
|
assert "The folder ../data/dataset/testMulti30kDataset/invalid_dir does not exist or is not a directory or" \
|
||||||
|
" permission denied" in str(error_info.value)
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage="INVALID")
|
||||||
|
assert "Input usage is not within the valid set of ['train', 'test', 'valid', 'all']." in str(error_info.value)
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage="test", language_pair=["ch", "ja"])
|
||||||
|
assert "language_pair can only be ['en', 'de'] or ['en', 'de'], but got ['ch', 'ja']" in str(error_info.value)
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage="test", language_pair=["en", "en", "de"])
|
||||||
|
assert "language_pair should be a list or tuple of length 2, but got 3" in str(error_info.value)
|
||||||
|
with pytest.raises(ValueError) as error_info:
|
||||||
|
_ = ds.Multi30kDataset(DATA_ALL_FILE, usage='test', num_samples=-1)
|
||||||
|
assert "num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)!" in str(error_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_en_pipeline():
|
||||||
|
"""
|
||||||
|
Feature: Multi30kDataset
|
||||||
|
Description: test Multi30kDataset in pipeline mode
|
||||||
|
Expectation: the data is processed successfully
|
||||||
|
"""
|
||||||
|
expected = ["this is the first english sentence in train.",
|
||||||
|
"this is the second english sentence in train.",
|
||||||
|
"this is the third english sentence in train."]
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, 'train', shuffle=False)
|
||||||
|
filter_wikipedia_xml_op = a_c_trans.CaseFold()
|
||||||
|
dataset = dataset.map(input_columns=["text"], operations=filter_wikipedia_xml_op, num_parallel_workers=1)
|
||||||
|
count = 0
|
||||||
|
for i in dataset.create_dict_iterator(output_numpy=True):
|
||||||
|
strs = i["text"].item().decode("utf8")
|
||||||
|
assert strs == expected[count]
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi30k_dataset_de_pipeline():
|
||||||
|
"""
|
||||||
|
Feature: Multi30kDataset
|
||||||
|
Description: test Multi30kDataset in pipeline mode
|
||||||
|
Expectation: the data is processed successfully
|
||||||
|
"""
|
||||||
|
expected = ["this is the first germany sentence in train.",
|
||||||
|
"this is the second germany sentence in train.",
|
||||||
|
"this is the third germany sentence in train."]
|
||||||
|
dataset = ds.Multi30kDataset(DATA_ALL_FILE, 'train', shuffle=False)
|
||||||
|
filter_wikipedia_xml_op = a_c_trans.CaseFold()
|
||||||
|
dataset = dataset.map(input_columns=["translation"], operations=filter_wikipedia_xml_op, num_parallel_workers=1)
|
||||||
|
count = 0
|
||||||
|
for i in dataset.create_dict_iterator(output_numpy=True):
|
||||||
|
strs = i["translation"].item().decode("utf8")
|
||||||
|
assert strs == expected[count]
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_data_file_multi30k_text()
|
||||||
|
test_data_file_multi30k_translation()
|
||||||
|
test_all_file_multi30k()
|
||||||
|
test_dataset_num_samples_none()
|
||||||
|
test_num_shards_multi30k()
|
||||||
|
test_multi30k_dataset_num_samples()
|
||||||
|
test_multi30k_dataset_shuffle_files()
|
||||||
|
test_multi30k_dataset_shuffle_false()
|
||||||
|
test_multi30k_dataset_repeat()
|
||||||
|
test_multi30k_dataset_get_datasetsize()
|
||||||
|
test_multi30k_dataset_exceptions()
|
||||||
|
test_multi30k_dataset_en_pipeline()
|
||||||
|
test_multi30k_dataset_de_pipeline()
|
Loading…
Reference in New Issue