!22553 [assistant][ops] Add new loader UDPOSDataset

Merge pull request !22553 from 杨旭华/UDPOSDataset
This commit is contained in:
i-robot 2021-12-20 06:54:54 +00:00 committed by Gitee
commit 953920112c
18 changed files with 1705 additions and 1 deletion

View File

@@ -121,6 +121,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/usps_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yahoo_answers_node.h"
@@ -1685,6 +1686,14 @@ TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_f
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
UDPOSDataset::UDPOSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<UDPOSNode>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle,
num_shards, shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
YahooAnswersDataset::YahooAnswersDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {

View File

@@ -51,6 +51,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yahoo_answers_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yelp_review_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h"
@@ -528,6 +529,18 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(UDPOSNode, 2, ([](const py::module *m) {
(void)py::class_<UDPOSNode, DatasetNode, std::shared_ptr<UDPOSNode>>(*m, "UDPOSNode",
"to create an UDPOSNode")
.def(py::init([](std::string dataset_dir, std::string usage, int64_t num_samples, int32_t shuffle,
int32_t num_shards, int32_t shard_id) {
std::shared_ptr<UDPOSNode> udpos = std::make_shared<UDPOSNode>(
dataset_dir, usage, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
THROW_IF_ERROR(udpos->ValidateParams());
return udpos;
}));
}));
PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
(void)py::class_<USPSNode, DatasetNode, std::shared_ptr<USPSNode>>(*m, "USPSNode",
"to create an USPSNode")

View File

@@ -36,6 +36,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
stl10_op.cc
tedlium_op.cc
text_file_op.cc
udpos_op.cc
usps_op.cc
yahoo_answers_op.cc
yelp_review_op.cc

View File

@@ -0,0 +1,170 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/udpos_op.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/wait_post.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
UDPOSOp::UDPOSOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, const std::vector<std::string> &udpos_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
: TextFileOp(num_workers, total_rows, worker_connector_size, std::move(schema), udpos_files_list, op_connector_size,
shuffle_files, num_devices, device_id) {}
// A print method typically used for debugging.
void UDPOSOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op.
out << "\n";
} else {
// Call the super class for displaying any common detailed info.
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff.
out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nUDPOS files list:\n";
for (size_t i = 0; i < text_files_list_.size(); ++i) {
out << " " << text_files_list_[i];
}
out << "\nData Schema:\n";
out << *data_schema_ << "\n\n";
}
}
Status UDPOSOp::LoadTensor(const std::vector<std::string> &column, TensorRow *out_row, size_t index) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(column, &tensor));
(*out_row)[index] = std::move(tensor);
return Status::OK();
}
// Function to split string based on a character delimiter.
std::vector<std::string> UDPOSOp::Split(const std::string &s, char delim) {
std::vector<std::string> res;
std::stringstream ss(s);
std::string item;
while (getline(ss, item, delim)) {
res.push_back(item);
}
return res;
}
// Removes excess space before and after the string.
std::string UDPOSOp::Strip(const std::string &str) {
size_t strlen = str.size();
size_t i = 0;
while (i < strlen && str[i] == ' ') {
i++;
}
if (i == strlen) {
// The string is empty or contains only spaces.
return "";
}
size_t j = strlen - 1;
while (j > i && str[j] == ' ') {
j--;
}
return str.substr(i, j - i + 1);
}
Status UDPOSOp::Load(const std::vector<std::string> &word, const std::vector<std::string> &universal,
const std::vector<std::string> &stanford, const std::string &file, int32_t worker_id) {
size_t row_line = 3;
size_t word_line = 0, universal_line = 1, stanford_line = 2;
TensorRow tRow(row_line, nullptr);
// Add file path info.
std::vector<std::string> file_path(row_line, file);
tRow.setPath(file_path);
RETURN_IF_NOT_OK(LoadTensor(word, &tRow, word_line));
RETURN_IF_NOT_OK(LoadTensor(universal, &tRow, universal_line));
RETURN_IF_NOT_OK(LoadTensor(stanford, &tRow, stanford_line));
RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow)));
return Status::OK();
}
Status UDPOSOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
auto realpath = FileUtils::GetRealPath(file.data());
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Invalid file path, " + DatasetName() + " dataset dir: " << file << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, " + DatasetName() + " dataset dir: " + file + " does not exist.");
}
std::ifstream handle(realpath.value());
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + ": " + file);
}
int64_t rows_total = 0;
std::string line;
std::vector<std::string> word_column;
std::vector<std::string> universal_column;
std::vector<std::string> stanford_column;
while (getline(handle, line)) {
// Skip blank lines before the start offset without counting them.
if (line.empty() && rows_total < start_offset) {
continue;
}
// If read to the end offset of this file, break.
if (rows_total >= end_offset) {
if (word_column.size() != 0) {
RETURN_IF_NOT_OK(Load(word_column, universal_column, stanford_column, file, worker_id));
}
std::vector<std::string>().swap(word_column);
std::vector<std::string>().swap(universal_column);
std::vector<std::string>().swap(stanford_column);
break;
}
// Skip line before start offset.
if (rows_total < start_offset) {
rows_total++;
continue;
}
line = Strip(line);
if (line.empty() && rows_total >= start_offset) {
if (word_column.size() != 0) {
RETURN_IF_NOT_OK(Load(word_column, universal_column, stanford_column, file, worker_id));
}
std::vector<std::string>().swap(word_column);
std::vector<std::string>().swap(universal_column);
std::vector<std::string>().swap(stanford_column);
continue;
} else if (!line.empty() && rows_total >= start_offset) {
std::vector<std::string> column = Split(line, '\t');
size_t word_line = 0, universal_line = 1, stanford_line = 2;
if (column.size() <= stanford_line) {
RETURN_STATUS_UNEXPECTED("Invalid file, malformed line in " + DatasetName() + " file: " + file);
}
word_column.push_back(column[word_line]);
universal_column.push_back(column[universal_line]);
stanford_column.push_back(column[stanford_line]);
}
rows_total++;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
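For reference, the parsing contract implemented by LoadFile above is: each line carries a tab-separated word, universal tag, and stanford tag; a blank line closes the current sentence, which becomes one output row of three string tensors. A minimal Python sketch of the same rules, ignoring offsets and sharding (illustration only; the helper name parse_udpos is hypothetical and not part of this change):

# Sketch of UDPOSOp::LoadFile's parsing rules (assumption: tab-separated
# columns, a blank line ends a sentence). Not MindData code.
def parse_udpos(path):
    rows = []
    words, universal, stanford = [], [], []
    with open(path, encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            if not line:
                # A blank line closes the current sentence, if one is open.
                if words:
                    rows.append((words, universal, stanford))
                words, universal, stanford = [], [], []
                continue
            word, uni, stan = line.split("\t")[:3]
            words.append(word)
            universal.append(uni)
            stanford.append(stan)
    if words:
        # Flush a trailing sentence when the file lacks a final blank line.
        rows.append((words, universal, stanford))
    return rows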

View File

@@ -0,0 +1,96 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_UDPOS_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_UDPOS_OP_H_
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/util/queue.h"
namespace mindspore {
namespace dataset {
class JaggedConnector;
class UDPOSOp : public TextFileOp {
public:
/// \brief Constructor of UDPOSOp.
UDPOSOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr<DataSchema> schema,
const std::vector<std::string> &udpos_files_list, int32_t op_connector_size, bool shuffle_files,
int32_t num_devices, int32_t device_id);
/// \brief Destructor of UDPOSOp.
~UDPOSOp() = default;
/// \brief A print method typically used for debugging.
/// \param[in] out The output stream to write output to.
/// \param[in] show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "UDPOSOp"; }
/// \brief Dataset name getter.
/// \param[in] upper Whether to return the dataset name in uppercase.
/// \return Dataset name of the current Op.
std::string DatasetName(bool upper = false) const { return upper ? "UDPOS" : "udpos"; }
private:
/// \brief Converts the content of one column into a tensor and stores it in the output row.
/// \param[in] column The content of the column.
/// \param[out] out_row The output TensorRow to put the parsed data in.
/// \param[in] index Index of the column within the output row.
/// \return Status The error code returned.
Status LoadTensor(const std::vector<std::string> &column, TensorRow *out_row, size_t index);
/// \brief Removes excess space before and after the string.
/// \param[in] str The input string.
/// \return The stripped string.
std::string Strip(const std::string &str);
/// \brief Splits a string on a character delimiter.
/// \param[in] s The input string.
/// \param[in] delim The delimiter character.
/// \return A vector of the split substrings.
std::vector<std::string> Split(const std::string &s, char delim);
/// \brief Loads one parsed sample (word, universal and stanford columns) into a TensorRow.
/// \param[in] word A list of words in a sentence.
/// \param[in] universal General part of speech.
/// \param[in] stanford Stanford part of speech.
/// \param[in] file The file to read.
/// \param[in] worker_id The id of the worker that is executing this function.
/// \return Status The error code returned.
Status Load(const std::vector<std::string> &word, const std::vector<std::string> &universal,
const std::vector<std::string> &stanford, const std::string &file, int32_t worker_id);
/// \brief Reads a text file and loads the data into multiple TensorRows.
/// \param[in] file The file to read.
/// \param[in] start_offset The start offset of the file.
/// \param[in] end_offset The end offset of the file.
/// \param[in] worker_id The id of the worker that is executing this function.
/// \return Status The error code returned.
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_UDPOS_OP_H_

View File

@@ -112,6 +112,7 @@ constexpr char kSTL10Node[] = "STL10Dataset";
constexpr char kTedliumNode[] = "TedliumDataset";
constexpr char kTextFileNode[] = "TextFileDataset";
constexpr char kTFRecordNode[] = "TFRecordDataset";
constexpr char kUDPOSNode[] = "UDPOSDataset";
constexpr char kUSPSNode[] = "USPSDataset";
constexpr char kVOCNode[] = "VOCDataset";
constexpr char kYahooAnswersNode[] = "YahooAnswersDataset";

View File

@@ -38,6 +38,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
tedlium_node.cc
text_file_node.cc
tf_record_node.cc
udpos_node.cc
usps_node.cc
voc_node.cc
yahoo_answers_node.cc

View File

@@ -0,0 +1,196 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/udpos_node.h"
#include <algorithm>
#include <utility>
#include "minddata/dataset/engine/datasetops/source/udpos_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for UDPOSNode.
UDPOSNode::UDPOSNode(const std::string &dataset_dir, const std::string &usage, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: NonMappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id),
udpos_files_list_(WalkAllFiles(usage, dataset_dir)) {
// Update the num_shards_ in global context. This number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work even if num_shards_ isn't
// 100% correct. The reason is that, for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
std::shared_ptr<DatasetNode> UDPOSNode::Copy() {
auto node = std::make_shared<UDPOSNode>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
return node;
}
void UDPOSNode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
Status UDPOSNode::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("UDPOSNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("UDPOSNode", usage_, {"train", "test", "valid", "all"}));
RETURN_IF_NOT_OK(ValidateScalar("UDPOSNode", "num_samples", num_samples_, {0}, false));
RETURN_IF_NOT_OK(ValidateDatasetShardParams("UDPOSNode", num_shards_, shard_id_));
return Status::OK();
}
// Function to build UDPOSNode.
Status UDPOSNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Sort the dataset files in a lexicographical order.
std::vector<std::string> sorted_dataset_files = udpos_files_list_;
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
// Do internal Schema generation.
auto schema = std::make_unique<DataSchema>();
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("word", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("universal", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("stanford", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
// Create and initialize UDPOSOp.
std::shared_ptr<UDPOSOp> udpos_op =
std::make_shared<UDPOSOp>(num_workers_, num_samples_, worker_connector_size_, std::move(schema),
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(udpos_op->Init());
// If a global shuffle is used for UDPOS, it will inject a shuffle op over the UDPOS.
// But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
// This is achieved in the cache transform pass where we call MakeSimpleProducer to reset UDPOS's shuffle
// option to false.
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp.
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset.
RETURN_IF_NOT_OK(UDPOSOp::CountAllFileRows(sorted_dataset_files, &num_rows));
// Add the shuffle op after this op.
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
udpos_op->SetTotalRepeats(GetTotalRepeats());
udpos_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
// Add UDPOSOp.
node_ops->push_back(udpos_op);
return Status::OK();
}
// Get the shard id of node.
Status UDPOSNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
// Get Dataset size.
Status UDPOSNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size = num_samples_;
RETURN_IF_NOT_OK(UDPOSOp::CountAllFileRows(udpos_files_list_, &num_rows));
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status UDPOSNode::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// UDPOS by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we setup the sampler for a leaf node that does not use sampling.
Status UDPOSNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
// If a cache has been added into the ascendant tree over this UDPOS node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the UDPOS node need to be reset to its defaults so
// that this UDPOS node will produce the full set of data into the cache.
Status UDPOSNode::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
std::vector<std::string> UDPOSNode::WalkAllFiles(const std::string &usage, const std::string &dataset_dir) {
std::vector<std::string> udpos_files_list;
// Note: dataset_dir is concatenated directly, so it is expected to end with a path separator.
const std::string train_file = "en-ud-tag.v2.train.txt";
const std::string test_file = "en-ud-tag.v2.test.txt";
const std::string valid_file = "en-ud-tag.v2.dev.txt";
if (usage == "train") {
udpos_files_list.push_back(dataset_dir + train_file);
} else if (usage == "test") {
udpos_files_list.push_back(dataset_dir + test_file);
} else if (usage == "valid") {
udpos_files_list.push_back(dataset_dir + valid_file);
} else {
// "all": read every split.
udpos_files_list.push_back(dataset_dir + train_file);
udpos_files_list.push_back(dataset_dir + test_file);
udpos_files_list.push_back(dataset_dir + valid_file);
}
return udpos_files_list;
}
} // namespace dataset
} // namespace mindspore
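As a worked example of the arithmetic in GetDatasetSize above: with 6 total rows, num_shards=4 and num_samples=2, each shard gets ceil(6 / 4) = 2 rows, and the reported size is min(2, 2) = 2. A minimal Python sketch of the same computation (the helper name is hypothetical, for illustration only):

import math

def udpos_dataset_size(total_rows, num_shards, num_samples):
    # Mirrors UDPOSNode::GetDatasetSize: rows are divided across shards,
    # rounding up, then capped by num_samples when it is positive.
    per_shard = math.ceil(total_rows / num_shards)
    return min(per_shard, num_samples) if num_samples > 0 else per_shard

assert udpos_dataset_size(6, 4, 2) == 2
assert udpos_dataset_size(6, 1, 0) == 6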

View File

@@ -0,0 +1,120 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_UDPOS_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_UDPOS_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class UDPOSNode.
/// \brief A Dataset derived class to represent UDPOS dataset.
class UDPOSNode : public NonMappableSourceNode {
public:
/// \brief Constructor.
UDPOSNode(const std::string &dataset_dir, const std::string &usage, int32_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~UDPOSNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return "UDPOSNode"; }
/// \brief Print the description.
/// \param[out] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id of the node.
/// \return Status Status::OK() if getting the shard id succeeds.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size The size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
int32_t NumSamples() const { return num_samples_; }
int32_t NumShards() const { return num_shards_; }
int32_t ShardId() const { return shard_id_; }
ShuffleMode Shuffle() const { return shuffle_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief UDPOS by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
/// That is why we setup the sampler for a leaf node that does not use sampling.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this UDPOS node, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the UDPOS node need to be reset to its defaults
/// so that this UDPOS node will produce the full set of data into the cache.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function.
Status MakeSimpleProducer() override;
/// \brief Get all dataset files matching the given usage.
/// \param[in] usage Part of dataset of UDPOS.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \return The list of files to read for the given usage.
std::vector<std::string> WalkAllFiles(const std::string &usage, const std::string &dataset_dir);
private:
std::string dataset_dir_;
std::string usage_;
int32_t num_samples_;
int32_t num_shards_;
int32_t shard_id_;
ShuffleMode shuffle_;
std::vector<std::string> udpos_files_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_UDPOS_NODE_H_

View File

@@ -4280,6 +4280,68 @@ std::shared_ptr<TFRecordDataset> MS_API TFRecord(const std::vector<std::string>
return ds;
}
/// \class UDPOSDataset
/// \brief A source dataset for reading and parsing UDPOS dataset.
class MS_API UDPOSDataset : public Dataset {
public:
/// \brief Constructor of UDPOS Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data list txt file to be read, can be "train", "test", "valid" or "all".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
UDPOSDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of UDPOSDataset.
~UDPOSDataset() = default;
};
/// \brief Function to create a UDPOSDataset.
/// \note The generated dataset has three columns: ["word", "universal", "stanford"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of UDPOS, can be "train", "test", "valid" or "all" (default="all").
/// \param[in] num_samples The number of samples to be included in the dataset
/// (Default = 0, means all samples).
/// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode::kGlobal).
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1).
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified (Default = 0).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the UDPOSDataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string folder_path = "/path/to/udpos_dataset_directory";
/// std::shared_ptr<Dataset> ds = UDPOS(folder_path, "test", 0, ShuffleMode::kGlobal);
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
///
/// /* Note: In UDPOS dataset, each dictionary has keys "word", "universal", "stanford" */
/// auto word = row["word"];
/// \endcode
inline std::shared_ptr<UDPOSDataset> MS_API UDPOS(const std::string &dataset_dir, const std::string &usage = "all",
int64_t num_samples = 0, ShuffleMode shuffle = ShuffleMode::kGlobal,
int32_t num_shards = 1, int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<UDPOSDataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle,
num_shards, shard_id, cache);
}
/// \class USPSDataset
/// \brief A source dataset that reads and parses USPS datasets.
class MS_API USPSDataset : public Dataset {

View File

@@ -72,7 +72,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \
check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset, check_svhn_dataset, \
check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \
check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset
check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size, get_auto_offload
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -6176,6 +6176,64 @@ class Schema:
return schema_obj.cpp_schema.get_num_rows()
class UDPOSDataset(SourceDataset):
"""
A source dataset that reads and parses the UDPOS dataset.
The generated dataset has three columns: :py:obj:`[word, universal, stanford]`.
The tensor of each column is of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be `train`, `test`, `valid` or `all`. `train` will read from
12,543 train samples, `test` will read from 2,077 test samples, `valid` will read from 2,002 validation samples,
`all` will read from all 16,622 samples (default=None, all samples).
num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
(default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument can only be specified when num_shards is also specified.
num_parallel_workers (int, optional): Number of workers to read the data
(default=None, number set in the config).
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
Raises:
RuntimeError: If dataset_dir does not contain data files.
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
Examples:
>>> udpos_dataset_dir = "/path/to/udpos_dataset_dir"
>>> dataset = ds.UDPOSDataset(dataset_dir=udpos_dataset_dir, usage='all')
"""
@check_udpos_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, shuffle=Shuffle.GLOBAL, num_shards=None,
shard_id=None, num_parallel_workers=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, 'all')
def parse(self, children=None):
return cde.UDPOSNode(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag, self.num_shards,
self.shard_id)
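A short usage sketch for the class above, mirroring the pipeline exercised in the tests later in this change (the directory path is a placeholder):

import mindspore.dataset as ds

# Placeholder path; the directory must contain the en-ud-tag.v2.*.txt files.
udpos_dir = "/path/to/udpos_dataset_dir"
dataset = ds.UDPOSDataset(udpos_dir, usage="valid", shuffle=False)
for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
    # Each row holds three string tensors: word, universal and stanford.
    print(row["word"], row["universal"], row["stanford"])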
class USPSDataset(SourceDataset):
"""
A source dataset for reading and parsing the USPS dataset.

View File

@@ -404,6 +404,35 @@ def check_tfrecorddataset(method):
return new_method
def check_udpos_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(UDPOSDataset)."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
# check dataset_dir; required argument
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
# check usage
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "valid", "test", "all"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
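For illustration, two calls the checker above is meant to reject (a sketch; the path is a placeholder, and the exception types follow the shared validators and the class docstring):

import pytest
import mindspore.dataset as ds

udpos_dir = "/path/to/udpos_dataset_dir"
# check_valid_str rejects usage values outside train/valid/test/all.
with pytest.raises(ValueError):
    ds.UDPOSDataset(udpos_dir, usage="dev")
# check_sampler_shuffle_shard_options rejects shard_id without num_shards.
with pytest.raises(RuntimeError):
    ds.UDPOSDataset(udpos_dir, shard_id=0)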
def check_usps_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(USPSDataset)."""

View File

@@ -48,6 +48,7 @@ SET(DE_UT_SRCS
c_api_dataset_tedlium_test.cc
c_api_dataset_textfile_test.cc
c_api_dataset_tfrecord_test.cc
c_api_dataset_udpos_test.cc
c_api_dataset_usps_test.cc
c_api_dataset_voc_test.cc
c_api_dataset_yahoo_answers_test.cc

View File

@@ -0,0 +1,574 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
using mindspore::dataset::ShuffleMode;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: Test UDPOS Dataset.
/// Description: read data from a single file.
/// Expectation: the three rows in the file are read back correctly.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetBasic.";
// Test UDPOS Dataset with single UDPOS file and many default inputs.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create a UDPOS Dataset, with a single UDPOS file.
// Note: en-ud-tag.v2.dev.txt (the "valid" split) has 3 rows.
// Use all samples (num_samples=0).
// Use defaults for other input parameters.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::vector<std::string> column_names = {"word", "universal", "stanford"};
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "valid", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("word"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"From", "Abed", "Ido"}, {"Psg", "Psg", "Nine"}, {"Bus", "Psg", "Nine"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto word = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {{}}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 samples.
EXPECT_EQ(i, 3);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test UDPOS Dataset.
/// Description: repeat and concat two datasets, then read.
/// Expectation: 2 + 3 = 5 repeated samples are read.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetBasicWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetBasicWithPipeline.";
// Test UDPOS Dataset with single UDPOS file and many default inputs.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create two UDPOS Datasets, each with a single UDPOS file.
// Note: en-ud-tag.v2.test.txt has 1 row.
// Use all samples (num_samples=0).
// Use defaults for other input parameters.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds1 = UDPOS(dataset_dir, "test", 0, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds2 = UDPOS(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds.
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 3;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds.
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
std::vector<std::string> column_names = {"word", "universal", "stanford"};
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("word"), row.end());
std::vector<std::vector<std::string>> expected_result = {{"What", "Psg", "What"}};
uint64_t i = 0;
while (row.size() != 0) {
auto word = row["word"];
MS_LOG(INFO) << "Tensor word shape: " << word.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 5 samples.
EXPECT_EQ(i, 5);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test UDPOS Dataset.
/// Description: Includes tests for shape, type, size.
/// Expectation: correct shape, type, size.
TEST_F(MindDataTestPipeline, TestUDPOSGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSGetters.";
// Test UDPOS Dataset with single UDPOS file and many default inputs.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(987);
GlobalContext::config_manager()->set_num_parallel_workers(4);
// Create a UDPOS Dataset, with a single UDPOS file.
// Note: en-ud-tag.v2.train.txt has 2 rows.
// Use 2 samples.
// Use defaults for other input parameters.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "train", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
std::vector<std::string> column_names = {"word", "universal", "stanford"};
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
EXPECT_EQ(types.size(), 3);
EXPECT_EQ(types[0].ToString(), "string");
EXPECT_EQ(types[1].ToString(), "string");
EXPECT_EQ(types[2].ToString(), "string");
EXPECT_EQ(shapes.size(), 3);
EXPECT_EQ(shapes[0].ToString(), "<6>");
EXPECT_EQ(shapes[1].ToString(), "<6>");
EXPECT_EQ(shapes[2].ToString(), "<6>");
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 2);
EXPECT_EQ(ds->GetColumnNames(), column_names);
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test UDPOS Dataset.
/// Description: test with num_samples=-1.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetInvalidSamplers) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetInvalidSamplers.";
// Create a UDPOS Dataset.
// With invalid num_samples=-1.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "test", -1, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: UDPOS number of samples cannot be negative.
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test UDPOS Dataset.
/// Description: test with a non-existent dataset_dir.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetInvalidFilePath) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetInvalidFilePath.";
// Attempt to create a UDPOS Dataset.
// With a non-existent dataset_dir input.
std::shared_ptr<Dataset> ds = UDPOS("NotExistFile", "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: the specified dataset_dir does not exist.
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test UDPOS Dataset.
/// Description: test with invalid usage "dev".
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetInvalidFileName) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetInvalidFileName.";
// Create a UDPOS Dataset.
// With invalid usage "dev" input.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "dev", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: usage "dev" is not a valid option.
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test UDPOS Dataset.
/// Description: test with an empty string dataset_dir.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetEmptyFilePath) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetEmptyFilePath.";
// Create a UDPOS Dataset.
// With an empty string dataset_dir input.
std::shared_ptr<Dataset> ds = UDPOS("", "dev", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: dataset_dir is not specified.
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test UDPOS Dataset.
/// Description: test with invalid num_shards=0 value.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetInvalidNumShards) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetInvalidNumShards.";
// Create a UDPOS Dataset.
// With invalid num_shards=0 value.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "test", 0, ShuffleMode::kFalse, 0);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: Number of shards cannot be <=0.
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test UDPOS Dataset.
/// Description: test with invalid shard_id=-1 value.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetInvalidShardId) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetInvalidShardId.";
// Create a UDPOS Dataset.
// With invalid shard_id=-1 value.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "dev", 0, ShuffleMode::kFalse, -1);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: shard_id cannot be negative.
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test UDPOS Dataset.
/// Description: test with invalid shard_id=2 and num_shards=2 combination.
/// Expectation: unable to read in data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetInvalidIdAndShards) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetInvalidIdAndShards.";
// Create a UDPOS Dataset.
// With invalid shard_id=2 and num_shards=2 combination.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "dev", 0, ShuffleMode::kFalse, 2, 2);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure: Cannot have shard_id >= num_shards.
EXPECT_EQ(iter, nullptr);
}
/// Feature: Test UDPOS Dataset.
/// Description: read all data with no shuffle, num_parallel_workers=1.
/// Expectation: return correct data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetShuffleFalse) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetShuffleFalse.";
// Test UDPOS Dataset with three UDPOS files and no shuffle, num_parallel_workers=1.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(654);
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create a UDPOS Dataset, with three UDPOS files, en-ud-tag.v2.dev.txt,
// en-ud-tag.v2.test.txt and en-ud-tag.v2.train.txt, in lexicographical order.
// Note: en-ud-tag.v2.dev.txt (the "valid" split) has 3 rows.
// Note: en-ud-tag.v2.test.txt has 1 row.
// Note: en-ud-tag.v2.train.txt has 2 rows.
// Use default of all samples.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "all", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
std::vector<std::string> column_names = {"word", "universal", "stanford"};
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("word"), row.end());
std::vector<std::vector<std::string>> expected_result = {{"From", "Abed", "Ido"}, {"Psg", "Psg", "Nine"},
{"Bus", "Psg", "Nine"}, {"What", "Psg", "What"},
{"Abed", "Psg", "Nine"}, {"...", "Psg", "---"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto word = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {{}}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 + 1 + 2 = 6 samples.
EXPECT_EQ(i, 6);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test UDPOS Dataset.
/// Description: read all data with files shuffle, num_parallel_workers=1.
/// Expectation: return correct data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetShuffleFilesA) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetShuffleFilesA.";
// Test UDPOS Dataset with files shuffle, num_parallel_workers=1.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(135);
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create a UDPOS Dataset, with three UDPOS files, en-ud-tag.v2.dev.txt,
// en-ud-tag.v2.test.txt and en-ud-tag.v2.train.txt, in lexicographical order.
// Note: en-ud-tag.v2.dev.txt (the "valid" split) has 3 rows.
// Note: en-ud-tag.v2.test.txt has 1 row.
// Note: en-ud-tag.v2.train.txt has 2 rows.
// Set shuffle to files shuffle.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "all", 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
std::vector<std::string> column_names = {"word", "universal", "stanford"};
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("word"), row.end());
std::vector<std::vector<std::string>> expected_result = {{"Abed", "Psg", "Nine"}, {"...", "Psg", "---"},
{"What", "Psg", "What"}, {"From", "Abed", "Ido"},
{"Psg", "Psg", "Nine"}, {"Bus", "Psg", "Nine"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto word = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {{}}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 + 1 + 2 = 6 samples.
EXPECT_EQ(i, 6);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test UDPOS Dataset.
/// Description: read all data with infile shuffle, num_parallel_workers=1.
/// Expectation: return correct data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetShuffleFilesB) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetShuffleFilesB.";
// Test UDPOS Dataset with infile shuffle, num_parallel_workers=1.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(135);
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create a UDPOS Dataset, with three UDPOS files, en-ud-tag.v2.dev.txt,
// en-ud-tag.v2.test.txt and en-ud-tag.v2.train.txt, in lexicographical order.
// Note: en-ud-tag.v2.dev.txt (the "valid" split) has 3 rows.
// Note: en-ud-tag.v2.test.txt has 1 row.
// Note: en-ud-tag.v2.train.txt has 2 rows.
// Set shuffle to infile shuffle.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "all", 0, ShuffleMode::kInfile);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
std::vector<std::string> column_names = {"word", "universal", "stanford"};
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("word"), row.end());
std::vector<std::vector<std::string>> expected_result = {{"From", "Abed", "Ido"}, {"Psg", "Psg", "Nine"},
{"Bus", "Psg", "Nine"}, {"What", "Psg", "What"},
{"Abed", "Psg", "Nine"}, {"...", "Psg", "---"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto word = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {{}}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 + 1 + 2 = 6 samples.
EXPECT_EQ(i, 6);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: Test UDPOS Dataset.
/// Description: read all data with global shuffle, num_parallel_workers=1.
/// Expectation: return correct data.
TEST_F(MindDataTestPipeline, TestUDPOSDatasetShuffleGlobal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestUDPOSDatasetShuffleGlobal.";
// Test UDPOS Dataset with one UDPOS file, global shuffle, num_parallel_workers=1.
// Set configuration.
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(246);
GlobalContext::config_manager()->set_num_parallel_workers(1);
// Create a UDPOS Dataset, with one UDPOS file.
// Note: en-ud-tag.v2.test.txt has 1 row.
// Set shuffle to global shuffle.
std::string dataset_dir = datasets_root_path_ + "/testUDPOSDataset/";
std::shared_ptr<Dataset> ds = UDPOS(dataset_dir, "test", 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
std::vector<std::string> column_names = {"word", "universal", "stanford"};
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("word"), row.end());
std::vector<std::vector<std::string>> expected_result = {{"What", "Psg", "What"}};
uint64_t i = 0;
while (row.size() != 0) {
for (size_t j = 0; j < column_names.size(); j++) {
auto word = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(word, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {0}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 1 sample.
EXPECT_EQ(i, 1);
// Manually terminate the pipeline.
iter->Stop();
// Restore configuration.
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

View File

@ -0,0 +1,21 @@
From Abed Ido
The Dead Dead
Abed Psg Nine
Come Vivi Vivi
The Dead Dead
Std Nine Nine
Psg Psg Nine
Bus Psg Nine
Ori Abed Iike
The Psg Nine
Abed Vivi Vivi
The Nine Come
Bus Psg Nine
Nine Vivi Vivi
Job Psg Nine
Mom Psg Nine
Abed Psg Nine
From Abed Iike

View File

@ -0,0 +1,7 @@
What Psg What
Like Std Iike
Good Psg Nine
Mom Vivi Vivi
Iike Abed Iike
Good Psg Nine

View File

@ -0,0 +1,14 @@
Abed Psg Nine
... Psg High
Zoom Psg Nine
... Psg ...
Abed Abed Job
From Nine Nine
... Psg ---
The Dead Dead
ken Nine Nine
Ori Abed Iike
... Dead Dead
Respect Abed Job

View File

@ -0,0 +1,331 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed
DATA_DIR = '../data/dataset/testUDPOSDataset/'
def test_udpos_dataset_one_file():
"""
Feature: Test UDPOS Dataset.
Description: read one file (usage="test").
Expectation: the row count matches the number of samples in the file.
"""
data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["word"]))
count += 1
assert count == 1
def test_udpos_dataset_all_file():
"""
Feature: Test UDPOS Dataset.
Description: read all files (usage="all").
Expectation: the row count matches the total number of samples across all files.
"""
data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=False)
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["word"]))
count += 1
assert count == 6
def test_udpos_dataset_shuffle_false_four_parallel():
"""
Feature: Test UDPOS Dataset.
Description: read all files with shuffle=False and num_parallel_workers=4.
Expectation: rows match the expected order for four workers.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(987)
data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=False)
count = 0
numword = 6
line = ["From", "The", "Abed", "Come", "The", "Std",
"What", "Like", "Good", "Mom", "Iike", "Good",
"Abed", "...", "Zoom", "...", "Abed", "From",
"Psg", "Bus", "Ori", "The", "Abed", "The",
"...", "The", "ken", "Ori", "...", "Respect",
"Bus", "Nine", "Job", "Mom", "Abed", "From"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for j in range(numword):
strs = i["word"][j].item().decode("utf8")
assert strs == line[count*6+j]
count += 1
assert count == 6
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_udpos_dataset_shuffle_false_one_parallel():
"""
Feature: Test UDPOS Dataset.
Description: read all files with shuffle=False and num_parallel_workers=1.
Expectation: rows are returned in sequential file order.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(987)
data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=False)
count = 0
numword = 6
line = ["From", "The", "Abed", "Come", "The", "Std",
"Psg", "Bus", "Ori", "The", "Abed", "The",
"Bus", "Nine", "Job", "Mom", "Abed", "From",
"What", "Like", "Good", "Mom", "Iike", "Good",
"Abed", "...", "Zoom", "...", "Abed", "From",
"...", "The", "ken", "Ori", "...", "Respect"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for j in range(numword):
strs = i["word"][j].item().decode("utf8")
assert strs == line[count*6+j]
count += 1
assert count == 6
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_udpos_dataset_shuffle_files_four_parallel():
"""
Feature: Test UDPOS Dataset.
Description: read all files with file-level shuffle and num_parallel_workers=4.
Expectation: files are read in shuffled order and rows match the expected order.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(135)
data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
count = 0
numword = 6
line = ["Abed", "...", "Zoom", "...", "Abed", "From",
"What", "Like", "Good", "Mom", "Iike", "Good",
"From", "The", "Abed", "Come", "The", "Std",
"...", "The", "ken", "Ori", "...", "Respect",
"Psg", "Bus", "Ori", "The", "Abed", "The",
"Bus", "Nine", "Job", "Mom", "Abed", "From"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for j in range(numword):
strs = i["word"][j].item().decode("utf8")
assert strs == line[count*6+j]
count += 1
assert count == 6
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_udpos_dataset_shuffle_files_one_parallel():
"""
Feature: Test UDPOS Dataset.
Description: read all files with file-level shuffle and num_parallel_workers=1.
Expectation: files are read in shuffled order and rows match the expected order.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(135)
data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
count = 0
numword = 6
line = ["Abed", "...", "Zoom", "...", "Abed", "From",
"...", "The", "ken", "Ori", "...", "Respect",
"What", "Like", "Good", "Mom", "Iike", "Good",
"From", "The", "Abed", "Come", "The", "Std",
"Psg", "Bus", "Ori", "The", "Abed", "The",
"Bus", "Nine", "Job", "Mom", "Abed", "From"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for j in range(numword):
strs = i["word"][j].item().decode("utf8")
assert strs == line[count*6+j]
count += 1
assert count == 6
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_udpos_dataset_shuffle_global_four_parallel():
"""
Feature: Test UDPOS Dataset.
Description: read all files with global shuffle and num_parallel_workers=4.
Expectation: rows are globally shuffled and match the expected order.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
original_seed = config_get_set_seed(246)
data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
count = 0
numword = 6
line = ["Bus", "Nine", "Job", "Mom", "Abed", "From",
"Abed", "...", "Zoom", "...", "Abed", "From",
"From", "The", "Abed", "Come", "The", "Std",
"Psg", "Bus", "Ori", "The", "Abed", "The",
"What", "Like", "Good", "Mom", "Iike", "Good",
"...", "The", "ken", "Ori", "...", "Respect"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for j in range(numword):
strs = i["word"][j].item().decode("utf8")
assert strs == line[count*6+j]
count += 1
assert count == 6
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_udpos_dataset_shuffle_global_one_parallel():
"""
Feature: Test UDPOS Dataset.
Description: read all files with global shuffle and num_parallel_workers=1.
Expectation: rows are globally shuffled and match the expected order.
"""
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
original_seed = config_get_set_seed(246)
data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
count = 0
numword = 6
line = ["...", "The", "ken", "Ori", "...", "Respect",
"Psg", "Bus", "Ori", "The", "Abed", "The",
"From", "The", "Abed", "Come", "The", "Std",
"Bus", "Nine", "Job", "Mom", "Abed", "From",
"What", "Like", "Good", "Mom", "Iike", "Good",
"Abed", "...", "Zoom", "...", "Abed", "From"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for j in range(numword):
strs = i["word"][j].item().decode("utf8")
assert strs == line[count*6+j]
count += 1
assert count == 6
# Restore configuration
ds.config.set_num_parallel_workers(original_num_parallel_workers)
ds.config.set_seed(original_seed)
def test_udpos_dataset_num_samples():
"""
Feature: Test UDPOS Dataset.
Description: read the test file with num_samples=2 while only one row is available.
Expectation: the row count is capped by the data actually available.
"""
data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False, num_samples=2)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 1
def test_udpos_dataset_distribution():
"""
Feature: Test UDPOS Dataset.
Description: read the test file with num_shards=2 and shard_id=1.
Expectation: the shard receives the expected number of rows.
"""
data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False, num_shards=2, shard_id=1)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 1
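# The complementary shard can be read the same way; a sketch under the same
# fixture (shard_id=0 instead of 1, everything else unchanged):
#   data0 = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False, num_shards=2, shard_id=0)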
def test_udpos_dataset_repeat():
"""
Feature: Test UDPOS Dataset.
Description: repeat the dataset three times.
Expectation: each epoch returns the same rows.
"""
data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
data = data.repeat(3)
count = 0
numword = 6
line = ["What", "Like", "Good", "Mom", "Iike", "Good",
"What", "Like", "Good", "Mom", "Iike", "Good",
"What", "Like", "Good", "Mom", "Iike", "Good"]
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
for j in range(numword):
strs = i["word"][j].item().decode("utf8")
assert strs == line[count*6+j]
count += 1
assert count == 3
def test_udpos_dataset_get_datasetsize():
"""
Feature: Test UDPOS Dataset.
Description: get the dataset size.
Expectation: get_dataset_size returns the expected value.
"""
data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
size = data.get_dataset_size()
assert size == 6
def test_udpos_dataset_to_device():
"""
Feature: Test UDPOS Dataset.
Description: transfer data from the host to the device.
Expectation: data is sent without error.
"""
data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
data = data.to_device()
data.send()
def test_udpos_dataset_exceptions():
"""
Feature: Test UDPOS Dataset.
Description: pass invalid parameters and a failing map function.
Expectation: the corresponding exceptions are raised with the expected messages.
"""
with pytest.raises(ValueError) as error_info:
_ = ds.UDPOSDataset(DATA_DIR, usage="test", num_samples=-1)
assert "num_samples exceeds the boundary" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.UDPOSDataset("NotExistFile", usage="test")
assert "The folder NotExistFile does not exist or is not a directory or permission denied!" in str(error_info.value)
with pytest.raises(ValueError) as error_info:
_ = ds.TextFileDataset("")
assert "The following patterns did not match any files" in str(error_info.value)
def exception_func(item):
raise Exception("Error occurred!")
with pytest.raises(RuntimeError) as error_info:
data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
data = data.map(operations=exception_func, input_columns=["word"], num_parallel_workers=1)
for _ in data.__iter__():
pass
assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
if __name__ == "__main__":
test_udpos_dataset_one_file()
test_udpos_dataset_all_file()
test_udpos_dataset_shuffle_false_four_parallel()
test_udpos_dataset_shuffle_false_one_parallel()
test_udpos_dataset_shuffle_files_one_parallel()
test_udpos_dataset_shuffle_files_four_parallel()
test_udpos_dataset_shuffle_global_four_parallel()
test_udpos_dataset_shuffle_global_one_parallel()
test_udpos_dataset_num_samples()
test_udpos_dataset_distribution()
test_udpos_dataset_repeat()
test_udpos_dataset_get_datasetsize()
test_udpos_dataset_to_device()
test_udpos_dataset_exceptions()
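# A minimal end-to-end usage sketch (illustrative only; the column names
# "word", "universal" and "stanford" follow the tests above, the rest is an
# assumed typical pipeline):
#   data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
#   for sample in data.create_dict_iterator(num_epochs=1, output_numpy=True):
#       words = [w.item().decode("utf8") for w in sample["word"]]
#       tags = [t.item().decode("utf8") for t in sample["universal"]]
#       print(list(zip(words, tags)))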