[assistant][ops] New operator implementation, including SST2Dataset

This commit is contained in:
uccInf 2022-09-17 16:16:43 +08:00
parent 6d9fdaacc1
commit f287b4865c
23 changed files with 1507 additions and 2 deletions

View File

@ -0,0 +1,69 @@
mindspore.dataset.SST2Dataset
=============================
.. py:class:: mindspore.dataset.SST2Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None)
A source dataset that reads and parses the SST2 dataset.
The train.tsv and dev.tsv files in the dataset have two columns `[sentence, label]` .
The test.tsv file in the dataset has one column `[sentence]` .
The data in both the `sentence` and `label` columns is of the string type.
Parameters:
- **dataset_dir** (str) - Path to the root directory that contains the dataset files.
- **usage** (str, optional) - Subset of the dataset to read, can be 'train', 'test' or 'dev'.
'train' reads 67,349 training samples, 'test' reads 1,821 test samples, and 'dev' reads all 872 samples. Default: None, reads the train samples.
- **num_samples** (int, optional) - Number of samples to read from the dataset. Default: None, reads all samples.
- **num_parallel_workers** (int, optional) - Number of worker threads used to read the data. Default: None, uses the number of threads set in mindspore.dataset.config.
- **shuffle** (Union[bool, Shuffle], optional) - Mode of shuffling the data every epoch; both the bool type and the Shuffle enum are accepted. Default: `Shuffle.GLOBAL` .
If `shuffle` is False, no shuffling is performed; if `shuffle` is True, it is equivalent to setting `shuffle` to mindspore.dataset.Shuffle.GLOBAL.
The shuffle mode can also be set by passing in an enumeration variable:
- **Shuffle.GLOBAL**: Shuffle the samples.
- **num_shards** (int, optional) - Number of shards that the dataset is divided into for distributed training. Default: None. When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
- **shard_id** (int, optional) - Shard ID used for distributed training. Default: None. This argument can only be specified when `num_shards` is also specified.
- **cache** (DatasetCache, optional) - Single-node data caching service, used to speed up dataset processing; for details, see `Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ . Default: None, no cache is used.
Raises:
- **RuntimeError** - The directory specified by `dataset_dir` does not exist or is missing dataset files.
- **ValueError** - `num_parallel_workers` exceeds the maximum number of threads on the system.
- **RuntimeError** - `num_shards` is specified but `shard_id` is not.
- **RuntimeError** - `shard_id` is specified but `num_shards` is not.
- **ValueError** - `shard_id` is invalid (less than 0 or greater than or equal to `num_shards` ).
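A minimal usage sketch, assuming the dataset has been extracted under a local `sst2_dataset_dir` (the path below is a placeholder, not part of this commit):
.. code-block:: python

    import mindspore.dataset as ds

    # Read the train subset in order; columns are "sentence" and "label".
    sst2 = ds.SST2Dataset(dataset_dir="/path/to/sst2_dataset_dir", usage="train", shuffle=False)
    for row in sst2.create_dict_iterator(output_numpy=True):
        print(row["sentence"], row["label"])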
**About the SST2 dataset:**
The Stanford Sentiment Treebank is a corpus with fully labeled parse trees that allows for a complete analysis of the compositional effects of sentiment in language.
The corpus is based on the dataset introduced by Pang and Lee (2005) and consists of 11,855 single sentences extracted from movie reviews. It was parsed with the
Stanford parser and includes a total of 215,154 unique phrases from those parse trees, each annotated by 3 human judges.
The original SST2 dataset has the structure below. You can unzip the dataset files into the following directory structure and read them with MindSpore's API.
.. code-block::
.
└── sst2_dataset_dir
├── train.tsv
├── test.tsv
├── dev.tsv
└── original
**Citation:**
.. code-block::
@inproceedings{socher-etal-2013-recursive,
title = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
author = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
Christopher D. and Ng, Andrew and Potts, Christopher},
booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
month = oct,
year = {2013},
address = {Seattle, Washington, USA},
publisher = {Association for Computational Linguistics},
url = {https://www.aclweb.org/anthology/D13-1170},
pages = {1631--1642},
}
.. include:: mindspore.dataset.api_list_nlp.rst

View File

@ -149,6 +149,7 @@ mindspore.dataset
mindspore.dataset.IWSLT2017Dataset
mindspore.dataset.PennTreebankDataset
mindspore.dataset.SogouNewsDataset
mindspore.dataset.SST2Dataset
mindspore.dataset.TextFileDataset
mindspore.dataset.UDPOSDataset
mindspore.dataset.WikiTextDataset

View File

@ -62,6 +62,7 @@ Text
mindspore.dataset.IWSLT2017Dataset
mindspore.dataset.PennTreebankDataset
mindspore.dataset.SogouNewsDataset
mindspore.dataset.SST2Dataset
mindspore.dataset.TextFileDataset
mindspore.dataset.UDPOSDataset
mindspore.dataset.WikiTextDataset

View File

@ -121,6 +121,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/sogou_news_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
@ -1856,6 +1857,14 @@ SQuADDataset::SQuADDataset(const std::vector<char> &dataset_dir, const std::vect
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
SST2Dataset::SST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
const std::shared_ptr<DatasetCache> &cache) {
auto ds = std::make_shared<SST2Node>(CharToString(dataset_dir), CharToString(usage), num_samples, shuffle, num_shards,
shard_id, cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
TedliumDataset::TedliumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &release,
const std::vector<char> &usage, const std::vector<char> &extensions,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {

View File

@ -59,6 +59,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/semeion_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/speech_commands_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/squad_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/stl10_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/sun397_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/tedlium_node.h"
@ -697,6 +698,18 @@ PYBIND_REGISTER(SQuADNode, 2, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(SST2Node, 2, ([](const py::module *m) {
(void)py::class_<SST2Node, DatasetNode, std::shared_ptr<SST2Node>>(*m, "SST2Node",
"to create a SST2Node")
.def(py::init([](const std::string &dataset_dir, const std::string &usage, int64_t num_samples,
int32_t shuffle, int32_t num_shards, int32_t shard_id) {
auto sst2 = std::make_shared<SST2Node>(dataset_dir, usage, num_samples, toShuffleMode(shuffle),
num_shards, shard_id, nullptr);
THROW_IF_ERROR(sst2->ValidateParams());
return sst2;
}));
}));
PYBIND_REGISTER(STL10Node, 2, ([](const py::module *m) {
(void)py::class_<STL10Node, DatasetNode, std::shared_ptr<STL10Node>>(*m, "STL10Node",
"to create a STL10Node")

View File

@ -49,6 +49,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
sogou_news_op.cc
speech_commands_op.cc
squad_op.cc
sst2_op.cc
stl10_op.cc
sun397_op.cc
tedlium_op.cc

View File

@ -225,7 +225,7 @@ class CsvOp : public NonMappableLeafOp {
/// @param[in] s - the input string
/// @param[in] delim - the delimiter
/// @return - a string vector
std::vector<std::string> split(const std::string &s, char delim);
virtual std::vector<std::string> split(const std::string &s, char delim);
// Private function for analysing the column name in every CSV file
// @return - Status

View File

@ -0,0 +1,135 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/sst2_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <stdexcept>
#include "include/common/debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/random.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
SST2Op::SST2Op(const std::vector<std::string> &dataset_files_list, const std::string &usage, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t num_samples,
int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int32_t num_devices,
int32_t device_id)
: CsvOp(dataset_files_list, field_delim, column_default, column_name, num_workers, num_samples,
worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id),
usage_(usage) {}
void SST2Op::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nSST2 files list:\n";
for (size_t i = 0; i < csv_files_list_.size(); ++i) {
out << " " << csv_files_list_[i];
}
out << "\n\n";
}
}
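// Split one line of a TSV file on the field delimiter. For the "test" subset the first
// field is the index column, which is dropped so only the sentence content is returned.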
std::vector<std::string> SST2Op::split(const std::string &s, char delim) {
std::vector<std::string> res;
std::stringstream ss(s);
std::string item;
bool skip = usage_ == "test";
while (getline(ss, item, delim)) {
if (skip) {
skip = false;
} else {
res.push_back(item);
}
}
return res;
}
Status SST2Op::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
CsvParser csv_parser(worker_id, jagged_rows_connector_.get(), field_delim_, column_default_list_, file);
RETURN_IF_NOT_OK(csv_parser.InitCsvParser());
csv_parser.SetStartOffset(start_offset);
csv_parser.SetEndOffset(end_offset);
auto realpath = FileUtils::GetRealPath(file.c_str());
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist.";
RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist.");
}
std::ifstream ifs;
ifs.open(realpath.value(), std::ifstream::in);
if (!ifs.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + file + ", the file is damaged or permission denied.");
}
if (column_name_list_.empty()) {
std::string tmp;
getline(ifs, tmp);
}
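// Rows of the "test" subset start with an index column; skip characters up to the first
// field delimiter on each row so the index is not fed to the CSV parser.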
bool skip = usage_ == "test";
csv_parser.Reset();
try {
while (ifs.good()) {
// when ifstream reaches the end of file, the function get() return std::char_traits<char>::eof()
// which is a 32-bit -1, it's not equal to the 8-bit -1 on Euler OS. So instead of char, we use
// int to receive its return value.
int chr = ifs.get();
if (skip) {
if (chr == field_delim_) {
skip = false;
}
continue;
}
if (usage_ == "test" && chr == '\n') {
skip = true;
}
int err = csv_parser.ProcessMessage(chr);
if (err != 0) {
// if error code is -2, the returned error is interrupted
if (err == -2) {
return Status(kMDInterrupted);
}
RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse csv file: " + file + " at line " +
std::to_string(csv_parser.GetTotalRows() + 1) +
". Error message: " + csv_parser.GetErrorMessage());
}
}
} catch (std::invalid_argument &ia) {
std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
RETURN_STATUS_UNEXPECTED("Invalid csv, csv file: " + file + " parse failed at line " + err_row +
", type does not match.");
} catch (std::out_of_range &oor) {
std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
RETURN_STATUS_UNEXPECTED("Invalid csv, " + file + " parse failed at line " + err_row + " : value out of range.");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,87 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SST2_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SST2_OP_H_
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
namespace mindspore {
namespace dataset {
class SST2Op : public CsvOp {
public:
/// \brief Constructor.
/// \param[in] dataset_files_list List of file paths for the dataset files.
/// \param[in] usage Usage of this dataset, can be 'train', 'test' or 'dev'.
/// \param[in] field_delim A char that indicates the delimiter to separate fields.
/// \param[in] column_default List of default values for the CSV field (default={}). Each item in the list is
///     a valid type (float, int, or string).
/// \param[in] column_name List of column names of the dataset file.
/// \param[in] num_workers Num of workers reading files in parallel.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] worker_connector_size Size of each internal queue.
/// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from.
/// \param[in] shuffle_files Whether or not to shuffle the files before reading data.
/// \param[in] num_devices Number of devices that the dataset should be divided into.
/// \param[in] device_id The device ID within num_devices.
SST2Op(const std::vector<std::string> &dataset_files_list, const std::string &usage, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id);
/// \brief Destructor.
~SST2Op() override = default;
/// \brief A print method typically used for debugging
/// \param[out] out The output stream to write output to
/// \param[in] show_all A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \brief DatasetName name getter.
/// \param[in] upper A bool to control if you want to return uppercase or lowercase Op name.
/// \return DatasetName of the current Op.
std::string DatasetName(bool upper = false) const { return upper ? "SST2" : "sst2"; }
/// \brief Op name getter
/// \return Name of the current Op.
std::string Name() const override { return "SST2Op"; }
/// \brief Split string based on a character delimiter.
/// \param[in] s The input string.
/// \param[in] delim The delimiter.
/// \return A string vector.
std::vector<std::string> split(const std::string &s, char delim) override;
/// \brief Reads a csv file and loads the data into multiple tensors.
/// \param[in] file The file to read.
/// \param[in] start_offset The start offset of file.
/// \param[in] end_offset The end offset of file.
/// \param[in] worker_id The id of the worker that is executing this function.
/// \return Status The error code returned.
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
private:
std::string usage_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SST2_OP_H_

View File

@ -123,6 +123,7 @@ constexpr char kSemeionNode[] = "SemeionDataset";
constexpr char kSogouNewsNode[] = "SogouNewsDataset";
constexpr char kSpeechCommandsNode[] = "SpeechCommandsDataset";
constexpr char kSQuADNode[] = "SQuADDataset";
constexpr char kSST2Node[] = "SST2Dataset";
constexpr char kSTL10Node[] = "STL10Dataset";
constexpr char kSUN397Node[] = "SUN397Dataset";
constexpr char kTedliumNode[] = "TedliumDataset";

View File

@ -50,6 +50,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
sogou_news_node.cc
speech_commands_node.cc
squad_node.cc
sst2_node.cc
stl10_node.cc
sun397_node.cc
tedlium_node.cc

View File

@ -0,0 +1,204 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
SST2Node::SST2Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
: NonMappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
usage_(usage),
num_samples_(num_samples),
shuffle_(shuffle),
num_shards_(num_shards),
shard_id_(shard_id) {
// Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
// is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
// 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
// PreBuildSampler is phased out, this can be cleaned up.
GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
}
std::shared_ptr<DatasetNode> SST2Node::Copy() {
auto node = std::make_shared<SST2Node>(dataset_dir_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
(void)node->SetNumWorkers(num_workers_);
(void)node->SetConnectorQueueSize(connector_que_size_);
return node;
}
void SST2Node::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") +
", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")");
}
Status SST2Node::ValidateParams() {
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SST2Node", dataset_dir_));
RETURN_IF_NOT_OK(ValidateStringValue("SST2Node", usage_, {"dev", "train", "test"}));
RETURN_IF_NOT_OK(ValidateScalar("SST2Node", "num_samples", num_samples_, {0}, false));
RETURN_IF_NOT_OK(ValidateDatasetShardParams("SST2Node", num_shards_, shard_id_));
return Status::OK();
}
// Function to build SST2Node
Status SST2Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
RETURN_UNEXPECTED_IF_NULL(node_ops);
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
// Sort the dataset files in a lexicographical order
std::vector<std::string> sorted_dataset_files;
RETURN_IF_NOT_OK(WalkAllFiles(dataset_dir_, usage_, &sorted_dataset_files));
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
char field_delim = '\t';
std::vector<std::string> column_names;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
std::shared_ptr<SST2Op> sst2_op = std::make_shared<SST2Op>(
sorted_dataset_files, usage_, field_delim, column_default_list, column_names, num_workers_, num_samples_,
worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
RETURN_IF_NOT_OK(sst2_op->Init());
// If a global shuffle is used for SST2, it will inject a shuffle op over the SST2.
// But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built.
// This is achieved in the cache transform pass where we call MakeSimpleProducer to reset SST2's shuffle
// option to false.
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_IF_NOT_OK(SST2Op::CountAllFileRows(sorted_dataset_files, column_names.empty(), &num_rows));
// Add the shuffle op after this op
RETURN_IF_NOT_OK(
AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
shuffle_op->SetTotalRepeats(GetTotalRepeats());
shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
sst2_op->SetTotalRepeats(GetTotalRepeats());
sst2_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
node_ops->push_back(sst2_op);
return Status::OK();
}
Status SST2Node::WalkAllFiles(const std::string &dataset_dir, const std::string &usage,
std::vector<std::string> *dataset_files) {
RETURN_UNEXPECTED_IF_NULL(dataset_files);
Path train_file_name("train.tsv");
Path test_file_name("test.tsv");
Path dev_file_name("dev.tsv");
Path dir(dataset_dir);
if (usage == "train") {
Path file_path = dir / train_file_name;
dataset_files->push_back(file_path.ToString());
} else if (usage == "test") {
Path file_path = dir / test_file_name;
dataset_files->push_back(file_path.ToString());
} else if (usage == "dev") {
Path file_path = dir / dev_file_name;
dataset_files->push_back(file_path.ToString());
}
return Status::OK();
}
// Get the shard id of node
Status SST2Node::GetShardId(int32_t *shard_id) {
RETURN_UNEXPECTED_IF_NULL(shard_id);
*shard_id = shard_id_;
return Status::OK();
}
// Get Dataset size
Status SST2Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) {
RETURN_UNEXPECTED_IF_NULL(dataset_size);
if (dataset_size_ > 0) {
*dataset_size = dataset_size_;
return Status::OK();
}
int64_t num_rows, sample_size;
std::vector<std::string> dataset_files;
RETURN_IF_NOT_OK(WalkAllFiles(dataset_dir_, usage_, &dataset_files));
RETURN_IF_NOT_OK(SST2Op::CountAllFileRows(dataset_files, true, &num_rows));
sample_size = num_samples_;
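// Rows are spread evenly across shards: each shard gets ceil(num_rows / num_shards) rows,
// e.g. 5 rows over 2 shards yields 3 rows per shard.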
num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_)));
*dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
dataset_size_ = *dataset_size;
return Status::OK();
}
Status SST2Node::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_parallel_workers"] = num_workers_;
args["connector_queue_size"] = connector_que_size_;
args["dataset_dir"] = dataset_dir_;
args["usage"] = usage_;
args["num_samples"] = num_samples_;
args["shuffle"] = shuffle_;
args["num_shards"] = num_shards_;
args["shard_id"] = shard_id_;
if (cache_ != nullptr) {
nlohmann::json cache_args;
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
args["cache"] = cache_args;
}
*out_json = args;
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// SST2 by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we setup the sampler for a leaf node that does not use sampling.
Status SST2Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
return Status::OK();
}
// If a cache has been added into the ascendant tree over this SST2 node, then the cache will be executing
// a sampler for fetching the data. As such, any options in the SST2 node need to be reset to its defaults so
// that this SST2 node will produce the full set of data into the cache.
Status SST2Node::MakeSimpleProducer() {
shard_id_ = 0;
num_shards_ = 1;
shuffle_ = ShuffleMode::kFalse;
num_samples_ = 0;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,120 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/sst2_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class SST2Node : public NonMappableSourceNode {
public:
/// \brief Constructor.
SST2Node(const std::string &dataset_dir, const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor.
~SST2Node() override = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kSST2Node; }
/// \brief Print the description.
/// \param[out] out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Generate a list of read file names according to usage.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of SST2.
/// \param[out] dataset_files List of file paths for the dataset files.
/// \return Status of the function.
Status WalkAllFiles(const std::string &dataset_dir, const std::string &usage,
std::vector<std::string> *dataset_files);
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size The size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
int64_t NumSamples() const { return num_samples_; }
ShuffleMode Shuffle() const { return shuffle_; }
int32_t NumShards() const { return num_shards_; }
int32_t ShardId() const { return shard_id_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief SST2 by itself is a non-mappable dataset that does not support sampling.
/// However, if a cache operator is injected at some other place higher in the tree, that cache can
/// inherit this sampler from the leaf, providing sampling support from the caching layer.
/// That is why we setup the sampler for a leaf node that does not use sampling.
/// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \param[in] sampler The sampler to setup.
/// \return Status of the function.
Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override;
/// \brief If a cache has been added into the ascendant tree over this SST2 node, then the cache will be
/// executing a sampler for fetching the data. As such, any options in the SST2 node need to be reset
/// to its defaults so that this SST2 node will produce the full set of data into the cache. Note:
/// This function is common among NonMappableSourceNode and should be promoted to its parent class.
/// \return Status of the function.
Status MakeSimpleProducer() override;
private:
std::string dataset_dir_;
std::string usage_;
int64_t num_samples_;
ShuffleMode shuffle_;
int32_t num_shards_;
int32_t shard_id_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SST2_NODE_H_

View File

@ -5237,6 +5237,64 @@ inline std::shared_ptr<SQuADDataset> DATASET_API SQuAD(const std::string &datase
num_shards, shard_id, cache);
}
/// \class SST2Dataset
/// \brief A source dataset for reading and parsing SST2 dataset.
class DATASET_API SST2Dataset : public Dataset {
public:
/// \brief Constructor of SST2Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of SST2, can be "train", "test" or "dev".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle The mode for shuffling data every epoch.
///     Can be any of:
///     ShuffleMode::kFalse - No shuffling is performed.
///     ShuffleMode::kFiles - Shuffle files only.
///     ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified.
/// \param[in] cache Tensor cache to use.
SST2Dataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, int64_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of SST2Dataset.
~SST2Dataset() override = default;
};
/// \brief Function to create a SST2Dataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of SST2, can be "train", "test" or "dev". Default: "train".
/// \param[in] num_samples The number of samples to be included in the dataset.
/// Default: 0, means all samples.
/// \param[in] shuffle The mode for shuffling data every epoch. Default: ShuffleMode::kGlobal.
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
/// ShuffleMode::kFiles - Shuffle files only.
/// ShuffleMode::kGlobal - Shuffle both the files and samples.
/// \param[in] num_shards Number of shards that the dataset should be divided into. Default: 1.
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. Default: 0.
/// \param[in] cache Tensor cache to use. Default: nullptr, which means no cache is used.
/// \return Shared pointer to the SST2Dataset.
/// \par Example
/// \code
/// /* Define dataset path and MindData object */
/// std::string folder_path = "/path/to/sst2_dataset_directory";
/// std::shared_ptr<Dataset> ds = SST2(folder_path, "train");
///
/// /* Create iterator to read dataset */
/// std::shared_ptr<Iterator> iter = ds->CreateIterator();
/// std::unordered_map<std::string, mindspore::MSTensor> row;
/// iter->GetNextRow(&row);
/// \endcode
inline std::shared_ptr<SST2Dataset> DATASET_API SST2(const std::string &dataset_dir, const std::string &usage = "train",
int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<SST2Dataset>(StringToChar(dataset_dir), StringToChar(usage), num_samples, shuffle, num_shards,
shard_id, cache);
}
/// \class STL10Dataset
/// \brief A source dataset that reads and parses STL10 dataset.
class DATASET_API STL10Dataset : public Dataset {

View File

@ -68,6 +68,7 @@ class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
friend class SBUDataset;
friend class SemeionDataset;
friend class SpeechCommandsDataset;
friend class SST2Dataset;
friend class STL10Dataset;
friend class SUN397Dataset;
friend class TedliumDataset;

View File

@ -85,6 +85,7 @@ __all__ = ["Caltech101Dataset", # Vision
"PennTreebankDataset", # Text
"SogouNewsDataset", # Text
"SQuADDataset", # Text
"SST2Dataset", # Text
"TextFileDataset", # Text
"UDPOSDataset", # Text
"WikiTextDataset", # Text

View File

@ -30,7 +30,8 @@ from .validators import check_imdb_dataset, check_iwslt2016_dataset, check_iwslt
check_penn_treebank_dataset, check_ag_news_dataset, check_amazon_review_dataset, check_udpos_dataset, \
check_wiki_text_dataset, check_conll2000_dataset, check_cluedataset, \
check_sogou_news_dataset, check_textfiledataset, check_dbpedia_dataset, check_yelp_review_dataset, \
check_en_wik9_dataset, check_yahoo_answers_dataset, check_multi30k_dataset, check_squad_dataset
check_en_wik9_dataset, check_yahoo_answers_dataset, check_multi30k_dataset, check_squad_dataset, \
check_sst2_dataset
from ..core.validator_helpers import replace_none
@ -1496,6 +1497,105 @@ class SQuADDataset(SourceDataset, TextBaseDataset):
self.num_shards, self.shard_id)
class SST2Dataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses the SST2 dataset.
When `usage` is 'train' or 'dev', the generated dataset has two columns :py:obj:`[sentence, label]` .
When `usage` is 'test', the generated dataset has one column :py:obj:`[sentence]` .
The tensors of column :py:obj:`sentence` and column :py:obj:`label` are of the string type.
Args:
dataset_dir (str): Path to the root directory that contains the dataset.
usage (str, optional): Usage of this dataset, can be `train`, `test` or `dev`. `train` will read
from 67,349 train samples, `test` will read from 1,821 test samples, `dev` will read from
all 872 samples. Default: None, will read train samples.
num_samples (int, optional): The number of samples to be included in the dataset.
Default: None, will include all text.
num_parallel_workers (int, optional): Number of workers to read the data.
Default: None, number set in the mindspore.dataset.config.
shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch.
Bool type and Shuffle enum are both supported to pass in. Default: `Shuffle.GLOBAL` .
If shuffle is False, no shuffling will be performed.
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL.
Set the mode of data shuffling by passing in enumeration variables:
- Shuffle.GLOBAL: Shuffle the samples.
num_shards (int, optional): Number of shards that the dataset will be divided into. Default: None.
When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
shard_id (int, optional): The shard ID within num_shards. This argument can only be specified when
num_shards is also specified. Default: None.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. More details:
`Single-Node Data Cache <https://www.mindspore.cn/tutorials/experts/en/master/dataset/cache.html>`_ .
Default: None, which means no cache is used.
Raises:
RuntimeError: If `dataset_dir` does not contain data files.
RuntimeError: If `num_parallel_workers` exceeds the max thread numbers.
RuntimeError: If `num_shards` is specified but `shard_id` is None.
RuntimeError: If `shard_id` is specified but `num_shards` is None.
ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).
Examples:
>>> sst2_dataset_dir = "/path/to/sst2_dataset_directory"
>>>
>>> # 1) Read 3 samples from SST2 dataset
>>> dataset = ds.SST2Dataset(dataset_dir=sst2_dataset_dir, num_samples=3)
>>>
>>> # 2) Read train samples from SST2 dataset
>>> dataset = ds.SST2Dataset(dataset_dir=sst2_dataset_dir, usage="train")
About SST2 dataset:
The Stanford Sentiment Treebank is a corpus with fully labeled parse trees that allows for a complete
analysis of the compositional effects of sentiment in language. The corpus is based on the dataset introduced
by Pang and Lee (2005) and consists of 11,855 single sentences extracted from movie reviews. It was parsed
with the Stanford parser and includes a total of 215,154 unique phrases from those parse trees, each
annotated by 3 human judges.
Here is the original SST2 dataset structure.
You can unzip the dataset files into this directory structure and read them with MindSpore's API.
.. code-block::
.
└── sst2_dataset_dir
    ├── train.tsv
    ├── test.tsv
    ├── dev.tsv
    └── original
Citation:
.. code-block::
@inproceedings{socher-etal-2013-recursive,
title = {Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank},
author = {Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning,
Christopher D. and Ng, Andrew and Potts, Christopher},
booktitle = {Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing},
month = oct,
year = {2013},
address = {Seattle, Washington, USA},
publisher = {Association for Computational Linguistics},
url = {https://www.aclweb.org/anthology/D13-1170},
pages = {1631--1642},
}
"""
@check_sst2_dataset
def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL,
num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "train")
def parse(self, children=None):
return cde.SST2Node(self.dataset_dir, self.usage, self.num_samples, self.shuffle_flag,
self.num_shards, self.shard_id)
class TextFileDataset(SourceDataset, TextBaseDataset):
"""
A source dataset that reads and parses datasets stored on disk in text format.

View File

@ -2819,6 +2819,34 @@ def check_svhn_dataset(method):
return new_method
def check_sst2_dataset(method):
"""A wrapper that wraps a parameter checker around the original SST2 Dataset."""
@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
dataset_dir = param_dict.get('dataset_dir')
check_dir(dataset_dir)
usage = param_dict.get('usage')
if usage is not None:
check_valid_str(usage, ["train", "test", "dev"], "usage")
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
cache = param_dict.get('cache')
check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
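# Hedged usage sketch (placeholder paths): calls this checker is meant to reject,
# mirroring the Raises section of the SST2Dataset docstring above.
#   ds.SST2Dataset("/path/to/sst2", usage="all")   # ValueError: usage not in ["train", "test", "dev"]
#   ds.SST2Dataset("/path/to/sst2", shard_id=0)    # RuntimeError: shard_id given without num_shards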
def check_stl10_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(STL10Dataset)."""

View File

@ -0,0 +1,524 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/engine/ir/datasetops/source/sst2_node.h"
#include "minddata/dataset/include/dataset/datasets.h"
using namespace mindspore::dataset;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: SST2Dataset
/// Description: Read test data
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetBasic.";
// Create a SST2 Dataset
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::vector<std::string> column_names = {"sentence"};
std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset.
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row.
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_EQ(row.find("index"), row.end());
EXPECT_NE(row.find("sentence"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"test read SST2dataset 1 ."},
{"test read SST2dataset 2 ."},
{"test read SST2dataset 3 ."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: SST2Dataset
/// Description: Read train data
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetUsageTrain) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetUsageTrain.";
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::vector<std::string> column_names = {"sentence", "label"};
std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("sentence"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"train read SST2Dataset 1 . ", "0"},
{"train read SST2Dataset 2 . ", "1"},
{"train read SST2Dataset 3 . ", "1"},
{"train read SST2Dataset 4 . ", "1"},
{"train read SST2Dataset 5 . ", "0"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: SST2Dataset
/// Description: Includes tests for shape, type, size
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetGetters) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetGetters.";
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kFalse);
std::vector<std::string> column_names = {"sentence"};
EXPECT_NE(ds, nullptr);
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
EXPECT_EQ(types.size(), 1);
EXPECT_EQ(types[0].ToString(), "string");
EXPECT_EQ(shapes.size(), 1);
EXPECT_EQ(shapes[0].ToString(), "<>");
EXPECT_EQ(ds->GetBatchSize(), 1);
EXPECT_EQ(ds->GetRepeatCount(), 1);
EXPECT_EQ(ds->GetDatasetSize(), 3);
EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: SST2Dataset
/// Description: Read 2 samples from train file
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetNumSamples) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetNumSamples.";
// Create a SST2Dataset.
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::vector<std::string> column_names = {"sentence", "label"};
std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 2, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("sentence"), row.end());
EXPECT_NE(row.find("label"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"train read SST2Dataset 1 . ", "0"},
{"train read SST2Dataset 2 . ", "1"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: SST2Dataset
/// Description: Test in a distributed state
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetDistribution) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetDistribution.";
// Create a SST2Dataset
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::vector<std::string> column_names = {"sentence", "label"};
std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 0, ShuffleMode::kFalse, 2, 0);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("sentence"), row.end());
EXPECT_NE(row.find("label"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"train read SST2Dataset 1 . ", "0"},
{"train read SST2Dataset 2 . ", "1"},
{"train read SST2Dataset 3 . ", "1"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: SST2Dataset
/// Description: Test with invalid input
/// Expectation: Throw error messages when certain errors occur
TEST_F(MindDataTestPipeline, TestSST2DatasetFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetFail.";
// Create a SST2 Dataset
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::string invalid_folder_path = "./NotExistPath";
std::vector<std::string> column_names = {"sentence", "label"};
// Test invalid folder_path
std::shared_ptr<Dataset> ds0 = SST2(invalid_folder_path, "train", -1, ShuffleMode::kFalse);
EXPECT_NE(ds0, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
// Expect failure: invalid SST2 input
EXPECT_EQ(iter0, nullptr);
// Test invalid usage
std::shared_ptr<Dataset> ds1 = SST2(folder_path, "all", 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid SST2 input
EXPECT_EQ(iter1, nullptr);
// Test invalid num_samples < 0
std::shared_ptr<Dataset> ds2 = SST2(folder_path, "train", -1, ShuffleMode::kFalse);
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
// Expect failure: invalid SST2 input
EXPECT_EQ(iter2, nullptr);
// Test invalid num_shards < 1
std::shared_ptr<Dataset> ds3 = SST2(folder_path, "train", 0, ShuffleMode::kFalse, 0);
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid SST2 input
EXPECT_EQ(iter3, nullptr);
// Test invalid shard_id >= num_shards
std::shared_ptr<Dataset> ds4 = SST2(folder_path, "train", 0, ShuffleMode::kFalse, 2, 2);
EXPECT_NE(ds4, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
// Expect failure: invalid SST2 input
EXPECT_EQ(iter4, nullptr);
}
/// Feature: SST2Dataset
/// Description: Read data with pipeline from test file
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetWithPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetWithPipeline.";
// Create two SST2 Datasets, each reading the same single SST2 file
std::string dataset_dir = datasets_root_path_ + "/testSST2/";
std::shared_ptr<Dataset> ds1 = SST2(dataset_dir, "test", 0, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds2 = SST2(dataset_dir, "test", 0, ShuffleMode::kFalse);
EXPECT_NE(ds1, nullptr);
EXPECT_NE(ds2, nullptr);
// Create two Repeat operation on ds
int32_t repeat_num = 2;
ds1 = ds1->Repeat(repeat_num);
EXPECT_NE(ds1, nullptr);
repeat_num = 3;
ds2 = ds2->Repeat(repeat_num);
EXPECT_NE(ds2, nullptr);
// Create two Project operation on ds
std::vector<std::string> column_project = {"sentence"};
ds1 = ds1->Project(column_project);
EXPECT_NE(ds1, nullptr);
ds2 = ds2->Project(column_project);
EXPECT_NE(ds2, nullptr);
// Create a Concat operation on the ds
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("sentence"), row.end());
uint64_t i = 0;
while (row.size() != 0) {
auto text = row["sentence"];
MS_LOG(INFO) << "Tensor text shape: " << text.Shape();
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
// Expect 2 × 3 + 3 × 3 = 15 samples
EXPECT_EQ(i, 15);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: SST2Dataset
/// Description: Test with shuffle files
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetShuffleFilesA) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetShuffleFilesA.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::vector<std::string> column_names = {"sentence", "label"};
std::shared_ptr<Dataset> ds = SST2(folder_path, "train", 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("sentence"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"train read SST2Dataset 1 . ", "0"},
{"train read SST2Dataset 2 . ", "1"},
{"train read SST2Dataset 3 . ", "1"},
{"train read SST2Dataset 4 . ", "1"},
{"train read SST2Dataset 5 . ", "0"}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: SST2Dataset
/// Description: Test with shuffle in file
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetShuffleFilesB) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetShuffleFilesB.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::vector<std::string> column_names = {"sentence"};
std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kInfile);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("sentence"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"test read SST2dataset 1 ."},
{"test read SST2dataset 2 ."},
{"test read SST2dataset 3 ."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
/// Feature: SST2Dataset
/// Description: Test with global shuffle
/// Expectation: The data is processed successfully
TEST_F(MindDataTestPipeline, TestSST2DatasetShuffleGlobal) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSST2DatasetShuffleGlobal.";
// Set configuration
uint32_t original_seed = GlobalContext::config_manager()->seed();
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
GlobalContext::config_manager()->set_seed(130);
GlobalContext::config_manager()->set_num_parallel_workers(4);
std::string folder_path = datasets_root_path_ + "/testSST2/";
std::vector<std::string> column_names = {"sentence"};
std::shared_ptr<Dataset> ds = SST2(folder_path, "test", 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_NE(row.find("sentence"), row.end());
std::vector<std::vector<std::string>> expected_result = {
{"test read SST2dataset 1 ."},
{"test read SST2dataset 3 ."},
{"test read SST2dataset 2 ."}};
uint64_t i = 0;
while (row.size() != 0) {
for (int j = 0; j < column_names.size(); j++) {
auto text = row[column_names[j]];
std::shared_ptr<Tensor> de_text;
ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text));
std::string_view sv;
ASSERT_OK(de_text->GetItemAt(&sv, {}));
std::string ss(sv);
EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
}
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
// Expect 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
// Restore configuration
GlobalContext::config_manager()->set_seed(original_seed);
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}

View File

@ -0,0 +1,5 @@
sentence label
dev read SST2Dataset 1 . 0
dev read SST2Dataset 2 . 1
dev read SST2Dataset 3 . 1
dev read SST2Dataset 4 . 1

View File

@ -0,0 +1,4 @@
index sentence
0 test read SST2dataset 1 .
1 test read SST2dataset 2 .
2 test read SST2dataset 3 .

View File

@ -0,0 +1,6 @@
sentence label
train read SST2Dataset 1 . 0
train read SST2Dataset 2 . 1
train read SST2Dataset 3 . 1
train read SST2Dataset 4 . 1
train read SST2Dataset 5 . 0

View File

@ -0,0 +1,136 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test SST2 dataset operators
"""
import mindspore.dataset as ds
DATA_DIR = '../data/dataset/testSST2/'
def test_sst2_dataset_basic():
"""
Feature: SST2Dataset
Description: Read data from train file
Expectation: The data is processed successfully
"""
buffer = []
data = ds.SST2Dataset(DATA_DIR, usage="train", shuffle=False)
data = data.repeat(2)
data = data.skip(3)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append(d)
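# 5 train samples repeated twice gives 10 rows; skipping 3 leaves 7.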
assert len(buffer) == 7
def test_sst2_dataset_quoted():
"""
Feature: SST2Dataset
Description: Read the data and compare it to expectations
Expectation: The data is processed successfully
"""
data = ds.SST2Dataset(DATA_DIR, usage="test", shuffle=False)
buffer = []
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.extend([d['sentence']])
assert buffer == ["test read SST2dataset 1 .",
"test read SST2dataset 2 .",
"test read SST2dataset 3 ."]
def test_sst2_dataset_usage():
"""
Feature: SST2Dataset.
Description: Read the dev file with usage "dev".
Expectation: The data is processed successfully.
"""
buffer = []
data = ds.SST2Dataset(DATA_DIR, usage="dev", shuffle=False)
for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
buffer.append(d)
assert len(buffer) == 4
def test_sst2_dataset_get_dataset_size():
"""
Feature: SST2Dataset
Description: Test get_dataset_size function
Expectation: The data is processed successfully
"""
data = ds.SST2Dataset(DATA_DIR, usage="dev", shuffle=False)
size = data.get_dataset_size()
assert size == 4
def test_sst2_dataset_distribution():
"""
Feature: SST2Dataset
Description: Test in a distributed state
Expectation: The data is processed successfully
"""
data = ds.SST2Dataset(DATA_DIR, usage="train", shuffle=False, num_shards=2, shard_id=0)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 3
def test_sst2_dataset_num_samples():
"""
Feature: SST2Dataset
Description: Test num_samples parameter
Expectation: The data is processed successfully
"""
data = ds.SST2Dataset(DATA_DIR, usage="test", shuffle=False, num_samples=2)
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
assert count == 2
def test_sst2_dataset_exception():
"""
Feature: SST2Dataset
Description: Test the wrong input
Expectation: Unable to read data properly
"""
def exception_func(item):
raise Exception("Error occur!")
try:
data = ds.SST2Dataset(DATA_DIR, usage="test", shuffle=False)
data = data.map(operations=exception_func, input_columns=["sentence"], num_parallel_workers=1)
for _ in data.create_dict_iterator():
pass
assert False
except RuntimeError as e:
assert "map operation: [PyFunc] failed. The corresponding data file" in str(e)
if __name__ == "__main__":
test_sst2_dataset_basic()
test_sst2_dataset_quoted()
test_sst2_dataset_usage()
test_sst2_dataset_get_dataset_size()
test_sst2_dataset_distribution()
test_sst2_dataset_num_samples()
test_sst2_dataset_exception()